tensor4all_core/defaults/
direct_sum.rs1use crate::defaults::DynIndex;
10use crate::index_like::IndexLike;
11use crate::tensor::TensorDynLen;
12use anyhow::Result;
13use num_traits::Zero;
14use tensor4all_tensorbackend::TensorElement;
15
16pub fn direct_sum(
54 a: &TensorDynLen,
55 b: &TensorDynLen,
56 pairs: &[(DynIndex, DynIndex)],
57) -> Result<(TensorDynLen, Vec<DynIndex>)> {
58 if a.is_f64() && b.is_f64() {
59 direct_sum_typed::<f64>(a, b, pairs)
60 } else if a.is_complex() && b.is_complex() {
61 direct_sum_typed::<num_complex::Complex64>(a, b, pairs)
62 } else {
63 Err(anyhow::anyhow!(
64 "direct_sum requires both tensors to have the same dense scalar type (f64 or Complex64)"
65 ))
66 }
67}
68
69#[allow(dead_code)]
71struct DirectSumSetup {
72 common_a_positions: Vec<usize>,
73 common_b_positions: Vec<usize>,
74 paired_a_positions: Vec<usize>,
75 paired_b_positions: Vec<usize>,
76 paired_dims_a: Vec<usize>,
77 paired_dims_b: Vec<usize>,
78 result_indices: Vec<DynIndex>,
79 result_dims: Vec<usize>,
80 result_strides: Vec<usize>,
81 result_total: usize,
82 a_strides: Vec<usize>,
83 b_strides: Vec<usize>,
84 new_indices: Vec<DynIndex>,
85 n_common: usize,
86}
87
88fn setup_direct_sum(
89 a: &TensorDynLen,
90 b: &TensorDynLen,
91 pairs: &[(DynIndex, DynIndex)],
92) -> Result<DirectSumSetup> {
93 use std::collections::HashMap;
94
95 if pairs.is_empty() {
96 return Err(anyhow::anyhow!(
97 "direct_sum requires at least one index pair"
98 ));
99 }
100
101 let a_idx_map: HashMap<&DynIndex, usize> = a
104 .indices
105 .iter()
106 .enumerate()
107 .map(|(i, idx)| (idx, i))
108 .collect();
109 let b_idx_map: HashMap<&DynIndex, usize> = b
110 .indices
111 .iter()
112 .enumerate()
113 .map(|(i, idx)| (idx, i))
114 .collect();
115
116 let mut a_paired_indices: std::collections::HashSet<&DynIndex> =
118 std::collections::HashSet::new();
119 let mut b_paired_indices: std::collections::HashSet<&DynIndex> =
120 std::collections::HashSet::new();
121
122 for (a_idx, b_idx) in pairs {
123 if !a_idx_map.contains_key(a_idx) {
124 return Err(anyhow::anyhow!(
125 "Index not found in first tensor (direct_sum)"
126 ));
127 }
128 if !b_idx_map.contains_key(b_idx) {
129 return Err(anyhow::anyhow!(
130 "Index not found in second tensor (direct_sum)"
131 ));
132 }
133 a_paired_indices.insert(a_idx);
134 b_paired_indices.insert(b_idx);
135 }
136
137 let mut common_a_positions: Vec<usize> = Vec::new();
141 let mut common_b_positions: Vec<usize> = Vec::new();
142 for (a_pos, a_idx) in a.indices.iter().enumerate() {
143 if a_paired_indices.contains(a_idx) {
144 continue;
145 }
146 if let Some(&b_pos) = b_idx_map.get(a_idx) {
147 if b_paired_indices.contains(a_idx) {
148 continue;
149 }
150 let a_dims = a.dims();
151 let b_dims = b.dims();
152 if a_dims[a_pos] != b_dims[b_pos] {
153 return Err(anyhow::anyhow!(
154 "Dimension mismatch for common index: {} vs {}",
155 a_dims[a_pos],
156 b_dims[b_pos]
157 ));
158 }
159 common_a_positions.push(a_pos);
160 common_b_positions.push(b_pos);
161 }
162 }
163
164 let paired_a_positions: Vec<usize> = pairs.iter().map(|(a_idx, _)| a_idx_map[a_idx]).collect();
166 let paired_b_positions: Vec<usize> = pairs.iter().map(|(_, b_idx)| b_idx_map[b_idx]).collect();
167 let a_dims = a.dims();
168 let b_dims = b.dims();
169 let paired_dims_a: Vec<usize> = paired_a_positions.iter().map(|&p| a_dims[p]).collect();
170 let paired_dims_b: Vec<usize> = paired_b_positions.iter().map(|&p| b_dims[p]).collect();
171
172 let mut new_indices: Vec<DynIndex> = Vec::new();
174 for (&dim_a, &dim_b) in paired_dims_a.iter().zip(&paired_dims_b) {
175 let new_dim = dim_a + dim_b;
176 let new_index = DynIndex::new_link(new_dim)
177 .map_err(|e| anyhow::anyhow!("Failed to create index: {:?}", e))?;
178 new_indices.push(new_index);
179 }
180
181 let mut result_indices: Vec<DynIndex> = Vec::new();
183 let mut result_dims: Vec<usize> = Vec::new();
184
185 let a_dims = a.dims();
187 for &a_pos in &common_a_positions {
188 result_indices.push(a.indices[a_pos].clone());
189 result_dims.push(a_dims[a_pos]);
190 }
191
192 for new_idx in &new_indices {
194 result_indices.push(new_idx.clone());
195 result_dims.push(new_idx.dim());
196 }
197
198 let result_total: usize = result_dims.iter().product();
200 let mut result_strides: Vec<usize> = vec![1; result_dims.len()];
201 for i in 1..result_dims.len() {
202 result_strides[i] = result_strides[i - 1] * result_dims[i - 1];
203 }
204
205 let a_dims = a.dims();
206 let b_dims = b.dims();
207 let mut a_strides: Vec<usize> = vec![1; a_dims.len()];
208 for i in 1..a_dims.len() {
209 a_strides[i] = a_strides[i - 1] * a_dims[i - 1];
210 }
211
212 let mut b_strides: Vec<usize> = vec![1; b_dims.len()];
213 for i in 1..b_dims.len() {
214 b_strides[i] = b_strides[i - 1] * b_dims[i - 1];
215 }
216
217 let n_common = common_a_positions.len();
218
219 Ok(DirectSumSetup {
220 common_a_positions,
221 common_b_positions,
222 paired_a_positions,
223 paired_b_positions,
224 paired_dims_a,
225 paired_dims_b,
226 result_indices,
227 result_dims,
228 result_strides,
229 result_total,
230 a_strides,
231 b_strides,
232 new_indices,
233 n_common,
234 })
235}
236
/// Splits a linear offset into one coordinate per dimension.
///
/// The first dimension varies fastest: coordinate `k` is the remainder of the
/// running offset modulo `dims[k]`, and the offset is then divided by
/// `dims[k]`. Any part of `linear` beyond the product of `dims` is discarded.
///
/// # Panics
///
/// Panics if any entry of `dims` is zero.
fn linear_to_multi(linear: usize, dims: &[usize]) -> Vec<usize> {
    let mut rest = linear;
    dims.iter()
        .map(|&extent| {
            let coord = rest % extent;
            rest /= extent;
            coord
        })
        .collect()
}
246
/// Folds a multi-index into a linear offset.
///
/// Returns the sum of `multi[k] * strides[k]` over the positions both slices
/// have in common; surplus entries of the longer slice are ignored.
fn multi_to_linear(multi: &[usize], strides: &[usize]) -> usize {
    let mut offset = 0;
    for (&coord, &stride) in multi.iter().zip(strides) {
        offset += coord * stride;
    }
    offset
}
250
251fn direct_sum_typed<T: TensorElement + Zero>(
252 a: &TensorDynLen,
253 b: &TensorDynLen,
254 pairs: &[(DynIndex, DynIndex)],
255) -> Result<(TensorDynLen, Vec<DynIndex>)> {
256 let setup = setup_direct_sum(a, b, pairs)?;
257 let a_data = a.to_vec::<T>()?;
258 let b_data = b.to_vec::<T>()?;
259
260 let mut result_data: Vec<T> = vec![T::zero(); setup.result_total];
261
262 #[allow(clippy::needless_range_loop)]
263 for result_linear in 0..setup.result_total {
264 let result_multi = linear_to_multi(result_linear, &setup.result_dims);
265 let common_multi: Vec<usize> = result_multi[..setup.n_common].to_vec();
266 let paired_multi: Vec<usize> = result_multi[setup.n_common..].to_vec();
267
268 let all_from_a = paired_multi
269 .iter()
270 .enumerate()
271 .all(|(i, &pm)| pm < setup.paired_dims_a[i]);
272 let all_from_b = paired_multi
273 .iter()
274 .enumerate()
275 .all(|(i, &pm)| pm >= setup.paired_dims_a[i]);
276
277 if all_from_a {
278 let a_dims = a.dims();
279 let mut a_multi = vec![0usize; a_dims.len()];
280 for (i, &cp) in setup.common_a_positions.iter().enumerate() {
281 a_multi[cp] = common_multi[i];
282 }
283 for (i, &pp) in setup.paired_a_positions.iter().enumerate() {
284 a_multi[pp] = paired_multi[i];
285 }
286 let a_linear = multi_to_linear(&a_multi, &setup.a_strides);
287 result_data[result_linear] = a_data[a_linear];
288 } else if all_from_b {
289 let b_dims = b.dims();
290 let mut b_multi = vec![0usize; b_dims.len()];
291 for (i, &cp) in setup.common_b_positions.iter().enumerate() {
292 b_multi[cp] = common_multi[i];
293 }
294 for (i, &pp) in setup.paired_b_positions.iter().enumerate() {
295 b_multi[pp] = paired_multi[i] - setup.paired_dims_a[i];
296 }
297 let b_linear = multi_to_linear(&b_multi, &setup.b_strides);
298 result_data[result_linear] = b_data[b_linear];
299 }
300 }
302
303 let result = TensorDynLen::from_dense(setup.result_indices, result_data)?;
304 Ok((result, setup.new_indices))
305}
306
307#[cfg(test)]
308mod tests;