tensor4all_core/defaults/
direct_sum.rs

1//! Direct sum operations for tensors.
2//!
3//! This module provides functionality to compute the direct sum of tensors
4//! along specified index pairs. The direct sum concatenates tensor data along
5//! the paired indices, creating new indices with combined dimensions.
6//!
7//! This module works with concrete types (`DynIndex`, `TensorDynLen`) only.
8
9use crate::defaults::DynIndex;
10use crate::index_like::IndexLike;
11use crate::tensor::TensorDynLen;
12use anyhow::Result;
13use num_traits::Zero;
14use tensor4all_tensorbackend::TensorElement;
15
16/// Compute the direct sum of two tensors along specified index pairs.
17///
18/// For tensors A and B with indices to be summed specified as pairs,
19/// creates a new tensor C where each paired index has dimension = dim_A + dim_B.
20/// Non-paired indices must match exactly between A and B (same ID).
21///
22/// # Arguments
23///
24/// * `a` - First tensor
25/// * `b` - Second tensor
26/// * `pairs` - Pairs of (a_index, b_index) to be summed. Each pair creates
27///   a new index in the result with dimension = dim(a_index) + dim(b_index).
28///
29/// # Returns
30///
31/// A tuple of:
32/// - The direct sum tensor
33/// - The new indices created for the summed dimensions (one per pair)
34///
35/// # Example
36///
37/// ```ignore
38/// // A has indices [i, j] with dims [2, 3]
39/// // B has indices [i, k] with dims [2, 4]
40/// // If we pair (j, k), result has indices [i, m] with dims [2, 7]
41/// // where m is a new index with dim = 3 + 4 = 7
42/// let (result, new_indices) = direct_sum(&a, &b, &[(j, k)])?;
43/// ```
44pub fn direct_sum(
45    a: &TensorDynLen,
46    b: &TensorDynLen,
47    pairs: &[(DynIndex, DynIndex)],
48) -> Result<(TensorDynLen, Vec<DynIndex>)> {
49    if a.is_f64() && b.is_f64() {
50        direct_sum_typed::<f64>(a, b, pairs)
51    } else if a.is_complex() && b.is_complex() {
52        direct_sum_typed::<num_complex::Complex64>(a, b, pairs)
53    } else {
54        Err(anyhow::anyhow!(
55            "direct_sum requires both tensors to have the same dense scalar type (f64 or Complex64)"
56        ))
57    }
58}
59
60/// Setup data for direct sum computation
61#[allow(dead_code)]
62struct DirectSumSetup {
63    common_a_positions: Vec<usize>,
64    common_b_positions: Vec<usize>,
65    paired_a_positions: Vec<usize>,
66    paired_b_positions: Vec<usize>,
67    paired_dims_a: Vec<usize>,
68    paired_dims_b: Vec<usize>,
69    result_indices: Vec<DynIndex>,
70    result_dims: Vec<usize>,
71    result_strides: Vec<usize>,
72    result_total: usize,
73    a_strides: Vec<usize>,
74    b_strides: Vec<usize>,
75    new_indices: Vec<DynIndex>,
76    n_common: usize,
77}
78
79fn setup_direct_sum(
80    a: &TensorDynLen,
81    b: &TensorDynLen,
82    pairs: &[(DynIndex, DynIndex)],
83) -> Result<DirectSumSetup> {
84    use crate::defaults::DynId;
85    use std::collections::HashMap;
86
87    if pairs.is_empty() {
88        return Err(anyhow::anyhow!(
89            "direct_sum requires at least one index pair"
90        ));
91    }
92
93    // Build maps from index IDs to positions
94    let a_idx_map: HashMap<&DynId, usize> = a
95        .indices
96        .iter()
97        .enumerate()
98        .map(|(i, idx)| (idx.id(), i))
99        .collect();
100    let b_idx_map: HashMap<&DynId, usize> = b
101        .indices
102        .iter()
103        .enumerate()
104        .map(|(i, idx)| (idx.id(), i))
105        .collect();
106
107    // Build set of paired index IDs
108    let mut a_paired_ids: std::collections::HashSet<&DynId> = std::collections::HashSet::new();
109    let mut b_paired_ids: std::collections::HashSet<&DynId> = std::collections::HashSet::new();
110
111    for (a_idx, b_idx) in pairs {
112        if !a_idx_map.contains_key(a_idx.id()) {
113            return Err(anyhow::anyhow!(
114                "Index not found in first tensor (direct_sum)"
115            ));
116        }
117        if !b_idx_map.contains_key(b_idx.id()) {
118            return Err(anyhow::anyhow!(
119                "Index not found in second tensor (direct_sum)"
120            ));
121        }
122        a_paired_ids.insert(a_idx.id());
123        b_paired_ids.insert(b_idx.id());
124    }
125
126    // Identify common (non-paired) indices
127    // Iterate over a.indices (Vec) to preserve deterministic order,
128    // using b_idx_map only for lookups.
129    let mut common_a_positions: Vec<usize> = Vec::new();
130    let mut common_b_positions: Vec<usize> = Vec::new();
131    for (a_pos, a_idx) in a.indices.iter().enumerate() {
132        let a_id = a_idx.id();
133        if a_paired_ids.contains(a_id) {
134            continue;
135        }
136        if let Some(&b_pos) = b_idx_map.get(a_id) {
137            if b_paired_ids.contains(a_id) {
138                continue;
139            }
140            let a_dims = a.dims();
141            let b_dims = b.dims();
142            if a_dims[a_pos] != b_dims[b_pos] {
143                return Err(anyhow::anyhow!(
144                    "Dimension mismatch for common index: {} vs {}",
145                    a_dims[a_pos],
146                    b_dims[b_pos]
147                ));
148            }
149            common_a_positions.push(a_pos);
150            common_b_positions.push(b_pos);
151        }
152    }
153
154    // Build paired positions and dimensions
155    let paired_a_positions: Vec<usize> = pairs
156        .iter()
157        .map(|(a_idx, _)| a_idx_map[a_idx.id()])
158        .collect();
159    let paired_b_positions: Vec<usize> = pairs
160        .iter()
161        .map(|(_, b_idx)| b_idx_map[b_idx.id()])
162        .collect();
163    let a_dims = a.dims();
164    let b_dims = b.dims();
165    let paired_dims_a: Vec<usize> = paired_a_positions.iter().map(|&p| a_dims[p]).collect();
166    let paired_dims_b: Vec<usize> = paired_b_positions.iter().map(|&p| b_dims[p]).collect();
167
168    // Create new indices for paired dimensions
169    let mut new_indices: Vec<DynIndex> = Vec::new();
170    for (&dim_a, &dim_b) in paired_dims_a.iter().zip(&paired_dims_b) {
171        let new_dim = dim_a + dim_b;
172        let new_index = DynIndex::new_link(new_dim)
173            .map_err(|e| anyhow::anyhow!("Failed to create index: {:?}", e))?;
174        new_indices.push(new_index);
175    }
176
177    // Build result indices and dimensions
178    let mut result_indices: Vec<DynIndex> = Vec::new();
179    let mut result_dims: Vec<usize> = Vec::new();
180
181    // Common indices first
182    let a_dims = a.dims();
183    for &a_pos in &common_a_positions {
184        result_indices.push(a.indices[a_pos].clone());
185        result_dims.push(a_dims[a_pos]);
186    }
187
188    // New indices from pairs preserve the full index identity semantics (id + plev + tags).
189    for new_idx in &new_indices {
190        result_indices.push(new_idx.clone());
191        result_dims.push(new_idx.dim());
192    }
193
194    // Compute column-major strides (leftmost index fastest).
195    let result_total: usize = result_dims.iter().product();
196    let mut result_strides: Vec<usize> = vec![1; result_dims.len()];
197    for i in 1..result_dims.len() {
198        result_strides[i] = result_strides[i - 1] * result_dims[i - 1];
199    }
200
201    let a_dims = a.dims();
202    let b_dims = b.dims();
203    let mut a_strides: Vec<usize> = vec![1; a_dims.len()];
204    for i in 1..a_dims.len() {
205        a_strides[i] = a_strides[i - 1] * a_dims[i - 1];
206    }
207
208    let mut b_strides: Vec<usize> = vec![1; b_dims.len()];
209    for i in 1..b_dims.len() {
210        b_strides[i] = b_strides[i - 1] * b_dims[i - 1];
211    }
212
213    let n_common = common_a_positions.len();
214
215    Ok(DirectSumSetup {
216        common_a_positions,
217        common_b_positions,
218        paired_a_positions,
219        paired_b_positions,
220        paired_dims_a,
221        paired_dims_b,
222        result_indices,
223        result_dims,
224        result_strides,
225        result_total,
226        a_strides,
227        b_strides,
228        new_indices,
229        n_common,
230    })
231}
232
/// Decode a column-major linear offset into one coordinate per dimension.
///
/// The leftmost dimension varies fastest: coordinate `k` is the remainder of
/// the running quotient by `dims[k]`, and the quotient is then divided by
/// `dims[k]`. Any quotient left after the last dimension is discarded.
///
/// # Panics
///
/// Panics if any entry of `dims` is zero (division by zero).
fn linear_to_multi(linear: usize, dims: &[usize]) -> Vec<usize> {
    let mut rest = linear;
    dims.iter()
        .map(|&extent| {
            let coord = rest % extent;
            rest /= extent;
            coord
        })
        .collect()
}
242
/// Encode a multi-index as a linear offset: the sum of `multi[k] * strides[k]`.
///
/// Only the leading entries present in both slices contribute; surplus
/// entries of the longer slice are ignored.
fn multi_to_linear(multi: &[usize], strides: &[usize]) -> usize {
    let mut offset = 0;
    for (&coord, &stride) in multi.iter().zip(strides.iter()) {
        offset += coord * stride;
    }
    offset
}
246
247fn direct_sum_typed<T: TensorElement + Zero>(
248    a: &TensorDynLen,
249    b: &TensorDynLen,
250    pairs: &[(DynIndex, DynIndex)],
251) -> Result<(TensorDynLen, Vec<DynIndex>)> {
252    let setup = setup_direct_sum(a, b, pairs)?;
253    let a_data = a.to_vec::<T>()?;
254    let b_data = b.to_vec::<T>()?;
255
256    let mut result_data: Vec<T> = vec![T::zero(); setup.result_total];
257
258    #[allow(clippy::needless_range_loop)]
259    for result_linear in 0..setup.result_total {
260        let result_multi = linear_to_multi(result_linear, &setup.result_dims);
261        let common_multi: Vec<usize> = result_multi[..setup.n_common].to_vec();
262        let paired_multi: Vec<usize> = result_multi[setup.n_common..].to_vec();
263
264        let all_from_a = paired_multi
265            .iter()
266            .enumerate()
267            .all(|(i, &pm)| pm < setup.paired_dims_a[i]);
268        let all_from_b = paired_multi
269            .iter()
270            .enumerate()
271            .all(|(i, &pm)| pm >= setup.paired_dims_a[i]);
272
273        if all_from_a {
274            let a_dims = a.dims();
275            let mut a_multi = vec![0usize; a_dims.len()];
276            for (i, &cp) in setup.common_a_positions.iter().enumerate() {
277                a_multi[cp] = common_multi[i];
278            }
279            for (i, &pp) in setup.paired_a_positions.iter().enumerate() {
280                a_multi[pp] = paired_multi[i];
281            }
282            let a_linear = multi_to_linear(&a_multi, &setup.a_strides);
283            result_data[result_linear] = a_data[a_linear];
284        } else if all_from_b {
285            let b_dims = b.dims();
286            let mut b_multi = vec![0usize; b_dims.len()];
287            for (i, &cp) in setup.common_b_positions.iter().enumerate() {
288                b_multi[cp] = common_multi[i];
289            }
290            for (i, &pp) in setup.paired_b_positions.iter().enumerate() {
291                b_multi[pp] = paired_multi[i] - setup.paired_dims_a[i];
292            }
293            let b_linear = multi_to_linear(&b_multi, &setup.b_strides);
294            result_data[result_linear] = b_data[b_linear];
295        }
296        // else: mixed case stays T::zero()
297    }
298
299    let result = TensorDynLen::from_dense(setup.result_indices, result_data)?;
300    Ok((result, setup.new_indices))
301}
302
303#[cfg(test)]
304mod tests;