// Backend feature guards: the three backend features are pairwise exclusive.
// Each guard below aborts compilation with an explicit message when two of
// `faer`, `blas` and `blas-inject` are enabled together.
#[cfg(all(feature = "faer", feature = "blas"))]
compile_error!("Features `faer` and `blas` are mutually exclusive. Use one or the other.");

#[cfg(all(feature = "faer", feature = "blas-inject"))]
compile_error!("Features `faer` and `blas-inject` are mutually exclusive.");

#[cfg(all(feature = "blas", feature = "blas-inject"))]
compile_error!("Features `blas` and `blas-inject` are mutually exclusive.");
32
// With `blas-inject` (and without `blas`), `cblas_inject` is linked under the
// name `cblas_sys`, so paths spelled `cblas_sys::...` resolve to it.
#[cfg(all(feature = "blas-inject", not(feature = "blas")))]
extern crate cblas_inject as cblas_sys;
// With `blas` (and without `blas-inject`), the `cblas_sys` crate itself is linked.
#[cfg(all(feature = "blas", not(feature = "blas-inject")))]
extern crate cblas_sys;
37
// Compiled only when `faer` is disabled and exactly one of `blas` /
// `blas-inject` is enabled (the same condition that selects the
// `bgemm_blas::BlasGemm`-bounded `Scalar` trait further down).
#[cfg(all(
    not(feature = "faer"),
    any(
        all(feature = "blas", not(feature = "blas-inject")),
        all(feature = "blas-inject", not(feature = "blas"))
    )
))]
pub mod bgemm_blas;

// Compiled whenever the `faer` feature is enabled.
#[cfg(feature = "faer")]
pub mod bgemm_faer;
// The modules below are compiled unconditionally, regardless of backend features.
pub mod bgemm_naive;
pub mod contiguous;
pub mod plan;
pub mod trace;
pub mod util;

pub mod backend;
63
64use std::any::TypeId;
65use std::fmt::Debug;
66use std::hash::Hash;
67
68use strided_kernel::zip_map2_into;
69#[cfg(any(feature = "faer", feature = "blas", feature = "blas-inject"))]
70use strided_view::StridedArray;
71use strided_view::{Adjoint, Conj, ElementOp, ElementOpApply, StridedView, StridedViewMut};
72
73pub use strided_traits::ScalarBase;
74
75pub use backend::{BackendConfig, BgemmBackend};
76pub use plan::Einsum2Plan;
77
/// Marker trait for types usable as einsum axis labels.
///
/// It declares no methods; it only bundles the bounds
/// `Clone + Eq + Hash + Debug` under a single name.
pub trait AxisId: Clone + Eq + Hash + Debug {}

/// Blanket implementation: every type that satisfies the bundled bounds is an
/// [`AxisId`] automatically (e.g. `char` or `u32`, as used by the tests).
impl<T> AxisId for T where T: Clone + Eq + Hash + Debug {}
81
82#[cfg(all(feature = "faer", not(any(feature = "blas", feature = "blas-inject"))))]
91pub trait Scalar: ScalarBase + ElementOpApply + faer_traits::ComplexField {}
92
93#[cfg(all(feature = "faer", not(any(feature = "blas", feature = "blas-inject"))))]
94impl<T> Scalar for T where T: ScalarBase + ElementOpApply + faer_traits::ComplexField {}
95
96#[cfg(all(
100 not(feature = "faer"),
101 any(
102 all(feature = "blas", not(feature = "blas-inject")),
103 all(feature = "blas-inject", not(feature = "blas"))
104 )
105))]
106pub trait Scalar: ScalarBase + ElementOpApply + bgemm_blas::BlasGemm {}
107
108#[cfg(all(
109 not(feature = "faer"),
110 any(
111 all(feature = "blas", not(feature = "blas-inject")),
112 all(feature = "blas-inject", not(feature = "blas"))
113 )
114))]
115impl<T> Scalar for T where T: ScalarBase + ElementOpApply + bgemm_blas::BlasGemm {}
116
117#[cfg(not(any(feature = "faer", feature = "blas", feature = "blas-inject")))]
119pub trait Scalar: ScalarBase + ElementOpApply {}
120
121#[cfg(not(any(feature = "faer", feature = "blas", feature = "blas-inject")))]
122impl<T> Scalar for T where T: ScalarBase + ElementOpApply {}
123
124#[cfg(any(
129 all(feature = "faer", any(feature = "blas", feature = "blas-inject")),
130 all(feature = "blas", feature = "blas-inject")
131))]
132pub trait Scalar: ScalarBase + ElementOpApply {}
133
134#[cfg(any(
135 all(feature = "faer", any(feature = "blas", feature = "blas-inject")),
136 all(feature = "blas", feature = "blas-inject")
137))]
138impl<T> Scalar for T where T: ScalarBase + ElementOpApply {}
139
140#[derive(Debug, thiserror::Error)]
142pub enum EinsumError {
143 #[error("duplicate axis label: {0}")]
144 DuplicateAxis(String),
145 #[error("output axis {0} not found in any input")]
146 OrphanOutputAxis(String),
147 #[error("dimension mismatch for axis {axis:?}: {dim_a} vs {dim_b}")]
148 DimensionMismatch {
149 axis: String,
150 dim_a: usize,
151 dim_b: usize,
152 },
153 #[error(transparent)]
154 Strided(#[from] strided_view::StridedError),
155}
156
157pub type Result<T> = std::result::Result<T, EinsumError>;
159
160fn op_is_conj<Op: 'static>() -> bool {
169 TypeId::of::<Op>() == TypeId::of::<Conj>() || TypeId::of::<Op>() == TypeId::of::<Adjoint>()
170}
171
172pub fn einsum2_into<T: Scalar, OpA, OpB, ID: AxisId>(
183 c: StridedViewMut<T>,
184 a: &StridedView<T, OpA>,
185 b: &StridedView<T, OpB>,
186 ic: &[ID],
187 ia: &[ID],
188 ib: &[ID],
189 alpha: T,
190 beta: T,
191) -> Result<()>
192where
193 OpA: ElementOp<T> + 'static,
194 OpB: ElementOp<T> + 'static,
195{
196 let plan = Einsum2Plan::new(ia, ib, ic)?;
198
199 validate_dimensions::<ID>(&plan, a.dims(), b.dims(), c.dims(), ia, ib, ic)?;
201
202 let (a_buf, conj_a) = if !plan.left_trace.is_empty() {
207 let trace_indices = plan.left_trace_indices(ia);
208 (Some(trace::reduce_trace_axes(a, &trace_indices)?), false)
209 } else {
210 (None, op_is_conj::<OpA>())
211 };
212
213 let a_view: StridedView<T> = match a_buf.as_ref() {
214 Some(buf) => buf.view(),
215 None => StridedView::new(a.data(), a.dims(), a.strides(), a.offset())
216 .expect("strip_op_view: metadata already validated"),
217 };
218
219 let (b_buf, conj_b) = if !plan.right_trace.is_empty() {
220 let trace_indices = plan.right_trace_indices(ib);
221 (Some(trace::reduce_trace_axes(b, &trace_indices)?), false)
222 } else {
223 (None, op_is_conj::<OpB>())
224 };
225
226 let b_view: StridedView<T> = match b_buf.as_ref() {
227 Some(buf) => buf.view(),
228 None => StridedView::new(b.data(), b.dims(), b.strides(), b.offset())
229 .expect("strip_op_view: metadata already validated"),
230 };
231
232 #[cfg(any(
234 all(feature = "faer", not(any(feature = "blas", feature = "blas-inject"))),
235 all(
236 not(feature = "faer"),
237 any(
238 all(feature = "blas", not(feature = "blas-inject")),
239 all(feature = "blas-inject", not(feature = "blas"))
240 )
241 )
242 ))]
243 einsum2_gemm_dispatch(c, &a_view, &b_view, &plan, alpha, beta, conj_a, conj_b)?;
244
245 #[cfg(not(any(feature = "faer", feature = "blas", feature = "blas-inject")))]
246 {
247 let a_perm = a_view.permute(&plan.left_perm)?;
248 let b_perm = b_view.permute(&plan.right_perm)?;
249 let mut c_perm = c.permute(&plan.c_to_internal_perm)?;
250
251 if plan.sum.is_empty() && plan.lo.is_empty() && plan.ro.is_empty() && beta == T::zero() {
252 let mul_fn = move |a_val: T, b_val: T| -> T {
253 let a_c = if conj_a { Conj::apply(a_val) } else { a_val };
254 let b_c = if conj_b { Conj::apply(b_val) } else { b_val };
255 alpha * a_c * b_c
256 };
257 zip_map2_into(&mut c_perm, &a_perm, &b_perm, mul_fn)?;
258 return Ok(());
259 }
260
261 bgemm_naive::bgemm_strided_into(
262 &mut c_perm,
263 &a_perm,
264 &b_perm,
265 plan.batch.len(),
266 plan.lo.len(),
267 plan.ro.len(),
268 plan.sum.len(),
269 alpha,
270 beta,
271 conj_a,
272 conj_b,
273 )?;
274 }
275
276 Ok(())
277}
278
279pub fn einsum2_naive_into<T, ID, MapA, MapB>(
287 c: StridedViewMut<T>,
288 a: &StridedView<T>,
289 b: &StridedView<T>,
290 ic: &[ID],
291 ia: &[ID],
292 ib: &[ID],
293 alpha: T,
294 beta: T,
295 map_a: MapA,
296 map_b: MapB,
297) -> Result<()>
298where
299 T: ScalarBase,
300 ID: AxisId,
301 MapA: Fn(T) -> T,
302 MapB: Fn(T) -> T,
303{
304 let plan = Einsum2Plan::new(ia, ib, ic)?;
305 validate_dimensions::<ID>(&plan, a.dims(), b.dims(), c.dims(), ia, ib, ic)?;
306
307 let (a_buf, use_map_a) = if !plan.left_trace.is_empty() {
311 let trace_indices = plan.left_trace_indices(ia);
312 let mut mapped = unsafe { strided_view::StridedArray::<T>::col_major_uninit(a.dims()) };
313 strided_kernel::map_into(&mut mapped.view_mut(), a, &map_a)?;
314 let reduced = trace::reduce_trace_axes(&mapped.view(), &trace_indices)?;
315 (Some(reduced), false)
316 } else {
317 (None, true)
318 };
319 let a_view: StridedView<T> = match a_buf.as_ref() {
320 Some(buf) => buf.view(),
321 None => a.clone(),
322 };
323
324 let (b_buf, use_map_b) = if !plan.right_trace.is_empty() {
325 let trace_indices = plan.right_trace_indices(ib);
326 let mut mapped = unsafe { strided_view::StridedArray::<T>::col_major_uninit(b.dims()) };
327 strided_kernel::map_into(&mut mapped.view_mut(), b, &map_b)?;
328 let reduced = trace::reduce_trace_axes(&mapped.view(), &trace_indices)?;
329 (Some(reduced), false)
330 } else {
331 (None, true)
332 };
333 let b_view: StridedView<T> = match b_buf.as_ref() {
334 Some(buf) => buf.view(),
335 None => b.clone(),
336 };
337
338 let a_perm = a_view.permute(&plan.left_perm)?;
339 let b_perm = b_view.permute(&plan.right_perm)?;
340 let mut c_perm = c.permute(&plan.c_to_internal_perm)?;
341
342 if plan.sum.is_empty() && plan.lo.is_empty() && plan.ro.is_empty() && beta == T::zero() {
344 let mul_fn = move |a_val: T, b_val: T| -> T {
345 let a_c = if use_map_a { map_a(a_val) } else { a_val };
346 let b_c = if use_map_b { map_b(b_val) } else { b_val };
347 alpha * a_c * b_c
348 };
349 zip_map2_into(&mut c_perm, &a_perm, &b_perm, mul_fn)?;
350 return Ok(());
351 }
352
353 let final_map_a: Box<dyn Fn(T) -> T> = if use_map_a {
354 Box::new(map_a)
355 } else {
356 Box::new(|x| x)
357 };
358 let final_map_b: Box<dyn Fn(T) -> T> = if use_map_b {
359 Box::new(map_b)
360 } else {
361 Box::new(|x| x)
362 };
363
364 bgemm_naive::bgemm_strided_into_with_map(
365 &mut c_perm,
366 &a_perm,
367 &b_perm,
368 plan.batch.len(),
369 plan.lo.len(),
370 plan.ro.len(),
371 plan.sum.len(),
372 alpha,
373 beta,
374 final_map_a,
375 final_map_b,
376 )?;
377
378 Ok(())
379}
380
381pub fn einsum2_with_backend_into<T, B, ID>(
390 c: StridedViewMut<T>,
391 a: &StridedView<T>,
392 b: &StridedView<T>,
393 ic: &[ID],
394 ia: &[ID],
395 ib: &[ID],
396 alpha: T,
397 beta: T,
398) -> Result<()>
399where
400 T: ScalarBase,
401 B: BgemmBackend<T> + BackendConfig,
402 ID: AxisId,
403{
404 let plan = Einsum2Plan::new(ia, ib, ic)?;
405 validate_dimensions::<ID>(&plan, a.dims(), b.dims(), c.dims(), ia, ib, ic)?;
406
407 let a_buf = if !plan.left_trace.is_empty() {
409 let trace_indices = plan.left_trace_indices(ia);
410 Some(trace::reduce_trace_axes(a, &trace_indices)?)
411 } else {
412 None
413 };
414 let a_view: StridedView<T> = match a_buf.as_ref() {
415 Some(buf) => buf.view(),
416 None => a.clone(),
417 };
418
419 let b_buf = if !plan.right_trace.is_empty() {
420 let trace_indices = plan.right_trace_indices(ib);
421 Some(trace::reduce_trace_axes(b, &trace_indices)?)
422 } else {
423 None
424 };
425 let b_view: StridedView<T> = match b_buf.as_ref() {
426 Some(buf) => buf.view(),
427 None => b.clone(),
428 };
429
430 let a_perm = a_view.permute(&plan.left_perm)?;
432 let b_perm = b_view.permute(&plan.right_perm)?;
433 let mut c_perm = c.permute(&plan.c_to_internal_perm)?;
434
435 let n_batch = plan.batch.len();
436 let n_lo = plan.lo.len();
437 let n_ro = plan.ro.len();
438 let n_sum = plan.sum.len();
439
440 if plan.sum.is_empty() && plan.lo.is_empty() && plan.ro.is_empty() && beta == T::zero() {
442 if alpha == T::one() {
443 zip_map2_into(&mut c_perm, &a_perm, &b_perm, |a_val, b_val| a_val * b_val)?;
444 } else {
445 let mul_fn = move |a_val: T, b_val: T| -> T { alpha * a_val * b_val };
446 zip_map2_into(&mut c_perm, &a_perm, &b_perm, mul_fn)?;
447 }
448 return Ok(());
449 }
450
451 let a_op = contiguous::prepare_input_view_for_backend::<T, B>(&a_perm, n_batch, n_lo, n_sum)?;
453 let b_op = contiguous::prepare_input_view_for_backend::<T, B>(&b_perm, n_batch, n_sum, n_ro)?;
454 let mut c_op = contiguous::prepare_output_view_for_backend::<T, B>(
455 &mut c_perm,
456 n_batch,
457 n_lo,
458 n_ro,
459 beta,
460 )?;
461
462 let lo_dims = &a_perm.dims()[..n_lo];
464 let sum_dims = &a_perm.dims()[n_lo..n_lo + n_sum];
465 let batch_dims = &a_perm.dims()[n_lo + n_sum..];
466 let ro_dims = &b_perm.dims()[n_sum..n_sum + n_ro];
467 let m: usize = lo_dims.iter().product::<usize>().max(1);
468 let k: usize = sum_dims.iter().product::<usize>().max(1);
469 let n: usize = ro_dims.iter().product::<usize>().max(1);
470
471 B::bgemm_contiguous_into(&mut c_op, &a_op, &b_op, batch_dims, m, n, k, alpha, beta)?;
473
474 c_op.finalize_into(&mut c_perm)?;
476
477 Ok(())
478}
479
/// Shared back end of [`einsum2_into`] for builds with exactly one valid
/// backend feature set: permutes the operands as dictated by `plan`, then
/// either runs an elementwise fast path or calls
/// `backend::ActiveBackend::bgemm_contiguous_into`.
///
/// `conj_a` / `conj_b` request conjugation of the respective input elements.
///
/// # Errors
/// Propagates failures from permutation, operand preparation, the backend
/// call and the final write-back.
#[cfg(any(
    all(feature = "faer", not(any(feature = "blas", feature = "blas-inject"))),
    all(
        not(feature = "faer"),
        any(
            all(feature = "blas", not(feature = "blas-inject")),
            all(feature = "blas-inject", not(feature = "blas"))
        )
    )
))]
fn einsum2_gemm_dispatch<T: Scalar>(
    c: StridedViewMut<T>,
    a: &StridedView<T>,
    b: &StridedView<T>,
    plan: &Einsum2Plan<impl AxisId>,
    alpha: T,
    beta: T,
    conj_a: bool,
    conj_b: bool,
) -> Result<()>
where
    backend::ActiveBackend: BgemmBackend<T>,
{
    let lhs = a.permute(&plan.left_perm)?;
    let rhs = b.permute(&plan.right_perm)?;
    let mut out = c.permute(&plan.c_to_internal_perm)?;

    // Only batch axes and beta == 0: a plain elementwise product suffices.
    let elementwise_only =
        plan.sum.is_empty() && plan.lo.is_empty() && plan.ro.is_empty() && beta == T::zero();
    if elementwise_only {
        let unit_alpha = alpha == T::one();
        if unit_alpha && !conj_a && !conj_b {
            // Cheapest case: no scaling and no conjugation.
            zip_map2_into(&mut out, &lhs, &rhs, |x, y| x * y)?;
        } else if unit_alpha {
            // Conjugation only.
            let conj_product = move |x: T, y: T| -> T {
                let x = if conj_a { Conj::apply(x) } else { x };
                let y = if conj_b { Conj::apply(y) } else { y };
                x * y
            };
            zip_map2_into(&mut out, &lhs, &rhs, conj_product)?;
        } else {
            // General case: optional conjugation plus scaling by alpha.
            let scaled_product = move |x: T, y: T| -> T {
                let x = if conj_a { Conj::apply(x) } else { x };
                let y = if conj_b { Conj::apply(y) } else { y };
                alpha * x * y
            };
            zip_map2_into(&mut out, &lhs, &rhs, scaled_product)?;
        }
        return Ok(());
    }

    let (n_batch, n_lo, n_ro, n_sum) = (
        plan.batch.len(),
        plan.lo.len(),
        plan.ro.len(),
        plan.sum.len(),
    );

    let a_op = contiguous::prepare_input_view(&lhs, n_batch, n_lo, n_sum, conj_a)?;
    let b_op = contiguous::prepare_input_view(&rhs, n_batch, n_sum, n_ro, conj_b)?;
    let mut c_op = contiguous::prepare_output_view(&mut out, n_batch, n_lo, n_ro, beta)?;

    // After permutation the left operand's axes are sliced as
    // [lo.., sum.., batch..] and the right operand's leading axes as
    // [sum.., ro..]; the GEMM sizes are the products of those groups.
    // NOTE(review): the product of an empty group is already 1, so `.max(1)`
    // only has an effect when an extent is 0 — confirm that is intended.
    let group_size = |dims: &[usize]| -> usize { dims.iter().product::<usize>().max(1) };
    let lhs_dims = lhs.dims();
    let sum_end = n_lo + n_sum;
    let m = group_size(&lhs_dims[..n_lo]);
    let k = group_size(&lhs_dims[n_lo..sum_end]);
    let batch_dims = &lhs_dims[sum_end..];
    let n = group_size(&rhs.dims()[n_sum..n_sum + n_ro]);

    backend::ActiveBackend::bgemm_contiguous_into(
        &mut c_op, &a_op, &b_op, batch_dims, m, n, k, alpha, beta,
    )?;

    // Hand the prepared output operand back to the permuted output view.
    c_op.finalize_into(&mut out)?;

    Ok(())
}
567
/// Binary einsum that takes ownership of both input arrays, so they can be
/// handed to `contiguous::prepare_input_owned` by value.
///
/// `conj_a` / `conj_b` request conjugation of the respective input; the flag
/// is dropped for an input that goes through a trace reduction. `alpha` and
/// `beta` scale the product and the existing output respectively.
///
/// Only available when exactly one valid backend feature set is enabled.
///
/// # Errors
/// Propagates failures from plan construction, dimension validation, trace
/// reduction, permutation, operand preparation, the backend call and the
/// final write-back.
#[cfg(any(
    all(feature = "faer", not(any(feature = "blas", feature = "blas-inject"))),
    all(
        not(feature = "faer"),
        any(
            all(feature = "blas", not(feature = "blas-inject")),
            all(feature = "blas-inject", not(feature = "blas"))
        )
    )
))]
pub fn einsum2_into_owned<T: Scalar, ID: AxisId>(
    c: StridedViewMut<T>,
    a: StridedArray<T>,
    b: StridedArray<T>,
    ic: &[ID],
    ia: &[ID],
    ib: &[ID],
    alpha: T,
    beta: T,
    conj_a: bool,
    conj_b: bool,
) -> Result<()>
where
    backend::ActiveBackend: BgemmBackend<T>,
{
    let plan = Einsum2Plan::new(ia, ib, ic)?;

    validate_dimensions::<ID>(&plan, a.dims(), b.dims(), c.dims(), ia, ib, ic)?;

    // Without trace axes the owned array is passed through untouched;
    // otherwise it is replaced by its trace-reduced version.
    let (a_ready, conj_a) = if plan.left_trace.is_empty() {
        (a, conj_a)
    } else {
        let traced = plan.left_trace_indices(ia);
        (trace::reduce_trace_axes(&a.view(), &traced)?, false)
    };

    let (b_ready, conj_b) = if plan.right_trace.is_empty() {
        (b, conj_b)
    } else {
        let traced = plan.right_trace_indices(ib);
        (trace::reduce_trace_axes(&b.view(), &traced)?, false)
    };

    let lhs = a_ready.permuted(&plan.left_perm)?;
    let rhs = b_ready.permuted(&plan.right_perm)?;
    let mut out = c.permute(&plan.c_to_internal_perm)?;

    let (n_batch, n_lo, n_ro, n_sum) = (
        plan.batch.len(),
        plan.lo.len(),
        plan.ro.len(),
        plan.sum.len(),
    );

    // Only batch axes and beta == 0: a plain elementwise product suffices.
    let elementwise_only =
        plan.sum.is_empty() && plan.lo.is_empty() && plan.ro.is_empty() && beta == T::zero();
    if elementwise_only {
        let scaled_product = move |x: T, y: T| -> T {
            let x = if conj_a { Conj::apply(x) } else { x };
            let y = if conj_b { Conj::apply(y) } else { y };
            alpha * x * y
        };
        zip_map2_into(&mut out, &lhs.view(), &rhs.view(), scaled_product)?;
        return Ok(());
    }

    // Copy the permuted extents now: `lhs` and `rhs` are moved into
    // `prepare_input_owned` below and cannot be queried afterwards.
    let lhs_dims = lhs.dims().to_vec();
    let rhs_dims = rhs.dims().to_vec();

    // The left operand's axes are sliced as [lo.., sum.., batch..] and the
    // right operand's leading axes as [sum.., ro..].
    // NOTE(review): the product of an empty group is already 1, so `.max(1)`
    // only has an effect when an extent is 0 — confirm that is intended.
    let group_size = |dims: &[usize]| -> usize { dims.iter().product::<usize>().max(1) };
    let sum_end = n_lo + n_sum;
    let m = group_size(&lhs_dims[..n_lo]);
    let k = group_size(&lhs_dims[n_lo..sum_end]);
    let batch_dims = &lhs_dims[sum_end..];
    let n = group_size(&rhs_dims[n_sum..n_sum + n_ro]);

    let a_op = contiguous::prepare_input_owned(lhs, n_batch, n_lo, n_sum, conj_a)?;
    let b_op = contiguous::prepare_input_owned(rhs, n_batch, n_sum, n_ro, conj_b)?;
    let mut c_op = contiguous::prepare_output_view(&mut out, n_batch, n_lo, n_ro, beta)?;

    backend::ActiveBackend::bgemm_contiguous_into(
        &mut c_op, &a_op, &b_op, batch_dims, m, n, k, alpha, beta,
    )?;

    // Hand the prepared output operand back to the permuted output view.
    c_op.finalize_into(&mut out)?;

    Ok(())
}
688
689fn validate_dimensions<ID: AxisId>(
691 plan: &Einsum2Plan<ID>,
692 a_dims: &[usize],
693 b_dims: &[usize],
694 c_dims: &[usize],
695 ia: &[ID],
696 ib: &[ID],
697 ic: &[ID],
698) -> Result<()> {
699 let find_dim = |labels: &[ID], dims: &[usize], id: &ID| -> usize {
700 labels
701 .iter()
702 .position(|x| x == id)
703 .map(|i| dims[i])
704 .unwrap()
705 };
706
707 for id in &plan.batch {
709 let da = find_dim(ia, a_dims, id);
710 let db = find_dim(ib, b_dims, id);
711 let dc = find_dim(ic, c_dims, id);
712 if da != db || da != dc {
713 return Err(EinsumError::DimensionMismatch {
714 axis: format!("{:?}", id),
715 dim_a: da,
716 dim_b: db,
717 });
718 }
719 }
720
721 for id in &plan.sum {
723 let da = find_dim(ia, a_dims, id);
724 let db = find_dim(ib, b_dims, id);
725 if da != db {
726 return Err(EinsumError::DimensionMismatch {
727 axis: format!("{:?}", id),
728 dim_a: da,
729 dim_b: db,
730 });
731 }
732 }
733
734 for id in &plan.lo {
736 let da = find_dim(ia, a_dims, id);
737 let dc = find_dim(ic, c_dims, id);
738 if da != dc {
739 return Err(EinsumError::DimensionMismatch {
740 axis: format!("{:?}", id),
741 dim_a: da,
742 dim_b: dc,
743 });
744 }
745 }
746
747 for id in &plan.ro {
749 let db = find_dim(ib, b_dims, id);
750 let dc = find_dim(ic, c_dims, id);
751 if db != dc {
752 return Err(EinsumError::DimensionMismatch {
753 axis: format!("{:?}", id),
754 dim_a: db,
755 dim_b: dc,
756 });
757 }
758 }
759
760 Ok(())
761}
762
763#[cfg(test)]
764mod tests {
765 use super::*;
766 use strided_view::StridedArray;
767
768 #[test]
769 fn test_matmul_ij_jk_ik() {
770 let a = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
772 [[1.0, 2.0], [3.0, 4.0]][idx[0]][idx[1]]
773 });
774 let b = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
775 [[5.0, 6.0], [7.0, 8.0]][idx[0]][idx[1]]
776 });
777 let mut c = StridedArray::<f64>::row_major(&[2, 2]);
778
779 einsum2_into(
780 c.view_mut(),
781 &a.view(),
782 &b.view(),
783 &['i', 'k'],
784 &['i', 'j'],
785 &['j', 'k'],
786 1.0,
787 0.0,
788 )
789 .unwrap();
790
791 assert_eq!(c.get(&[0, 0]), 19.0);
792 assert_eq!(c.get(&[0, 1]), 22.0);
793 assert_eq!(c.get(&[1, 0]), 43.0);
794 assert_eq!(c.get(&[1, 1]), 50.0);
795 }
796
797 #[test]
798 fn test_matmul_rect() {
799 let a =
801 StridedArray::<f64>::from_fn_row_major(&[2, 3], |idx| (idx[0] * 3 + idx[1] + 1) as f64);
802 let b =
803 StridedArray::<f64>::from_fn_row_major(&[3, 4], |idx| (idx[0] * 4 + idx[1] + 1) as f64);
804 let mut c = StridedArray::<f64>::row_major(&[2, 4]);
805
806 einsum2_into(
807 c.view_mut(),
808 &a.view(),
809 &b.view(),
810 &['i', 'k'],
811 &['i', 'j'],
812 &['j', 'k'],
813 1.0,
814 0.0,
815 )
816 .unwrap();
817
818 assert_eq!(c.get(&[0, 0]), 38.0);
820 assert_eq!(c.get(&[1, 3]), 128.0);
821 }
822
823 #[test]
824 fn test_batched_matmul() {
825 let a = StridedArray::<f64>::from_fn_row_major(&[2, 2, 3], |idx| {
827 (idx[0] * 6 + idx[1] * 3 + idx[2] + 1) as f64
828 });
829 let b = StridedArray::<f64>::from_fn_row_major(&[2, 3, 2], |idx| {
830 (idx[0] * 6 + idx[1] * 2 + idx[2] + 1) as f64
831 });
832 let mut c = StridedArray::<f64>::row_major(&[2, 2, 2]);
833
834 einsum2_into(
835 c.view_mut(),
836 &a.view(),
837 &b.view(),
838 &['b', 'i', 'k'],
839 &['b', 'i', 'j'],
840 &['b', 'j', 'k'],
841 1.0,
842 0.0,
843 )
844 .unwrap();
845
846 assert_eq!(c.get(&[0, 0, 0]), 22.0);
849 }
850
851 #[test]
852 fn test_batched_matmul_col_major_output() {
853 let a_data = vec![1.0, 0.0, 0.0, 1.0, 2.0, 0.0, 0.0, 2.0];
855 let b_data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
856 let a = StridedArray::<f64>::from_parts(a_data, &[2, 2, 2], &[4, 2, 1], 0).unwrap();
857 let b = StridedArray::<f64>::from_parts(b_data, &[2, 2, 2], &[4, 2, 1], 0).unwrap();
858 let mut c = StridedArray::<f64>::col_major(&[2, 2, 2]);
859
860 einsum2_into(
861 c.view_mut(),
862 &a.view(),
863 &b.view(),
864 &['b', 'i', 'k'],
865 &['b', 'i', 'j'],
866 &['b', 'j', 'k'],
867 1.0,
868 0.0,
869 )
870 .unwrap();
871
872 assert_eq!(c.get(&[0, 0, 0]), 1.0);
874 assert_eq!(c.get(&[0, 0, 1]), 2.0);
875 assert_eq!(c.get(&[0, 1, 0]), 3.0);
876 assert_eq!(c.get(&[0, 1, 1]), 4.0);
877 assert_eq!(c.get(&[1, 0, 0]), 10.0);
879 assert_eq!(c.get(&[1, 1, 1]), 16.0);
880 }
881
882 #[test]
883 fn test_outer_product() {
884 let a = StridedArray::<f64>::from_fn_row_major(&[3], |idx| (idx[0] + 1) as f64);
886 let b = StridedArray::<f64>::from_fn_row_major(&[4], |idx| (idx[0] + 1) as f64);
887 let mut c = StridedArray::<f64>::row_major(&[3, 4]);
888
889 einsum2_into(
890 c.view_mut(),
891 &a.view(),
892 &b.view(),
893 &['i', 'j'],
894 &['i'],
895 &['j'],
896 1.0,
897 0.0,
898 )
899 .unwrap();
900
901 assert_eq!(c.get(&[0, 0]), 1.0);
902 assert_eq!(c.get(&[2, 3]), 12.0);
903 }
904
905 #[test]
906 fn test_dot_product() {
907 let a = StridedArray::<f64>::from_fn_row_major(&[3], |idx| (idx[0] + 1) as f64);
909 let b = StridedArray::<f64>::from_fn_row_major(&[3], |idx| (idx[0] + 1) as f64);
910 let mut c = StridedArray::<f64>::row_major(&[]);
911
912 einsum2_into(
913 c.view_mut(),
914 &a.view(),
915 &b.view(),
916 &[] as &[char],
917 &['i'],
918 &['i'],
919 1.0,
920 0.0,
921 )
922 .unwrap();
923
924 assert_eq!(c.get(&[]), 14.0);
926 }
927
928 #[test]
929 fn test_alpha_beta() {
930 let a = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
932 [[1.0, 0.0], [0.0, 1.0]][idx[0]][idx[1]] });
934 let b = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
935 [[1.0, 2.0], [3.0, 4.0]][idx[0]][idx[1]]
936 });
937 let mut c = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
938 [[10.0, 20.0], [30.0, 40.0]][idx[0]][idx[1]]
939 });
940
941 einsum2_into(
942 c.view_mut(),
943 &a.view(),
944 &b.view(),
945 &['i', 'k'],
946 &['i', 'j'],
947 &['j', 'k'],
948 2.0,
949 3.0,
950 )
951 .unwrap();
952
953 assert_eq!(c.get(&[0, 0]), 32.0); assert_eq!(c.get(&[1, 1]), 128.0); }
957
958 #[test]
959 fn test_transposed_output() {
960 let a = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
962 [[1.0, 2.0], [3.0, 4.0]][idx[0]][idx[1]]
963 });
964 let b = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
965 [[5.0, 6.0], [7.0, 8.0]][idx[0]][idx[1]]
966 });
967 let mut c = StridedArray::<f64>::row_major(&[2, 2]);
968
969 einsum2_into(
970 c.view_mut(),
971 &a.view(),
972 &b.view(),
973 &['k', 'i'], &['i', 'j'],
975 &['j', 'k'],
976 1.0,
977 0.0,
978 )
979 .unwrap();
980
981 assert_eq!(c.get(&[0, 0]), 19.0); assert_eq!(c.get(&[0, 1]), 43.0); assert_eq!(c.get(&[1, 0]), 22.0); assert_eq!(c.get(&[1, 1]), 50.0); }
988
989 #[test]
990 fn test_left_trace() {
991 let a =
994 StridedArray::<f64>::from_fn_row_major(&[2, 3], |idx| (idx[0] * 3 + idx[1] + 1) as f64);
995 let b =
998 StridedArray::<f64>::from_fn_row_major(&[3, 2], |idx| (idx[0] * 2 + idx[1] + 1) as f64);
999 let mut c = StridedArray::<f64>::row_major(&[2]);
1001
1002 einsum2_into(
1003 c.view_mut(),
1004 &a.view(),
1005 &b.view(),
1006 &['k'],
1007 &['i', 'j'],
1008 &['j', 'k'],
1009 1.0,
1010 0.0,
1011 )
1012 .unwrap();
1013
1014 assert_eq!(c.get(&[0]), 71.0);
1018 assert_eq!(c.get(&[1]), 92.0);
1019 }
1020
1021 #[test]
1022 fn test_u32_labels() {
1023 let a = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
1025 [[1.0, 2.0], [3.0, 4.0]][idx[0]][idx[1]]
1026 });
1027 let b = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
1028 [[5.0, 6.0], [7.0, 8.0]][idx[0]][idx[1]]
1029 });
1030 let mut c = StridedArray::<f64>::row_major(&[2, 2]);
1031
1032 einsum2_into(
1033 c.view_mut(),
1034 &a.view(),
1035 &b.view(),
1036 &[0u32, 2],
1037 &[0u32, 1],
1038 &[1u32, 2],
1039 1.0,
1040 0.0,
1041 )
1042 .unwrap();
1043
1044 assert_eq!(c.get(&[0, 0]), 19.0);
1045 assert_eq!(c.get(&[1, 1]), 50.0);
1046 }
1047
1048 #[test]
1049 fn test_complex_matmul() {
1050 use num_complex::Complex64;
1051 let i = Complex64::i();
1052
1053 let a_vals = [
1055 [1.0 + i, Complex64::new(2.0, 0.0)],
1056 [Complex64::new(3.0, 0.0), 4.0 - i],
1057 ];
1058 let a = StridedArray::<Complex64>::from_fn_row_major(&[2, 2], |idx| a_vals[idx[0]][idx[1]]);
1059
1060 let b_vals = [
1062 [Complex64::new(1.0, 0.0), i],
1063 [Complex64::new(0.0, 0.0), Complex64::new(1.0, 0.0)],
1064 ];
1065 let b = StridedArray::<Complex64>::from_fn_row_major(&[2, 2], |idx| b_vals[idx[0]][idx[1]]);
1066
1067 let mut c = StridedArray::<Complex64>::row_major(&[2, 2]);
1068
1069 einsum2_into(
1070 c.view_mut(),
1071 &a.view(),
1072 &b.view(),
1073 &['i', 'k'],
1074 &['i', 'j'],
1075 &['j', 'k'],
1076 Complex64::new(1.0, 0.0),
1077 Complex64::new(0.0, 0.0),
1078 )
1079 .unwrap();
1080
1081 assert_eq!(c.get(&[0, 0]), 1.0 + i);
1087 assert_eq!(c.get(&[0, 1]), 1.0 + i);
1088 assert_eq!(c.get(&[1, 0]), Complex64::new(3.0, 0.0));
1089 assert_eq!(c.get(&[1, 1]), 4.0 + 2.0 * i);
1090 }
1091
1092 #[test]
1093 fn test_complex_matmul_with_conj() {
1094 use num_complex::Complex64;
1095 let i = Complex64::i();
1096
1097 let a_vals = [[1.0 + i, 2.0 * i], [Complex64::new(3.0, 0.0), 4.0 - i]];
1099 let a = StridedArray::<Complex64>::from_fn_row_major(&[2, 2], |idx| a_vals[idx[0]][idx[1]]);
1100
1101 let b = StridedArray::<Complex64>::from_fn_row_major(&[2, 2], |idx| {
1103 if idx[0] == idx[1] {
1104 Complex64::new(1.0, 0.0)
1105 } else {
1106 Complex64::new(0.0, 0.0)
1107 }
1108 });
1109
1110 let mut c = StridedArray::<Complex64>::row_major(&[2, 2]);
1111
1112 let a_conj = a.view().conj();
1114 einsum2_into(
1115 c.view_mut(),
1116 &a_conj,
1117 &b.view(),
1118 &['i', 'k'],
1119 &['i', 'j'],
1120 &['j', 'k'],
1121 Complex64::new(1.0, 0.0),
1122 Complex64::new(0.0, 0.0),
1123 )
1124 .unwrap();
1125
1126 assert_eq!(c.get(&[0, 0]), 1.0 - i);
1128 assert_eq!(c.get(&[0, 1]), -2.0 * i);
1129 assert_eq!(c.get(&[1, 0]), Complex64::new(3.0, 0.0));
1130 assert_eq!(c.get(&[1, 1]), 4.0 + i);
1131 }
1132
1133 #[test]
1134 fn test_complex_matmul_with_conj_both() {
1135 use num_complex::Complex64;
1136 let i = Complex64::i();
1137
1138 let a_vals = [
1140 [1.0 + i, Complex64::new(0.0, 0.0)],
1141 [Complex64::new(0.0, 0.0), 2.0 - i],
1142 ];
1143 let a = StridedArray::<Complex64>::from_fn_row_major(&[2, 2], |idx| a_vals[idx[0]][idx[1]]);
1144
1145 let b_vals = [
1147 [Complex64::new(1.0, 0.0), i],
1148 [Complex64::new(0.0, 0.0), 1.0 + i],
1149 ];
1150 let b = StridedArray::<Complex64>::from_fn_row_major(&[2, 2], |idx| b_vals[idx[0]][idx[1]]);
1151
1152 let mut c = StridedArray::<Complex64>::row_major(&[2, 2]);
1153
1154 let a_conj = a.view().conj();
1156 let b_conj = b.view().conj();
1157 einsum2_into(
1158 c.view_mut(),
1159 &a_conj,
1160 &b_conj,
1161 &['i', 'k'],
1162 &['i', 'j'],
1163 &['j', 'k'],
1164 Complex64::new(1.0, 0.0),
1165 Complex64::new(0.0, 0.0),
1166 )
1167 .unwrap();
1168
1169 assert_eq!(c.get(&[0, 0]), 1.0 - i);
1177 assert_eq!(c.get(&[0, 1]), -(1.0 + i));
1178 assert_eq!(c.get(&[1, 0]), Complex64::new(0.0, 0.0));
1179 assert_eq!(c.get(&[1, 1]), 3.0 - i);
1180 }
1181
1182 #[test]
1183 fn test_elementwise_hadamard() {
1184 let a = StridedArray::<f64>::from_fn_row_major(&[3, 4, 5], |idx| {
1186 (idx[0] * 20 + idx[1] * 5 + idx[2] + 1) as f64
1187 });
1188 let b = StridedArray::<f64>::from_fn_row_major(&[3, 4, 5], |idx| {
1189 (idx[0] * 20 + idx[1] * 5 + idx[2] + 1) as f64 * 0.1
1190 });
1191 let mut c = StridedArray::<f64>::row_major(&[3, 4, 5]);
1192
1193 einsum2_into(
1194 c.view_mut(),
1195 &a.view(),
1196 &b.view(),
1197 &['i', 'j', 'k'],
1198 &['i', 'j', 'k'],
1199 &['i', 'j', 'k'],
1200 1.0,
1201 0.0,
1202 )
1203 .unwrap();
1204
1205 assert!((c.get(&[0, 0, 0]) - 0.1).abs() < 1e-12);
1207 assert!((c.get(&[2, 3, 4]) - 360.0).abs() < 1e-10);
1209 }
1210
1211 #[test]
1212 fn test_elementwise_hadamard_with_alpha() {
1213 let a =
1214 StridedArray::<f64>::from_fn_row_major(&[2, 3], |idx| (idx[0] * 3 + idx[1] + 1) as f64);
1215 let b =
1216 StridedArray::<f64>::from_fn_row_major(&[2, 3], |idx| (idx[0] * 3 + idx[1] + 1) as f64);
1217 let mut c = StridedArray::<f64>::row_major(&[2, 3]);
1218
1219 einsum2_into(
1220 c.view_mut(),
1221 &a.view(),
1222 &b.view(),
1223 &['i', 'j'],
1224 &['i', 'j'],
1225 &['i', 'j'],
1226 2.0,
1227 0.0,
1228 )
1229 .unwrap();
1230
1231 assert_eq!(c.get(&[0, 0]), 2.0);
1233 assert_eq!(c.get(&[1, 2]), 72.0);
1235 }
1236
1237 #[cfg(any(feature = "faer", feature = "blas", feature = "blas-inject"))]
1238 #[test]
1239 fn test_einsum2_owned_matmul() {
1240 let a = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
1241 [[1.0, 2.0], [3.0, 4.0]][idx[0]][idx[1]]
1242 });
1243 let b = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
1244 [[5.0, 6.0], [7.0, 8.0]][idx[0]][idx[1]]
1245 });
1246 let mut c = StridedArray::<f64>::row_major(&[2, 2]);
1247
1248 einsum2_into_owned(
1249 c.view_mut(),
1250 a,
1251 b,
1252 &['i', 'k'],
1253 &['i', 'j'],
1254 &['j', 'k'],
1255 1.0,
1256 0.0,
1257 false,
1258 false,
1259 )
1260 .unwrap();
1261
1262 assert_eq!(c.get(&[0, 0]), 19.0);
1263 assert_eq!(c.get(&[0, 1]), 22.0);
1264 assert_eq!(c.get(&[1, 0]), 43.0);
1265 assert_eq!(c.get(&[1, 1]), 50.0);
1266 }
1267
1268 #[cfg(any(feature = "faer", feature = "blas", feature = "blas-inject"))]
1269 #[test]
1270 fn test_einsum2_owned_batched() {
1271 let a = StridedArray::<f64>::from_fn_row_major(&[2, 2, 3], |idx| {
1272 (idx[0] * 6 + idx[1] * 3 + idx[2] + 1) as f64
1273 });
1274 let b = StridedArray::<f64>::from_fn_row_major(&[2, 3, 2], |idx| {
1275 (idx[0] * 6 + idx[1] * 2 + idx[2] + 1) as f64
1276 });
1277 let mut c = StridedArray::<f64>::row_major(&[2, 2, 2]);
1278
1279 einsum2_into_owned(
1280 c.view_mut(),
1281 a,
1282 b,
1283 &['b', 'i', 'k'],
1284 &['b', 'i', 'j'],
1285 &['b', 'j', 'k'],
1286 1.0,
1287 0.0,
1288 false,
1289 false,
1290 )
1291 .unwrap();
1292
1293 assert_eq!(c.get(&[0, 0, 0]), 22.0);
1296 }
1297
1298 #[cfg(any(feature = "faer", feature = "blas", feature = "blas-inject"))]
1299 #[test]
1300 fn test_einsum2_owned_alpha_beta() {
1301 let a = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
1302 [[1.0, 0.0], [0.0, 1.0]][idx[0]][idx[1]]
1303 });
1304 let b = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
1305 [[1.0, 2.0], [3.0, 4.0]][idx[0]][idx[1]]
1306 });
1307 let mut c = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
1308 [[10.0, 20.0], [30.0, 40.0]][idx[0]][idx[1]]
1309 });
1310
1311 einsum2_into_owned(
1312 c.view_mut(),
1313 a,
1314 b,
1315 &['i', 'k'],
1316 &['i', 'j'],
1317 &['j', 'k'],
1318 2.0,
1319 3.0,
1320 false,
1321 false,
1322 )
1323 .unwrap();
1324
1325 assert_eq!(c.get(&[0, 0]), 32.0); assert_eq!(c.get(&[1, 1]), 128.0); }
1329
/// Element-wise (Hadamard) product via the owned-operand entry point: all three
/// operands share the labels `i, j`, so nothing is contracted and the result is
/// `out[i, j] = 2 * lhs[i, j] * rhs[i, j]`.
#[cfg(any(feature = "faer", feature = "blas", feature = "blas-inject"))]
#[test]
fn test_einsum2_owned_elementwise() {
    // Both inputs hold 1..=6 laid out row-major in a 2x3 grid.
    let lhs = StridedArray::<f64>::from_fn_row_major(&[2, 3], |pos| {
        let linear = pos[0] * 3 + pos[1] + 1;
        linear as f64
    });
    let rhs = StridedArray::<f64>::from_fn_row_major(&[2, 3], |pos| {
        let linear = pos[0] * 3 + pos[1] + 1;
        linear as f64
    });
    let mut out = StridedArray::<f64>::row_major(&[2, 3]);

    let status = einsum2_into_owned(
        out.view_mut(),
        lhs,
        rhs,
        &['i', 'j'],
        &['i', 'j'],
        &['i', 'j'],
        2.0,
        0.0,
        false,
        false,
    );
    status.unwrap();

    // First element: 2 * 1 * 1; last element: 2 * 6 * 6.
    let first = out.get(&[0, 0]);
    let last = out.get(&[1, 2]);
    assert_eq!(first, 2.0);
    assert_eq!(last, 72.0);
}
1357
/// Owned-operand contraction where the label `i` occurs only in the left operand.
///
/// The asserted values correspond to summing `i` away before contracting `j`:
/// `out[k] = sum_j (sum_i lhs[i, j]) * rhs[j, k]`.
#[cfg(any(feature = "faer", feature = "blas", feature = "blas-inject"))]
#[test]
fn test_einsum2_owned_left_trace() {
    // lhs = [[1, 2, 3], [4, 5, 6]]  -> column sums [5, 7, 9]
    let lhs = StridedArray::<f64>::from_fn_row_major(&[2, 3], |pos| {
        (pos[0] * 3 + pos[1] + 1) as f64
    });
    // rhs = [[1, 2], [3, 4], [5, 6]]
    let rhs = StridedArray::<f64>::from_fn_row_major(&[3, 2], |pos| {
        (pos[0] * 2 + pos[1] + 1) as f64
    });
    let mut out = StridedArray::<f64>::row_major(&[2]);

    einsum2_into_owned(
        out.view_mut(),
        lhs,
        rhs,
        &['k'],
        &['i', 'j'],
        &['j', 'k'],
        1.0,
        0.0,
        false,
        false,
    )
    .unwrap();

    // [5, 7, 9] . [1, 3, 5] = 71 and [5, 7, 9] . [2, 4, 6] = 92
    let expected: [([usize; 1], f64); 2] = [([0], 71.0), ([1], 92.0)];
    for (idx, want) in expected {
        assert_eq!(out.get(&idx), want);
    }
}
1391
/// Plain 2x2 matrix product through the naive (backend-free) entry point:
/// `[[1, 2], [3, 4]] * [[5, 6], [7, 8]] = [[19, 22], [43, 50]]`.
#[test]
fn test_einsum2_naive_matmul() {
    let lhs = StridedArray::<f64>::from_fn_row_major(&[2, 2], |pos| {
        [[1.0, 2.0], [3.0, 4.0]][pos[0]][pos[1]]
    });
    let rhs = StridedArray::<f64>::from_fn_row_major(&[2, 2], |pos| {
        [[5.0, 6.0], [7.0, 8.0]][pos[0]][pos[1]]
    });
    let mut out = StridedArray::<f64>::row_major(&[2, 2]);

    let lhs_view = lhs.view();
    let rhs_view = rhs.view();
    einsum2_naive_into(
        out.view_mut(),
        &lhs_view,
        &rhs_view,
        &['i', 'k'],
        &['i', 'j'],
        &['j', 'k'],
        1.0,
        0.0,
        // Identity maps; NOTE(review): presumably applied per element of each input — confirm.
        |v| v,
        |v| v,
    )
    .unwrap();

    let expected: [([usize; 2], f64); 4] = [
        ([0, 0], 19.0),
        ([0, 1], 22.0),
        ([1, 0], 43.0),
        ([1, 1], 50.0),
    ];
    for (idx, want) in expected {
        assert_eq!(out.get(&idx), want);
    }
}
1422
/// Element-wise product through the naive entry point: identical labels on all
/// operands, scaled by `alpha = 2`, i.e. `out[i, j] = 2 * lhs[i, j] * rhs[i, j]`.
#[test]
fn test_einsum2_naive_elementwise() {
    // Both inputs are the 2x3 grid 1..=6 in row-major order.
    let lhs = StridedArray::<f64>::from_fn_row_major(&[2, 3], |pos| {
        let linear = pos[0] * 3 + pos[1] + 1;
        linear as f64
    });
    let rhs = StridedArray::<f64>::from_fn_row_major(&[2, 3], |pos| {
        let linear = pos[0] * 3 + pos[1] + 1;
        linear as f64
    });
    let mut out = StridedArray::<f64>::row_major(&[2, 3]);

    let outcome = einsum2_naive_into(
        out.view_mut(),
        &lhs.view(),
        &rhs.view(),
        &['i', 'j'],
        &['i', 'j'],
        &['i', 'j'],
        2.0,
        0.0,
        |v| v,
        |v| v,
    );
    outcome.unwrap();

    // Corners of the grid: 2 * 1 * 1 and 2 * 6 * 6.
    let top_left = out.get(&[0, 0]);
    let bottom_right = out.get(&[1, 2]);
    assert_eq!(top_left, 2.0);
    assert_eq!(bottom_right, 72.0);
}
1449
/// Runs the naive entry point with a user-defined scalar type that only provides
/// the arithmetic traits implemented below, multiplying two 2x2 matrices into a
/// column-major output.
#[test]
fn test_einsum2_naive_custom_type() {
    use num_traits::{One, Zero};

    /// Thin `f64` wrapper with ordinary addition and multiplication.
    #[derive(Debug, Clone, Copy, PartialEq)]
    struct MyVal(f64);

    impl std::ops::Add for MyVal {
        type Output = Self;
        fn add(self, rhs: Self) -> Self {
            let (Self(x), Self(y)) = (self, rhs);
            Self(x + y)
        }
    }

    impl std::ops::Mul for MyVal {
        type Output = Self;
        fn mul(self, rhs: Self) -> Self {
            let (Self(x), Self(y)) = (self, rhs);
            Self(x * y)
        }
    }

    impl Zero for MyVal {
        fn zero() -> Self {
            Self(0.0)
        }
        fn is_zero(&self) -> bool {
            self.0 == 0.0
        }
    }

    impl One for MyVal {
        fn one() -> Self {
            Self(1.0)
        }
    }

    impl Default for MyVal {
        fn default() -> Self {
            Self(0.0)
        }
    }

    // Row-major 2x2 inputs (strides [2, 1], offset 0).
    let lhs_data: Vec<MyVal> = [1.0, 2.0, 3.0, 4.0].iter().copied().map(MyVal).collect();
    let rhs_data: Vec<MyVal> = [5.0, 6.0, 7.0, 8.0].iter().copied().map(MyVal).collect();
    let lhs = StridedArray::from_parts(lhs_data, &[2, 2], &[2, 1], 0).unwrap();
    let rhs = StridedArray::from_parts(rhs_data, &[2, 2], &[2, 1], 0).unwrap();
    let mut out = StridedArray::<MyVal>::col_major(&[2, 2]);

    einsum2_naive_into(
        out.view_mut(),
        &lhs.view(),
        &rhs.view(),
        &['i', 'k'],
        &['i', 'j'],
        &['j', 'k'],
        MyVal(1.0),
        MyVal(0.0),
        |v| v,
        |v| v,
    )
    .unwrap();

    // [[1, 2], [3, 4]] * [[5, 6], [7, 8]] = [[19, 22], [43, 50]]
    let expected: [([usize; 2], f64); 4] = [
        ([0, 0], 19.0),
        ([0, 1], 22.0),
        ([1, 0], 43.0),
        ([1, 1], 50.0),
    ];
    for (idx, want) in expected {
        assert_eq!(out.get(&idx), MyVal(want));
    }
}
1532
/// Naive-path contraction where the label `i` occurs only in the left operand.
///
/// The asserted values correspond to
/// `out[k] = sum_j (sum_i lhs[i, j]) * rhs[j, k]`.
#[test]
fn test_einsum2_naive_left_trace() {
    // lhs = [[1, 2, 3], [4, 5, 6]]  -> column sums [5, 7, 9]
    let lhs = StridedArray::<f64>::from_fn_row_major(&[2, 3], |pos| {
        (pos[0] * 3 + pos[1] + 1) as f64
    });
    // rhs = [[1, 2], [3, 4], [5, 6]]
    let rhs = StridedArray::<f64>::from_fn_row_major(&[3, 2], |pos| {
        (pos[0] * 2 + pos[1] + 1) as f64
    });
    let mut out = StridedArray::<f64>::row_major(&[2]);

    let lhs_view = lhs.view();
    let rhs_view = rhs.view();
    einsum2_naive_into(
        out.view_mut(),
        &lhs_view,
        &rhs_view,
        &['k'],
        &['i', 'j'],
        &['j', 'k'],
        1.0,
        0.0,
        |v| v,
        |v| v,
    )
    .unwrap();

    // [5, 7, 9] . [1, 3, 5] = 71 and [5, 7, 9] . [2, 4, 6] = 92
    let expected: [([usize; 1], f64); 2] = [([0], 71.0), ([1], 92.0)];
    for (idx, want) in expected {
        assert_eq!(out.get(&idx), want);
    }
}
1559
1560 struct TestNaiveBackend;
1562
1563 impl BackendConfig for TestNaiveBackend {
1564 const MATERIALIZES_CONJ: bool = false;
1565 const REQUIRES_UNIT_STRIDE: bool = false;
1566 }
1567
1568 impl BgemmBackend<f64> for TestNaiveBackend {
1569 fn bgemm_contiguous_into(
1570 c: &mut contiguous::ContiguousOperandMut<f64>,
1571 a: &contiguous::ContiguousOperand<f64>,
1572 b: &contiguous::ContiguousOperand<f64>,
1573 batch_dims: &[usize],
1574 m: usize,
1575 n: usize,
1576 k: usize,
1577 alpha: f64,
1578 beta: f64,
1579 ) -> strided_view::Result<()> {
1580 let a_ptr = a.ptr();
1582 let b_ptr = b.ptr();
1583 let c_ptr = c.ptr();
1584 let a_rs = a.row_stride();
1585 let a_cs = a.col_stride();
1586 let b_rs = b.row_stride();
1587 let b_cs = b.col_stride();
1588 let c_rs = c.row_stride();
1589 let c_cs = c.col_stride();
1590
1591 let mut batch_idx = crate::util::MultiIndex::new(batch_dims);
1592 while batch_idx.next().is_some() {
1593 let a_base = batch_idx.offset(a.batch_strides());
1594 let b_base = batch_idx.offset(b.batch_strides());
1595 let c_base = batch_idx.offset(c.batch_strides());
1596
1597 for i in 0..m {
1598 for j in 0..n {
1599 let mut acc = 0.0f64;
1600 for l in 0..k {
1601 let a_val = unsafe {
1602 *a_ptr.offset(a_base + i as isize * a_rs + l as isize * a_cs)
1603 };
1604 let b_val = unsafe {
1605 *b_ptr.offset(b_base + l as isize * b_rs + j as isize * b_cs)
1606 };
1607 acc += a_val * b_val;
1608 }
1609 unsafe {
1610 let c_elem =
1611 c_ptr.offset(c_base + i as isize * c_rs + j as isize * c_cs);
1612 if beta == 0.0 {
1613 *c_elem = alpha * acc;
1614 } else {
1615 *c_elem = alpha * acc + beta * (*c_elem);
1616 }
1617 }
1618 }
1619 }
1620 }
1621 Ok(())
1622 }
1623 }
1624
/// 2x2 matrix product dispatched through the test-local `TestNaiveBackend`:
/// `[[1, 2], [3, 4]] * [[5, 6], [7, 8]] = [[19, 22], [43, 50]]`.
#[test]
fn test_einsum2_with_backend_matmul() {
    let lhs = StridedArray::<f64>::from_fn_row_major(&[2, 2], |pos| {
        [[1.0, 2.0], [3.0, 4.0]][pos[0]][pos[1]]
    });
    let rhs = StridedArray::<f64>::from_fn_row_major(&[2, 2], |pos| {
        [[5.0, 6.0], [7.0, 8.0]][pos[0]][pos[1]]
    });
    let mut out = StridedArray::<f64>::row_major(&[2, 2]);

    let lhs_view = lhs.view();
    let rhs_view = rhs.view();
    einsum2_with_backend_into::<_, TestNaiveBackend, _>(
        out.view_mut(),
        &lhs_view,
        &rhs_view,
        &['i', 'k'],
        &['i', 'j'],
        &['j', 'k'],
        1.0,
        0.0,
    )
    .unwrap();

    let expected: [([usize; 2], f64); 4] = [
        ([0, 0], 19.0),
        ([0, 1], 22.0),
        ([1, 0], 43.0),
        ([1, 1], 50.0),
    ];
    for (idx, want) in expected {
        assert_eq!(out.get(&idx), want);
    }
}
1652
/// Element-wise product requested through the custom-backend entry point:
/// all operands carry the labels `i, j`, and the result is scaled by `alpha = 2`.
#[test]
fn test_einsum2_with_backend_elementwise() {
    // Both inputs are the 2x3 grid 1..=6 in row-major order.
    let lhs = StridedArray::<f64>::from_fn_row_major(&[2, 3], |pos| {
        let linear = pos[0] * 3 + pos[1] + 1;
        linear as f64
    });
    let rhs = StridedArray::<f64>::from_fn_row_major(&[2, 3], |pos| {
        let linear = pos[0] * 3 + pos[1] + 1;
        linear as f64
    });
    let mut out = StridedArray::<f64>::row_major(&[2, 3]);

    let outcome = einsum2_with_backend_into::<_, TestNaiveBackend, _>(
        out.view_mut(),
        &lhs.view(),
        &rhs.view(),
        &['i', 'j'],
        &['i', 'j'],
        &['i', 'j'],
        2.0,
        0.0,
    );
    outcome.unwrap();

    // Corners of the grid: 2 * 1 * 1 and 2 * 6 * 6.
    let top_left = out.get(&[0, 0]);
    let bottom_right = out.get(&[1, 2]);
    assert_eq!(top_left, 2.0);
    assert_eq!(bottom_right, 72.0);
}
1676
/// Runs `einsum2_with_backend_into` with a user-defined scalar type and a backend
/// implemented specifically for that type, multiplying two 2x2 matrices into a
/// column-major output.
#[test]
fn test_einsum2_with_backend_custom_type() {
    use num_traits::{One, Zero};

    // NOTE: despite its name, `Tropical` is given ordinary `f64` addition and
    // multiplication below (not min/max-plus), so the expected values are those
    // of a plain matrix product.
    #[derive(Debug, Clone, Copy, PartialEq)]
    struct Tropical(f64);

    impl std::ops::Add for Tropical {
        type Output = Self;
        fn add(self, rhs: Self) -> Self {
            let (Self(x), Self(y)) = (self, rhs);
            Self(x + y)
        }
    }

    impl std::ops::Mul for Tropical {
        type Output = Self;
        fn mul(self, rhs: Self) -> Self {
            let (Self(x), Self(y)) = (self, rhs);
            Self(x * y)
        }
    }

    impl Zero for Tropical {
        fn zero() -> Self {
            Self(0.0)
        }
        fn is_zero(&self) -> bool {
            self.0 == 0.0
        }
    }

    impl One for Tropical {
        fn one() -> Self {
            Self(1.0)
        }
    }

    impl Default for Tropical {
        fn default() -> Self {
            Self(0.0)
        }
    }

    /// Backend marker whose GEMM kernel is written for `Tropical` scalars.
    struct TropicalBackend;

    // Both capability flags are switched off.
    // NOTE(review): presumably meaning "no unit-stride requirement, no conjugate
    // materialization" — confirm against `BackendConfig`.
    impl BackendConfig for TropicalBackend {
        const REQUIRES_UNIT_STRIDE: bool = false;
        const MATERIALIZES_CONJ: bool = false;
    }

    impl BgemmBackend<Tropical> for TropicalBackend {
        fn bgemm_contiguous_into(
            c: &mut contiguous::ContiguousOperandMut<Tropical>,
            a: &contiguous::ContiguousOperand<Tropical>,
            b: &contiguous::ContiguousOperand<Tropical>,
            batch_dims: &[usize],
            m: usize,
            n: usize,
            k: usize,
            alpha: Tropical,
            beta: Tropical,
        ) -> strided_view::Result<()> {
            let lhs = a.ptr();
            let rhs = b.ptr();
            let out = c.ptr();
            let lhs_row = a.row_stride();
            let lhs_col = a.col_stride();
            let rhs_row = b.row_stride();
            let rhs_col = b.col_stride();
            let out_row = c.row_stride();
            let out_col = c.col_stride();

            let mut batches = crate::util::MultiIndex::new(batch_dims);
            while batches.next().is_some() {
                // Element offset of the current batch within each operand.
                let lhs_base = batches.offset(a.batch_strides());
                let rhs_base = batches.offset(b.batch_strides());
                let out_base = batches.offset(c.batch_strides());

                for row in 0..m {
                    for col in 0..n {
                        // Accumulate the row-by-column product left to right from zero.
                        let dot = (0..k).fold(Tropical::zero(), |partial, inner| {
                            // SAFETY: relies on the operands describing valid, in-bounds
                            // elements for these indices — NOTE(review): the contract is
                            // not visible here; confirm against `BgemmBackend`.
                            let x = unsafe {
                                *lhs.offset(
                                    lhs_base + row as isize * lhs_row + inner as isize * lhs_col,
                                )
                            };
                            let y = unsafe {
                                *rhs.offset(
                                    rhs_base + inner as isize * rhs_row + col as isize * rhs_col,
                                )
                            };
                            partial + x * y
                        });

                        // SAFETY: same assumption as above, for the output operand.
                        unsafe {
                            let slot =
                                out.offset(out_base + row as isize * out_row + col as isize * out_col);
                            // With a zero beta the previous output value is never read.
                            *slot = if beta.is_zero() {
                                alpha * dot
                            } else {
                                alpha * dot + beta * (*slot)
                            };
                        }
                    }
                }
            }
            Ok(())
        }
    }

    // Row-major 2x2 inputs (strides [2, 1], offset 0).
    let lhs_data: Vec<Tropical> = [1.0, 2.0, 3.0, 4.0].iter().copied().map(Tropical).collect();
    let rhs_data: Vec<Tropical> = [5.0, 6.0, 7.0, 8.0].iter().copied().map(Tropical).collect();
    let lhs = StridedArray::from_parts(lhs_data, &[2, 2], &[2, 1], 0).unwrap();
    let rhs = StridedArray::from_parts(rhs_data, &[2, 2], &[2, 1], 0).unwrap();
    let mut out = StridedArray::<Tropical>::col_major(&[2, 2]);

    einsum2_with_backend_into::<_, TropicalBackend, _>(
        out.view_mut(),
        &lhs.view(),
        &rhs.view(),
        &['i', 'k'],
        &['i', 'j'],
        &['j', 'k'],
        Tropical(1.0),
        Tropical(0.0),
    )
    .unwrap();

    // [[1, 2], [3, 4]] * [[5, 6], [7, 8]] = [[19, 22], [43, 50]]
    let expected: [([usize; 2], f64); 4] = [
        ([0, 0], 19.0),
        ([0, 1], 22.0),
        ([1, 0], 43.0),
        ([1, 1], 50.0),
    ];
    for (idx, want) in expected {
        assert_eq!(out.get(&idx), Tropical(want));
    }
}
1818}