1use std::collections::HashMap;
2
3use strided_kernel::{copy_into, reduce_axis};
4use strided_view::{ElementOpApply, Identity, StridedArray, StridedView};
5
6pub fn single_tensor_einsum<T>(
17 src: &StridedView<T>,
18 input_ids: &[char],
19 output_ids: &[char],
20 size_dict: Option<&HashMap<char, usize>>,
21) -> crate::Result<StridedArray<T>>
22where
23 T: Copy + ElementOpApply + Send + Sync + std::ops::Add<Output = T> + num_traits::Zero + Default,
24{
25 let mut unique_output_ids: Vec<char> = Vec::new();
28 let mut duplicate_map: Vec<Vec<usize>> = Vec::new(); for (pos, &ch) in output_ids.iter().enumerate() {
30 if let Some(idx) = unique_output_ids.iter().position(|&c| c == ch) {
31 duplicate_map[idx].push(pos);
32 } else {
33 unique_output_ids.push(ch);
34 duplicate_map.push(vec![pos]);
35 }
36 }
37 let has_duplicate = duplicate_map.iter().any(|positions| positions.len() > 1);
38
39 let generative_labels: Vec<char> = unique_output_ids
41 .iter()
42 .filter(|ch| !input_ids.contains(ch))
43 .copied()
44 .collect();
45 let has_repeat = !generative_labels.is_empty();
46
47 if !has_repeat && !has_duplicate {
50 return single_tensor_einsum_classic(src, input_ids, output_ids);
51 }
52
53 let empty_dict = HashMap::new();
55 let sd = size_dict.unwrap_or(&empty_dict);
56 let mut gen_sizes: HashMap<char, usize> = HashMap::new();
57 for &ch in &generative_labels {
58 match sd.get(&ch) {
59 Some(&sz) => {
60 gen_sizes.insert(ch, sz);
61 }
62 None => return Err(crate::EinsumError::OrphanOutputAxis(ch.to_string())),
63 }
64 }
65
66 let core_output_ids: Vec<char> = unique_output_ids
69 .iter()
70 .filter(|ch| !generative_labels.contains(ch))
71 .copied()
72 .collect();
73
74 let core_result = if core_output_ids.is_empty() && input_ids.is_empty() {
75 let mut out = StridedArray::<T>::col_major(&[]);
77 out.data_mut()[0] = unsafe { *src.data().as_ptr().offset(src.offset() as isize) };
78 out
79 } else if core_output_ids.is_empty() && !input_ids.is_empty() {
80 single_tensor_einsum_classic(src, input_ids, &[])?
82 } else {
83 single_tensor_einsum_classic(src, input_ids, &core_output_ids)?
84 };
85
86 let intermediate = if has_repeat {
89 let mut inter_dims: Vec<usize> = Vec::new();
90 let mut inter_strides: Vec<isize> = Vec::new();
91 let core_view = core_result.view();
92 for &ch in &unique_output_ids {
93 if generative_labels.contains(&ch) {
94 inter_dims.push(gen_sizes[&ch]);
95 inter_strides.push(0); } else {
97 let core_pos = core_output_ids.iter().position(|&c| c == ch).unwrap();
98 inter_dims.push(core_view.dims()[core_pos]);
99 inter_strides.push(core_view.strides()[core_pos]);
100 }
101 }
102 let broadcast_view: StridedView<'_, T, Identity> = StridedView::new(
104 core_view.data(),
105 &inter_dims,
106 &inter_strides,
107 core_view.offset() as isize,
108 )?;
109 let mut materialized = StridedArray::<T>::col_major(&inter_dims);
110 copy_into(&mut materialized.view_mut(), &broadcast_view)?;
111 materialized
112 } else {
113 if core_output_ids == unique_output_ids {
115 core_result
116 } else {
117 let perm: Vec<usize> = unique_output_ids
119 .iter()
120 .map(|ch| core_output_ids.iter().position(|c| c == ch).unwrap())
121 .collect();
122 let permuted_view = core_result.view().permute(&perm)?;
123 let mut out = StridedArray::<T>::col_major(permuted_view.dims());
124 copy_into(&mut out.view_mut(), &permuted_view)?;
125 out
126 }
127 };
128
129 if !has_duplicate {
131 return Ok(intermediate);
132 }
133
134 let out_dims: Vec<usize> = output_ids
136 .iter()
137 .map(|ch| {
138 let idx = unique_output_ids.iter().position(|c| c == ch).unwrap();
139 intermediate.dims()[idx]
140 })
141 .collect();
142 let total: usize = out_dims.iter().product::<usize>().max(1);
143 let mut out_data = vec![T::zero(); total];
144 let out_strides = strided_view::col_major_strides(&out_dims);
145
146 let inter_dims = intermediate.dims().to_vec();
149 let inter_total: usize = inter_dims.iter().product::<usize>().max(1);
150 let inter_strides_cm = strided_view::col_major_strides(&inter_dims);
151
152 let mut inter_idx = vec![0usize; inter_dims.len()];
153 for flat in 0..inter_total {
154 let mut inter_flat = 0usize;
156 for d in 0..inter_dims.len() {
157 inter_flat += inter_idx[d] * inter_strides_cm[d] as usize;
158 }
159 let val = intermediate.data()[inter_flat];
160
161 let mut out_flat = 0usize;
164 for (out_pos, &ch) in output_ids.iter().enumerate() {
165 let inter_pos = unique_output_ids.iter().position(|&c| c == ch).unwrap();
166 out_flat += inter_idx[inter_pos] * out_strides[out_pos] as usize;
167 }
168 out_data[out_flat] = val;
169
170 if flat + 1 < inter_total {
172 for d in 0..inter_dims.len() {
173 inter_idx[d] += 1;
174 if inter_idx[d] < inter_dims[d] {
175 break;
176 }
177 inter_idx[d] = 0;
178 }
179 }
180 }
181
182 StridedArray::from_parts(out_data, &out_dims, &out_strides, 0)
183 .map_err(|e| crate::EinsumError::Strided(e))
184}
185
186fn single_tensor_einsum_classic<T>(
190 src: &StridedView<T>,
191 input_ids: &[char],
192 output_ids: &[char],
193) -> crate::Result<StridedArray<T>>
194where
195 T: Copy + ElementOpApply + Send + Sync + std::ops::Add<Output = T> + num_traits::Zero + Default,
196{
197 if output_ids.is_empty()
199 && !input_ids.is_empty()
200 && input_ids.iter().all(|&c| c == input_ids[0])
201 {
202 let n = src.dims()[0];
203 let diag_stride: isize = src.strides().iter().sum();
204 let ptr = src.data().as_ptr();
205 let mut offset = src.offset() as isize;
206 let mut acc = T::zero();
207 for _ in 0..n {
208 acc = acc + unsafe { *ptr.offset(offset) };
209 offset += diag_stride;
210 }
211 let mut out = StridedArray::<T>::col_major(&[]);
212 out.data_mut()[0] = acc;
213 return Ok(out);
214 }
215
216 {
219 let mut pair: Option<(char, usize, usize)> = None;
220 let mut seen_chars: Vec<(char, usize)> = Vec::new();
221 let mut has_triple = false;
222 for (i, &ch) in input_ids.iter().enumerate() {
223 if let Some(&(_, first)) = seen_chars.iter().find(|(c, _)| *c == ch) {
224 if pair.is_some() {
225 has_triple = true;
227 break;
228 }
229 pair = Some((ch, first, i));
230 } else {
231 seen_chars.push((ch, i));
232 }
233 }
234
235 if let Some((repeated_ch, pos0, pos1)) = pair {
236 if !has_triple && !output_ids.contains(&repeated_ch) {
237 let free_ids: Vec<(char, usize)> = input_ids
239 .iter()
240 .enumerate()
241 .filter(|(i, _)| *i != pos0 && *i != pos1)
242 .map(|(i, &ch)| (ch, i))
243 .collect();
244 let all_free_in_output = free_ids.iter().all(|(ch, _)| output_ids.contains(ch));
245
246 if all_free_in_output && free_ids.len() == output_ids.len() {
247 let n = src.dims()[pos0]; let diag_stride = src.strides()[pos0] + src.strides()[pos1];
249 let ptr = src.data().as_ptr();
250 let base_offset = src.offset() as isize;
251
252 let out_dims: Vec<usize> = output_ids
254 .iter()
255 .map(|oc| {
256 let (_, src_axis) = free_ids.iter().find(|(ch, _)| ch == oc).unwrap();
257 src.dims()[*src_axis]
258 })
259 .collect();
260 let out_strides_src: Vec<isize> = output_ids
261 .iter()
262 .map(|oc| {
263 let (_, src_axis) = free_ids.iter().find(|(ch, _)| ch == oc).unwrap();
264 src.strides()[*src_axis]
265 })
266 .collect();
267
268 let total_out: usize = out_dims.iter().product::<usize>().max(1);
269 let out_col_strides = strided_view::col_major_strides(&out_dims);
270 let mut out_data = vec![T::zero(); total_out];
271
272 let out_rank = out_dims.len();
274 let mut idx = vec![0usize; out_rank];
275 for flat in 0..total_out {
276 let mut src_off = base_offset;
278 for d in 0..out_rank {
279 src_off += idx[d] as isize * out_strides_src[d];
280 }
281 let mut acc = T::zero();
283 let mut diag_off = src_off;
284 for _ in 0..n {
285 acc = acc + unsafe { *ptr.offset(diag_off) };
286 diag_off += diag_stride;
287 }
288 let mut out_flat = 0usize;
290 for d in 0..out_rank {
291 out_flat += idx[d] * out_col_strides[d] as usize;
292 }
293 out_data[out_flat] = acc;
294
295 if flat + 1 < total_out {
297 for d in 0..out_rank {
298 idx[d] += 1;
299 if idx[d] < out_dims[d] {
300 break;
301 }
302 idx[d] = 0;
303 }
304 }
305 }
306
307 return StridedArray::from_parts(out_data, &out_dims, &out_col_strides, 0)
308 .map_err(|e| crate::EinsumError::Strided(e));
309 }
310 }
311 }
312 }
313
314 let mut pairs: Vec<(usize, usize)> = Vec::new();
317 let mut seen: Vec<(char, usize)> = Vec::new();
318 for (i, &ch) in input_ids.iter().enumerate() {
319 if let Some(&(_, first)) = seen.iter().find(|(c, _)| *c == ch) {
320 pairs.push((first, i));
321 } else {
322 seen.push((ch, i));
323 }
324 }
325
326 let (diag_arr, unique_ids): (Option<StridedArray<T>>, Vec<char>);
330 let mut diag_reduce_done = false;
331 if pairs.is_empty() {
332 diag_arr = None;
333 unique_ids = input_ids.to_vec();
334 } else {
335 let dv = src.diagonal_view(&pairs)?;
336 let axes_to_remove: Vec<usize> = pairs.iter().map(|&(_, b)| b).collect();
338 unique_ids = input_ids
339 .iter()
340 .enumerate()
341 .filter(|(i, _)| !axes_to_remove.contains(i))
342 .map(|(_, &ch)| ch)
343 .collect();
344 let dims = dv.dims().to_vec();
345 if dims.iter().product::<usize>() == 0 {
346 let out_dims: Vec<usize> = output_ids
348 .iter()
349 .map(|oc| {
350 let pos = unique_ids.iter().position(|c| c == oc).unwrap();
351 dims[pos]
352 })
353 .collect();
354 return Ok(StridedArray::<T>::col_major(&out_dims));
355 }
356
357 let has_reduce = unique_ids.iter().any(|ch| !output_ids.contains(ch));
360 if has_reduce {
361 let mut axes_to_reduce_diag: Vec<usize> = Vec::new();
363 for (i, ch) in unique_ids.iter().enumerate() {
364 if !output_ids.contains(ch) {
365 axes_to_reduce_diag.push(i);
366 }
367 }
368 axes_to_reduce_diag.sort_unstable();
369 axes_to_reduce_diag.reverse();
370
371 let mut reduced =
373 reduce_axis(&dv, axes_to_reduce_diag[0], |x| x, |a, b| a + b, T::zero())?;
374 for &ax in &axes_to_reduce_diag[1..] {
376 reduced = reduce_axis(&reduced.view(), ax, |x| x, |a, b| a + b, T::zero())?;
377 }
378 diag_arr = Some(reduced);
379 diag_reduce_done = true;
380 } else {
381 let mut owned = StridedArray::<T>::col_major(&dims);
383 copy_into(&mut owned.view_mut(), &dv)?;
384 diag_arr = Some(owned);
385 }
386 }
387
388 let mut axes_to_reduce: Vec<usize> = Vec::new();
390 for (i, ch) in unique_ids.iter().enumerate() {
391 if !output_ids.contains(ch) {
392 axes_to_reduce.push(i);
393 }
394 }
395
396 axes_to_reduce.sort_unstable();
399 axes_to_reduce.reverse();
400
401 let mut current_arr: Option<StridedArray<T>> = None;
402
403 if !diag_reduce_done {
404 for &ax in axes_to_reduce.iter() {
405 let reduced = if let Some(ref arr) = current_arr {
406 reduce_axis(&arr.view(), ax, |x| x, |a, b| a + b, T::zero())?
407 } else if let Some(ref arr) = diag_arr {
408 reduce_axis(&arr.view(), ax, |x| x, |a, b| a + b, T::zero())?
409 } else {
410 reduce_axis(src, ax, |x| x, |a, b| a + b, T::zero())?
411 };
412 current_arr = Some(reduced);
413 }
414 }
415
416 let mut current_ids = unique_ids.clone();
418 if diag_reduce_done {
419 let mut reduced_axes: Vec<usize> = Vec::new();
422 for (i, ch) in unique_ids.iter().enumerate() {
423 if !output_ids.contains(ch) {
424 reduced_axes.push(i);
425 }
426 }
427 reduced_axes.sort_unstable();
428 reduced_axes.reverse();
429 for ax in reduced_axes {
430 current_ids.remove(ax);
431 }
432 } else {
433 for &ax in axes_to_reduce.iter() {
435 current_ids.remove(ax);
436 }
437 }
438
439 if output_ids.is_empty() {
444 let result_arr = if let Some(arr) = current_arr {
445 arr
446 } else if let Some(arr) = diag_arr {
447 arr
448 } else {
449 let mut owned = StridedArray::<T>::col_major(&[]);
451 owned.data_mut()[0] = unsafe { *src.data().as_ptr().offset(src.offset() as isize) };
452 owned
453 };
454 return Ok(result_arr);
455 }
456
457 if current_ids == output_ids {
459 let result_arr = if let Some(arr) = current_arr {
460 arr
461 } else if let Some(arr) = diag_arr {
462 arr
463 } else {
464 let dims = src.dims().to_vec();
466 let mut owned = StridedArray::<T>::col_major(&dims);
467 copy_into(&mut owned.view_mut(), src)?;
468 owned
469 };
470 return Ok(result_arr);
471 }
472
473 let mut perm: Vec<usize> = Vec::with_capacity(output_ids.len());
475 for oc in output_ids {
476 match current_ids.iter().position(|c| c == oc) {
477 Some(pos) => perm.push(pos),
478 None => return Err(crate::EinsumError::OrphanOutputAxis(oc.to_string())),
479 }
480 }
481
482 let source_view = if let Some(ref arr) = current_arr {
484 arr.view()
485 } else if let Some(ref arr) = diag_arr {
486 arr.view()
487 } else {
488 src.clone()
490 };
491
492 let permuted_view = source_view.permute(&perm)?;
493 let out_dims = permuted_view.dims().to_vec();
494 let mut out = StridedArray::<T>::col_major(&out_dims);
495 copy_into(&mut out.view_mut(), &permuted_view)?;
496
497 Ok(out)
498}
499
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use std::collections::HashMap;
    use strided_view::StridedArray;

    /// Builds a column-major vector holding `values`.
    fn vector(values: &[f64]) -> StridedArray<f64> {
        let mut arr = StridedArray::<f64>::col_major(&[values.len()]);
        for (i, &v) in values.iter().enumerate() {
            arr.set(&[i], v);
        }
        arr
    }

    #[test]
    fn test_permutation_only() {
        let arr = StridedArray::<f64>::from_fn_row_major(&[2, 3, 4], |idx| {
            (idx[0] * 12 + idx[1] * 4 + idx[2]) as f64
        });
        let result =
            single_tensor_einsum(&arr.view(), &['i', 'j', 'k'], &['k', 'j', 'i'], None).unwrap();
        assert_eq!(result.dims(), &[4, 3, 2]);
        assert_abs_diff_eq!(result.get(&[0, 0, 0]), 0.0);
        assert_abs_diff_eq!(result.get(&[3, 2, 1]), 23.0);
    }

    #[test]
    fn test_full_trace() {
        let mut arr = StridedArray::<f64>::col_major(&[3, 3]);
        for i in 0..3 {
            for j in 0..3 {
                arr.set(&[i, j], (i * 10 + j) as f64);
            }
        }
        let result = single_tensor_einsum(&arr.view(), &['i', 'i'], &[], None).unwrap();
        // 0 + 11 + 22
        assert_abs_diff_eq!(result.data()[0], 33.0);
    }

    #[test]
    fn test_partial_trace() {
        let arr = StridedArray::<f64>::from_fn_row_major(&[2, 2, 3], |idx| {
            (idx[0] * 6 + idx[1] * 3 + idx[2]) as f64
        });
        let result = single_tensor_einsum(&arr.view(), &['i', 'i', 'j'], &['j'], None).unwrap();
        assert_eq!(result.len(), 3);
        // out[j] = x[0,0,j] + x[1,1,j] = j + (9 + j); checked per index so a
        // misplaced value cannot pass.
        for j in 0..3 {
            assert_abs_diff_eq!(result.get(&[j]), 9.0 + 2.0 * j as f64);
        }
    }

    #[test]
    fn test_diagonal_extraction() {
        let arr = StridedArray::<f64>::from_fn_row_major(&[2, 3, 3], |idx| {
            (idx[0] * 9 + idx[1] * 3 + idx[2]) as f64
        });
        let result =
            single_tensor_einsum(&arr.view(), &['i', 'j', 'j'], &['i', 'j'], None).unwrap();
        assert_eq!(result.dims(), &[2, 3]);
        assert_abs_diff_eq!(result.get(&[0, 0]), 0.0);
        assert_abs_diff_eq!(result.get(&[0, 1]), 4.0);
        assert_abs_diff_eq!(result.get(&[0, 2]), 8.0);
        assert_abs_diff_eq!(result.get(&[1, 0]), 9.0);
        assert_abs_diff_eq!(result.get(&[1, 1]), 13.0);
        assert_abs_diff_eq!(result.get(&[1, 2]), 17.0);
    }

    #[test]
    fn test_sum_axis() {
        let arr =
            StridedArray::<f64>::from_fn_row_major(&[2, 3], |idx| (idx[0] * 3 + idx[1]) as f64);
        let result = single_tensor_einsum(&arr.view(), &['i', 'j'], &['i'], None).unwrap();
        assert_eq!(result.len(), 2);
        assert_abs_diff_eq!(result.data()[0], 3.0);
        assert_abs_diff_eq!(result.data()[1], 12.0);
    }

    #[test]
    fn test_sum_all() {
        let arr =
            StridedArray::<f64>::from_fn_row_major(&[2, 3], |idx| (idx[0] * 3 + idx[1]) as f64);
        let result = single_tensor_einsum(&arr.view(), &['i', 'j'], &[], None).unwrap();
        // 0 + 1 + 2 + 3 + 4 + 5
        assert_abs_diff_eq!(result.data()[0], 15.0);
    }

    #[test]
    fn test_repeated_output_label_embeds_diagonal() {
        let v = vector(&[1.0, 2.0, 3.0]);
        let result = single_tensor_einsum(&v.view(), &['i'], &['i', 'i'], None).unwrap();
        assert_eq!(result.dims(), &[3, 3]);
        for r in 0..3 {
            for c in 0..3 {
                let expected = if r == c { (r + 1) as f64 } else { 0.0 };
                assert_abs_diff_eq!(result.get(&[r, c]), expected);
            }
        }
    }

    #[test]
    fn test_generative_label_broadcasts() {
        let v = vector(&[1.0, 2.0, 3.0]);
        let mut sizes = HashMap::new();
        sizes.insert('j', 2);
        let result =
            single_tensor_einsum(&v.view(), &['i'], &['i', 'j'], Some(&sizes)).unwrap();
        assert_eq!(result.dims(), &[3, 2]);
        for i in 0..3 {
            for j in 0..2 {
                assert_abs_diff_eq!(result.get(&[i, j]), (i + 1) as f64);
            }
        }
    }

    #[test]
    fn test_generative_label_without_size_is_error() {
        let v = vector(&[1.0, 2.0]);
        let result = single_tensor_einsum(&v.view(), &['i'], &['i', 'j'], None);
        assert!(matches!(
            &result,
            Err(crate::EinsumError::OrphanOutputAxis(label)) if label == "j"
        ));
    }
}