tensor4all_core/defaults/
direct_sum.rs1use crate::defaults::DynIndex;
10use crate::index_like::IndexLike;
11use crate::tensor::TensorDynLen;
12use anyhow::Result;
13use num_traits::Zero;
14use tensor4all_tensorbackend::TensorElement;
15
16pub fn direct_sum(
45 a: &TensorDynLen,
46 b: &TensorDynLen,
47 pairs: &[(DynIndex, DynIndex)],
48) -> Result<(TensorDynLen, Vec<DynIndex>)> {
49 if a.is_f64() && b.is_f64() {
50 direct_sum_typed::<f64>(a, b, pairs)
51 } else if a.is_complex() && b.is_complex() {
52 direct_sum_typed::<num_complex::Complex64>(a, b, pairs)
53 } else {
54 Err(anyhow::anyhow!(
55 "direct_sum requires both tensors to have the same dense scalar type (f64 or Complex64)"
56 ))
57 }
58}
59
60#[allow(dead_code)]
62struct DirectSumSetup {
63 common_a_positions: Vec<usize>,
64 common_b_positions: Vec<usize>,
65 paired_a_positions: Vec<usize>,
66 paired_b_positions: Vec<usize>,
67 paired_dims_a: Vec<usize>,
68 paired_dims_b: Vec<usize>,
69 result_indices: Vec<DynIndex>,
70 result_dims: Vec<usize>,
71 result_strides: Vec<usize>,
72 result_total: usize,
73 a_strides: Vec<usize>,
74 b_strides: Vec<usize>,
75 new_indices: Vec<DynIndex>,
76 n_common: usize,
77}
78
79fn setup_direct_sum(
80 a: &TensorDynLen,
81 b: &TensorDynLen,
82 pairs: &[(DynIndex, DynIndex)],
83) -> Result<DirectSumSetup> {
84 use crate::defaults::DynId;
85 use std::collections::HashMap;
86
87 if pairs.is_empty() {
88 return Err(anyhow::anyhow!(
89 "direct_sum requires at least one index pair"
90 ));
91 }
92
93 let a_idx_map: HashMap<&DynId, usize> = a
95 .indices
96 .iter()
97 .enumerate()
98 .map(|(i, idx)| (idx.id(), i))
99 .collect();
100 let b_idx_map: HashMap<&DynId, usize> = b
101 .indices
102 .iter()
103 .enumerate()
104 .map(|(i, idx)| (idx.id(), i))
105 .collect();
106
107 let mut a_paired_ids: std::collections::HashSet<&DynId> = std::collections::HashSet::new();
109 let mut b_paired_ids: std::collections::HashSet<&DynId> = std::collections::HashSet::new();
110
111 for (a_idx, b_idx) in pairs {
112 if !a_idx_map.contains_key(a_idx.id()) {
113 return Err(anyhow::anyhow!(
114 "Index not found in first tensor (direct_sum)"
115 ));
116 }
117 if !b_idx_map.contains_key(b_idx.id()) {
118 return Err(anyhow::anyhow!(
119 "Index not found in second tensor (direct_sum)"
120 ));
121 }
122 a_paired_ids.insert(a_idx.id());
123 b_paired_ids.insert(b_idx.id());
124 }
125
126 let mut common_a_positions: Vec<usize> = Vec::new();
130 let mut common_b_positions: Vec<usize> = Vec::new();
131 for (a_pos, a_idx) in a.indices.iter().enumerate() {
132 let a_id = a_idx.id();
133 if a_paired_ids.contains(a_id) {
134 continue;
135 }
136 if let Some(&b_pos) = b_idx_map.get(a_id) {
137 if b_paired_ids.contains(a_id) {
138 continue;
139 }
140 let a_dims = a.dims();
141 let b_dims = b.dims();
142 if a_dims[a_pos] != b_dims[b_pos] {
143 return Err(anyhow::anyhow!(
144 "Dimension mismatch for common index: {} vs {}",
145 a_dims[a_pos],
146 b_dims[b_pos]
147 ));
148 }
149 common_a_positions.push(a_pos);
150 common_b_positions.push(b_pos);
151 }
152 }
153
154 let paired_a_positions: Vec<usize> = pairs
156 .iter()
157 .map(|(a_idx, _)| a_idx_map[a_idx.id()])
158 .collect();
159 let paired_b_positions: Vec<usize> = pairs
160 .iter()
161 .map(|(_, b_idx)| b_idx_map[b_idx.id()])
162 .collect();
163 let a_dims = a.dims();
164 let b_dims = b.dims();
165 let paired_dims_a: Vec<usize> = paired_a_positions.iter().map(|&p| a_dims[p]).collect();
166 let paired_dims_b: Vec<usize> = paired_b_positions.iter().map(|&p| b_dims[p]).collect();
167
168 let mut new_indices: Vec<DynIndex> = Vec::new();
170 for (&dim_a, &dim_b) in paired_dims_a.iter().zip(&paired_dims_b) {
171 let new_dim = dim_a + dim_b;
172 let new_index = DynIndex::new_link(new_dim)
173 .map_err(|e| anyhow::anyhow!("Failed to create index: {:?}", e))?;
174 new_indices.push(new_index);
175 }
176
177 let mut result_indices: Vec<DynIndex> = Vec::new();
179 let mut result_dims: Vec<usize> = Vec::new();
180
181 let a_dims = a.dims();
183 for &a_pos in &common_a_positions {
184 result_indices.push(a.indices[a_pos].clone());
185 result_dims.push(a_dims[a_pos]);
186 }
187
188 for new_idx in &new_indices {
190 result_indices.push(new_idx.clone());
191 result_dims.push(new_idx.dim());
192 }
193
194 let result_total: usize = result_dims.iter().product();
196 let mut result_strides: Vec<usize> = vec![1; result_dims.len()];
197 for i in 1..result_dims.len() {
198 result_strides[i] = result_strides[i - 1] * result_dims[i - 1];
199 }
200
201 let a_dims = a.dims();
202 let b_dims = b.dims();
203 let mut a_strides: Vec<usize> = vec![1; a_dims.len()];
204 for i in 1..a_dims.len() {
205 a_strides[i] = a_strides[i - 1] * a_dims[i - 1];
206 }
207
208 let mut b_strides: Vec<usize> = vec![1; b_dims.len()];
209 for i in 1..b_dims.len() {
210 b_strides[i] = b_strides[i - 1] * b_dims[i - 1];
211 }
212
213 let n_common = common_a_positions.len();
214
215 Ok(DirectSumSetup {
216 common_a_positions,
217 common_b_positions,
218 paired_a_positions,
219 paired_b_positions,
220 paired_dims_a,
221 paired_dims_b,
222 result_indices,
223 result_dims,
224 result_strides,
225 result_total,
226 a_strides,
227 b_strides,
228 new_indices,
229 n_common,
230 })
231}
232
/// Splits a linear offset into one coordinate per entry of `dims`, with the
/// first coordinate varying fastest.
///
/// Any part of `linear` beyond the product of `dims` is discarded.
///
/// # Panics
///
/// Panics with a division by zero if an entry of `dims` is `0`.
fn linear_to_multi(linear: usize, dims: &[usize]) -> Vec<usize> {
    let mut rest = linear;
    dims.iter()
        .map(|&extent| {
            // Peel off the coordinate along this axis, then move to the next.
            let coord = rest % extent;
            rest /= extent;
            coord
        })
        .collect()
}
242
/// Combines per-axis coordinates into a linear offset by summing
/// `coordinate * stride` over all axes.
///
/// If the two slices differ in length, the extra entries of the longer one
/// are ignored.
fn multi_to_linear(multi: &[usize], strides: &[usize]) -> usize {
    let mut offset = 0;
    for (coord, stride) in multi.iter().zip(strides.iter()) {
        offset += coord * stride;
    }
    offset
}
246
247fn direct_sum_typed<T: TensorElement + Zero>(
248 a: &TensorDynLen,
249 b: &TensorDynLen,
250 pairs: &[(DynIndex, DynIndex)],
251) -> Result<(TensorDynLen, Vec<DynIndex>)> {
252 let setup = setup_direct_sum(a, b, pairs)?;
253 let a_data = a.to_vec::<T>()?;
254 let b_data = b.to_vec::<T>()?;
255
256 let mut result_data: Vec<T> = vec![T::zero(); setup.result_total];
257
258 #[allow(clippy::needless_range_loop)]
259 for result_linear in 0..setup.result_total {
260 let result_multi = linear_to_multi(result_linear, &setup.result_dims);
261 let common_multi: Vec<usize> = result_multi[..setup.n_common].to_vec();
262 let paired_multi: Vec<usize> = result_multi[setup.n_common..].to_vec();
263
264 let all_from_a = paired_multi
265 .iter()
266 .enumerate()
267 .all(|(i, &pm)| pm < setup.paired_dims_a[i]);
268 let all_from_b = paired_multi
269 .iter()
270 .enumerate()
271 .all(|(i, &pm)| pm >= setup.paired_dims_a[i]);
272
273 if all_from_a {
274 let a_dims = a.dims();
275 let mut a_multi = vec![0usize; a_dims.len()];
276 for (i, &cp) in setup.common_a_positions.iter().enumerate() {
277 a_multi[cp] = common_multi[i];
278 }
279 for (i, &pp) in setup.paired_a_positions.iter().enumerate() {
280 a_multi[pp] = paired_multi[i];
281 }
282 let a_linear = multi_to_linear(&a_multi, &setup.a_strides);
283 result_data[result_linear] = a_data[a_linear];
284 } else if all_from_b {
285 let b_dims = b.dims();
286 let mut b_multi = vec![0usize; b_dims.len()];
287 for (i, &cp) in setup.common_b_positions.iter().enumerate() {
288 b_multi[cp] = common_multi[i];
289 }
290 for (i, &pp) in setup.paired_b_positions.iter().enumerate() {
291 b_multi[pp] = paired_multi[i] - setup.paired_dims_a[i];
292 }
293 let b_linear = multi_to_linear(&b_multi, &setup.b_strides);
294 result_data[result_linear] = b_data[b_linear];
295 }
296 }
298
299 let result = TensorDynLen::from_dense(setup.result_indices, result_data)?;
300 Ok((result, setup.new_indices))
301}
302
303#[cfg(test)]
304mod tests;