// The `blas` feature uses the `cblas_sys` crate, while `blas-inject` uses
// `cblas_inject`; enabling both at once is rejected at compile time.
#[cfg(all(feature = "blas", feature = "blas-inject"))]
compile_error!("Features `blas` and `blas-inject` are mutually exclusive.");

// At most one explicit BLAS provider feature may be enabled.
#[cfg(any(
    all(feature = "blas-accelerate", feature = "blas-openblas"),
    all(feature = "blas-accelerate", feature = "blas-mkl"),
    all(feature = "blas-openblas", feature = "blas-mkl")
))]
compile_error!("Select at most one explicit BLAS provider feature.");

// Both BLAS front-ends are brought in under the single crate name `cblas_sys`,
// presumably so that BLAS-calling code (e.g. `bgemm_blas`) can use one path.
#[cfg(all(feature = "blas-inject", not(feature = "blas")))]
extern crate cblas_inject as cblas_sys;
#[cfg(all(feature = "blas", not(feature = "blas-inject")))]
extern crate cblas_sys;

// The BLAS-backed batched GEMM module exists only when exactly one of
// `blas` / `blas-inject` is enabled.
#[cfg(any(
    all(feature = "blas", not(feature = "blas-inject")),
    all(feature = "blas-inject", not(feature = "blas"))
))]
pub mod bgemm_blas;
44
45#[cfg(feature = "faer")]
46pub mod bgemm_faer;
48pub mod bgemm_naive;
50pub mod contiguous;
52pub mod dot_general;
54pub mod plan;
56pub mod raw_bgemm;
58pub mod trace;
60pub mod util;
62
63pub mod backend;
65
66use std::any::TypeId;
67use std::fmt::Debug;
68use std::hash::Hash;
69
70use strided_kernel::zip_map2_into;
71#[cfg(any(feature = "faer", feature = "blas", feature = "blas-inject"))]
72use strided_view::StridedArray;
73use strided_view::{Adjoint, Conj, ElementOp, ElementOpApply};
74
75pub use strided_traits::ScalarBase;
76pub use strided_view::{
77 col_major_strides, RawStridedMut, RawStridedRef, StridedView, StridedViewMut,
78};
79
80pub use backend::Backend;
81pub use dot_general::{dot_general_into, dot_general_with_backend_into, DotGeneralConfig};
82pub use plan::Einsum2Plan;
83pub use raw_bgemm::{
84 bgemm_raw_strided_into, bgemm_raw_strided_into_unchecked, bgemm_raw_with_backend_into,
85 bgemm_raw_with_backend_into_unchecked,
86};
87
/// Type usable as an axis label in einsum specifications.
///
/// Every `Clone + Eq + Hash + Debug` type (for example `char`, `u32`, or
/// `String`) implements this trait automatically via the blanket impl.
pub trait AxisId: Clone + Eq + Hash + Debug {}

impl<T> AxisId for T where T: Clone + Eq + Hash + Debug {}
91
92#[cfg(any(
101 all(feature = "blas", not(feature = "blas-inject")),
102 all(feature = "blas-inject", not(feature = "blas"))
103))]
104pub trait Scalar: ScalarBase + ElementOpApply + bgemm_blas::BlasGemm {}
105
106#[cfg(any(
107 all(feature = "blas", not(feature = "blas-inject")),
108 all(feature = "blas-inject", not(feature = "blas"))
109))]
110impl<T> Scalar for T where T: ScalarBase + ElementOpApply + bgemm_blas::BlasGemm {}
111
112#[cfg(all(feature = "faer", not(any(feature = "blas", feature = "blas-inject"))))]
117pub trait Scalar: ScalarBase + ElementOpApply + faer_traits::ComplexField {}
118
119#[cfg(all(feature = "faer", not(any(feature = "blas", feature = "blas-inject"))))]
120impl<T> Scalar for T where T: ScalarBase + ElementOpApply + faer_traits::ComplexField {}
121
122#[cfg(not(any(feature = "faer", feature = "blas", feature = "blas-inject")))]
124pub trait Scalar: ScalarBase + ElementOpApply {}
125
126#[cfg(not(any(feature = "faer", feature = "blas", feature = "blas-inject")))]
127impl<T> Scalar for T where T: ScalarBase + ElementOpApply {}
128
129#[cfg(all(feature = "blas", feature = "blas-inject"))]
134pub trait Scalar: ScalarBase + ElementOpApply {}
135
136#[cfg(all(feature = "blas", feature = "blas-inject"))]
137impl<T> Scalar for T where T: ScalarBase + ElementOpApply {}
138
139#[derive(Debug, thiserror::Error)]
141pub enum EinsumError {
142 #[error("duplicate axis label: {0}")]
143 DuplicateAxis(String),
144 #[error("output axis {0} not found in any input")]
145 OrphanOutputAxis(String),
146 #[error("dimension mismatch for axis {axis:?}: {dim_a} vs {dim_b}")]
147 DimensionMismatch {
148 axis: String,
149 dim_a: usize,
150 dim_b: usize,
151 },
152 #[error("invalid dot-general config: {0}")]
153 InvalidDotGeneralConfig(String),
154 #[error("output shape mismatch: expected {expected:?}, got {got:?}")]
155 OutputShapeMismatch {
156 expected: Vec<usize>,
157 got: Vec<usize>,
158 },
159 #[error(transparent)]
160 Strided(#[from] strided_view::StridedError),
161}
162
163pub type Result<T> = std::result::Result<T, EinsumError>;
165
166fn op_is_conj<Op: 'static>() -> bool {
175 TypeId::of::<Op>() == TypeId::of::<Conj>() || TypeId::of::<Op>() == TypeId::of::<Adjoint>()
176}
177
178pub fn einsum2_into<T: Scalar, OpA, OpB, ID: AxisId>(
189 c: StridedViewMut<T>,
190 a: &StridedView<T, OpA>,
191 b: &StridedView<T, OpB>,
192 ic: &[ID],
193 ia: &[ID],
194 ib: &[ID],
195 alpha: T,
196 beta: T,
197) -> Result<()>
198where
199 OpA: ElementOp<T> + 'static,
200 OpB: ElementOp<T> + 'static,
201{
202 let plan = Einsum2Plan::new(ia, ib, ic)?;
204
205 validate_dimensions::<ID>(&plan, a.dims(), b.dims(), c.dims(), ia, ib, ic)?;
207
208 let left_trace = trace::find_trace_indices(ia, ib, ic);
213 let (a_buf, conj_a) = if !left_trace.is_empty() {
214 (Some(trace::reduce_trace_axes(a, &left_trace)?), false)
215 } else {
216 (None, op_is_conj::<OpA>())
217 };
218
219 let a_view: StridedView<T> = match a_buf.as_ref() {
220 Some(buf) => buf.view(),
221 None => StridedView::new(a.data(), a.dims(), a.strides(), a.offset())
222 .expect("strip_op_view: metadata already validated"),
223 };
224
225 let right_trace = trace::find_trace_indices(ib, ia, ic);
226 let (b_buf, conj_b) = if !right_trace.is_empty() {
227 (Some(trace::reduce_trace_axes(b, &right_trace)?), false)
228 } else {
229 (None, op_is_conj::<OpB>())
230 };
231
232 let b_view: StridedView<T> = match b_buf.as_ref() {
233 Some(buf) => buf.view(),
234 None => StridedView::new(b.data(), b.dims(), b.strides(), b.offset())
235 .expect("strip_op_view: metadata already validated"),
236 };
237
238 #[cfg(any(feature = "faer", feature = "blas", feature = "blas-inject"))]
240 {
241 let conj_fn = make_conj_fn::<T>();
242 einsum2_dispatch::<T, backend::ActiveBackend, _>(
243 c, &a_view, &b_view, &plan, alpha, beta, conj_a, conj_b, conj_fn,
244 )?;
245 }
246
247 #[cfg(not(any(feature = "faer", feature = "blas", feature = "blas-inject")))]
248 {
249 let a_perm = a_view.permute(&plan.left_perm)?;
250 let b_perm = b_view.permute(&plan.right_perm)?;
251 let mut c_perm = c.permute(&plan.c_to_internal_perm)?;
252
253 if plan.sum.is_empty() && plan.lo.is_empty() && plan.ro.is_empty() && beta == T::zero() {
254 let mul_fn = move |a_val: T, b_val: T| -> T {
255 let a_c = if conj_a { Conj::apply(a_val) } else { a_val };
256 let b_c = if conj_b { Conj::apply(b_val) } else { b_val };
257 alpha * a_c * b_c
258 };
259 zip_map2_into(&mut c_perm, &a_perm, &b_perm, mul_fn)?;
260 return Ok(());
261 }
262
263 bgemm_naive::bgemm_strided_into(
264 &mut c_perm,
265 &a_perm,
266 &b_perm,
267 plan.batch.len(),
268 plan.lo.len(),
269 plan.ro.len(),
270 plan.sum.len(),
271 alpha,
272 beta,
273 conj_a,
274 conj_b,
275 )?;
276 }
277
278 Ok(())
279}
280
/// Binary einsum `c = alpha * map_a(a) * map_b(b) + beta * c` using the naive
/// strided kernels, with caller-supplied element maps for each input.
///
/// `ia`, `ib`, and `ic` label the axes of `a`, `b`, and `c`. When an input has
/// trace axes (axes found only in that input), its map is applied to every
/// element before the trace is reduced, and the reduced values are not mapped
/// again. Otherwise the map is applied to elements as the multiply kernel reads
/// them.
///
/// # Errors
/// Fails if the labels do not form a valid plan, if shared axes disagree in
/// extent, or if a permutation, trace reduction, or kernel call fails.
pub fn einsum2_naive_into<T, ID, MapA, MapB>(
    c: StridedViewMut<T>,
    a: &StridedView<T>,
    b: &StridedView<T>,
    ic: &[ID],
    ia: &[ID],
    ib: &[ID],
    alpha: T,
    beta: T,
    map_a: MapA,
    map_b: MapB,
) -> Result<()>
where
    T: ScalarBase,
    ID: AxisId,
    MapA: Fn(T) -> T + strided_kernel::MaybeSync,
    MapB: Fn(T) -> T + strided_kernel::MaybeSync,
{
    let plan = Einsum2Plan::new(ia, ib, ic)?;
    validate_dimensions::<ID>(&plan, a.dims(), b.dims(), c.dims(), ia, ib, ic)?;

    let left_trace = trace::find_trace_indices(ia, ib, ic);
    // With trace axes: materialize map_a(a) into a fresh buffer, then reduce it;
    // the map has been applied, so `use_map_a` is false for the reduced result.
    let (a_buf, use_map_a) = if !left_trace.is_empty() {
        // SAFETY(review): the buffer starts uninitialized and is assumed to be fully
        // overwritten by `map_into` (same dims as `a`) before it is read — confirm.
        let mut mapped = unsafe { strided_view::StridedArray::<T>::col_major_uninit(a.dims()) };
        strided_kernel::map_into(&mut mapped.view_mut(), a, &map_a)?;
        let reduced = trace::reduce_trace_axes(&mapped.view(), &left_trace)?;
        (Some(reduced), false)
    } else {
        (None, true)
    };
    let a_view: StridedView<T> = match a_buf.as_ref() {
        Some(buf) => buf.view(),
        None => a.clone(),
    };

    let right_trace = trace::find_trace_indices(ib, ia, ic);
    // Same treatment for `b`.
    let (b_buf, use_map_b) = if !right_trace.is_empty() {
        // SAFETY(review): as above, relies on `map_into` writing every element.
        let mut mapped = unsafe { strided_view::StridedArray::<T>::col_major_uninit(b.dims()) };
        strided_kernel::map_into(&mut mapped.view_mut(), b, &map_b)?;
        let reduced = trace::reduce_trace_axes(&mapped.view(), &right_trace)?;
        (Some(reduced), false)
    } else {
        (None, true)
    };
    let b_view: StridedView<T> = match b_buf.as_ref() {
        Some(buf) => buf.view(),
        None => b.clone(),
    };

    // Reorder axes into the plan's internal layout for each operand.
    let a_perm = a_view.permute(&plan.left_perm)?;
    let b_perm = b_view.permute(&plan.right_perm)?;
    let mut c_perm = c.permute(&plan.c_to_internal_perm)?;

    // Pure batch/element-wise product with no accumulation into `c`.
    if plan.sum.is_empty() && plan.lo.is_empty() && plan.ro.is_empty() && beta == T::zero() {
        let mul_fn = move |a_val: T, b_val: T| -> T {
            let a_c = if use_map_a { map_a(a_val) } else { a_val };
            let b_c = if use_map_b { map_b(b_val) } else { b_val };
            alpha * a_c * b_c
        };
        zip_map2_into(&mut c_perm, &a_perm, &b_perm, mul_fn)?;
        return Ok(());
    }

    // Type-erase the maps so both the "map" and "identity" choices have one type.
    let final_map_a: Box<dyn Fn(T) -> T> = if use_map_a {
        Box::new(map_a)
    } else {
        Box::new(|x| x)
    };
    let final_map_b: Box<dyn Fn(T) -> T> = if use_map_b {
        Box::new(map_b)
    } else {
        Box::new(|x| x)
    };

    bgemm_naive::bgemm_strided_into_with_map(
        &mut c_perm,
        &a_perm,
        &b_perm,
        plan.batch.len(),
        plan.lo.len(),
        plan.ro.len(),
        plan.sum.len(),
        alpha,
        beta,
        final_map_a,
        final_map_b,
    )?;

    Ok(())
}
382
383pub fn einsum2_with_backend_into<T, B, ID>(
392 c: StridedViewMut<T>,
393 a: &StridedView<T>,
394 b: &StridedView<T>,
395 ic: &[ID],
396 ia: &[ID],
397 ib: &[ID],
398 alpha: T,
399 beta: T,
400) -> Result<()>
401where
402 T: ScalarBase,
403 B: Backend<T>,
404 ID: AxisId,
405{
406 let plan = Einsum2Plan::new(ia, ib, ic)?;
407 validate_dimensions::<ID>(&plan, a.dims(), b.dims(), c.dims(), ia, ib, ic)?;
408
409 let left_trace = trace::find_trace_indices(ia, ib, ic);
411 let a_buf = if !left_trace.is_empty() {
412 Some(trace::reduce_trace_axes(a, &left_trace)?)
413 } else {
414 None
415 };
416 let a_view: StridedView<T> = match a_buf.as_ref() {
417 Some(buf) => buf.view(),
418 None => a.clone(),
419 };
420
421 let right_trace = trace::find_trace_indices(ib, ia, ic);
422 let b_buf = if !right_trace.is_empty() {
423 Some(trace::reduce_trace_axes(b, &right_trace)?)
424 } else {
425 None
426 };
427 let b_view: StridedView<T> = match b_buf.as_ref() {
428 Some(buf) => buf.view(),
429 None => b.clone(),
430 };
431
432 einsum2_dispatch::<T, B, _>(c, &a_view, &b_view, &plan, alpha, beta, false, false, None)
434}
435
/// Returns a conjugation function when `backend::ActiveBackend` materializes
/// conjugated operands itself (`MATERIALIZES_CONJ`), and `None` otherwise.
#[cfg(any(feature = "faer", feature = "blas", feature = "blas-inject"))]
fn make_conj_fn<T: Scalar>() -> Option<fn(T) -> T> {
    fn conj<T: Scalar>(x: T) -> T {
        Conj::apply(x)
    }

    if !<backend::ActiveBackend as Backend<T>>::MATERIALIZES_CONJ {
        return None;
    }
    Some(conj::<T> as fn(T) -> T)
}
449
450fn scale_or_zero_strided_mut<T: ScalarBase>(c: &mut StridedViewMut<T>, beta: T) {
451 if c.is_empty() {
452 return;
453 }
454
455 let dims = c.dims().to_vec();
456 let strides = c.strides().to_vec();
457 let ptr = c.as_mut_ptr();
458 let zero = T::zero();
459
460 fn visit<T: ScalarBase>(
461 ptr: *mut T,
462 dims: &[usize],
463 strides: &[isize],
464 axis: usize,
465 offset: isize,
466 beta: T,
467 zero: T,
468 ) {
469 if axis == dims.len() {
470 unsafe {
471 let dst = ptr.offset(offset);
472 if beta == zero {
473 *dst = zero;
474 } else {
475 *dst = beta * *dst;
476 }
477 }
478 return;
479 }
480
481 for i in 0..dims[axis] {
482 visit(
483 ptr,
484 dims,
485 strides,
486 axis + 1,
487 offset + i as isize * strides[axis],
488 beta,
489 zero,
490 );
491 }
492 }
493
494 visit(ptr, &dims, &strides, 0, 0, beta, zero);
495}
496
497pub(crate) fn einsum2_dispatch<T, B, ID>(
510 c: StridedViewMut<T>,
511 a: &StridedView<T>,
512 b: &StridedView<T>,
513 plan: &Einsum2Plan<ID>,
514 alpha: T,
515 beta: T,
516 conj_a: bool,
517 conj_b: bool,
518 conj_fn: Option<fn(T) -> T>,
519) -> Result<()>
520where
521 T: ScalarBase,
522 B: Backend<T>,
523 ID: AxisId,
524{
525 let a_perm = a.permute(&plan.left_perm)?;
527 let b_perm = b.permute(&plan.right_perm)?;
528 let mut c_perm = c.permute(&plan.c_to_internal_perm)?;
529
530 if c_perm.is_empty() {
531 return Ok(());
532 }
533
534 if plan.sum.is_empty() && plan.lo.is_empty() && plan.ro.is_empty() && beta == T::zero() {
536 if !conj_a && !conj_b && alpha == T::one() {
537 zip_map2_into(&mut c_perm, &a_perm, &b_perm, |a_val, b_val| a_val * b_val)?;
538 } else if !conj_a && !conj_b {
539 let mul_fn = move |a_val: T, b_val: T| -> T { alpha * a_val * b_val };
540 zip_map2_into(&mut c_perm, &a_perm, &b_perm, mul_fn)?;
541 } else {
542 let conj_fn = conj_fn.unwrap_or(|x| x);
543 let mul_fn = move |a_val: T, b_val: T| -> T {
544 let a_c = if conj_a { conj_fn(a_val) } else { a_val };
545 let b_c = if conj_b { conj_fn(b_val) } else { b_val };
546 alpha * a_c * b_c
547 };
548 zip_map2_into(&mut c_perm, &a_perm, &b_perm, mul_fn)?;
549 }
550 return Ok(());
551 }
552
553 let n_lo = plan.lo.len();
555 let n_ro = plan.ro.len();
556 let n_sum = plan.sum.len();
557 let use_pool = true;
558 let materialize = if B::MATERIALIZES_CONJ { conj_fn } else { None };
559
560 let a_op = contiguous::prepare_input_view(
561 &a_perm,
562 n_lo,
563 n_sum,
564 conj_a,
565 B::REQUIRES_UNIT_STRIDE,
566 use_pool,
567 materialize,
568 )?;
569 let b_op = contiguous::prepare_input_view(
570 &b_perm,
571 n_sum,
572 n_ro,
573 conj_b,
574 B::REQUIRES_UNIT_STRIDE,
575 use_pool,
576 materialize,
577 )?;
578 let mut c_op = contiguous::prepare_output_view(
579 &mut c_perm,
580 n_lo,
581 n_ro,
582 beta,
583 B::REQUIRES_UNIT_STRIDE,
584 use_pool,
585 )?;
586
587 let lo_dims = &a_perm.dims()[..n_lo];
589 let sum_dims = &a_perm.dims()[n_lo..n_lo + n_sum];
590 let batch_dims = &a_perm.dims()[n_lo + n_sum..];
591 let ro_dims = &b_perm.dims()[n_sum..n_sum + n_ro];
592 if sum_dims.iter().any(|&dim| dim == 0) {
593 scale_or_zero_strided_mut(&mut c_perm, beta);
594 return Ok(());
595 }
596 let m: usize = lo_dims.iter().product::<usize>().max(1);
597 let k: usize = sum_dims.iter().product::<usize>().max(1);
598 let n: usize = ro_dims.iter().product::<usize>().max(1);
599
600 B::bgemm_contiguous_into(&mut c_op, &a_op, &b_op, batch_dims, m, n, k, alpha, beta)?;
602
603 c_op.finalize_into(&mut c_perm)?;
605
606 Ok(())
607}
608
/// Binary einsum over owned inputs on `backend::ActiveBackend`.
///
/// Like [`einsum2_into`], but `a` and `b` are taken by value so their buffers can be
/// handed to `contiguous::prepare_input_owned`, and conjugation is requested with the
/// explicit `conj_a` / `conj_b` flags instead of element-op types.
///
/// # Errors
/// Fails if the labels do not form a valid plan, if shared axes disagree in extent,
/// or if trace reduction, permutation, operand preparation, the backend GEMM, or the
/// final write-back fails.
#[cfg(any(feature = "faer", feature = "blas", feature = "blas-inject"))]
pub fn einsum2_into_owned<T: Scalar, ID: AxisId>(
    c: StridedViewMut<T>,
    a: StridedArray<T>,
    b: StridedArray<T>,
    ic: &[ID],
    ia: &[ID],
    ib: &[ID],
    alpha: T,
    beta: T,
    conj_a: bool,
    conj_b: bool,
) -> Result<()>
where
    backend::ActiveBackend: Backend<T>,
{
    let plan = Einsum2Plan::new(ia, ib, ic)?;

    validate_dimensions::<ID>(&plan, a.dims(), b.dims(), c.dims(), ia, ib, ic)?;

    // Sum away trace axes first. The reduction runs on the plain (unconjugated) data,
    // and conjugation commutes with summation (conj(sum x) = sum conj(x)), so the
    // caller's conjugation flag must still be applied to the reduced operand.
    let left_trace = trace::find_trace_indices(ia, ib, ic);
    let a_for_gemm = if left_trace.is_empty() {
        a
    } else {
        trace::reduce_trace_axes(&a.view(), &left_trace)?
    };

    let right_trace = trace::find_trace_indices(ib, ia, ic);
    let b_for_gemm = if right_trace.is_empty() {
        b
    } else {
        trace::reduce_trace_axes(&b.view(), &right_trace)?
    };

    let a_perm = a_for_gemm.permuted(&plan.left_perm)?;
    let b_perm = b_for_gemm.permuted(&plan.right_perm)?;
    let mut c_perm = c.permute(&plan.c_to_internal_perm)?;

    // No output elements: nothing to compute. This also stops the `.max(1)` clamps
    // below from turning a zero `lo`/`ro` extent into a one-element GEMM.
    if c_perm.is_empty() {
        return Ok(());
    }

    let n_lo = plan.lo.len();
    let n_ro = plan.ro.len();
    let n_sum = plan.sum.len();

    // Pure batch/element-wise product with no accumulation into `c`.
    if plan.sum.is_empty() && plan.lo.is_empty() && plan.ro.is_empty() && beta == T::zero() {
        let mul_fn = move |a_val: T, b_val: T| -> T {
            let a_c = if conj_a { Conj::apply(a_val) } else { a_val };
            let b_c = if conj_b { Conj::apply(b_val) } else { b_val };
            alpha * a_c * b_c
        };
        zip_map2_into(&mut c_perm, &a_perm.view(), &b_perm.view(), mul_fn)?;
        return Ok(());
    }

    // Copy the permuted extents: `a_perm` / `b_perm` are moved into operand
    // preparation below.
    let a_dims_perm = a_perm.dims().to_vec();
    let b_dims_perm = b_perm.dims().to_vec();

    let lo_dims = &a_dims_perm[..n_lo];
    let sum_dims = &a_dims_perm[n_lo..n_lo + n_sum];
    let batch_dims = &a_dims_perm[n_lo + n_sum..];
    let ro_dims = &b_dims_perm[n_sum..n_sum + n_ro];

    // An empty contraction contributes nothing, so c = beta * c. Without this guard,
    // `.max(1)` would clamp k = 0 to 1 and hand the backend operands with no data.
    if sum_dims.contains(&0) {
        scale_or_zero_strided_mut(&mut c_perm, beta);
        return Ok(());
    }

    let m: usize = lo_dims.iter().product::<usize>().max(1);
    let k: usize = sum_dims.iter().product::<usize>().max(1);
    let n: usize = ro_dims.iter().product::<usize>().max(1);

    // `make_conj_fn` already yields `None` unless the backend materializes conjugation.
    let materialize = make_conj_fn::<T>();
    let use_pool = true;
    let unit_stride = <backend::ActiveBackend as Backend<T>>::REQUIRES_UNIT_STRIDE;
    let a_op = contiguous::prepare_input_owned(
        a_perm,
        n_lo,
        n_sum,
        conj_a,
        unit_stride,
        use_pool,
        materialize,
    )?;
    let b_op = contiguous::prepare_input_owned(
        b_perm,
        n_sum,
        n_ro,
        conj_b,
        unit_stride,
        use_pool,
        materialize,
    )?;
    let mut c_op =
        contiguous::prepare_output_view(&mut c_perm, n_lo, n_ro, beta, unit_stride, use_pool)?;

    backend::ActiveBackend::bgemm_contiguous_into(
        &mut c_op, &a_op, &b_op, batch_dims, m, n, k, alpha, beta,
    )?;

    c_op.finalize_into(&mut c_perm)?;

    Ok(())
}
744
745fn validate_dimensions<ID: AxisId>(
747 plan: &Einsum2Plan<ID>,
748 a_dims: &[usize],
749 b_dims: &[usize],
750 c_dims: &[usize],
751 ia: &[ID],
752 ib: &[ID],
753 ic: &[ID],
754) -> Result<()> {
755 let find_dim = |labels: &[ID], dims: &[usize], id: &ID| -> usize {
756 labels
757 .iter()
758 .position(|x| x == id)
759 .map(|i| dims[i])
760 .unwrap()
761 };
762
763 for id in &plan.batch {
765 let da = find_dim(ia, a_dims, id);
766 let db = find_dim(ib, b_dims, id);
767 let dc = find_dim(ic, c_dims, id);
768 if da != db || da != dc {
769 return Err(EinsumError::DimensionMismatch {
770 axis: format!("{:?}", id),
771 dim_a: da,
772 dim_b: db,
773 });
774 }
775 }
776
777 for id in &plan.sum {
779 let da = find_dim(ia, a_dims, id);
780 let db = find_dim(ib, b_dims, id);
781 if da != db {
782 return Err(EinsumError::DimensionMismatch {
783 axis: format!("{:?}", id),
784 dim_a: da,
785 dim_b: db,
786 });
787 }
788 }
789
790 for id in &plan.lo {
792 let da = find_dim(ia, a_dims, id);
793 let dc = find_dim(ic, c_dims, id);
794 if da != dc {
795 return Err(EinsumError::DimensionMismatch {
796 axis: format!("{:?}", id),
797 dim_a: da,
798 dim_b: dc,
799 });
800 }
801 }
802
803 for id in &plan.ro {
805 let db = find_dim(ib, b_dims, id);
806 let dc = find_dim(ic, c_dims, id);
807 if db != dc {
808 return Err(EinsumError::DimensionMismatch {
809 axis: format!("{:?}", id),
810 dim_a: db,
811 dim_b: dc,
812 });
813 }
814 }
815
816 Ok(())
817}
818
819#[cfg(test)]
820mod tests {
821 use super::*;
822 use strided_view::StridedArray;
823
824 #[test]
825 fn test_matmul_ij_jk_ik() {
826 let a = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
828 [[1.0, 2.0], [3.0, 4.0]][idx[0]][idx[1]]
829 });
830 let b = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
831 [[5.0, 6.0], [7.0, 8.0]][idx[0]][idx[1]]
832 });
833 let mut c = StridedArray::<f64>::row_major(&[2, 2]);
834
835 einsum2_into(
836 c.view_mut(),
837 &a.view(),
838 &b.view(),
839 &['i', 'k'],
840 &['i', 'j'],
841 &['j', 'k'],
842 1.0,
843 0.0,
844 )
845 .unwrap();
846
847 assert_eq!(c.get(&[0, 0]), 19.0);
848 assert_eq!(c.get(&[0, 1]), 22.0);
849 assert_eq!(c.get(&[1, 0]), 43.0);
850 assert_eq!(c.get(&[1, 1]), 50.0);
851 }
852
853 #[test]
854 fn test_matmul_rect() {
855 let a =
857 StridedArray::<f64>::from_fn_row_major(&[2, 3], |idx| (idx[0] * 3 + idx[1] + 1) as f64);
858 let b =
859 StridedArray::<f64>::from_fn_row_major(&[3, 4], |idx| (idx[0] * 4 + idx[1] + 1) as f64);
860 let mut c = StridedArray::<f64>::row_major(&[2, 4]);
861
862 einsum2_into(
863 c.view_mut(),
864 &a.view(),
865 &b.view(),
866 &['i', 'k'],
867 &['i', 'j'],
868 &['j', 'k'],
869 1.0,
870 0.0,
871 )
872 .unwrap();
873
874 assert_eq!(c.get(&[0, 0]), 38.0);
876 assert_eq!(c.get(&[1, 3]), 128.0);
877 }
878
879 #[test]
880 fn test_batched_matmul() {
881 let a = StridedArray::<f64>::from_fn_row_major(&[2, 2, 3], |idx| {
883 (idx[0] * 6 + idx[1] * 3 + idx[2] + 1) as f64
884 });
885 let b = StridedArray::<f64>::from_fn_row_major(&[2, 3, 2], |idx| {
886 (idx[0] * 6 + idx[1] * 2 + idx[2] + 1) as f64
887 });
888 let mut c = StridedArray::<f64>::row_major(&[2, 2, 2]);
889
890 einsum2_into(
891 c.view_mut(),
892 &a.view(),
893 &b.view(),
894 &['b', 'i', 'k'],
895 &['b', 'i', 'j'],
896 &['b', 'j', 'k'],
897 1.0,
898 0.0,
899 )
900 .unwrap();
901
902 assert_eq!(c.get(&[0, 0, 0]), 22.0);
905 }
906
907 #[test]
908 fn test_batched_matmul_col_major_output() {
909 let a_data = vec![1.0, 0.0, 0.0, 1.0, 2.0, 0.0, 0.0, 2.0];
911 let b_data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
912 let a = StridedArray::<f64>::from_parts(a_data, &[2, 2, 2], &[4, 2, 1], 0).unwrap();
913 let b = StridedArray::<f64>::from_parts(b_data, &[2, 2, 2], &[4, 2, 1], 0).unwrap();
914 let mut c = StridedArray::<f64>::col_major(&[2, 2, 2]);
915
916 einsum2_into(
917 c.view_mut(),
918 &a.view(),
919 &b.view(),
920 &['b', 'i', 'k'],
921 &['b', 'i', 'j'],
922 &['b', 'j', 'k'],
923 1.0,
924 0.0,
925 )
926 .unwrap();
927
928 assert_eq!(c.get(&[0, 0, 0]), 1.0);
930 assert_eq!(c.get(&[0, 0, 1]), 2.0);
931 assert_eq!(c.get(&[0, 1, 0]), 3.0);
932 assert_eq!(c.get(&[0, 1, 1]), 4.0);
933 assert_eq!(c.get(&[1, 0, 0]), 10.0);
935 assert_eq!(c.get(&[1, 1, 1]), 16.0);
936 }
937
938 #[test]
939 fn test_outer_product() {
940 let a = StridedArray::<f64>::from_fn_row_major(&[3], |idx| (idx[0] + 1) as f64);
942 let b = StridedArray::<f64>::from_fn_row_major(&[4], |idx| (idx[0] + 1) as f64);
943 let mut c = StridedArray::<f64>::row_major(&[3, 4]);
944
945 einsum2_into(
946 c.view_mut(),
947 &a.view(),
948 &b.view(),
949 &['i', 'j'],
950 &['i'],
951 &['j'],
952 1.0,
953 0.0,
954 )
955 .unwrap();
956
957 assert_eq!(c.get(&[0, 0]), 1.0);
958 assert_eq!(c.get(&[2, 3]), 12.0);
959 }
960
961 #[test]
962 fn test_dot_product() {
963 let a = StridedArray::<f64>::from_fn_row_major(&[3], |idx| (idx[0] + 1) as f64);
965 let b = StridedArray::<f64>::from_fn_row_major(&[3], |idx| (idx[0] + 1) as f64);
966 let mut c = StridedArray::<f64>::row_major(&[]);
967
968 einsum2_into(
969 c.view_mut(),
970 &a.view(),
971 &b.view(),
972 &[] as &[char],
973 &['i'],
974 &['i'],
975 1.0,
976 0.0,
977 )
978 .unwrap();
979
980 assert_eq!(c.get(&[]), 14.0);
982 }
983
984 #[test]
985 fn test_alpha_beta() {
986 let a = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
988 [[1.0, 0.0], [0.0, 1.0]][idx[0]][idx[1]] });
990 let b = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
991 [[1.0, 2.0], [3.0, 4.0]][idx[0]][idx[1]]
992 });
993 let mut c = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
994 [[10.0, 20.0], [30.0, 40.0]][idx[0]][idx[1]]
995 });
996
997 einsum2_into(
998 c.view_mut(),
999 &a.view(),
1000 &b.view(),
1001 &['i', 'k'],
1002 &['i', 'j'],
1003 &['j', 'k'],
1004 2.0,
1005 3.0,
1006 )
1007 .unwrap();
1008
1009 assert_eq!(c.get(&[0, 0]), 32.0); assert_eq!(c.get(&[1, 1]), 128.0); }
1013
1014 #[test]
1015 fn test_transposed_output() {
1016 let a = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
1018 [[1.0, 2.0], [3.0, 4.0]][idx[0]][idx[1]]
1019 });
1020 let b = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
1021 [[5.0, 6.0], [7.0, 8.0]][idx[0]][idx[1]]
1022 });
1023 let mut c = StridedArray::<f64>::row_major(&[2, 2]);
1024
1025 einsum2_into(
1026 c.view_mut(),
1027 &a.view(),
1028 &b.view(),
1029 &['k', 'i'], &['i', 'j'],
1031 &['j', 'k'],
1032 1.0,
1033 0.0,
1034 )
1035 .unwrap();
1036
1037 assert_eq!(c.get(&[0, 0]), 19.0); assert_eq!(c.get(&[0, 1]), 43.0); assert_eq!(c.get(&[1, 0]), 22.0); assert_eq!(c.get(&[1, 1]), 50.0); }
1044
1045 #[test]
1046 fn test_left_trace() {
1047 let a =
1050 StridedArray::<f64>::from_fn_row_major(&[2, 3], |idx| (idx[0] * 3 + idx[1] + 1) as f64);
1051 let b =
1054 StridedArray::<f64>::from_fn_row_major(&[3, 2], |idx| (idx[0] * 2 + idx[1] + 1) as f64);
1055 let mut c = StridedArray::<f64>::row_major(&[2]);
1057
1058 einsum2_into(
1059 c.view_mut(),
1060 &a.view(),
1061 &b.view(),
1062 &['k'],
1063 &['i', 'j'],
1064 &['j', 'k'],
1065 1.0,
1066 0.0,
1067 )
1068 .unwrap();
1069
1070 assert_eq!(c.get(&[0]), 71.0);
1074 assert_eq!(c.get(&[1]), 92.0);
1075 }
1076
1077 #[test]
1078 fn test_u32_labels() {
1079 let a = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
1081 [[1.0, 2.0], [3.0, 4.0]][idx[0]][idx[1]]
1082 });
1083 let b = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
1084 [[5.0, 6.0], [7.0, 8.0]][idx[0]][idx[1]]
1085 });
1086 let mut c = StridedArray::<f64>::row_major(&[2, 2]);
1087
1088 einsum2_into(
1089 c.view_mut(),
1090 &a.view(),
1091 &b.view(),
1092 &[0u32, 2],
1093 &[0u32, 1],
1094 &[1u32, 2],
1095 1.0,
1096 0.0,
1097 )
1098 .unwrap();
1099
1100 assert_eq!(c.get(&[0, 0]), 19.0);
1101 assert_eq!(c.get(&[1, 1]), 50.0);
1102 }
1103
1104 #[test]
1105 fn test_complex_matmul() {
1106 use num_complex::Complex64;
1107 let i = Complex64::i();
1108
1109 let a_vals = [
1111 [1.0 + i, Complex64::new(2.0, 0.0)],
1112 [Complex64::new(3.0, 0.0), 4.0 - i],
1113 ];
1114 let a = StridedArray::<Complex64>::from_fn_row_major(&[2, 2], |idx| a_vals[idx[0]][idx[1]]);
1115
1116 let b_vals = [
1118 [Complex64::new(1.0, 0.0), i],
1119 [Complex64::new(0.0, 0.0), Complex64::new(1.0, 0.0)],
1120 ];
1121 let b = StridedArray::<Complex64>::from_fn_row_major(&[2, 2], |idx| b_vals[idx[0]][idx[1]]);
1122
1123 let mut c = StridedArray::<Complex64>::row_major(&[2, 2]);
1124
1125 einsum2_into(
1126 c.view_mut(),
1127 &a.view(),
1128 &b.view(),
1129 &['i', 'k'],
1130 &['i', 'j'],
1131 &['j', 'k'],
1132 Complex64::new(1.0, 0.0),
1133 Complex64::new(0.0, 0.0),
1134 )
1135 .unwrap();
1136
1137 assert_eq!(c.get(&[0, 0]), 1.0 + i);
1143 assert_eq!(c.get(&[0, 1]), 1.0 + i);
1144 assert_eq!(c.get(&[1, 0]), Complex64::new(3.0, 0.0));
1145 assert_eq!(c.get(&[1, 1]), 4.0 + 2.0 * i);
1146 }
1147
1148 #[test]
1149 fn test_complex_matmul_with_conj() {
1150 use num_complex::Complex64;
1151 let i = Complex64::i();
1152
1153 let a_vals = [[1.0 + i, 2.0 * i], [Complex64::new(3.0, 0.0), 4.0 - i]];
1155 let a = StridedArray::<Complex64>::from_fn_row_major(&[2, 2], |idx| a_vals[idx[0]][idx[1]]);
1156
1157 let b = StridedArray::<Complex64>::from_fn_row_major(&[2, 2], |idx| {
1159 if idx[0] == idx[1] {
1160 Complex64::new(1.0, 0.0)
1161 } else {
1162 Complex64::new(0.0, 0.0)
1163 }
1164 });
1165
1166 let mut c = StridedArray::<Complex64>::row_major(&[2, 2]);
1167
1168 let a_conj = a.view().conj();
1170 einsum2_into(
1171 c.view_mut(),
1172 &a_conj,
1173 &b.view(),
1174 &['i', 'k'],
1175 &['i', 'j'],
1176 &['j', 'k'],
1177 Complex64::new(1.0, 0.0),
1178 Complex64::new(0.0, 0.0),
1179 )
1180 .unwrap();
1181
1182 assert_eq!(c.get(&[0, 0]), 1.0 - i);
1184 assert_eq!(c.get(&[0, 1]), -2.0 * i);
1185 assert_eq!(c.get(&[1, 0]), Complex64::new(3.0, 0.0));
1186 assert_eq!(c.get(&[1, 1]), 4.0 + i);
1187 }
1188
1189 #[test]
1190 fn test_complex_matmul_with_conj_both() {
1191 use num_complex::Complex64;
1192 let i = Complex64::i();
1193
1194 let a_vals = [
1196 [1.0 + i, Complex64::new(0.0, 0.0)],
1197 [Complex64::new(0.0, 0.0), 2.0 - i],
1198 ];
1199 let a = StridedArray::<Complex64>::from_fn_row_major(&[2, 2], |idx| a_vals[idx[0]][idx[1]]);
1200
1201 let b_vals = [
1203 [Complex64::new(1.0, 0.0), i],
1204 [Complex64::new(0.0, 0.0), 1.0 + i],
1205 ];
1206 let b = StridedArray::<Complex64>::from_fn_row_major(&[2, 2], |idx| b_vals[idx[0]][idx[1]]);
1207
1208 let mut c = StridedArray::<Complex64>::row_major(&[2, 2]);
1209
1210 let a_conj = a.view().conj();
1212 let b_conj = b.view().conj();
1213 einsum2_into(
1214 c.view_mut(),
1215 &a_conj,
1216 &b_conj,
1217 &['i', 'k'],
1218 &['i', 'j'],
1219 &['j', 'k'],
1220 Complex64::new(1.0, 0.0),
1221 Complex64::new(0.0, 0.0),
1222 )
1223 .unwrap();
1224
1225 assert_eq!(c.get(&[0, 0]), 1.0 - i);
1233 assert_eq!(c.get(&[0, 1]), -(1.0 + i));
1234 assert_eq!(c.get(&[1, 0]), Complex64::new(0.0, 0.0));
1235 assert_eq!(c.get(&[1, 1]), 3.0 - i);
1236 }
1237
1238 #[test]
1239 fn test_elementwise_hadamard() {
1240 let a = StridedArray::<f64>::from_fn_row_major(&[3, 4, 5], |idx| {
1242 (idx[0] * 20 + idx[1] * 5 + idx[2] + 1) as f64
1243 });
1244 let b = StridedArray::<f64>::from_fn_row_major(&[3, 4, 5], |idx| {
1245 (idx[0] * 20 + idx[1] * 5 + idx[2] + 1) as f64 * 0.1
1246 });
1247 let mut c = StridedArray::<f64>::row_major(&[3, 4, 5]);
1248
1249 einsum2_into(
1250 c.view_mut(),
1251 &a.view(),
1252 &b.view(),
1253 &['i', 'j', 'k'],
1254 &['i', 'j', 'k'],
1255 &['i', 'j', 'k'],
1256 1.0,
1257 0.0,
1258 )
1259 .unwrap();
1260
1261 assert!((c.get(&[0, 0, 0]) - 0.1).abs() < 1e-12);
1263 assert!((c.get(&[2, 3, 4]) - 360.0).abs() < 1e-10);
1265 }
1266
1267 #[test]
1268 fn test_elementwise_hadamard_with_alpha() {
1269 let a =
1270 StridedArray::<f64>::from_fn_row_major(&[2, 3], |idx| (idx[0] * 3 + idx[1] + 1) as f64);
1271 let b =
1272 StridedArray::<f64>::from_fn_row_major(&[2, 3], |idx| (idx[0] * 3 + idx[1] + 1) as f64);
1273 let mut c = StridedArray::<f64>::row_major(&[2, 3]);
1274
1275 einsum2_into(
1276 c.view_mut(),
1277 &a.view(),
1278 &b.view(),
1279 &['i', 'j'],
1280 &['i', 'j'],
1281 &['i', 'j'],
1282 2.0,
1283 0.0,
1284 )
1285 .unwrap();
1286
1287 assert_eq!(c.get(&[0, 0]), 2.0);
1289 assert_eq!(c.get(&[1, 2]), 72.0);
1291 }
1292
1293 #[cfg(any(feature = "faer", feature = "blas", feature = "blas-inject"))]
1294 #[test]
1295 fn test_einsum2_owned_matmul() {
1296 let a = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
1297 [[1.0, 2.0], [3.0, 4.0]][idx[0]][idx[1]]
1298 });
1299 let b = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
1300 [[5.0, 6.0], [7.0, 8.0]][idx[0]][idx[1]]
1301 });
1302 let mut c = StridedArray::<f64>::row_major(&[2, 2]);
1303
1304 einsum2_into_owned(
1305 c.view_mut(),
1306 a,
1307 b,
1308 &['i', 'k'],
1309 &['i', 'j'],
1310 &['j', 'k'],
1311 1.0,
1312 0.0,
1313 false,
1314 false,
1315 )
1316 .unwrap();
1317
1318 assert_eq!(c.get(&[0, 0]), 19.0);
1319 assert_eq!(c.get(&[0, 1]), 22.0);
1320 assert_eq!(c.get(&[1, 0]), 43.0);
1321 assert_eq!(c.get(&[1, 1]), 50.0);
1322 }
1323
1324 #[cfg(any(feature = "faer", feature = "blas", feature = "blas-inject"))]
1325 #[test]
1326 fn test_einsum2_owned_batched() {
1327 let a = StridedArray::<f64>::from_fn_row_major(&[2, 2, 3], |idx| {
1328 (idx[0] * 6 + idx[1] * 3 + idx[2] + 1) as f64
1329 });
1330 let b = StridedArray::<f64>::from_fn_row_major(&[2, 3, 2], |idx| {
1331 (idx[0] * 6 + idx[1] * 2 + idx[2] + 1) as f64
1332 });
1333 let mut c = StridedArray::<f64>::row_major(&[2, 2, 2]);
1334
1335 einsum2_into_owned(
1336 c.view_mut(),
1337 a,
1338 b,
1339 &['b', 'i', 'k'],
1340 &['b', 'i', 'j'],
1341 &['b', 'j', 'k'],
1342 1.0,
1343 0.0,
1344 false,
1345 false,
1346 )
1347 .unwrap();
1348
1349 assert_eq!(c.get(&[0, 0, 0]), 22.0);
1352 }
1353
1354 #[cfg(any(feature = "faer", feature = "blas", feature = "blas-inject"))]
1355 #[test]
1356 fn test_einsum2_owned_alpha_beta() {
1357 let a = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
1358 [[1.0, 0.0], [0.0, 1.0]][idx[0]][idx[1]]
1359 });
1360 let b = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
1361 [[1.0, 2.0], [3.0, 4.0]][idx[0]][idx[1]]
1362 });
1363 let mut c = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
1364 [[10.0, 20.0], [30.0, 40.0]][idx[0]][idx[1]]
1365 });
1366
1367 einsum2_into_owned(
1368 c.view_mut(),
1369 a,
1370 b,
1371 &['i', 'k'],
1372 &['i', 'j'],
1373 &['j', 'k'],
1374 2.0,
1375 3.0,
1376 false,
1377 false,
1378 )
1379 .unwrap();
1380
1381 assert_eq!(c.get(&[0, 0]), 32.0); assert_eq!(c.get(&[1, 1]), 128.0); }
1385
1386 #[cfg(any(feature = "faer", feature = "blas", feature = "blas-inject"))]
1387 #[test]
1388 fn test_einsum2_owned_elementwise() {
1389 let a =
1391 StridedArray::<f64>::from_fn_row_major(&[2, 3], |idx| (idx[0] * 3 + idx[1] + 1) as f64);
1392 let b =
1393 StridedArray::<f64>::from_fn_row_major(&[2, 3], |idx| (idx[0] * 3 + idx[1] + 1) as f64);
1394 let mut c = StridedArray::<f64>::row_major(&[2, 3]);
1395
1396 einsum2_into_owned(
1397 c.view_mut(),
1398 a,
1399 b,
1400 &['i', 'j'],
1401 &['i', 'j'],
1402 &['i', 'j'],
1403 2.0,
1404 0.0,
1405 false,
1406 false,
1407 )
1408 .unwrap();
1409
1410 assert_eq!(c.get(&[0, 0]), 2.0); assert_eq!(c.get(&[1, 2]), 72.0); }
1413
1414 #[cfg(any(feature = "faer", feature = "blas", feature = "blas-inject"))]
1415 #[test]
1416 fn test_einsum2_owned_left_trace() {
1417 let a =
1420 StridedArray::<f64>::from_fn_row_major(&[2, 3], |idx| (idx[0] * 3 + idx[1] + 1) as f64);
1421 let b =
1424 StridedArray::<f64>::from_fn_row_major(&[3, 2], |idx| (idx[0] * 2 + idx[1] + 1) as f64);
1425 let mut c = StridedArray::<f64>::row_major(&[2]);
1427
1428 einsum2_into_owned(
1429 c.view_mut(),
1430 a,
1431 b,
1432 &['k'],
1433 &['i', 'j'],
1434 &['j', 'k'],
1435 1.0,
1436 0.0,
1437 false,
1438 false,
1439 )
1440 .unwrap();
1441
1442 assert_eq!(c.get(&[0]), 71.0);
1445 assert_eq!(c.get(&[1]), 92.0);
1446 }
1447
1448 #[test]
1449 fn test_einsum2_naive_matmul() {
1450 let a = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
1452 [[1.0, 2.0], [3.0, 4.0]][idx[0]][idx[1]]
1453 });
1454 let b = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
1455 [[5.0, 6.0], [7.0, 8.0]][idx[0]][idx[1]]
1456 });
1457 let mut c = StridedArray::<f64>::row_major(&[2, 2]);
1458
1459 einsum2_naive_into(
1460 c.view_mut(),
1461 &a.view(),
1462 &b.view(),
1463 &['i', 'k'],
1464 &['i', 'j'],
1465 &['j', 'k'],
1466 1.0,
1467 0.0,
1468 |x| x,
1469 |x| x,
1470 )
1471 .unwrap();
1472
1473 assert_eq!(c.get(&[0, 0]), 19.0);
1474 assert_eq!(c.get(&[0, 1]), 22.0);
1475 assert_eq!(c.get(&[1, 0]), 43.0);
1476 assert_eq!(c.get(&[1, 1]), 50.0);
1477 }
1478
1479 #[test]
1480 fn test_einsum2_naive_elementwise() {
1481 let a =
1483 StridedArray::<f64>::from_fn_row_major(&[2, 3], |idx| (idx[0] * 3 + idx[1] + 1) as f64);
1484 let b =
1485 StridedArray::<f64>::from_fn_row_major(&[2, 3], |idx| (idx[0] * 3 + idx[1] + 1) as f64);
1486 let mut c = StridedArray::<f64>::row_major(&[2, 3]);
1487
1488 einsum2_naive_into(
1489 c.view_mut(),
1490 &a.view(),
1491 &b.view(),
1492 &['i', 'j'],
1493 &['i', 'j'],
1494 &['i', 'j'],
1495 2.0,
1496 0.0,
1497 |x| x,
1498 |x| x,
1499 )
1500 .unwrap();
1501
1502 assert_eq!(c.get(&[0, 0]), 2.0); assert_eq!(c.get(&[1, 2]), 72.0); }
1505
1506 #[test]
1507 fn test_einsum2_naive_custom_type() {
1508 use num_traits::{One, Zero};
1511
1512 #[derive(Debug, Clone, Copy, PartialEq)]
1513 struct MyVal(f64);
1514
1515 impl Default for MyVal {
1516 fn default() -> Self {
1517 MyVal(0.0)
1518 }
1519 }
1520
1521 impl std::ops::Add for MyVal {
1522 type Output = Self;
1523 fn add(self, rhs: Self) -> Self {
1524 MyVal(self.0 + rhs.0)
1525 }
1526 }
1527
1528 impl std::ops::Mul for MyVal {
1529 type Output = Self;
1530 fn mul(self, rhs: Self) -> Self {
1531 MyVal(self.0 * rhs.0)
1532 }
1533 }
1534
1535 impl Zero for MyVal {
1536 fn zero() -> Self {
1537 MyVal(0.0)
1538 }
1539 fn is_zero(&self) -> bool {
1540 self.0 == 0.0
1541 }
1542 }
1543
1544 impl One for MyVal {
1545 fn one() -> Self {
1546 MyVal(1.0)
1547 }
1548 }
1549
1550 let a = StridedArray::from_parts(
1552 vec![MyVal(1.0), MyVal(2.0), MyVal(3.0), MyVal(4.0)],
1553 &[2, 2],
1554 &[2, 1],
1555 0,
1556 )
1557 .unwrap();
1558 let b = StridedArray::from_parts(
1559 vec![MyVal(5.0), MyVal(6.0), MyVal(7.0), MyVal(8.0)],
1560 &[2, 2],
1561 &[2, 1],
1562 0,
1563 )
1564 .unwrap();
1565 let mut c = StridedArray::<MyVal>::col_major(&[2, 2]);
1566
1567 einsum2_naive_into(
1568 c.view_mut(),
1569 &a.view(),
1570 &b.view(),
1571 &['i', 'k'],
1572 &['i', 'j'],
1573 &['j', 'k'],
1574 MyVal(1.0),
1575 MyVal(0.0),
1576 |x| x,
1577 |x| x,
1578 )
1579 .unwrap();
1580
1581 assert_eq!(c.get(&[0, 0]), MyVal(19.0));
1584 assert_eq!(c.get(&[0, 1]), MyVal(22.0));
1585 assert_eq!(c.get(&[1, 0]), MyVal(43.0));
1586 assert_eq!(c.get(&[1, 1]), MyVal(50.0));
1587 }
1588
1589 #[test]
1590 fn test_einsum2_naive_left_trace() {
1591 let a =
1593 StridedArray::<f64>::from_fn_row_major(&[2, 3], |idx| (idx[0] * 3 + idx[1] + 1) as f64);
1594 let b =
1595 StridedArray::<f64>::from_fn_row_major(&[3, 2], |idx| (idx[0] * 2 + idx[1] + 1) as f64);
1596 let mut c = StridedArray::<f64>::row_major(&[2]);
1597
1598 einsum2_naive_into(
1599 c.view_mut(),
1600 &a.view(),
1601 &b.view(),
1602 &['k'],
1603 &['i', 'j'],
1604 &['j', 'k'],
1605 1.0,
1606 0.0,
1607 |x| x,
1608 |x| x,
1609 )
1610 .unwrap();
1611
1612 assert_eq!(c.get(&[0]), 71.0);
1613 assert_eq!(c.get(&[1]), 92.0);
1614 }
1615
    /// Minimal `Backend<f64>` used by the `einsum2_with_backend_into` tests:
    /// a plain triple loop over raw strided pointers, one batch at a time.
    struct TestNaiveBackend;

    impl Backend<f64> for TestNaiveBackend {
        // Both capability flags are `false`. Presumably this tells the planner
        // that conjugation is not materialised by this backend and that
        // operands may have arbitrary strides; confirm against the `Backend`
        // trait docs.
        const MATERIALIZES_CONJ: bool = false;
        const REQUIRES_UNIT_STRIDE: bool = false;

        /// For every batch index, computes
        /// `C = alpha * A * B + beta * C`, where A is `m x k`, B is `k x n`
        /// and C is `m x n`. Each operand is addressed through its own
        /// row/column strides plus a per-batch offset.
        ///
        /// When `beta == 0.0` the previous contents of C are never read, so
        /// whatever C held before (e.g. NaN) does not leak into the result.
        /// Always returns `Ok(())`.
        fn bgemm_contiguous_into(
            c: &mut contiguous::ContiguousOperandMut<f64>,
            a: &contiguous::ContiguousOperand<f64>,
            b: &contiguous::ContiguousOperand<f64>,
            batch_dims: &[usize],
            m: usize,
            n: usize,
            k: usize,
            alpha: f64,
            beta: f64,
        ) -> strided_view::Result<()> {
            let a_ptr = a.ptr();
            let b_ptr = b.ptr();
            let c_ptr = c.ptr();
            let a_rs = a.row_stride();
            let a_cs = a.col_stride();
            let b_rs = b.row_stride();
            let b_cs = b.col_stride();
            let c_rs = c.row_stride();
            let c_cs = c.col_stride();

            // NOTE(review): assumes `MultiIndex::next()` steps through every
            // index tuple of `batch_dims` (presumably a single empty tuple when
            // `batch_dims` is empty), and that `offset(strides)` yields the
            // element offset of the current tuple. Confirm in `crate::util`.
            let mut batch_idx = crate::util::MultiIndex::new(batch_dims);
            while batch_idx.next().is_some() {
                let a_base = batch_idx.offset(a.batch_strides());
                let b_base = batch_idx.offset(b.batch_strides());
                let c_base = batch_idx.offset(c.batch_strides());

                for i in 0..m {
                    for j in 0..n {
                        let mut acc = 0.0f64;
                        for l in 0..k {
                            // SAFETY: relies on the caller passing a pointer,
                            // strides and dims that describe a live A, so every
                            // offset with i < m, l < k stays inside A's
                            // allocation. Not checked here.
                            let a_val = unsafe {
                                *a_ptr.offset(a_base + i as isize * a_rs + l as isize * a_cs)
                            };
                            // SAFETY: same caller contract for B with l < k, j < n.
                            let b_val = unsafe {
                                *b_ptr.offset(b_base + l as isize * b_rs + j as isize * b_cs)
                            };
                            acc += a_val * b_val;
                        }
                        // SAFETY: same caller contract for C with i < m, j < n.
                        // Assumes C does not alias A or B (not checked here).
                        unsafe {
                            let c_elem =
                                c_ptr.offset(c_base + i as isize * c_rs + j as isize * c_cs);
                            // beta == 0: overwrite without reading the old value.
                            if beta == 0.0 {
                                *c_elem = alpha * acc;
                            } else {
                                *c_elem = alpha * acc + beta * (*c_elem);
                            }
                        }
                    }
                }
            }
            Ok(())
        }
    }
1678
1679 #[test]
1680 fn test_einsum2_with_backend_matmul() {
1681 let a = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
1682 [[1.0, 2.0], [3.0, 4.0]][idx[0]][idx[1]]
1683 });
1684 let b = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
1685 [[5.0, 6.0], [7.0, 8.0]][idx[0]][idx[1]]
1686 });
1687 let mut c = StridedArray::<f64>::row_major(&[2, 2]);
1688
1689 einsum2_with_backend_into::<_, TestNaiveBackend, _>(
1690 c.view_mut(),
1691 &a.view(),
1692 &b.view(),
1693 &['i', 'k'],
1694 &['i', 'j'],
1695 &['j', 'k'],
1696 1.0,
1697 0.0,
1698 )
1699 .unwrap();
1700
1701 assert_eq!(c.get(&[0, 0]), 19.0);
1702 assert_eq!(c.get(&[0, 1]), 22.0);
1703 assert_eq!(c.get(&[1, 0]), 43.0);
1704 assert_eq!(c.get(&[1, 1]), 50.0);
1705 }
1706
1707 #[test]
1708 fn test_einsum2_with_backend_elementwise() {
1709 let a =
1710 StridedArray::<f64>::from_fn_row_major(&[2, 3], |idx| (idx[0] * 3 + idx[1] + 1) as f64);
1711 let b =
1712 StridedArray::<f64>::from_fn_row_major(&[2, 3], |idx| (idx[0] * 3 + idx[1] + 1) as f64);
1713 let mut c = StridedArray::<f64>::row_major(&[2, 3]);
1714
1715 einsum2_with_backend_into::<_, TestNaiveBackend, _>(
1716 c.view_mut(),
1717 &a.view(),
1718 &b.view(),
1719 &['i', 'j'],
1720 &['i', 'j'],
1721 &['i', 'j'],
1722 2.0,
1723 0.0,
1724 )
1725 .unwrap();
1726
1727 assert_eq!(c.get(&[0, 0]), 2.0); assert_eq!(c.get(&[1, 2]), 72.0); }
1730
1731 #[test]
1732 fn test_einsum2_with_backend_custom_type() {
1733 use num_traits::{One, Zero};
1734
1735 #[derive(Debug, Clone, Copy, PartialEq)]
1736 struct Tropical(f64);
1737
1738 impl Default for Tropical {
1739 fn default() -> Self {
1740 Tropical(0.0)
1741 }
1742 }
1743
1744 impl std::ops::Add for Tropical {
1745 type Output = Self;
1746 fn add(self, rhs: Self) -> Self {
1747 Tropical(self.0 + rhs.0)
1748 }
1749 }
1750
1751 impl std::ops::Mul for Tropical {
1752 type Output = Self;
1753 fn mul(self, rhs: Self) -> Self {
1754 Tropical(self.0 * rhs.0)
1755 }
1756 }
1757
1758 impl Zero for Tropical {
1759 fn zero() -> Self {
1760 Tropical(0.0)
1761 }
1762 fn is_zero(&self) -> bool {
1763 self.0 == 0.0
1764 }
1765 }
1766
1767 impl One for Tropical {
1768 fn one() -> Self {
1769 Tropical(1.0)
1770 }
1771 }
1772
1773 struct TropicalBackend;
1774
1775 impl Backend<Tropical> for TropicalBackend {
1776 const MATERIALIZES_CONJ: bool = false;
1777 const REQUIRES_UNIT_STRIDE: bool = false;
1778
1779 fn bgemm_contiguous_into(
1780 c: &mut contiguous::ContiguousOperandMut<Tropical>,
1781 a: &contiguous::ContiguousOperand<Tropical>,
1782 b: &contiguous::ContiguousOperand<Tropical>,
1783 batch_dims: &[usize],
1784 m: usize,
1785 n: usize,
1786 k: usize,
1787 alpha: Tropical,
1788 beta: Tropical,
1789 ) -> strided_view::Result<()> {
1790 let a_ptr = a.ptr();
1792 let b_ptr = b.ptr();
1793 let c_ptr = c.ptr();
1794 let a_rs = a.row_stride();
1795 let a_cs = a.col_stride();
1796 let b_rs = b.row_stride();
1797 let b_cs = b.col_stride();
1798 let c_rs = c.row_stride();
1799 let c_cs = c.col_stride();
1800
1801 let mut batch_idx = crate::util::MultiIndex::new(batch_dims);
1802 while batch_idx.next().is_some() {
1803 let a_base = batch_idx.offset(a.batch_strides());
1804 let b_base = batch_idx.offset(b.batch_strides());
1805 let c_base = batch_idx.offset(c.batch_strides());
1806
1807 for i in 0..m {
1808 for j in 0..n {
1809 let mut acc = Tropical::zero();
1810 for l in 0..k {
1811 let a_val = unsafe {
1812 *a_ptr.offset(a_base + i as isize * a_rs + l as isize * a_cs)
1813 };
1814 let b_val = unsafe {
1815 *b_ptr.offset(b_base + l as isize * b_rs + j as isize * b_cs)
1816 };
1817 acc = acc + a_val * b_val;
1818 }
1819 unsafe {
1820 let c_elem =
1821 c_ptr.offset(c_base + i as isize * c_rs + j as isize * c_cs);
1822 if beta == Tropical::zero() {
1823 *c_elem = alpha * acc;
1824 } else {
1825 *c_elem = alpha * acc + beta * (*c_elem);
1826 }
1827 }
1828 }
1829 }
1830 }
1831 Ok(())
1832 }
1833 }
1834
1835 let a = StridedArray::from_parts(
1836 vec![Tropical(1.0), Tropical(2.0), Tropical(3.0), Tropical(4.0)],
1837 &[2, 2],
1838 &[2, 1],
1839 0,
1840 )
1841 .unwrap();
1842 let b = StridedArray::from_parts(
1843 vec![Tropical(5.0), Tropical(6.0), Tropical(7.0), Tropical(8.0)],
1844 &[2, 2],
1845 &[2, 1],
1846 0,
1847 )
1848 .unwrap();
1849 let mut c = StridedArray::<Tropical>::col_major(&[2, 2]);
1850
1851 einsum2_with_backend_into::<_, TropicalBackend, _>(
1852 c.view_mut(),
1853 &a.view(),
1854 &b.view(),
1855 &['i', 'k'],
1856 &['i', 'j'],
1857 &['j', 'k'],
1858 Tropical(1.0),
1859 Tropical(0.0),
1860 )
1861 .unwrap();
1862
1863 assert_eq!(c.get(&[0, 0]), Tropical(19.0));
1866 assert_eq!(c.get(&[0, 1]), Tropical(22.0));
1867 assert_eq!(c.get(&[1, 0]), Tropical(43.0));
1868 assert_eq!(c.get(&[1, 1]), Tropical(50.0));
1869 }
1870}