Skip to main content

tensor4all_core/defaults/
direct_sum.rs

1//! Direct sum operations for tensors.
2//!
3//! This module provides functionality to compute the direct sum of tensors
4//! along specified index pairs. The direct sum concatenates tensor data along
5//! the paired indices, creating new indices with combined dimensions.
6//!
7//! This module works with concrete types (`DynIndex`, `TensorDynLen`) only.
8
9use crate::defaults::DynIndex;
10use crate::index_like::IndexLike;
11use crate::tensor::TensorDynLen;
12use anyhow::Result;
13use num_traits::Zero;
14use tensor4all_tensorbackend::TensorElement;
15
16/// Compute the direct sum of two tensors along specified index pairs.
17///
18/// For tensors A and B with indices to be summed specified as pairs,
19/// creates a new tensor C where each paired index has dimension = dim_A + dim_B.
20/// Non-paired indices must match exactly between A and B.
21///
22/// # Arguments
23///
24/// * `a` - First tensor
25/// * `b` - Second tensor
26/// * `pairs` - Pairs of (a_index, b_index) to be summed. Each pair creates
27///   a new index in the result with dimension = dim(a_index) + dim(b_index).
28///
29/// # Returns
30///
31/// A tuple of:
32/// - The direct sum tensor
33/// - The new indices created for the summed dimensions (one per pair)
34///
35/// # Example
36///
37/// ```
38/// use tensor4all_core::{direct_sum, DynIndex, TensorDynLen};
39///
40/// # fn main() -> anyhow::Result<()> {
41/// let j = DynIndex::new_dyn(2);
42/// let k = DynIndex::new_dyn(3);
43///
44/// let a = TensorDynLen::from_dense(vec![j.clone()], vec![1.0, 2.0])?;
45/// let b = TensorDynLen::from_dense(vec![k.clone()], vec![3.0, 4.0, 5.0])?;
46/// let (result, new_indices) = direct_sum(&a, &b, &[(j.clone(), k.clone())])?;
47///
48/// assert_eq!(new_indices.len(), 1);
49/// assert_eq!(result.to_vec::<f64>()?, vec![1.0, 2.0, 3.0, 4.0, 5.0]);
50/// # Ok(())
51/// # }
52/// ```
53pub fn direct_sum(
54    a: &TensorDynLen,
55    b: &TensorDynLen,
56    pairs: &[(DynIndex, DynIndex)],
57) -> Result<(TensorDynLen, Vec<DynIndex>)> {
58    if a.is_f64() && b.is_f64() {
59        direct_sum_typed::<f64>(a, b, pairs)
60    } else if a.is_complex() && b.is_complex() {
61        direct_sum_typed::<num_complex::Complex64>(a, b, pairs)
62    } else {
63        Err(anyhow::anyhow!(
64            "direct_sum requires both tensors to have the same dense scalar type (f64 or Complex64)"
65        ))
66    }
67}
68
69/// Setup data for direct sum computation
70#[allow(dead_code)]
71struct DirectSumSetup {
72    common_a_positions: Vec<usize>,
73    common_b_positions: Vec<usize>,
74    paired_a_positions: Vec<usize>,
75    paired_b_positions: Vec<usize>,
76    paired_dims_a: Vec<usize>,
77    paired_dims_b: Vec<usize>,
78    result_indices: Vec<DynIndex>,
79    result_dims: Vec<usize>,
80    result_strides: Vec<usize>,
81    result_total: usize,
82    a_strides: Vec<usize>,
83    b_strides: Vec<usize>,
84    new_indices: Vec<DynIndex>,
85    n_common: usize,
86}
87
88fn setup_direct_sum(
89    a: &TensorDynLen,
90    b: &TensorDynLen,
91    pairs: &[(DynIndex, DynIndex)],
92) -> Result<DirectSumSetup> {
93    use std::collections::HashMap;
94
95    if pairs.is_empty() {
96        return Err(anyhow::anyhow!(
97            "direct_sum requires at least one index pair"
98        ));
99    }
100
101    // Build maps from full index metadata to positions. Same-id indices that
102    // differ by prime level or tags are distinct axes.
103    let a_idx_map: HashMap<&DynIndex, usize> = a
104        .indices
105        .iter()
106        .enumerate()
107        .map(|(i, idx)| (idx, i))
108        .collect();
109    let b_idx_map: HashMap<&DynIndex, usize> = b
110        .indices
111        .iter()
112        .enumerate()
113        .map(|(i, idx)| (idx, i))
114        .collect();
115
116    // Build set of paired indices.
117    let mut a_paired_indices: std::collections::HashSet<&DynIndex> =
118        std::collections::HashSet::new();
119    let mut b_paired_indices: std::collections::HashSet<&DynIndex> =
120        std::collections::HashSet::new();
121
122    for (a_idx, b_idx) in pairs {
123        if !a_idx_map.contains_key(a_idx) {
124            return Err(anyhow::anyhow!(
125                "Index not found in first tensor (direct_sum)"
126            ));
127        }
128        if !b_idx_map.contains_key(b_idx) {
129            return Err(anyhow::anyhow!(
130                "Index not found in second tensor (direct_sum)"
131            ));
132        }
133        a_paired_indices.insert(a_idx);
134        b_paired_indices.insert(b_idx);
135    }
136
137    // Identify common (non-paired) indices
138    // Iterate over a.indices (Vec) to preserve deterministic order,
139    // using b_idx_map only for lookups.
140    let mut common_a_positions: Vec<usize> = Vec::new();
141    let mut common_b_positions: Vec<usize> = Vec::new();
142    for (a_pos, a_idx) in a.indices.iter().enumerate() {
143        if a_paired_indices.contains(a_idx) {
144            continue;
145        }
146        if let Some(&b_pos) = b_idx_map.get(a_idx) {
147            if b_paired_indices.contains(a_idx) {
148                continue;
149            }
150            let a_dims = a.dims();
151            let b_dims = b.dims();
152            if a_dims[a_pos] != b_dims[b_pos] {
153                return Err(anyhow::anyhow!(
154                    "Dimension mismatch for common index: {} vs {}",
155                    a_dims[a_pos],
156                    b_dims[b_pos]
157                ));
158            }
159            common_a_positions.push(a_pos);
160            common_b_positions.push(b_pos);
161        }
162    }
163
164    // Build paired positions and dimensions
165    let paired_a_positions: Vec<usize> = pairs.iter().map(|(a_idx, _)| a_idx_map[a_idx]).collect();
166    let paired_b_positions: Vec<usize> = pairs.iter().map(|(_, b_idx)| b_idx_map[b_idx]).collect();
167    let a_dims = a.dims();
168    let b_dims = b.dims();
169    let paired_dims_a: Vec<usize> = paired_a_positions.iter().map(|&p| a_dims[p]).collect();
170    let paired_dims_b: Vec<usize> = paired_b_positions.iter().map(|&p| b_dims[p]).collect();
171
172    // Create new indices for paired dimensions
173    let mut new_indices: Vec<DynIndex> = Vec::new();
174    for (&dim_a, &dim_b) in paired_dims_a.iter().zip(&paired_dims_b) {
175        let new_dim = dim_a + dim_b;
176        let new_index = DynIndex::new_link(new_dim)
177            .map_err(|e| anyhow::anyhow!("Failed to create index: {:?}", e))?;
178        new_indices.push(new_index);
179    }
180
181    // Build result indices and dimensions
182    let mut result_indices: Vec<DynIndex> = Vec::new();
183    let mut result_dims: Vec<usize> = Vec::new();
184
185    // Common indices first.
186    let a_dims = a.dims();
187    for &a_pos in &common_a_positions {
188        result_indices.push(a.indices[a_pos].clone());
189        result_dims.push(a_dims[a_pos]);
190    }
191
192    // New indices from pairs preserve the full index identity semantics (id + plev + tags).
193    for new_idx in &new_indices {
194        result_indices.push(new_idx.clone());
195        result_dims.push(new_idx.dim());
196    }
197
198    // Compute column-major strides (leftmost index fastest).
199    let result_total: usize = result_dims.iter().product();
200    let mut result_strides: Vec<usize> = vec![1; result_dims.len()];
201    for i in 1..result_dims.len() {
202        result_strides[i] = result_strides[i - 1] * result_dims[i - 1];
203    }
204
205    let a_dims = a.dims();
206    let b_dims = b.dims();
207    let mut a_strides: Vec<usize> = vec![1; a_dims.len()];
208    for i in 1..a_dims.len() {
209        a_strides[i] = a_strides[i - 1] * a_dims[i - 1];
210    }
211
212    let mut b_strides: Vec<usize> = vec![1; b_dims.len()];
213    for i in 1..b_dims.len() {
214        b_strides[i] = b_strides[i - 1] * b_dims[i - 1];
215    }
216
217    let n_common = common_a_positions.len();
218
219    Ok(DirectSumSetup {
220        common_a_positions,
221        common_b_positions,
222        paired_a_positions,
223        paired_b_positions,
224        paired_dims_a,
225        paired_dims_b,
226        result_indices,
227        result_dims,
228        result_strides,
229        result_total,
230        a_strides,
231        b_strides,
232        new_indices,
233        n_common,
234    })
235}
236
/// Decompose a linear offset into per-axis coordinates for a column-major
/// layout (the first axis varies fastest).
///
/// Each coordinate is the running quotient modulo that axis' extent; any
/// quotient left over after the last axis is discarded.
fn linear_to_multi(linear: usize, dims: &[usize]) -> Vec<usize> {
    let mut rest = linear;
    dims.iter()
        .map(|&extent| {
            let coord = rest % extent;
            rest /= extent;
            coord
        })
        .collect()
}
246
/// Fold per-axis coordinates into a linear offset using the given strides.
///
/// Coordinates and strides are matched position by position; if the slices
/// differ in length the surplus entries of the longer one are ignored.
fn multi_to_linear(multi: &[usize], strides: &[usize]) -> usize {
    let mut offset = 0;
    for (&coord, &stride) in multi.iter().zip(strides) {
        offset += coord * stride;
    }
    offset
}
250
251fn direct_sum_typed<T: TensorElement + Zero>(
252    a: &TensorDynLen,
253    b: &TensorDynLen,
254    pairs: &[(DynIndex, DynIndex)],
255) -> Result<(TensorDynLen, Vec<DynIndex>)> {
256    let setup = setup_direct_sum(a, b, pairs)?;
257    let a_data = a.to_vec::<T>()?;
258    let b_data = b.to_vec::<T>()?;
259
260    let mut result_data: Vec<T> = vec![T::zero(); setup.result_total];
261
262    #[allow(clippy::needless_range_loop)]
263    for result_linear in 0..setup.result_total {
264        let result_multi = linear_to_multi(result_linear, &setup.result_dims);
265        let common_multi: Vec<usize> = result_multi[..setup.n_common].to_vec();
266        let paired_multi: Vec<usize> = result_multi[setup.n_common..].to_vec();
267
268        let all_from_a = paired_multi
269            .iter()
270            .enumerate()
271            .all(|(i, &pm)| pm < setup.paired_dims_a[i]);
272        let all_from_b = paired_multi
273            .iter()
274            .enumerate()
275            .all(|(i, &pm)| pm >= setup.paired_dims_a[i]);
276
277        if all_from_a {
278            let a_dims = a.dims();
279            let mut a_multi = vec![0usize; a_dims.len()];
280            for (i, &cp) in setup.common_a_positions.iter().enumerate() {
281                a_multi[cp] = common_multi[i];
282            }
283            for (i, &pp) in setup.paired_a_positions.iter().enumerate() {
284                a_multi[pp] = paired_multi[i];
285            }
286            let a_linear = multi_to_linear(&a_multi, &setup.a_strides);
287            result_data[result_linear] = a_data[a_linear];
288        } else if all_from_b {
289            let b_dims = b.dims();
290            let mut b_multi = vec![0usize; b_dims.len()];
291            for (i, &cp) in setup.common_b_positions.iter().enumerate() {
292                b_multi[cp] = common_multi[i];
293            }
294            for (i, &pp) in setup.paired_b_positions.iter().enumerate() {
295                b_multi[pp] = paired_multi[i] - setup.paired_dims_a[i];
296            }
297            let b_linear = multi_to_linear(&b_multi, &setup.b_strides);
298            result_data[result_linear] = b_data[b_linear];
299        }
300        // else: mixed case stays T::zero()
301    }
302
303    let result = TensorDynLen::from_dense(setup.result_indices, result_data)?;
304    Ok((result, setup.new_indices))
305}
306
307#[cfg(test)]
308mod tests;