// The GEMM backends are mutually exclusive: at most one of `faer`, `blas` and
// `blas-inject` may be enabled (enabling none selects the naive fallback kernels).
// Each guard below turns one invalid pairing into an explicit compile error.
#[cfg(all(feature = "faer", feature = "blas"))]
compile_error!("Features `faer` and `blas` are mutually exclusive. Use one or the other.");

#[cfg(all(feature = "faer", feature = "blas-inject"))]
compile_error!("Features `faer` and `blas-inject` are mutually exclusive.");

#[cfg(all(feature = "blas", feature = "blas-inject"))]
compile_error!("Features `blas` and `blas-inject` are mutually exclusive.");
32
33#[cfg(all(feature = "blas-inject", not(feature = "blas")))]
34extern crate cblas_inject as cblas_sys;
35#[cfg(all(feature = "blas", not(feature = "blas-inject")))]
36extern crate cblas_sys;
37
38#[cfg(all(
39 not(feature = "faer"),
40 any(
41 all(feature = "blas", not(feature = "blas-inject")),
42 all(feature = "blas-inject", not(feature = "blas"))
43 )
44))]
45pub mod bgemm_blas;
46
47#[cfg(feature = "faer")]
48pub mod bgemm_faer;
50pub mod bgemm_naive;
52pub mod contiguous;
54pub mod plan;
56pub mod trace;
58pub mod util;
60
61pub mod backend;
63
64use std::any::TypeId;
65use std::fmt::Debug;
66use std::hash::Hash;
67
68use strided_kernel::zip_map2_into;
69#[cfg(any(feature = "faer", feature = "blas", feature = "blas-inject"))]
70use strided_view::StridedArray;
71use strided_view::{Adjoint, Conj, ElementOp, ElementOpApply, StridedView, StridedViewMut};
72
73pub use strided_traits::ScalarBase;
74
75pub use backend::Backend;
76pub use plan::Einsum2Plan;
77
/// Marker trait for einsum axis labels.
///
/// Every type that is `Clone + Eq + Hash + Debug` (for example `char`, `u32` or
/// `String`) qualifies automatically through the blanket impl, so callers never
/// implement this trait themselves.
pub trait AxisId
where
    Self: Clone + Eq + Hash + Debug,
{
}

impl<L> AxisId for L where L: Clone + Eq + Hash + Debug {}
81
82#[cfg(all(feature = "faer", not(any(feature = "blas", feature = "blas-inject"))))]
91pub trait Scalar: ScalarBase + ElementOpApply + faer_traits::ComplexField {}
92
93#[cfg(all(feature = "faer", not(any(feature = "blas", feature = "blas-inject"))))]
94impl<T> Scalar for T where T: ScalarBase + ElementOpApply + faer_traits::ComplexField {}
95
96#[cfg(all(
100 not(feature = "faer"),
101 any(
102 all(feature = "blas", not(feature = "blas-inject")),
103 all(feature = "blas-inject", not(feature = "blas"))
104 )
105))]
106pub trait Scalar: ScalarBase + ElementOpApply + bgemm_blas::BlasGemm {}
107
108#[cfg(all(
109 not(feature = "faer"),
110 any(
111 all(feature = "blas", not(feature = "blas-inject")),
112 all(feature = "blas-inject", not(feature = "blas"))
113 )
114))]
115impl<T> Scalar for T where T: ScalarBase + ElementOpApply + bgemm_blas::BlasGemm {}
116
117#[cfg(not(any(feature = "faer", feature = "blas", feature = "blas-inject")))]
119pub trait Scalar: ScalarBase + ElementOpApply {}
120
121#[cfg(not(any(feature = "faer", feature = "blas", feature = "blas-inject")))]
122impl<T> Scalar for T where T: ScalarBase + ElementOpApply {}
123
124#[cfg(any(
129 all(feature = "faer", any(feature = "blas", feature = "blas-inject")),
130 all(feature = "blas", feature = "blas-inject")
131))]
132pub trait Scalar: ScalarBase + ElementOpApply {}
133
134#[cfg(any(
135 all(feature = "faer", any(feature = "blas", feature = "blas-inject")),
136 all(feature = "blas", feature = "blas-inject")
137))]
138impl<T> Scalar for T where T: ScalarBase + ElementOpApply {}
139
140#[derive(Debug, thiserror::Error)]
142pub enum EinsumError {
143 #[error("duplicate axis label: {0}")]
144 DuplicateAxis(String),
145 #[error("output axis {0} not found in any input")]
146 OrphanOutputAxis(String),
147 #[error("dimension mismatch for axis {axis:?}: {dim_a} vs {dim_b}")]
148 DimensionMismatch {
149 axis: String,
150 dim_a: usize,
151 dim_b: usize,
152 },
153 #[error(transparent)]
154 Strided(#[from] strided_view::StridedError),
155}
156
157pub type Result<T> = std::result::Result<T, EinsumError>;
159
160fn op_is_conj<Op: 'static>() -> bool {
169 TypeId::of::<Op>() == TypeId::of::<Conj>() || TypeId::of::<Op>() == TypeId::of::<Adjoint>()
170}
171
172pub fn einsum2_into<T: Scalar, OpA, OpB, ID: AxisId>(
183 c: StridedViewMut<T>,
184 a: &StridedView<T, OpA>,
185 b: &StridedView<T, OpB>,
186 ic: &[ID],
187 ia: &[ID],
188 ib: &[ID],
189 alpha: T,
190 beta: T,
191) -> Result<()>
192where
193 OpA: ElementOp<T> + 'static,
194 OpB: ElementOp<T> + 'static,
195{
196 let plan = Einsum2Plan::new(ia, ib, ic)?;
198
199 validate_dimensions::<ID>(&plan, a.dims(), b.dims(), c.dims(), ia, ib, ic)?;
201
202 let left_trace = trace::find_trace_indices(ia, ib, ic);
207 let (a_buf, conj_a) = if !left_trace.is_empty() {
208 (Some(trace::reduce_trace_axes(a, &left_trace)?), false)
209 } else {
210 (None, op_is_conj::<OpA>())
211 };
212
213 let a_view: StridedView<T> = match a_buf.as_ref() {
214 Some(buf) => buf.view(),
215 None => StridedView::new(a.data(), a.dims(), a.strides(), a.offset())
216 .expect("strip_op_view: metadata already validated"),
217 };
218
219 let right_trace = trace::find_trace_indices(ib, ia, ic);
220 let (b_buf, conj_b) = if !right_trace.is_empty() {
221 (Some(trace::reduce_trace_axes(b, &right_trace)?), false)
222 } else {
223 (None, op_is_conj::<OpB>())
224 };
225
226 let b_view: StridedView<T> = match b_buf.as_ref() {
227 Some(buf) => buf.view(),
228 None => StridedView::new(b.data(), b.dims(), b.strides(), b.offset())
229 .expect("strip_op_view: metadata already validated"),
230 };
231
232 #[cfg(any(
234 all(feature = "faer", not(any(feature = "blas", feature = "blas-inject"))),
235 all(
236 not(feature = "faer"),
237 any(
238 all(feature = "blas", not(feature = "blas-inject")),
239 all(feature = "blas-inject", not(feature = "blas"))
240 )
241 )
242 ))]
243 {
244 let conj_fn = make_conj_fn::<T>();
245 einsum2_dispatch::<T, backend::ActiveBackend, _>(
246 c, &a_view, &b_view, &plan, alpha, beta, conj_a, conj_b, conj_fn,
247 )?;
248 }
249
250 #[cfg(not(any(feature = "faer", feature = "blas", feature = "blas-inject")))]
251 {
252 let a_perm = a_view.permute(&plan.left_perm)?;
253 let b_perm = b_view.permute(&plan.right_perm)?;
254 let mut c_perm = c.permute(&plan.c_to_internal_perm)?;
255
256 if plan.sum.is_empty() && plan.lo.is_empty() && plan.ro.is_empty() && beta == T::zero() {
257 let mul_fn = move |a_val: T, b_val: T| -> T {
258 let a_c = if conj_a { Conj::apply(a_val) } else { a_val };
259 let b_c = if conj_b { Conj::apply(b_val) } else { b_val };
260 alpha * a_c * b_c
261 };
262 zip_map2_into(&mut c_perm, &a_perm, &b_perm, mul_fn)?;
263 return Ok(());
264 }
265
266 bgemm_naive::bgemm_strided_into(
267 &mut c_perm,
268 &a_perm,
269 &b_perm,
270 plan.batch.len(),
271 plan.lo.len(),
272 plan.ro.len(),
273 plan.sum.len(),
274 alpha,
275 beta,
276 conj_a,
277 conj_b,
278 )?;
279 }
280
281 Ok(())
282}
283
284pub fn einsum2_naive_into<T, ID, MapA, MapB>(
292 c: StridedViewMut<T>,
293 a: &StridedView<T>,
294 b: &StridedView<T>,
295 ic: &[ID],
296 ia: &[ID],
297 ib: &[ID],
298 alpha: T,
299 beta: T,
300 map_a: MapA,
301 map_b: MapB,
302) -> Result<()>
303where
304 T: ScalarBase,
305 ID: AxisId,
306 MapA: Fn(T) -> T + strided_kernel::MaybeSync,
307 MapB: Fn(T) -> T + strided_kernel::MaybeSync,
308{
309 let plan = Einsum2Plan::new(ia, ib, ic)?;
310 validate_dimensions::<ID>(&plan, a.dims(), b.dims(), c.dims(), ia, ib, ic)?;
311
312 let left_trace = trace::find_trace_indices(ia, ib, ic);
316 let (a_buf, use_map_a) = if !left_trace.is_empty() {
317 let mut mapped = unsafe { strided_view::StridedArray::<T>::col_major_uninit(a.dims()) };
318 strided_kernel::map_into(&mut mapped.view_mut(), a, &map_a)?;
319 let reduced = trace::reduce_trace_axes(&mapped.view(), &left_trace)?;
320 (Some(reduced), false)
321 } else {
322 (None, true)
323 };
324 let a_view: StridedView<T> = match a_buf.as_ref() {
325 Some(buf) => buf.view(),
326 None => a.clone(),
327 };
328
329 let right_trace = trace::find_trace_indices(ib, ia, ic);
330 let (b_buf, use_map_b) = if !right_trace.is_empty() {
331 let mut mapped = unsafe { strided_view::StridedArray::<T>::col_major_uninit(b.dims()) };
332 strided_kernel::map_into(&mut mapped.view_mut(), b, &map_b)?;
333 let reduced = trace::reduce_trace_axes(&mapped.view(), &right_trace)?;
334 (Some(reduced), false)
335 } else {
336 (None, true)
337 };
338 let b_view: StridedView<T> = match b_buf.as_ref() {
339 Some(buf) => buf.view(),
340 None => b.clone(),
341 };
342
343 let a_perm = a_view.permute(&plan.left_perm)?;
344 let b_perm = b_view.permute(&plan.right_perm)?;
345 let mut c_perm = c.permute(&plan.c_to_internal_perm)?;
346
347 if plan.sum.is_empty() && plan.lo.is_empty() && plan.ro.is_empty() && beta == T::zero() {
349 let mul_fn = move |a_val: T, b_val: T| -> T {
350 let a_c = if use_map_a { map_a(a_val) } else { a_val };
351 let b_c = if use_map_b { map_b(b_val) } else { b_val };
352 alpha * a_c * b_c
353 };
354 zip_map2_into(&mut c_perm, &a_perm, &b_perm, mul_fn)?;
355 return Ok(());
356 }
357
358 let final_map_a: Box<dyn Fn(T) -> T> = if use_map_a {
359 Box::new(map_a)
360 } else {
361 Box::new(|x| x)
362 };
363 let final_map_b: Box<dyn Fn(T) -> T> = if use_map_b {
364 Box::new(map_b)
365 } else {
366 Box::new(|x| x)
367 };
368
369 bgemm_naive::bgemm_strided_into_with_map(
370 &mut c_perm,
371 &a_perm,
372 &b_perm,
373 plan.batch.len(),
374 plan.lo.len(),
375 plan.ro.len(),
376 plan.sum.len(),
377 alpha,
378 beta,
379 final_map_a,
380 final_map_b,
381 )?;
382
383 Ok(())
384}
385
386pub fn einsum2_with_backend_into<T, B, ID>(
395 c: StridedViewMut<T>,
396 a: &StridedView<T>,
397 b: &StridedView<T>,
398 ic: &[ID],
399 ia: &[ID],
400 ib: &[ID],
401 alpha: T,
402 beta: T,
403) -> Result<()>
404where
405 T: ScalarBase,
406 B: Backend<T>,
407 ID: AxisId,
408{
409 let plan = Einsum2Plan::new(ia, ib, ic)?;
410 validate_dimensions::<ID>(&plan, a.dims(), b.dims(), c.dims(), ia, ib, ic)?;
411
412 let left_trace = trace::find_trace_indices(ia, ib, ic);
414 let a_buf = if !left_trace.is_empty() {
415 Some(trace::reduce_trace_axes(a, &left_trace)?)
416 } else {
417 None
418 };
419 let a_view: StridedView<T> = match a_buf.as_ref() {
420 Some(buf) => buf.view(),
421 None => a.clone(),
422 };
423
424 let right_trace = trace::find_trace_indices(ib, ia, ic);
425 let b_buf = if !right_trace.is_empty() {
426 Some(trace::reduce_trace_axes(b, &right_trace)?)
427 } else {
428 None
429 };
430 let b_view: StridedView<T> = match b_buf.as_ref() {
431 Some(buf) => buf.view(),
432 None => b.clone(),
433 };
434
435 einsum2_dispatch::<T, B, _>(c, &a_view, &b_view, &plan, alpha, beta, false, false, None)
437}
438
/// Returns the element-wise conjugation function for `T`.
///
/// It is always `Some`. `einsum2_dispatch` needs it in its elementwise fast path
/// whenever an input is conjugated, regardless of backend. Returning `None` for
/// backends that do not materialize conjugation made that path fall back to the
/// identity and silently drop the conjugation.
///
/// Whether conjugation is *materialized* into packed buffers is decided by the
/// callers. Both `einsum2_dispatch` and `einsum2_into_owned` forward this
/// function for materialization only when `Backend::MATERIALIZES_CONJ` is set.
#[cfg(any(
    all(feature = "faer", not(any(feature = "blas", feature = "blas-inject"))),
    all(
        not(feature = "faer"),
        any(
            all(feature = "blas", not(feature = "blas-inject")),
            all(feature = "blas-inject", not(feature = "blas"))
        )
    )
))]
fn make_conj_fn<T: Scalar>() -> Option<fn(T) -> T> {
    Some(|x| Conj::apply(x))
}
461
462fn einsum2_dispatch<T, B, ID>(
475 c: StridedViewMut<T>,
476 a: &StridedView<T>,
477 b: &StridedView<T>,
478 plan: &Einsum2Plan<ID>,
479 alpha: T,
480 beta: T,
481 conj_a: bool,
482 conj_b: bool,
483 conj_fn: Option<fn(T) -> T>,
484) -> Result<()>
485where
486 T: ScalarBase,
487 B: Backend<T>,
488 ID: AxisId,
489{
490 let a_perm = a.permute(&plan.left_perm)?;
492 let b_perm = b.permute(&plan.right_perm)?;
493 let mut c_perm = c.permute(&plan.c_to_internal_perm)?;
494
495 if plan.sum.is_empty() && plan.lo.is_empty() && plan.ro.is_empty() && beta == T::zero() {
497 if !conj_a && !conj_b && alpha == T::one() {
498 zip_map2_into(&mut c_perm, &a_perm, &b_perm, |a_val, b_val| a_val * b_val)?;
499 } else if !conj_a && !conj_b {
500 let mul_fn = move |a_val: T, b_val: T| -> T { alpha * a_val * b_val };
501 zip_map2_into(&mut c_perm, &a_perm, &b_perm, mul_fn)?;
502 } else {
503 let conj_fn = conj_fn.unwrap_or(|x| x);
504 let mul_fn = move |a_val: T, b_val: T| -> T {
505 let a_c = if conj_a { conj_fn(a_val) } else { a_val };
506 let b_c = if conj_b { conj_fn(b_val) } else { b_val };
507 alpha * a_c * b_c
508 };
509 zip_map2_into(&mut c_perm, &a_perm, &b_perm, mul_fn)?;
510 }
511 return Ok(());
512 }
513
514 let n_lo = plan.lo.len();
516 let n_ro = plan.ro.len();
517 let n_sum = plan.sum.len();
518 let use_pool = true;
519 let materialize = if B::MATERIALIZES_CONJ { conj_fn } else { None };
520
521 let a_op = contiguous::prepare_input_view(
522 &a_perm,
523 n_lo,
524 n_sum,
525 conj_a,
526 B::REQUIRES_UNIT_STRIDE,
527 use_pool,
528 materialize,
529 )?;
530 let b_op = contiguous::prepare_input_view(
531 &b_perm,
532 n_sum,
533 n_ro,
534 conj_b,
535 B::REQUIRES_UNIT_STRIDE,
536 use_pool,
537 materialize,
538 )?;
539 let mut c_op = contiguous::prepare_output_view(
540 &mut c_perm,
541 n_lo,
542 n_ro,
543 beta,
544 B::REQUIRES_UNIT_STRIDE,
545 use_pool,
546 )?;
547
548 let lo_dims = &a_perm.dims()[..n_lo];
550 let sum_dims = &a_perm.dims()[n_lo..n_lo + n_sum];
551 let batch_dims = &a_perm.dims()[n_lo + n_sum..];
552 let ro_dims = &b_perm.dims()[n_sum..n_sum + n_ro];
553 let m: usize = lo_dims.iter().product::<usize>().max(1);
554 let k: usize = sum_dims.iter().product::<usize>().max(1);
555 let n: usize = ro_dims.iter().product::<usize>().max(1);
556
557 B::bgemm_contiguous_into(&mut c_op, &a_op, &b_op, batch_dims, m, n, k, alpha, beta)?;
559
560 c_op.finalize_into(&mut c_perm)?;
562
563 Ok(())
564}
565
/// Owned-input variant of `einsum2_into` that runs on the active backend.
///
/// Computes `c = alpha * contract(conj?(a), conj?(b)) + beta * c`, where
/// `conj_a` / `conj_b` request conjugation of the respective input. `a` and `b`
/// are taken by value and handed to `contiguous::prepare_input_owned`, which can
/// presumably reuse their buffers instead of copying.
///
/// # Errors
/// Fails if the labels cannot be planned, if axis extents disagree, or if a
/// strided view or backend operation fails.
#[cfg(any(
    all(feature = "faer", not(any(feature = "blas", feature = "blas-inject"))),
    all(
        not(feature = "faer"),
        any(
            all(feature = "blas", not(feature = "blas-inject")),
            all(feature = "blas-inject", not(feature = "blas"))
        )
    )
))]
pub fn einsum2_into_owned<T: Scalar, ID: AxisId>(
    c: StridedViewMut<T>,
    a: StridedArray<T>,
    b: StridedArray<T>,
    ic: &[ID],
    ia: &[ID],
    ib: &[ID],
    alpha: T,
    beta: T,
    conj_a: bool,
    conj_b: bool,
) -> Result<()>
where
    backend::ActiveBackend: Backend<T>,
{
    let plan = Einsum2Plan::new(ia, ib, ic)?;
    validate_dimensions::<ID>(&plan, a.dims(), b.dims(), c.dims(), ia, ib, ic)?;

    // Traced axes are reduced first, on plain `.view()`s that carry no element
    // op, so the reduction itself never conjugates. Trace reduction presumably
    // sums, and conjugation commutes with summation (conj(Σx) = Σ conj(x)). The
    // requested `conj_a` / `conj_b` is therefore kept and applied afterwards.
    // Previously the flag was reset to `false` here, so the conjugation was lost.
    let left_trace = trace::find_trace_indices(ia, ib, ic);
    let a_for_gemm = if left_trace.is_empty() {
        a
    } else {
        trace::reduce_trace_axes(&a.view(), &left_trace)?
    };

    let right_trace = trace::find_trace_indices(ib, ia, ic);
    let b_for_gemm = if right_trace.is_empty() {
        b
    } else {
        trace::reduce_trace_axes(&b.view(), &right_trace)?
    };

    let a_perm = a_for_gemm.permuted(&plan.left_perm)?;
    let b_perm = b_for_gemm.permuted(&plan.right_perm)?;
    let mut c_perm = c.permute(&plan.c_to_internal_perm)?;

    let n_lo = plan.lo.len();
    let n_ro = plan.ro.len();
    let n_sum = plan.sum.len();

    // Only batch axes and no accumulation: a single fused elementwise pass.
    if plan.sum.is_empty() && plan.lo.is_empty() && plan.ro.is_empty() && beta == T::zero() {
        let mul_fn = move |a_val: T, b_val: T| -> T {
            let a_c = if conj_a { Conj::apply(a_val) } else { a_val };
            let b_c = if conj_b { Conj::apply(b_val) } else { b_val };
            alpha * a_c * b_c
        };
        zip_map2_into(&mut c_perm, &a_perm.view(), &b_perm.view(), mul_fn)?;
        return Ok(());
    }

    // An empty output needs no work. Returning here also keeps zero extents away
    // from the packing helpers and the GEMM backend.
    if c_perm.dims().contains(&0) {
        return Ok(());
    }

    // `a_perm` / `b_perm` are moved into the packing helpers below, so copy their
    // extents out first.
    let a_dims_perm = a_perm.dims().to_vec();
    let b_dims_perm = b_perm.dims().to_vec();

    let lo_dims = &a_dims_perm[..n_lo];
    let sum_dims = &a_dims_perm[n_lo..n_lo + n_sum];
    let batch_dims = &a_dims_perm[n_lo + n_sum..];
    let ro_dims = &b_dims_perm[n_sum..n_sum + n_ro];
    // Exact GEMM extents (the empty product is already 1). The output is
    // non-empty here, so `m, n >= 1`. `k == 0` is an empty contraction.
    // NOTE(review): the active backend must accept `k == 0` — confirm.
    let m: usize = lo_dims.iter().product();
    let k: usize = sum_dims.iter().product();
    let n: usize = ro_dims.iter().product();

    let materialize = if <backend::ActiveBackend as Backend<T>>::MATERIALIZES_CONJ {
        make_conj_fn::<T>()
    } else {
        None
    };
    let use_pool = true;
    let unit_stride = <backend::ActiveBackend as Backend<T>>::REQUIRES_UNIT_STRIDE;
    let a_op = contiguous::prepare_input_owned(
        a_perm,
        n_lo,
        n_sum,
        conj_a,
        unit_stride,
        use_pool,
        materialize,
    )?;
    let b_op = contiguous::prepare_input_owned(
        b_perm,
        n_sum,
        n_ro,
        conj_b,
        unit_stride,
        use_pool,
        materialize,
    )?;
    let mut c_op =
        contiguous::prepare_output_view(&mut c_perm, n_lo, n_ro, beta, unit_stride, use_pool)?;

    backend::ActiveBackend::bgemm_contiguous_into(
        &mut c_op, &a_op, &b_op, batch_dims, m, n, k, alpha, beta,
    )?;

    // Copy back into `c` if `prepare_output_view` had to use a temporary buffer.
    c_op.finalize_into(&mut c_perm)?;

    Ok(())
}
710
711fn validate_dimensions<ID: AxisId>(
713 plan: &Einsum2Plan<ID>,
714 a_dims: &[usize],
715 b_dims: &[usize],
716 c_dims: &[usize],
717 ia: &[ID],
718 ib: &[ID],
719 ic: &[ID],
720) -> Result<()> {
721 let find_dim = |labels: &[ID], dims: &[usize], id: &ID| -> usize {
722 labels
723 .iter()
724 .position(|x| x == id)
725 .map(|i| dims[i])
726 .unwrap()
727 };
728
729 for id in &plan.batch {
731 let da = find_dim(ia, a_dims, id);
732 let db = find_dim(ib, b_dims, id);
733 let dc = find_dim(ic, c_dims, id);
734 if da != db || da != dc {
735 return Err(EinsumError::DimensionMismatch {
736 axis: format!("{:?}", id),
737 dim_a: da,
738 dim_b: db,
739 });
740 }
741 }
742
743 for id in &plan.sum {
745 let da = find_dim(ia, a_dims, id);
746 let db = find_dim(ib, b_dims, id);
747 if da != db {
748 return Err(EinsumError::DimensionMismatch {
749 axis: format!("{:?}", id),
750 dim_a: da,
751 dim_b: db,
752 });
753 }
754 }
755
756 for id in &plan.lo {
758 let da = find_dim(ia, a_dims, id);
759 let dc = find_dim(ic, c_dims, id);
760 if da != dc {
761 return Err(EinsumError::DimensionMismatch {
762 axis: format!("{:?}", id),
763 dim_a: da,
764 dim_b: dc,
765 });
766 }
767 }
768
769 for id in &plan.ro {
771 let db = find_dim(ib, b_dims, id);
772 let dc = find_dim(ic, c_dims, id);
773 if db != dc {
774 return Err(EinsumError::DimensionMismatch {
775 axis: format!("{:?}", id),
776 dim_a: db,
777 dim_b: dc,
778 });
779 }
780 }
781
782 Ok(())
783}
784
785#[cfg(test)]
786mod tests {
787 use super::*;
788 use strided_view::StridedArray;
789
790 #[test]
791 fn test_matmul_ij_jk_ik() {
792 let a = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
794 [[1.0, 2.0], [3.0, 4.0]][idx[0]][idx[1]]
795 });
796 let b = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
797 [[5.0, 6.0], [7.0, 8.0]][idx[0]][idx[1]]
798 });
799 let mut c = StridedArray::<f64>::row_major(&[2, 2]);
800
801 einsum2_into(
802 c.view_mut(),
803 &a.view(),
804 &b.view(),
805 &['i', 'k'],
806 &['i', 'j'],
807 &['j', 'k'],
808 1.0,
809 0.0,
810 )
811 .unwrap();
812
813 assert_eq!(c.get(&[0, 0]), 19.0);
814 assert_eq!(c.get(&[0, 1]), 22.0);
815 assert_eq!(c.get(&[1, 0]), 43.0);
816 assert_eq!(c.get(&[1, 1]), 50.0);
817 }
818
819 #[test]
820 fn test_matmul_rect() {
821 let a =
823 StridedArray::<f64>::from_fn_row_major(&[2, 3], |idx| (idx[0] * 3 + idx[1] + 1) as f64);
824 let b =
825 StridedArray::<f64>::from_fn_row_major(&[3, 4], |idx| (idx[0] * 4 + idx[1] + 1) as f64);
826 let mut c = StridedArray::<f64>::row_major(&[2, 4]);
827
828 einsum2_into(
829 c.view_mut(),
830 &a.view(),
831 &b.view(),
832 &['i', 'k'],
833 &['i', 'j'],
834 &['j', 'k'],
835 1.0,
836 0.0,
837 )
838 .unwrap();
839
840 assert_eq!(c.get(&[0, 0]), 38.0);
842 assert_eq!(c.get(&[1, 3]), 128.0);
843 }
844
845 #[test]
846 fn test_batched_matmul() {
847 let a = StridedArray::<f64>::from_fn_row_major(&[2, 2, 3], |idx| {
849 (idx[0] * 6 + idx[1] * 3 + idx[2] + 1) as f64
850 });
851 let b = StridedArray::<f64>::from_fn_row_major(&[2, 3, 2], |idx| {
852 (idx[0] * 6 + idx[1] * 2 + idx[2] + 1) as f64
853 });
854 let mut c = StridedArray::<f64>::row_major(&[2, 2, 2]);
855
856 einsum2_into(
857 c.view_mut(),
858 &a.view(),
859 &b.view(),
860 &['b', 'i', 'k'],
861 &['b', 'i', 'j'],
862 &['b', 'j', 'k'],
863 1.0,
864 0.0,
865 )
866 .unwrap();
867
868 assert_eq!(c.get(&[0, 0, 0]), 22.0);
871 }
872
873 #[test]
874 fn test_batched_matmul_col_major_output() {
875 let a_data = vec![1.0, 0.0, 0.0, 1.0, 2.0, 0.0, 0.0, 2.0];
877 let b_data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
878 let a = StridedArray::<f64>::from_parts(a_data, &[2, 2, 2], &[4, 2, 1], 0).unwrap();
879 let b = StridedArray::<f64>::from_parts(b_data, &[2, 2, 2], &[4, 2, 1], 0).unwrap();
880 let mut c = StridedArray::<f64>::col_major(&[2, 2, 2]);
881
882 einsum2_into(
883 c.view_mut(),
884 &a.view(),
885 &b.view(),
886 &['b', 'i', 'k'],
887 &['b', 'i', 'j'],
888 &['b', 'j', 'k'],
889 1.0,
890 0.0,
891 )
892 .unwrap();
893
894 assert_eq!(c.get(&[0, 0, 0]), 1.0);
896 assert_eq!(c.get(&[0, 0, 1]), 2.0);
897 assert_eq!(c.get(&[0, 1, 0]), 3.0);
898 assert_eq!(c.get(&[0, 1, 1]), 4.0);
899 assert_eq!(c.get(&[1, 0, 0]), 10.0);
901 assert_eq!(c.get(&[1, 1, 1]), 16.0);
902 }
903
904 #[test]
905 fn test_outer_product() {
906 let a = StridedArray::<f64>::from_fn_row_major(&[3], |idx| (idx[0] + 1) as f64);
908 let b = StridedArray::<f64>::from_fn_row_major(&[4], |idx| (idx[0] + 1) as f64);
909 let mut c = StridedArray::<f64>::row_major(&[3, 4]);
910
911 einsum2_into(
912 c.view_mut(),
913 &a.view(),
914 &b.view(),
915 &['i', 'j'],
916 &['i'],
917 &['j'],
918 1.0,
919 0.0,
920 )
921 .unwrap();
922
923 assert_eq!(c.get(&[0, 0]), 1.0);
924 assert_eq!(c.get(&[2, 3]), 12.0);
925 }
926
927 #[test]
928 fn test_dot_product() {
929 let a = StridedArray::<f64>::from_fn_row_major(&[3], |idx| (idx[0] + 1) as f64);
931 let b = StridedArray::<f64>::from_fn_row_major(&[3], |idx| (idx[0] + 1) as f64);
932 let mut c = StridedArray::<f64>::row_major(&[]);
933
934 einsum2_into(
935 c.view_mut(),
936 &a.view(),
937 &b.view(),
938 &[] as &[char],
939 &['i'],
940 &['i'],
941 1.0,
942 0.0,
943 )
944 .unwrap();
945
946 assert_eq!(c.get(&[]), 14.0);
948 }
949
950 #[test]
951 fn test_alpha_beta() {
952 let a = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
954 [[1.0, 0.0], [0.0, 1.0]][idx[0]][idx[1]] });
956 let b = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
957 [[1.0, 2.0], [3.0, 4.0]][idx[0]][idx[1]]
958 });
959 let mut c = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
960 [[10.0, 20.0], [30.0, 40.0]][idx[0]][idx[1]]
961 });
962
963 einsum2_into(
964 c.view_mut(),
965 &a.view(),
966 &b.view(),
967 &['i', 'k'],
968 &['i', 'j'],
969 &['j', 'k'],
970 2.0,
971 3.0,
972 )
973 .unwrap();
974
975 assert_eq!(c.get(&[0, 0]), 32.0); assert_eq!(c.get(&[1, 1]), 128.0); }
979
980 #[test]
981 fn test_transposed_output() {
982 let a = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
984 [[1.0, 2.0], [3.0, 4.0]][idx[0]][idx[1]]
985 });
986 let b = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
987 [[5.0, 6.0], [7.0, 8.0]][idx[0]][idx[1]]
988 });
989 let mut c = StridedArray::<f64>::row_major(&[2, 2]);
990
991 einsum2_into(
992 c.view_mut(),
993 &a.view(),
994 &b.view(),
995 &['k', 'i'], &['i', 'j'],
997 &['j', 'k'],
998 1.0,
999 0.0,
1000 )
1001 .unwrap();
1002
1003 assert_eq!(c.get(&[0, 0]), 19.0); assert_eq!(c.get(&[0, 1]), 43.0); assert_eq!(c.get(&[1, 0]), 22.0); assert_eq!(c.get(&[1, 1]), 50.0); }
1010
1011 #[test]
1012 fn test_left_trace() {
1013 let a =
1016 StridedArray::<f64>::from_fn_row_major(&[2, 3], |idx| (idx[0] * 3 + idx[1] + 1) as f64);
1017 let b =
1020 StridedArray::<f64>::from_fn_row_major(&[3, 2], |idx| (idx[0] * 2 + idx[1] + 1) as f64);
1021 let mut c = StridedArray::<f64>::row_major(&[2]);
1023
1024 einsum2_into(
1025 c.view_mut(),
1026 &a.view(),
1027 &b.view(),
1028 &['k'],
1029 &['i', 'j'],
1030 &['j', 'k'],
1031 1.0,
1032 0.0,
1033 )
1034 .unwrap();
1035
1036 assert_eq!(c.get(&[0]), 71.0);
1040 assert_eq!(c.get(&[1]), 92.0);
1041 }
1042
1043 #[test]
1044 fn test_u32_labels() {
1045 let a = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
1047 [[1.0, 2.0], [3.0, 4.0]][idx[0]][idx[1]]
1048 });
1049 let b = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
1050 [[5.0, 6.0], [7.0, 8.0]][idx[0]][idx[1]]
1051 });
1052 let mut c = StridedArray::<f64>::row_major(&[2, 2]);
1053
1054 einsum2_into(
1055 c.view_mut(),
1056 &a.view(),
1057 &b.view(),
1058 &[0u32, 2],
1059 &[0u32, 1],
1060 &[1u32, 2],
1061 1.0,
1062 0.0,
1063 )
1064 .unwrap();
1065
1066 assert_eq!(c.get(&[0, 0]), 19.0);
1067 assert_eq!(c.get(&[1, 1]), 50.0);
1068 }
1069
1070 #[test]
1071 fn test_complex_matmul() {
1072 use num_complex::Complex64;
1073 let i = Complex64::i();
1074
1075 let a_vals = [
1077 [1.0 + i, Complex64::new(2.0, 0.0)],
1078 [Complex64::new(3.0, 0.0), 4.0 - i],
1079 ];
1080 let a = StridedArray::<Complex64>::from_fn_row_major(&[2, 2], |idx| a_vals[idx[0]][idx[1]]);
1081
1082 let b_vals = [
1084 [Complex64::new(1.0, 0.0), i],
1085 [Complex64::new(0.0, 0.0), Complex64::new(1.0, 0.0)],
1086 ];
1087 let b = StridedArray::<Complex64>::from_fn_row_major(&[2, 2], |idx| b_vals[idx[0]][idx[1]]);
1088
1089 let mut c = StridedArray::<Complex64>::row_major(&[2, 2]);
1090
1091 einsum2_into(
1092 c.view_mut(),
1093 &a.view(),
1094 &b.view(),
1095 &['i', 'k'],
1096 &['i', 'j'],
1097 &['j', 'k'],
1098 Complex64::new(1.0, 0.0),
1099 Complex64::new(0.0, 0.0),
1100 )
1101 .unwrap();
1102
1103 assert_eq!(c.get(&[0, 0]), 1.0 + i);
1109 assert_eq!(c.get(&[0, 1]), 1.0 + i);
1110 assert_eq!(c.get(&[1, 0]), Complex64::new(3.0, 0.0));
1111 assert_eq!(c.get(&[1, 1]), 4.0 + 2.0 * i);
1112 }
1113
1114 #[test]
1115 fn test_complex_matmul_with_conj() {
1116 use num_complex::Complex64;
1117 let i = Complex64::i();
1118
1119 let a_vals = [[1.0 + i, 2.0 * i], [Complex64::new(3.0, 0.0), 4.0 - i]];
1121 let a = StridedArray::<Complex64>::from_fn_row_major(&[2, 2], |idx| a_vals[idx[0]][idx[1]]);
1122
1123 let b = StridedArray::<Complex64>::from_fn_row_major(&[2, 2], |idx| {
1125 if idx[0] == idx[1] {
1126 Complex64::new(1.0, 0.0)
1127 } else {
1128 Complex64::new(0.0, 0.0)
1129 }
1130 });
1131
1132 let mut c = StridedArray::<Complex64>::row_major(&[2, 2]);
1133
1134 let a_conj = a.view().conj();
1136 einsum2_into(
1137 c.view_mut(),
1138 &a_conj,
1139 &b.view(),
1140 &['i', 'k'],
1141 &['i', 'j'],
1142 &['j', 'k'],
1143 Complex64::new(1.0, 0.0),
1144 Complex64::new(0.0, 0.0),
1145 )
1146 .unwrap();
1147
1148 assert_eq!(c.get(&[0, 0]), 1.0 - i);
1150 assert_eq!(c.get(&[0, 1]), -2.0 * i);
1151 assert_eq!(c.get(&[1, 0]), Complex64::new(3.0, 0.0));
1152 assert_eq!(c.get(&[1, 1]), 4.0 + i);
1153 }
1154
1155 #[test]
1156 fn test_complex_matmul_with_conj_both() {
1157 use num_complex::Complex64;
1158 let i = Complex64::i();
1159
1160 let a_vals = [
1162 [1.0 + i, Complex64::new(0.0, 0.0)],
1163 [Complex64::new(0.0, 0.0), 2.0 - i],
1164 ];
1165 let a = StridedArray::<Complex64>::from_fn_row_major(&[2, 2], |idx| a_vals[idx[0]][idx[1]]);
1166
1167 let b_vals = [
1169 [Complex64::new(1.0, 0.0), i],
1170 [Complex64::new(0.0, 0.0), 1.0 + i],
1171 ];
1172 let b = StridedArray::<Complex64>::from_fn_row_major(&[2, 2], |idx| b_vals[idx[0]][idx[1]]);
1173
1174 let mut c = StridedArray::<Complex64>::row_major(&[2, 2]);
1175
1176 let a_conj = a.view().conj();
1178 let b_conj = b.view().conj();
1179 einsum2_into(
1180 c.view_mut(),
1181 &a_conj,
1182 &b_conj,
1183 &['i', 'k'],
1184 &['i', 'j'],
1185 &['j', 'k'],
1186 Complex64::new(1.0, 0.0),
1187 Complex64::new(0.0, 0.0),
1188 )
1189 .unwrap();
1190
1191 assert_eq!(c.get(&[0, 0]), 1.0 - i);
1199 assert_eq!(c.get(&[0, 1]), -(1.0 + i));
1200 assert_eq!(c.get(&[1, 0]), Complex64::new(0.0, 0.0));
1201 assert_eq!(c.get(&[1, 1]), 3.0 - i);
1202 }
1203
1204 #[test]
1205 fn test_elementwise_hadamard() {
1206 let a = StridedArray::<f64>::from_fn_row_major(&[3, 4, 5], |idx| {
1208 (idx[0] * 20 + idx[1] * 5 + idx[2] + 1) as f64
1209 });
1210 let b = StridedArray::<f64>::from_fn_row_major(&[3, 4, 5], |idx| {
1211 (idx[0] * 20 + idx[1] * 5 + idx[2] + 1) as f64 * 0.1
1212 });
1213 let mut c = StridedArray::<f64>::row_major(&[3, 4, 5]);
1214
1215 einsum2_into(
1216 c.view_mut(),
1217 &a.view(),
1218 &b.view(),
1219 &['i', 'j', 'k'],
1220 &['i', 'j', 'k'],
1221 &['i', 'j', 'k'],
1222 1.0,
1223 0.0,
1224 )
1225 .unwrap();
1226
1227 assert!((c.get(&[0, 0, 0]) - 0.1).abs() < 1e-12);
1229 assert!((c.get(&[2, 3, 4]) - 360.0).abs() < 1e-10);
1231 }
1232
1233 #[test]
1234 fn test_elementwise_hadamard_with_alpha() {
1235 let a =
1236 StridedArray::<f64>::from_fn_row_major(&[2, 3], |idx| (idx[0] * 3 + idx[1] + 1) as f64);
1237 let b =
1238 StridedArray::<f64>::from_fn_row_major(&[2, 3], |idx| (idx[0] * 3 + idx[1] + 1) as f64);
1239 let mut c = StridedArray::<f64>::row_major(&[2, 3]);
1240
1241 einsum2_into(
1242 c.view_mut(),
1243 &a.view(),
1244 &b.view(),
1245 &['i', 'j'],
1246 &['i', 'j'],
1247 &['i', 'j'],
1248 2.0,
1249 0.0,
1250 )
1251 .unwrap();
1252
1253 assert_eq!(c.get(&[0, 0]), 2.0);
1255 assert_eq!(c.get(&[1, 2]), 72.0);
1257 }
1258
1259 #[cfg(any(feature = "faer", feature = "blas", feature = "blas-inject"))]
1260 #[test]
1261 fn test_einsum2_owned_matmul() {
1262 let a = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
1263 [[1.0, 2.0], [3.0, 4.0]][idx[0]][idx[1]]
1264 });
1265 let b = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
1266 [[5.0, 6.0], [7.0, 8.0]][idx[0]][idx[1]]
1267 });
1268 let mut c = StridedArray::<f64>::row_major(&[2, 2]);
1269
1270 einsum2_into_owned(
1271 c.view_mut(),
1272 a,
1273 b,
1274 &['i', 'k'],
1275 &['i', 'j'],
1276 &['j', 'k'],
1277 1.0,
1278 0.0,
1279 false,
1280 false,
1281 )
1282 .unwrap();
1283
1284 assert_eq!(c.get(&[0, 0]), 19.0);
1285 assert_eq!(c.get(&[0, 1]), 22.0);
1286 assert_eq!(c.get(&[1, 0]), 43.0);
1287 assert_eq!(c.get(&[1, 1]), 50.0);
1288 }
1289
1290 #[cfg(any(feature = "faer", feature = "blas", feature = "blas-inject"))]
1291 #[test]
1292 fn test_einsum2_owned_batched() {
1293 let a = StridedArray::<f64>::from_fn_row_major(&[2, 2, 3], |idx| {
1294 (idx[0] * 6 + idx[1] * 3 + idx[2] + 1) as f64
1295 });
1296 let b = StridedArray::<f64>::from_fn_row_major(&[2, 3, 2], |idx| {
1297 (idx[0] * 6 + idx[1] * 2 + idx[2] + 1) as f64
1298 });
1299 let mut c = StridedArray::<f64>::row_major(&[2, 2, 2]);
1300
1301 einsum2_into_owned(
1302 c.view_mut(),
1303 a,
1304 b,
1305 &['b', 'i', 'k'],
1306 &['b', 'i', 'j'],
1307 &['b', 'j', 'k'],
1308 1.0,
1309 0.0,
1310 false,
1311 false,
1312 )
1313 .unwrap();
1314
1315 assert_eq!(c.get(&[0, 0, 0]), 22.0);
1318 }
1319
1320 #[cfg(any(feature = "faer", feature = "blas", feature = "blas-inject"))]
1321 #[test]
1322 fn test_einsum2_owned_alpha_beta() {
1323 let a = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
1324 [[1.0, 0.0], [0.0, 1.0]][idx[0]][idx[1]]
1325 });
1326 let b = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
1327 [[1.0, 2.0], [3.0, 4.0]][idx[0]][idx[1]]
1328 });
1329 let mut c = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
1330 [[10.0, 20.0], [30.0, 40.0]][idx[0]][idx[1]]
1331 });
1332
1333 einsum2_into_owned(
1334 c.view_mut(),
1335 a,
1336 b,
1337 &['i', 'k'],
1338 &['i', 'j'],
1339 &['j', 'k'],
1340 2.0,
1341 3.0,
1342 false,
1343 false,
1344 )
1345 .unwrap();
1346
1347 assert_eq!(c.get(&[0, 0]), 32.0); assert_eq!(c.get(&[1, 1]), 128.0); }
1351
1352 #[cfg(any(feature = "faer", feature = "blas", feature = "blas-inject"))]
1353 #[test]
1354 fn test_einsum2_owned_elementwise() {
1355 let a =
1357 StridedArray::<f64>::from_fn_row_major(&[2, 3], |idx| (idx[0] * 3 + idx[1] + 1) as f64);
1358 let b =
1359 StridedArray::<f64>::from_fn_row_major(&[2, 3], |idx| (idx[0] * 3 + idx[1] + 1) as f64);
1360 let mut c = StridedArray::<f64>::row_major(&[2, 3]);
1361
1362 einsum2_into_owned(
1363 c.view_mut(),
1364 a,
1365 b,
1366 &['i', 'j'],
1367 &['i', 'j'],
1368 &['i', 'j'],
1369 2.0,
1370 0.0,
1371 false,
1372 false,
1373 )
1374 .unwrap();
1375
1376 assert_eq!(c.get(&[0, 0]), 2.0); assert_eq!(c.get(&[1, 2]), 72.0); }
1379
/// Label `i` appears only in `A` (not in `B` or `C`), so it is summed away
/// alongside the contracted `j`: `c[k] = Σ_i Σ_j a[i][j] * b[j][k]`.
#[cfg(any(feature = "faer", feature = "blas", feature = "blas-inject"))]
#[test]
fn test_einsum2_owned_left_trace() {
    // A = [[1, 2, 3], [4, 5, 6]]  (2 x 3)
    let lhs = StridedArray::<f64>::from_fn_row_major(&[2, 3], |idx| {
        let flat = 3 * idx[0] + idx[1];
        (flat + 1) as f64
    });
    // B = [[1, 2], [3, 4], [5, 6]]  (3 x 2)
    let rhs = StridedArray::<f64>::from_fn_row_major(&[3, 2], |idx| {
        let flat = 2 * idx[0] + idx[1];
        (flat + 1) as f64
    });
    let mut out = StridedArray::<f64>::row_major(&[2]);

    let status = einsum2_into_owned(
        out.view_mut(),
        lhs,
        rhs,
        &['k'],
        &['i', 'j'],
        &['j', 'k'],
        1.0,
        0.0,
        false,
        false,
    );
    status.unwrap();

    // Column sums of A are [5, 7, 9]; dotted with B's columns they give 71 and 92.
    for (k, &want) in [71.0, 92.0].iter().enumerate() {
        assert_eq!(out.get(&[k]), want);
    }
}
1413
/// Plain 2x2 matrix product through `einsum2_naive_into`, with identity
/// element maps on both operands.
#[test]
fn test_einsum2_naive_matmul() {
    // A = [[1, 2], [3, 4]]
    let lhs = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
        (2 * idx[0] + idx[1] + 1) as f64
    });
    // B = [[5, 6], [7, 8]]
    let rhs = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
        (2 * idx[0] + idx[1] + 5) as f64
    });
    let mut out = StridedArray::<f64>::row_major(&[2, 2]);

    einsum2_naive_into(
        out.view_mut(),
        &lhs.view(),
        &rhs.view(),
        &['i', 'k'],
        &['i', 'j'],
        &['j', 'k'],
        1.0,
        0.0,
        |x| x,
        |x| x,
    )
    .unwrap();

    let expected = [[19.0, 22.0], [43.0, 50.0]];
    for (i, row) in expected.iter().enumerate() {
        for (k, &want) in row.iter().enumerate() {
            assert_eq!(out.get(&[i, k]), want);
        }
    }
}
1444
/// When all three operands share labels `i, j`, `einsum2_naive_into`
/// performs a scaled element-wise product: `c[i][j] = 2 * a[i][j] * b[i][j]`.
/// All six entries are checked.
#[test]
fn test_einsum2_naive_elementwise() {
    // A = B = [[1, 2, 3], [4, 5, 6]]
    let a =
        StridedArray::<f64>::from_fn_row_major(&[2, 3], |idx| (idx[0] * 3 + idx[1] + 1) as f64);
    let b =
        StridedArray::<f64>::from_fn_row_major(&[2, 3], |idx| (idx[0] * 3 + idx[1] + 1) as f64);
    let mut c = StridedArray::<f64>::row_major(&[2, 3]);

    einsum2_naive_into(
        c.view_mut(),
        &a.view(),
        &b.view(),
        &['i', 'j'],
        &['i', 'j'],
        &['i', 'j'],
        2.0,
        0.0,
        |x| x,
        |x| x,
    )
    .unwrap();

    // 2 * x^2 for x = 1..=6 in row-major order.
    assert_eq!(c.get(&[0, 0]), 2.0);
    assert_eq!(c.get(&[0, 1]), 8.0);
    assert_eq!(c.get(&[0, 2]), 18.0);
    assert_eq!(c.get(&[1, 0]), 32.0);
    assert_eq!(c.get(&[1, 1]), 50.0);
    assert_eq!(c.get(&[1, 2]), 72.0);
}
1471
/// Runs a 2x2 matrix product through `einsum2_naive_into` using `MyVal`, a
/// user-defined `f64` newtype that implements `Add`, `Mul`, `Zero` and `One`
/// itself.
#[test]
fn test_einsum2_naive_custom_type() {
    use num_traits::{One, Zero};

    // Derived `Default` yields `MyVal(0.0)`, matching `Zero::zero`.
    #[derive(Debug, Clone, Copy, PartialEq, Default)]
    struct MyVal(f64);

    impl std::ops::Add for MyVal {
        type Output = Self;
        fn add(self, rhs: Self) -> Self {
            Self(self.0 + rhs.0)
        }
    }

    impl std::ops::Mul for MyVal {
        type Output = Self;
        fn mul(self, rhs: Self) -> Self {
            Self(self.0 * rhs.0)
        }
    }

    impl Zero for MyVal {
        fn zero() -> Self {
            Self(0.0)
        }
        fn is_zero(&self) -> bool {
            self.0 == 0.0
        }
    }

    impl One for MyVal {
        fn one() -> Self {
            Self(1.0)
        }
    }

    // Inputs are 2x2 buffers passed with `&[2, 2]` and `&[2, 1]`, presumably
    // dims and row-major strides (assumed from_parts(data, dims, strides, offset)).
    let a = StridedArray::from_parts(
        [1.0, 2.0, 3.0, 4.0].iter().map(|&v| MyVal(v)).collect::<Vec<_>>(),
        &[2, 2],
        &[2, 1],
        0,
    )
    .unwrap();
    let b = StridedArray::from_parts(
        [5.0, 6.0, 7.0, 8.0].iter().map(|&v| MyVal(v)).collect::<Vec<_>>(),
        &[2, 2],
        &[2, 1],
        0,
    )
    .unwrap();
    let mut c = StridedArray::<MyVal>::col_major(&[2, 2]);

    einsum2_naive_into(
        c.view_mut(),
        &a.view(),
        &b.view(),
        &['i', 'k'],
        &['i', 'j'],
        &['j', 'k'],
        MyVal::one(),
        MyVal::zero(),
        |x| x,
        |x| x,
    )
    .unwrap();

    let expected = [[19.0, 22.0], [43.0, 50.0]];
    for (i, row) in expected.iter().enumerate() {
        for (k, &want) in row.iter().enumerate() {
            assert_eq!(c.get(&[i, k]), MyVal(want));
        }
    }
}
1554
/// Naive-kernel counterpart of the owned left-trace test.
/// Label `i` appears only in `A`, so it is summed away together with the
/// contracted `j`: `c[k] = Σ_i Σ_j a[i][j] * b[j][k]`.
#[test]
fn test_einsum2_naive_left_trace() {
    // A = [[1, 2, 3], [4, 5, 6]]  (2 x 3)
    let lhs = StridedArray::<f64>::from_fn_row_major(&[2, 3], |idx| {
        let flat = 3 * idx[0] + idx[1];
        (flat + 1) as f64
    });
    // B = [[1, 2], [3, 4], [5, 6]]  (3 x 2)
    let rhs = StridedArray::<f64>::from_fn_row_major(&[3, 2], |idx| {
        let flat = 2 * idx[0] + idx[1];
        (flat + 1) as f64
    });
    let mut out = StridedArray::<f64>::row_major(&[2]);

    let status = einsum2_naive_into(
        out.view_mut(),
        &lhs.view(),
        &rhs.view(),
        &['k'],
        &['i', 'j'],
        &['j', 'k'],
        1.0,
        0.0,
        |x| x,
        |x| x,
    );
    status.unwrap();

    // Column sums of A are [5, 7, 9]; dotted with B's columns they give 71 and 92.
    for (k, &want) in [71.0, 92.0].iter().enumerate() {
        assert_eq!(out.get(&[k]), want);
    }
}
1581
1582 struct TestNaiveBackend;
1584
1585 impl Backend<f64> for TestNaiveBackend {
1586 const MATERIALIZES_CONJ: bool = false;
1587 const REQUIRES_UNIT_STRIDE: bool = false;
1588
1589 fn bgemm_contiguous_into(
1590 c: &mut contiguous::ContiguousOperandMut<f64>,
1591 a: &contiguous::ContiguousOperand<f64>,
1592 b: &contiguous::ContiguousOperand<f64>,
1593 batch_dims: &[usize],
1594 m: usize,
1595 n: usize,
1596 k: usize,
1597 alpha: f64,
1598 beta: f64,
1599 ) -> strided_view::Result<()> {
1600 let a_ptr = a.ptr();
1602 let b_ptr = b.ptr();
1603 let c_ptr = c.ptr();
1604 let a_rs = a.row_stride();
1605 let a_cs = a.col_stride();
1606 let b_rs = b.row_stride();
1607 let b_cs = b.col_stride();
1608 let c_rs = c.row_stride();
1609 let c_cs = c.col_stride();
1610
1611 let mut batch_idx = crate::util::MultiIndex::new(batch_dims);
1612 while batch_idx.next().is_some() {
1613 let a_base = batch_idx.offset(a.batch_strides());
1614 let b_base = batch_idx.offset(b.batch_strides());
1615 let c_base = batch_idx.offset(c.batch_strides());
1616
1617 for i in 0..m {
1618 for j in 0..n {
1619 let mut acc = 0.0f64;
1620 for l in 0..k {
1621 let a_val = unsafe {
1622 *a_ptr.offset(a_base + i as isize * a_rs + l as isize * a_cs)
1623 };
1624 let b_val = unsafe {
1625 *b_ptr.offset(b_base + l as isize * b_rs + j as isize * b_cs)
1626 };
1627 acc += a_val * b_val;
1628 }
1629 unsafe {
1630 let c_elem =
1631 c_ptr.offset(c_base + i as isize * c_rs + j as isize * c_cs);
1632 if beta == 0.0 {
1633 *c_elem = alpha * acc;
1634 } else {
1635 *c_elem = alpha * acc + beta * (*c_elem);
1636 }
1637 }
1638 }
1639 }
1640 }
1641 Ok(())
1642 }
1643 }
1644
/// Drives a 2x2 matrix product through `einsum2_with_backend_into`, using the
/// test-local `TestNaiveBackend` as the GEMM kernel.
#[test]
fn test_einsum2_with_backend_matmul() {
    // A = [[1, 2], [3, 4]]
    let lhs = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
        (2 * idx[0] + idx[1] + 1) as f64
    });
    // B = [[5, 6], [7, 8]]
    let rhs = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
        (2 * idx[0] + idx[1] + 5) as f64
    });
    let mut out = StridedArray::<f64>::row_major(&[2, 2]);

    einsum2_with_backend_into::<_, TestNaiveBackend, _>(
        out.view_mut(),
        &lhs.view(),
        &rhs.view(),
        &['i', 'k'],
        &['i', 'j'],
        &['j', 'k'],
        1.0,
        0.0,
    )
    .unwrap();

    let expected = [[19.0, 22.0], [43.0, 50.0]];
    for (i, row) in expected.iter().enumerate() {
        for (k, &want) in row.iter().enumerate() {
            assert_eq!(out.get(&[i, k]), want);
        }
    }
}
1672
/// When all three operands share labels `i, j`, `einsum2_with_backend_into`
/// (with `TestNaiveBackend`) must produce `c[i][j] = 2 * a[i][j] * b[i][j]`.
/// All six entries are checked.
#[test]
fn test_einsum2_with_backend_elementwise() {
    // A = B = [[1, 2, 3], [4, 5, 6]]
    let a =
        StridedArray::<f64>::from_fn_row_major(&[2, 3], |idx| (idx[0] * 3 + idx[1] + 1) as f64);
    let b =
        StridedArray::<f64>::from_fn_row_major(&[2, 3], |idx| (idx[0] * 3 + idx[1] + 1) as f64);
    let mut c = StridedArray::<f64>::row_major(&[2, 3]);

    einsum2_with_backend_into::<_, TestNaiveBackend, _>(
        c.view_mut(),
        &a.view(),
        &b.view(),
        &['i', 'j'],
        &['i', 'j'],
        &['i', 'j'],
        2.0,
        0.0,
    )
    .unwrap();

    // 2 * x^2 for x = 1..=6 in row-major order.
    assert_eq!(c.get(&[0, 0]), 2.0);
    assert_eq!(c.get(&[0, 1]), 8.0);
    assert_eq!(c.get(&[0, 2]), 18.0);
    assert_eq!(c.get(&[1, 0]), 32.0);
    assert_eq!(c.get(&[1, 1]), 50.0);
    assert_eq!(c.get(&[1, 2]), 72.0);
}
1696
/// `einsum2_with_backend_into` must dispatch to a caller-supplied backend for
/// a user-defined scalar type.
///
/// `MyVal` is a plain `f64` newtype with ordinary `+` and `*`. It is not a
/// tropical (max-plus) semiring, so the expected values are those of a regular
/// matrix product. The type mirrors the one in the naive custom-type test.
#[test]
fn test_einsum2_with_backend_custom_type() {
    use num_traits::{One, Zero};

    #[derive(Debug, Clone, Copy, PartialEq)]
    struct MyVal(f64);

    impl Default for MyVal {
        fn default() -> Self {
            MyVal(0.0)
        }
    }

    impl std::ops::Add for MyVal {
        type Output = Self;
        fn add(self, rhs: Self) -> Self {
            MyVal(self.0 + rhs.0)
        }
    }

    impl std::ops::Mul for MyVal {
        type Output = Self;
        fn mul(self, rhs: Self) -> Self {
            MyVal(self.0 * rhs.0)
        }
    }

    impl Zero for MyVal {
        fn zero() -> Self {
            MyVal(0.0)
        }
        fn is_zero(&self) -> bool {
            self.0 == 0.0
        }
    }

    impl One for MyVal {
        fn one() -> Self {
            MyVal(1.0)
        }
    }

    /// Triple-loop batched GEMM for `MyVal`, honouring the operands' strides.
    struct MyValBackend;

    impl Backend<MyVal> for MyValBackend {
        const MATERIALIZES_CONJ: bool = false;
        const REQUIRES_UNIT_STRIDE: bool = false;

        /// For each batch index: `C = alpha * A · B + beta * C`, with `A`
        /// walked as `m x k`, `B` as `k x n` and `C` as `m x n`. `C` is not
        /// read when `beta` is zero.
        fn bgemm_contiguous_into(
            c: &mut contiguous::ContiguousOperandMut<MyVal>,
            a: &contiguous::ContiguousOperand<MyVal>,
            b: &contiguous::ContiguousOperand<MyVal>,
            batch_dims: &[usize],
            m: usize,
            n: usize,
            k: usize,
            alpha: MyVal,
            beta: MyVal,
        ) -> strided_view::Result<()> {
            let a_ptr = a.ptr();
            let b_ptr = b.ptr();
            let c_ptr = c.ptr();
            let a_rs = a.row_stride();
            let a_cs = a.col_stride();
            let b_rs = b.row_stride();
            let b_cs = b.col_stride();
            let c_rs = c.row_stride();
            let c_cs = c.col_stride();

            let mut batch_idx = crate::util::MultiIndex::new(batch_dims);
            while batch_idx.next().is_some() {
                let a_base = batch_idx.offset(a.batch_strides());
                let b_base = batch_idx.offset(b.batch_strides());
                let c_base = batch_idx.offset(c.batch_strides());

                for i in 0..m {
                    for j in 0..n {
                        let mut acc = MyVal::zero();
                        for l in 0..k {
                            // SAFETY: (i, l) is inside the m x k extent of this batch
                            // slice; relies on the caller providing in-bounds strides.
                            let a_val = unsafe {
                                *a_ptr.offset(a_base + i as isize * a_rs + l as isize * a_cs)
                            };
                            // SAFETY: (l, j) is inside the k x n extent; same assumption.
                            let b_val = unsafe {
                                *b_ptr.offset(b_base + l as isize * b_rs + j as isize * b_cs)
                            };
                            acc = acc + a_val * b_val;
                        }
                        // SAFETY: (i, j) is inside the m x n output extent; same
                        // caller-provided in-bounds assumption as the reads above.
                        unsafe {
                            let c_elem =
                                c_ptr.offset(c_base + i as isize * c_rs + j as isize * c_cs);
                            if beta.is_zero() {
                                // Do not read the old value: it may be garbage.
                                *c_elem = alpha * acc;
                            } else {
                                *c_elem = alpha * acc + beta * (*c_elem);
                            }
                        }
                    }
                }
            }
            Ok(())
        }
    }

    let a = StridedArray::from_parts(
        vec![MyVal(1.0), MyVal(2.0), MyVal(3.0), MyVal(4.0)],
        &[2, 2],
        &[2, 1],
        0,
    )
    .unwrap();
    let b = StridedArray::from_parts(
        vec![MyVal(5.0), MyVal(6.0), MyVal(7.0), MyVal(8.0)],
        &[2, 2],
        &[2, 1],
        0,
    )
    .unwrap();
    let mut c = StridedArray::<MyVal>::col_major(&[2, 2]);

    einsum2_with_backend_into::<_, MyValBackend, _>(
        c.view_mut(),
        &a.view(),
        &b.view(),
        &['i', 'k'],
        &['i', 'j'],
        &['j', 'k'],
        MyVal(1.0),
        MyVal(0.0),
    )
    .unwrap();

    assert_eq!(c.get(&[0, 0]), MyVal(19.0));
    assert_eq!(c.get(&[0, 1]), MyVal(22.0));
    assert_eq!(c.get(&[1, 0]), MyVal(43.0));
    assert_eq!(c.get(&[1, 1]), MyVal(50.0));
}
1836}