strided_opteinsum/single_tensor.rs

use std::collections::HashMap;

use strided_kernel::{copy_into, reduce_axis};
use strided_view::{ElementOpApply, Identity, StridedArray, StridedView};

/// Execute a single-tensor einsum operation (5-step pipeline).
///
/// Given input with axis IDs `input_ids` and desired `output_ids`:
/// 1. Identify repeated input indices -> diagonal_view (stride trick, zero-copy)
/// 2. Identify indices to sum out -> reduce_axis
/// 3. Permute to output order -> copy_into
/// 4. Repeat: broadcast to NEW dimensions (output labels not in input, via size_dict)
/// 5. Duplicate: repeated output indices (e.g. "i->ii", diagonal write)
///
/// Pass `size_dict` to specify sizes for output indices not present in the input.
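///
/// A minimal usage sketch of the Repeat step; `v` is a hypothetical 1-D
/// `StridedArray<f64>` built as in the tests below:
///
/// ```ignore
/// // "j -> ij": broadcast a vector into two identical rows. The size of the
/// // new axis 'i' must come from `size_dict`, since it appears nowhere in
/// // the input.
/// let mut sizes = std::collections::HashMap::new();
/// sizes.insert('i', 2);
/// let rows = single_tensor_einsum(&v.view(), &['j'], &['i', 'j'], Some(&sizes))?;
/// assert_eq!(rows.dims(), &[2, v.len()]);
/// ```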
pub fn single_tensor_einsum<T>(
    src: &StridedView<T>,
    input_ids: &[char],
    output_ids: &[char],
    size_dict: Option<&HashMap<char, usize>>,
) -> crate::Result<StridedArray<T>>
where
    T: Copy + ElementOpApply + Send + Sync + std::ops::Add<Output = T> + num_traits::Zero + Default,
{
    // Decompose output_ids into unique labels and detect repeated output labels (Duplicate).
    // Also identify labels not present in input (Repeat/generative).
    let mut unique_output_ids: Vec<char> = Vec::new();
    let mut duplicate_map: Vec<Vec<usize>> = Vec::new(); // for each unique output label, positions in output_ids
    for (pos, &ch) in output_ids.iter().enumerate() {
        if let Some(idx) = unique_output_ids.iter().position(|&c| c == ch) {
            duplicate_map[idx].push(pos);
        } else {
            unique_output_ids.push(ch);
            duplicate_map.push(vec![pos]);
        }
    }
    let has_duplicate = duplicate_map.iter().any(|positions| positions.len() > 1);

    // Identify generative (Repeat) labels: in unique_output_ids but not in input_ids
    let generative_labels: Vec<char> = unique_output_ids
        .iter()
        .filter(|ch| !input_ids.contains(ch))
        .copied()
        .collect();
    let has_repeat = !generative_labels.is_empty();

    // If we have generative or duplicate labels, we need a different pipeline.
    // Otherwise, fall through to the existing optimized pipeline.
    if !has_repeat && !has_duplicate {
        return single_tensor_einsum_classic(src, input_ids, output_ids);
    }

    // For generative labels, look up sizes from size_dict
    let empty_dict = HashMap::new();
    let sd = size_dict.unwrap_or(&empty_dict);
    let mut gen_sizes: HashMap<char, usize> = HashMap::new();
    for &ch in &generative_labels {
        match sd.get(&ch) {
            Some(&sz) => {
                gen_sizes.insert(ch, sz);
            }
            None => return Err(crate::EinsumError::OrphanOutputAxis(ch.to_string())),
        }
    }

    // Steps 1-3: Run the classic pipeline on unique_output_ids minus generative labels.
    // This gives us the "core" result before Repeat and Duplicate.
    let core_output_ids: Vec<char> = unique_output_ids
        .iter()
        .filter(|ch| !generative_labels.contains(ch))
        .copied()
        .collect();

    let core_result = if core_output_ids.is_empty() && input_ids.is_empty() {
        // Scalar input, scalar core output — just copy the scalar
        let mut out = StridedArray::<T>::col_major(&[]);
        out.data_mut()[0] = unsafe { *src.data().as_ptr().offset(src.offset() as isize) };
        out
    } else if core_output_ids.is_empty() && !input_ids.is_empty() {
        // Need to reduce everything (full trace/sum)
        single_tensor_einsum_classic(src, input_ids, &[])?
    } else {
        single_tensor_einsum_classic(src, input_ids, &core_output_ids)?
    };

    // Step 4 (Repeat): broadcast core_result to include generative dimensions.
    // Build the intermediate shape: unique_output_ids ordering, with generative dims added.
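    // E.g. for "j->ij" with size_dict {'i': 2}: inter_dims = [2, len_j] and
    // inter_strides = [0, stride_j]; the stride-0 axis rereads the same data.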
    let intermediate = if has_repeat {
        let mut inter_dims: Vec<usize> = Vec::new();
        let mut inter_strides: Vec<isize> = Vec::new();
        let core_view = core_result.view();
        for &ch in &unique_output_ids {
            if generative_labels.contains(&ch) {
                inter_dims.push(gen_sizes[&ch]);
                inter_strides.push(0); // stride-0 = broadcast
            } else {
                let core_pos = core_output_ids.iter().position(|&c| c == ch).unwrap();
                inter_dims.push(core_view.dims()[core_pos]);
                inter_strides.push(core_view.strides()[core_pos]);
            }
        }
        // Create a broadcasted view and materialize it
        let broadcast_view: StridedView<'_, T, Identity> = StridedView::new(
            core_view.data(),
            &inter_dims,
            &inter_strides,
            core_view.offset() as isize,
        )?;
        let mut materialized = StridedArray::<T>::col_major(&inter_dims);
        copy_into(&mut materialized.view_mut(), &broadcast_view)?;
        materialized
    } else {
        // No repeat needed, but we may need to permute core_result to unique_output_ids order
        if core_output_ids == unique_output_ids {
            core_result
        } else {
            // Permute
            let perm: Vec<usize> = unique_output_ids
                .iter()
                .map(|ch| core_output_ids.iter().position(|c| c == ch).unwrap())
                .collect();
            let permuted_view = core_result.view().permute(&perm)?;
            let mut out = StridedArray::<T>::col_major(permuted_view.dims());
            copy_into(&mut out.view_mut(), &permuted_view)?;
            out
        }
    };

    // Step 5 (Duplicate): write to diagonal positions for repeated output labels.
    if !has_duplicate {
        return Ok(intermediate);
    }

    // Build the full output shape
    let out_dims: Vec<usize> = output_ids
        .iter()
        .map(|ch| {
            let idx = unique_output_ids.iter().position(|c| c == ch).unwrap();
            intermediate.dims()[idx]
        })
        .collect();
    let total: usize = out_dims.iter().product::<usize>().max(1);
    let mut out_data = vec![T::zero(); total];
    let out_strides = strided_view::col_major_strides(&out_dims);

    // Iterate over the intermediate (unique_output_ids shape) and write to
    // all matching positions in the output.
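    // For example, "i->ii" maps inter_idx = [k] to output position (k, k);
    // every off-diagonal entry keeps its T::zero() initialization.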
    let inter_dims = intermediate.dims().to_vec();
    let inter_total: usize = inter_dims.iter().product::<usize>().max(1);
    let inter_strides_cm = strided_view::col_major_strides(&inter_dims);

    let mut inter_idx = vec![0usize; inter_dims.len()];
    for flat in 0..inter_total {
        // Read value from intermediate
        let mut inter_flat = 0usize;
        for d in 0..inter_dims.len() {
            inter_flat += inter_idx[d] * inter_strides_cm[d] as usize;
        }
        let val = intermediate.data()[inter_flat];

        // Compute output flat index: for each output position, look up the
        // corresponding unique_output_id index from inter_idx
        let mut out_flat = 0usize;
        for (out_pos, &ch) in output_ids.iter().enumerate() {
            let inter_pos = unique_output_ids.iter().position(|&c| c == ch).unwrap();
            out_flat += inter_idx[inter_pos] * out_strides[out_pos] as usize;
        }
        out_data[out_flat] = val;

        // Increment inter_idx (col-major order)
        if flat + 1 < inter_total {
            for d in 0..inter_dims.len() {
                inter_idx[d] += 1;
                if inter_idx[d] < inter_dims[d] {
                    break;
                }
                inter_idx[d] = 0;
            }
        }
    }

    StridedArray::from_parts(out_data, &out_dims, &out_strides, 0)
        .map_err(crate::EinsumError::Strided)
}

/// Classic single-tensor einsum (steps 1-3 only: Diag, Sum, Permute).
///
/// No generative or duplicate output labels.
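///
/// For instance, "ijk -> kji" is a pure permutation, while "iij -> j" takes
/// the i-diagonal and sums over it.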
fn single_tensor_einsum_classic<T>(
    src: &StridedView<T>,
    input_ids: &[char],
    output_ids: &[char],
) -> crate::Result<StridedArray<T>>
where
    T: Copy + ElementOpApply + Send + Sync + std::ops::Add<Output = T> + num_traits::Zero + Default,
{
    // Fast path: full trace "cc...c -> " (all indices identical, scalar output)
    if output_ids.is_empty()
        && !input_ids.is_empty()
        && input_ids.iter().all(|&c| c == input_ids[0])
    {
        let n = src.dims()[0];
        let diag_stride: isize = src.strides().iter().sum();
        let ptr = src.data().as_ptr();
        let mut offset = src.offset() as isize;
        let mut acc = T::zero();
        for _ in 0..n {
            acc = acc + unsafe { *ptr.offset(offset) };
            offset += diag_stride;
        }
        let mut out = StridedArray::<T>::col_major(&[]);
        out.data_mut()[0] = acc;
        return Ok(out);
    }

    // Fast path: partial trace with one repeated pair,
    // e.g. "iij->j", "iji->j", "jii->j" — one pair of repeated indices, rest go to output
    {
        let mut pair: Option<(char, usize, usize)> = None;
        let mut seen_chars: Vec<(char, usize)> = Vec::new();
        let mut multiple_repeats = false;
        for (i, &ch) in input_ids.iter().enumerate() {
            if let Some(&(_, first)) = seen_chars.iter().find(|(c, _)| *c == ch) {
                if pair.is_some() {
                    // More than one repeat (a second pair, or a triple) — not a simple partial trace
                    multiple_repeats = true;
                    break;
                }
                pair = Some((ch, first, i));
            } else {
                seen_chars.push((ch, i));
            }
        }

        if let Some((repeated_ch, pos0, pos1)) = pair {
            if !multiple_repeats && !output_ids.contains(&repeated_ch) {
                // All non-repeated indices must be in output_ids
                let free_ids: Vec<(char, usize)> = input_ids
                    .iter()
                    .enumerate()
                    .filter(|(i, _)| *i != pos0 && *i != pos1)
                    .map(|(i, &ch)| (ch, i))
                    .collect();
                let all_free_in_output = free_ids.iter().all(|(ch, _)| output_ids.contains(ch));

                if all_free_in_output && free_ids.len() == output_ids.len() {
                    let n = src.dims()[pos0]; // diagonal length
                    let diag_stride = src.strides()[pos0] + src.strides()[pos1];
                    let ptr = src.data().as_ptr();
                    let base_offset = src.offset() as isize;

                    // Compute output permutation: output_ids order vs free_ids order
                    let out_dims: Vec<usize> = output_ids
                        .iter()
                        .map(|oc| {
                            let (_, src_axis) = free_ids.iter().find(|(ch, _)| ch == oc).unwrap();
                            src.dims()[*src_axis]
                        })
                        .collect();
                    let out_strides_src: Vec<isize> = output_ids
                        .iter()
                        .map(|oc| {
                            let (_, src_axis) = free_ids.iter().find(|(ch, _)| ch == oc).unwrap();
                            src.strides()[*src_axis]
                        })
                        .collect();

                    let total_out: usize = out_dims.iter().product::<usize>().max(1);
                    let out_col_strides = strided_view::col_major_strides(&out_dims);
                    let mut out_data = vec![T::zero(); total_out];

                    // Iterate over output elements using col-major order
                    let out_rank = out_dims.len();
                    let mut idx = vec![0usize; out_rank];
                    for flat in 0..total_out {
                        // Compute source offset for this output position
                        let mut src_off = base_offset;
                        for d in 0..out_rank {
                            src_off += idx[d] as isize * out_strides_src[d];
                        }
                        // Sum along the diagonal
                        let mut acc = T::zero();
                        let mut diag_off = src_off;
                        for _ in 0..n {
                            acc = acc + unsafe { *ptr.offset(diag_off) };
                            diag_off += diag_stride;
                        }
                        // Write to output using col-major flat index
                        let mut out_flat = 0usize;
                        for d in 0..out_rank {
                            out_flat += idx[d] * out_col_strides[d] as usize;
                        }
                        out_data[out_flat] = acc;

                        // Increment index (col-major order)
                        if flat + 1 < total_out {
                            for d in 0..out_rank {
                                idx[d] += 1;
                                if idx[d] < out_dims[d] {
                                    break;
                                }
                                idx[d] = 0;
                            }
                        }
                    }

                    return StridedArray::from_parts(out_data, &out_dims, &out_col_strides, 0)
                        .map_err(crate::EinsumError::Strided);
                }
            }
        }
    }

    // Step 1: Find repeated index pairs for diagonal_view.
    // Scan input_ids left-to-right. If a char appears twice, record (first_pos, second_pos).
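    // E.g. "iji" yields pairs = [(0, 2)] for the repeated 'i'.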
    let mut pairs: Vec<(usize, usize)> = Vec::new();
    let mut seen: Vec<(char, usize)> = Vec::new();
    for (i, &ch) in input_ids.iter().enumerate() {
        if let Some(&(_, first)) = seen.iter().find(|(c, _)| *c == ch) {
            pairs.push((first, i));
        } else {
            seen.push((ch, i));
        }
    }

    // Step 2: Apply diagonal_view if any pairs exist, and compute unique_ids.
    // If a diagonal is needed and a reduction follows, reduce directly from the diagonal
    // view (skip materialization). Otherwise materialize into an owned array.
    let (diag_arr, unique_ids): (Option<StridedArray<T>>, Vec<char>);
    let mut diag_reduce_done = false;
    if pairs.is_empty() {
        diag_arr = None;
        unique_ids = input_ids.to_vec();
    } else {
        let dv = src.diagonal_view(&pairs)?;
        // Compute unique_ids: remove the higher-indexed axis of each pair from input_ids.
        let axes_to_remove: Vec<usize> = pairs.iter().map(|&(_, b)| b).collect();
        unique_ids = input_ids
            .iter()
            .enumerate()
            .filter(|(i, _)| !axes_to_remove.contains(i))
            .map(|(_, &ch)| ch)
            .collect();
        let dims = dv.dims().to_vec();
        if dims.iter().product::<usize>() == 0 {
            // Empty tensor: return immediately
            let out_dims: Vec<usize> = output_ids
                .iter()
                .map(|oc| {
                    let pos = unique_ids.iter().position(|c| c == oc).unwrap();
                    dims[pos]
                })
                .collect();
            return Ok(StridedArray::<T>::col_major(&out_dims));
        }

        // Check if there are axes to reduce — if so, reduce from the diagonal view
        // directly without materializing first, saving one full copy.
        let has_reduce = unique_ids.iter().any(|ch| !output_ids.contains(ch));
        if has_reduce {
            // Compute axes to reduce (within the diagonal view's axes)
            let mut axes_to_reduce_diag: Vec<usize> = Vec::new();
            for (i, ch) in unique_ids.iter().enumerate() {
                if !output_ids.contains(ch) {
                    axes_to_reduce_diag.push(i);
                }
            }
            axes_to_reduce_diag.sort_unstable();
            axes_to_reduce_diag.reverse();

            // The first reduction reads directly from the diagonal view (no copy!)
            let mut reduced =
                reduce_axis(&dv, axes_to_reduce_diag[0], |x| x, |a, b| a + b, T::zero())?;
            // Subsequent reductions read from the owned result of the previous one
            for &ax in &axes_to_reduce_diag[1..] {
                reduced = reduce_axis(&reduced.view(), ax, |x| x, |a, b| a + b, T::zero())?;
            }
            diag_arr = Some(reduced);
            diag_reduce_done = true;
        } else {
            // No reduction follows — materialize for output
            let mut owned = StridedArray::<T>::col_major(&dims);
            copy_into(&mut owned.view_mut(), &dv)?;
            diag_arr = Some(owned);
        }
    }

    // Step 3: Find axes to sum out — indices in unique_ids that are NOT in output_ids.
    let mut axes_to_reduce: Vec<usize> = Vec::new();
    for (i, ch) in unique_ids.iter().enumerate() {
        if !output_ids.contains(ch) {
            axes_to_reduce.push(i);
        }
    }

    // Step 4: Reduce axes from back to front (to preserve axis indices).
    // Sort axes in descending order so removing higher axes first doesn't shift lower indices.
    axes_to_reduce.sort_unstable();
    axes_to_reduce.reverse();

    let mut current_arr: Option<StridedArray<T>> = None;

    if !diag_reduce_done {
        for &ax in axes_to_reduce.iter() {
            let reduced = if let Some(ref arr) = current_arr {
                reduce_axis(&arr.view(), ax, |x| x, |a, b| a + b, T::zero())?
            } else if let Some(ref arr) = diag_arr {
                reduce_axis(&arr.view(), ax, |x| x, |a, b| a + b, T::zero())?
            } else {
                reduce_axis(src, ax, |x| x, |a, b| a + b, T::zero())?
            };
            current_arr = Some(reduced);
        }
    }

    // Compute current_ids after reductions.
    let mut current_ids = unique_ids.clone();
    if diag_reduce_done {
        // Reductions already happened in the diagonal branch.
        // Remove all reduced axes from current_ids.
        let mut reduced_axes: Vec<usize> = Vec::new();
        for (i, ch) in unique_ids.iter().enumerate() {
            if !output_ids.contains(ch) {
                reduced_axes.push(i);
            }
        }
        reduced_axes.sort_unstable();
        reduced_axes.reverse();
        for ax in reduced_axes {
            current_ids.remove(ax);
        }
    } else {
        // Remove reduced axes (already sorted descending, so indices stay valid).
        for &ax in axes_to_reduce.iter() {
            current_ids.remove(ax);
        }
    }

    // Step 5: Handle scalar output (output_ids is empty).
    if output_ids.is_empty() {
        let result_arr = if let Some(arr) = current_arr {
            arr
        } else if let Some(arr) = diag_arr {
            arr
        } else {
            // Scalar from source (shouldn't normally happen — scalar output implies reduction)
            let mut owned = StridedArray::<T>::col_major(&[]);
            owned.data_mut()[0] = unsafe { *src.data().as_ptr().offset(src.offset() as isize) };
            owned
        };
        return Ok(result_arr);
    }

    // Step 6: Permute to output order if needed. For permute-only inputs
    // (no diagonal, no reduction), permute src directly to avoid a double copy.
    if current_ids == output_ids {
        let result_arr = if let Some(arr) = current_arr {
            arr
        } else if let Some(arr) = diag_arr {
            arr
        } else {
            // Identity: copy src to a col-major owned array
            let dims = src.dims().to_vec();
            let mut owned = StridedArray::<T>::col_major(&dims);
            copy_into(&mut owned.view_mut(), src)?;
            owned
        };
        return Ok(result_arr);
    }

    // Compute permutation: for each output axis, find its position in current_ids.
    let mut perm: Vec<usize> = Vec::with_capacity(output_ids.len());
    for oc in output_ids {
        match current_ids.iter().position(|c| c == oc) {
            Some(pos) => perm.push(pos),
            None => return Err(crate::EinsumError::OrphanOutputAxis(oc.to_string())),
        }
    }

    // Permute from the best available source (avoid intermediate copy)
    let source_view = if let Some(ref arr) = current_arr {
        arr.view()
    } else if let Some(ref arr) = diag_arr {
        arr.view()
    } else {
        // No intermediate — permute src directly (single copy instead of double)
        src.clone()
    };

    let permuted_view = source_view.permute(&perm)?;
    let out_dims = permuted_view.dims().to_vec();
    let mut out = StridedArray::<T>::col_major(&out_dims);
    copy_into(&mut out.view_mut(), &permuted_view)?;

    Ok(out)
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use strided_view::StridedArray;

    #[test]
    fn test_permutation_only() {
        // ijk -> kji
        let arr = StridedArray::<f64>::from_fn_row_major(&[2, 3, 4], |idx| {
            (idx[0] * 12 + idx[1] * 4 + idx[2]) as f64
        });
        let result =
            single_tensor_einsum(&arr.view(), &['i', 'j', 'k'], &['k', 'j', 'i'], None).unwrap();
        assert_eq!(result.dims(), &[4, 3, 2]);
        assert_abs_diff_eq!(result.get(&[0, 0, 0]), 0.0);
        assert_abs_diff_eq!(result.get(&[3, 2, 1]), 23.0);
    }

    #[test]
    fn test_full_trace() {
        // ii -> (scalar)
        let mut arr = StridedArray::<f64>::col_major(&[3, 3]);
        for i in 0..3 {
            for j in 0..3 {
                arr.set(&[i, j], (i * 10 + j) as f64);
            }
        }
        // trace = A[0,0] + A[1,1] + A[2,2] = 0 + 11 + 22 = 33
        let result = single_tensor_einsum(&arr.view(), &['i', 'i'], &[], None).unwrap();
        assert_abs_diff_eq!(result.data()[0], 33.0);
    }

    #[test]
    fn test_partial_trace() {
        // iij -> j  (sum over diagonal i)
        let arr = StridedArray::<f64>::from_fn_row_major(&[2, 2, 3], |idx| {
            (idx[0] * 6 + idx[1] * 3 + idx[2]) as f64
        });
        // A[0,0,:] = [0,1,2], A[1,1,:] = [9,10,11]
        // result[j] = A[0,0,j] + A[1,1,j] = [9, 11, 13]
        let result = single_tensor_einsum(&arr.view(), &['i', 'i', 'j'], &['j'], None).unwrap();
        assert_eq!(result.len(), 3);
        let values: Vec<f64> = (0..3).map(|j| result.data()[j]).collect();
        let mut sorted = values.clone();
        sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
        assert_abs_diff_eq!(sorted[0], 9.0);
        assert_abs_diff_eq!(sorted[1], 11.0);
        assert_abs_diff_eq!(sorted[2], 13.0);
    }

    #[test]
    fn test_diagonal_extraction() {
        // ijj -> ij
        let arr = StridedArray::<f64>::from_fn_row_major(&[2, 3, 3], |idx| {
            (idx[0] * 9 + idx[1] * 3 + idx[2]) as f64
        });
        // result[i,j] = A[i,j,j]
        let result =
            single_tensor_einsum(&arr.view(), &['i', 'j', 'j'], &['i', 'j'], None).unwrap();
        assert_eq!(result.dims(), &[2, 3]);
        assert_abs_diff_eq!(result.get(&[0, 0]), 0.0); // A[0,0,0]
        assert_abs_diff_eq!(result.get(&[0, 1]), 4.0); // A[0,1,1]
        assert_abs_diff_eq!(result.get(&[0, 2]), 8.0); // A[0,2,2]
        assert_abs_diff_eq!(result.get(&[1, 0]), 9.0); // A[1,0,0]
        assert_abs_diff_eq!(result.get(&[1, 1]), 13.0); // A[1,1,1]
        assert_abs_diff_eq!(result.get(&[1, 2]), 17.0); // A[1,2,2]
    }

    #[test]
    fn test_sum_axis() {
        // ij -> i (sum over j)
        let arr =
            StridedArray::<f64>::from_fn_row_major(&[2, 3], |idx| (idx[0] * 3 + idx[1]) as f64);
        // result[0] = 0+1+2 = 3, result[1] = 3+4+5 = 12
        let result = single_tensor_einsum(&arr.view(), &['i', 'j'], &['i'], None).unwrap();
        assert_eq!(result.len(), 2);
        assert_abs_diff_eq!(result.data()[0], 3.0);
        assert_abs_diff_eq!(result.data()[1], 12.0);
    }
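
    // The Repeat (step 4) and Duplicate (step 5) paths are not exercised by the
    // tests above. The two tests below are sketches of that behavior as derived
    // from the pipeline code in this file: a generative output axis takes its
    // size from `size_dict`, and a duplicated output label writes values onto
    // the diagonal while all other entries keep their zero initialization.
    #[test]
    fn test_repeat_broadcast() {
        // j -> ij (broadcast into a new leading axis of size 2, via size_dict)
        let mut arr = StridedArray::<f64>::col_major(&[3]);
        for j in 0..3 {
            arr.set(&[j], (j * 10) as f64);
        }
        let mut sizes = std::collections::HashMap::new();
        sizes.insert('i', 2);
        let result =
            single_tensor_einsum(&arr.view(), &['j'], &['i', 'j'], Some(&sizes)).unwrap();
        assert_eq!(result.dims(), &[2, 3]);
        for i in 0..2 {
            for j in 0..3 {
                assert_abs_diff_eq!(result.get(&[i, j]), (j * 10) as f64);
            }
        }
        // Without a size for the generative axis 'i', the call must fail.
        assert!(single_tensor_einsum(&arr.view(), &['j'], &['i', 'j'], None).is_err());
    }

    #[test]
    fn test_duplicate_diagonal_write() {
        // i -> ii (write the vector onto the diagonal; off-diagonal stays zero)
        let mut arr = StridedArray::<f64>::col_major(&[3]);
        for i in 0..3 {
            arr.set(&[i], (i + 1) as f64);
        }
        let result = single_tensor_einsum(&arr.view(), &['i'], &['i', 'i'], None).unwrap();
        assert_eq!(result.dims(), &[3, 3]);
        for i in 0..3 {
            for j in 0..3 {
                let expected = if i == j { (i + 1) as f64 } else { 0.0 };
                assert_abs_diff_eq!(result.get(&[i, j]), expected);
            }
        }
    }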
}