tenferro_einsum/lowering.rs
1//! Read-only einsum lowering plans.
2//!
3//! These wrappers expose shape and mode metadata computed by the ordinary
4//! einsum planner without exposing graph-building or arithmetic execution.
5//!
6//! # Examples
7//!
8//! ```
9//! use tenferro_einsum::{ContractionTree, Subscripts};
10//!
11//! let subs = Subscripts::new(&[&[0, 1], &[1, 2]], &[0, 2]);
12//! let tree = ContractionTree::from_pairs(&subs, &[&[2, 3], &[3, 4]], &[(0, 1)]).unwrap();
13//! let gemm = tree.step_plan(0).unwrap().gemm();
14//!
15//! assert_eq!(gemm.contracted_modes(), &[1]);
16//! ```
17
18use crate::planning::plan::{
19 DiagPlan as InnerDiagPlan, DiagStage as InnerDiagStage, GemmPlan as InnerGemmPlan,
20 ReducePlan as InnerReducePlan, StepPlan as InnerStepPlan,
21};
22
23/// Read-only lowering data for one pairwise contraction step.
24///
25/// # Examples
26///
27/// ```
28/// use tenferro_einsum::{ContractionTree, Subscripts};
29///
30/// let subs = Subscripts::new(&[&[0, 1], &[1, 2]], &[0, 2]);
31/// let tree = ContractionTree::from_pairs(&subs, &[&[2, 3], &[3, 4]], &[(0, 1)]).unwrap();
32/// let step = tree.step_plan(0).unwrap();
33///
34/// assert_eq!(step.gemm().contracted_modes(), &[1]);
35/// ```
36#[derive(Clone, Copy, Debug)]
37pub struct PairwiseStepPlan<'a> {
38 inner: &'a InnerStepPlan,
39}
40
41impl<'a> PairwiseStepPlan<'a> {
42 pub(crate) fn new(inner: &'a InnerStepPlan) -> Self {
43 Self { inner }
44 }
45
46 /// Return the left operand diagonal extraction plan, if any.
47 ///
48 /// # Examples
49 ///
50 /// ```
51 /// use tenferro_einsum::{ContractionTree, Subscripts};
52 ///
53 /// let subs = Subscripts::new(&[&[0, 1], &[1, 2]], &[0, 2]);
54 /// let tree = ContractionTree::from_pairs(&subs, &[&[2, 3], &[3, 4]], &[(0, 1)]).unwrap();
55 ///
56 /// assert!(tree.step_plan(0).unwrap().lhs_diag().is_none());
57 /// ```
58 #[must_use]
59 pub fn lhs_diag(&self) -> Option<DiagPlan<'a>> {
60 let inner: &'a InnerStepPlan = self.inner;
61 inner.diag_a.as_ref().map(DiagPlan::new)
62 }
63
64 /// Return the right operand diagonal extraction plan, if any.
65 ///
66 /// # Examples
67 ///
68 /// ```
69 /// use tenferro_einsum::{ContractionTree, Subscripts};
70 ///
71 /// let subs = Subscripts::new(&[&[0, 1], &[1, 2]], &[0, 2]);
72 /// let tree = ContractionTree::from_pairs(&subs, &[&[2, 3], &[3, 4]], &[(0, 1)]).unwrap();
73 ///
74 /// assert!(tree.step_plan(0).unwrap().rhs_diag().is_none());
75 /// ```
76 #[must_use]
77 pub fn rhs_diag(&self) -> Option<DiagPlan<'a>> {
78 let inner: &'a InnerStepPlan = self.inner;
79 inner.diag_b.as_ref().map(DiagPlan::new)
80 }
81
82 /// Return the left operand pre-reduction plan, if any.
83 ///
84 /// # Examples
85 ///
86 /// ```
87 /// use tenferro_einsum::{ContractionTree, Subscripts};
88 ///
89 /// let subs = Subscripts::new(&[&[0, 1], &[1, 2]], &[2]);
90 /// let tree = ContractionTree::from_pairs(&subs, &[&[2, 3], &[3, 4]], &[(0, 1)]).unwrap();
91 /// let reduce = tree.step_plan(0).unwrap().lhs_reduce().unwrap();
92 ///
93 /// assert_eq!(reduce.kept_subs(), &[1]);
94 /// ```
95 #[must_use]
96 pub fn lhs_reduce(&self) -> Option<ReducePlan<'a>> {
97 let inner: &'a InnerStepPlan = self.inner;
98 inner.gemm.reduce_a.as_ref().map(ReducePlan::new)
99 }
100
101 /// Return the right operand pre-reduction plan, if any.
102 ///
103 /// # Examples
104 ///
105 /// ```
106 /// use tenferro_einsum::{ContractionTree, Subscripts};
107 ///
108 /// let subs = Subscripts::new(&[&[0, 1], &[1, 2]], &[0]);
109 /// let tree = ContractionTree::from_pairs(&subs, &[&[2, 3], &[3, 4]], &[(0, 1)]).unwrap();
110 /// let reduce = tree.step_plan(0).unwrap().rhs_reduce().unwrap();
111 ///
112 /// assert_eq!(reduce.kept_subs(), &[1]);
113 /// ```
114 #[must_use]
115 pub fn rhs_reduce(&self) -> Option<ReducePlan<'a>> {
116 let inner: &'a InnerStepPlan = self.inner;
117 inner.gemm.reduce_b.as_ref().map(ReducePlan::new)
118 }
119
120 /// Return the GEMM decomposition for this pairwise step.
121 ///
122 /// # Examples
123 ///
124 /// ```
125 /// use tenferro_einsum::{ContractionTree, Subscripts};
126 ///
127 /// let subs = Subscripts::new(&[&[0, 1], &[1, 2]], &[0, 2]);
128 /// let tree = ContractionTree::from_pairs(&subs, &[&[2, 3], &[3, 4]], &[(0, 1)]).unwrap();
129 /// let gemm = tree.step_plan(0).unwrap().gemm();
130 ///
131 /// assert_eq!(gemm.output_gemm_shape(), &[2, 4]);
132 /// ```
133 #[must_use]
134 pub fn gemm(&self) -> GemmPlan<'a> {
135 let inner: &'a InnerStepPlan = self.inner;
136 GemmPlan::new(&inner.gemm)
137 }
138}
139
140/// Read-only diagonal extraction plan for one operand.
141///
142/// # Examples
143///
144/// ```
145/// use tenferro_einsum::{ContractionTree, Subscripts};
146///
147/// let subs = Subscripts::new(&[&[0, 0], &[0]], &[0]);
148/// let tree = ContractionTree::from_pairs(&subs, &[&[3, 3], &[3]], &[(0, 1)]).unwrap();
149/// let diag = tree.step_plan(0).unwrap().lhs_diag().unwrap();
150///
151/// assert_eq!(diag.result_subs(), &[0]);
152/// ```
153#[derive(Clone, Copy, Debug)]
154pub struct DiagPlan<'a> {
155 inner: &'a InnerDiagPlan,
156}
157
158impl<'a> DiagPlan<'a> {
159 fn new(inner: &'a InnerDiagPlan) -> Self {
160 Self { inner }
161 }
162
163 /// Return the sequential diagonal extraction stages.
164 ///
165 /// # Examples
166 ///
167 /// ```
168 /// use tenferro_einsum::{ContractionTree, Subscripts};
169 ///
170 /// let subs = Subscripts::new(&[&[0, 0], &[0]], &[0]);
171 /// let tree = ContractionTree::from_pairs(&subs, &[&[3, 3], &[3]], &[(0, 1)]).unwrap();
172 /// let diag = tree.step_plan(0).unwrap().lhs_diag().unwrap();
173 /// let mut stages = diag.stages();
174 ///
175 /// assert_eq!(stages.next().unwrap().axis_pairs(), &[(0, 1)]);
176 /// assert!(stages.next().is_none());
177 /// ```
178 pub fn stages(self) -> impl ExactSizeIterator<Item = DiagStage<'a>> + 'a {
179 let inner: &'a InnerDiagPlan = self.inner;
180 inner.stages.iter().map(DiagStage::new)
181 }
182
183 /// Return final subscripts after all diagonal extraction stages.
184 ///
185 /// # Examples
186 ///
187 /// ```
188 /// use tenferro_einsum::{ContractionTree, Subscripts};
189 ///
190 /// let subs = Subscripts::new(&[&[0, 0], &[0]], &[0]);
191 /// let tree = ContractionTree::from_pairs(&subs, &[&[3, 3], &[3]], &[(0, 1)]).unwrap();
192 /// let diag = tree.step_plan(0).unwrap().lhs_diag().unwrap();
193 ///
194 /// assert_eq!(diag.result_subs(), &[0]);
195 /// ```
196 #[must_use]
197 pub fn result_subs(&self) -> &'a [u32] {
198 let inner: &'a InnerDiagPlan = self.inner;
199 inner.result_subs.as_slice()
200 }
201}
202
203/// Read-only metadata for one diagonal extraction stage.
204///
205/// # Examples
206///
207/// ```
208/// use tenferro_einsum::{ContractionTree, Subscripts};
209///
210/// let subs = Subscripts::new(&[&[0, 0], &[0]], &[0]);
211/// let tree = ContractionTree::from_pairs(&subs, &[&[3, 3], &[3]], &[(0, 1)]).unwrap();
212/// let stage = tree.step_plan(0).unwrap().lhs_diag().unwrap().stages().next().unwrap();
213///
214/// assert_eq!(stage.axis_pairs(), &[(0, 1)]);
215/// ```
216#[derive(Clone, Copy, Debug)]
217pub struct DiagStage<'a> {
218 inner: &'a InnerDiagStage,
219}
220
221impl<'a> DiagStage<'a> {
222 fn new(inner: &'a InnerDiagStage) -> Self {
223 Self { inner }
224 }
225
226 /// Return axis pairs extracted by this diagonal stage.
227 ///
228 /// # Examples
229 ///
230 /// ```
231 /// use tenferro_einsum::{ContractionTree, Subscripts};
232 ///
233 /// let subs = Subscripts::new(&[&[0, 0], &[0]], &[0]);
234 /// let tree = ContractionTree::from_pairs(&subs, &[&[3, 3], &[3]], &[(0, 1)]).unwrap();
235 /// let stage = tree.step_plan(0).unwrap().lhs_diag().unwrap().stages().next().unwrap();
236 ///
237 /// assert_eq!(stage.axis_pairs(), &[(0, 1)]);
238 /// ```
239 #[must_use]
240 pub fn axis_pairs(&self) -> &'a [(usize, usize)] {
241 let inner: &'a InnerDiagStage = self.inner;
242 inner.axis_pairs.as_slice()
243 }
244
245 /// Return subscripts after this diagonal stage.
246 ///
247 /// # Examples
248 ///
249 /// ```
250 /// use tenferro_einsum::{ContractionTree, Subscripts};
251 ///
252 /// let subs = Subscripts::new(&[&[0, 0], &[0]], &[0]);
253 /// let tree = ContractionTree::from_pairs(&subs, &[&[3, 3], &[3]], &[(0, 1)]).unwrap();
254 /// let stage = tree.step_plan(0).unwrap().lhs_diag().unwrap().stages().next().unwrap();
255 ///
256 /// assert_eq!(stage.result_subs(), &[0]);
257 /// ```
258 #[must_use]
259 pub fn result_subs(&self) -> &'a [u32] {
260 let inner: &'a InnerDiagStage = self.inner;
261 inner.result_subs.as_slice()
262 }
263}
264
265/// Read-only pre-reduction plan for axes unique to one operand.
266///
267/// # Examples
268///
269/// ```
270/// use tenferro_einsum::{ContractionTree, Subscripts};
271///
272/// let subs = Subscripts::new(&[&[0, 1], &[1, 2]], &[2]);
273/// let tree = ContractionTree::from_pairs(&subs, &[&[2, 3], &[3, 4]], &[(0, 1)]).unwrap();
274/// let reduce = tree.step_plan(0).unwrap().lhs_reduce().unwrap();
275///
276/// assert_eq!(reduce.out_shape(), &[3]);
277/// ```
278#[derive(Clone, Copy, Debug)]
279pub struct ReducePlan<'a> {
280 inner: &'a InnerReducePlan,
281}
282
283impl<'a> ReducePlan<'a> {
284 fn new(inner: &'a InnerReducePlan) -> Self {
285 Self { inner }
286 }
287
288 /// Return subscripts before pre-reduction.
289 ///
290 /// # Examples
291 ///
292 /// ```
293 /// use tenferro_einsum::{ContractionTree, Subscripts};
294 ///
295 /// let subs = Subscripts::new(&[&[0, 1], &[1, 2]], &[2]);
296 /// let tree = ContractionTree::from_pairs(&subs, &[&[2, 3], &[3, 4]], &[(0, 1)]).unwrap();
297 /// let reduce = tree.step_plan(0).unwrap().lhs_reduce().unwrap();
298 ///
299 /// assert_eq!(reduce.original_subs(), &[0, 1]);
300 /// ```
301 #[must_use]
302 pub fn original_subs(&self) -> &'a [u32] {
303 let inner: &'a InnerReducePlan = self.inner;
304 inner.original_subs.as_slice()
305 }
306
307 /// Return subscripts kept after pre-reduction.
308 ///
309 /// # Examples
310 ///
311 /// ```
312 /// use tenferro_einsum::{ContractionTree, Subscripts};
313 ///
314 /// let subs = Subscripts::new(&[&[0, 1], &[1, 2]], &[2]);
315 /// let tree = ContractionTree::from_pairs(&subs, &[&[2, 3], &[3, 4]], &[(0, 1)]).unwrap();
316 /// let reduce = tree.step_plan(0).unwrap().lhs_reduce().unwrap();
317 ///
318 /// assert_eq!(reduce.kept_subs(), &[1]);
319 /// ```
320 #[must_use]
321 pub fn kept_subs(&self) -> &'a [u32] {
322 let inner: &'a InnerReducePlan = self.inner;
323 inner.kept_subs.as_slice()
324 }
325
326 /// Return the shape after pre-reduction.
327 ///
328 /// # Examples
329 ///
330 /// ```
331 /// use tenferro_einsum::{ContractionTree, Subscripts};
332 ///
333 /// let subs = Subscripts::new(&[&[0, 1], &[1, 2]], &[2]);
334 /// let tree = ContractionTree::from_pairs(&subs, &[&[2, 3], &[3, 4]], &[(0, 1)]).unwrap();
335 /// let reduce = tree.step_plan(0).unwrap().lhs_reduce().unwrap();
336 ///
337 /// assert_eq!(reduce.out_shape(), &[3]);
338 /// ```
339 #[must_use]
340 pub fn out_shape(&self) -> &'a [usize] {
341 let inner: &'a InnerReducePlan = self.inner;
342 inner.out_shape.as_slice()
343 }
344}
345
346/// Read-only GEMM decomposition plan for a pairwise contraction step.
347///
348/// # Examples
349///
350/// ```
351/// use tenferro_einsum::{ContractionTree, Subscripts};
352///
353/// let subs = Subscripts::new(&[&[0, 1], &[1, 2]], &[0, 2]);
354/// let tree = ContractionTree::from_pairs(&subs, &[&[2, 3], &[3, 4]], &[(0, 1)]).unwrap();
355/// let gemm = tree.step_plan(0).unwrap().gemm();
356///
357/// assert_eq!(gemm.lhs_gemm_shape(), &[2, 3]);
358/// ```
359#[derive(Clone, Copy, Debug)]
360pub struct GemmPlan<'a> {
361 inner: &'a InnerGemmPlan,
362}
363
364impl<'a> GemmPlan<'a> {
365 fn new(inner: &'a InnerGemmPlan) -> Self {
366 Self { inner }
367 }
368
369 /// Return modes present only on the left operand and output.
370 ///
371 /// # Examples
372 ///
373 /// ```
374 /// use tenferro_einsum::{ContractionTree, Subscripts};
375 ///
376 /// let subs = Subscripts::new(&[&[0, 1], &[1, 2]], &[0, 2]);
377 /// let tree = ContractionTree::from_pairs(&subs, &[&[2, 3], &[3, 4]], &[(0, 1)]).unwrap();
378 /// let gemm = tree.step_plan(0).unwrap().gemm();
379 ///
380 /// assert_eq!(gemm.left_only_modes(), &[0]);
381 /// ```
382 #[must_use]
383 pub fn left_only_modes(&self) -> &'a [u32] {
384 let inner: &'a InnerGemmPlan = self.inner;
385 inner.lo_modes.as_slice()
386 }
387
388 /// Return dimension sizes for [`Self::left_only_modes`].
389 ///
390 /// # Examples
391 ///
392 /// ```
393 /// use tenferro_einsum::{ContractionTree, Subscripts};
394 ///
395 /// let subs = Subscripts::new(&[&[0, 1], &[1, 2]], &[0, 2]);
396 /// let tree = ContractionTree::from_pairs(&subs, &[&[2, 3], &[3, 4]], &[(0, 1)]).unwrap();
397 /// let gemm = tree.step_plan(0).unwrap().gemm();
398 ///
399 /// assert_eq!(gemm.left_only_shape(), &[2]);
400 /// ```
401 #[must_use]
402 pub fn left_only_shape(&self) -> &'a [usize] {
403 let inner: &'a InnerGemmPlan = self.inner;
404 inner.lo_sizes.as_slice()
405 }
406
407 /// Return modes present only on the right operand and output.
408 ///
409 /// # Examples
410 ///
411 /// ```
412 /// use tenferro_einsum::{ContractionTree, Subscripts};
413 ///
414 /// let subs = Subscripts::new(&[&[0, 1], &[1, 2]], &[0, 2]);
415 /// let tree = ContractionTree::from_pairs(&subs, &[&[2, 3], &[3, 4]], &[(0, 1)]).unwrap();
416 /// let gemm = tree.step_plan(0).unwrap().gemm();
417 ///
418 /// assert_eq!(gemm.right_only_modes(), &[2]);
419 /// ```
420 #[must_use]
421 pub fn right_only_modes(&self) -> &'a [u32] {
422 let inner: &'a InnerGemmPlan = self.inner;
423 inner.ro_modes.as_slice()
424 }
425
426 /// Return dimension sizes for [`Self::right_only_modes`].
427 ///
428 /// # Examples
429 ///
430 /// ```
431 /// use tenferro_einsum::{ContractionTree, Subscripts};
432 ///
433 /// let subs = Subscripts::new(&[&[0, 1], &[1, 2]], &[0, 2]);
434 /// let tree = ContractionTree::from_pairs(&subs, &[&[2, 3], &[3, 4]], &[(0, 1)]).unwrap();
435 /// let gemm = tree.step_plan(0).unwrap().gemm();
436 ///
437 /// assert_eq!(gemm.right_only_shape(), &[4]);
438 /// ```
439 #[must_use]
440 pub fn right_only_shape(&self) -> &'a [usize] {
441 let inner: &'a InnerGemmPlan = self.inner;
442 inner.ro_sizes.as_slice()
443 }
444
445 /// Return modes contracted between the pairwise operands.
446 ///
447 /// # Examples
448 ///
449 /// ```
450 /// use tenferro_einsum::{ContractionTree, Subscripts};
451 ///
452 /// let subs = Subscripts::new(&[&[0, 1], &[1, 2]], &[0, 2]);
453 /// let tree = ContractionTree::from_pairs(&subs, &[&[2, 3], &[3, 4]], &[(0, 1)]).unwrap();
454 /// let gemm = tree.step_plan(0).unwrap().gemm();
455 ///
456 /// assert_eq!(gemm.contracted_modes(), &[1]);
457 /// ```
458 #[must_use]
459 pub fn contracted_modes(&self) -> &'a [u32] {
460 let inner: &'a InnerGemmPlan = self.inner;
461 inner.sum_modes.as_slice()
462 }
463
464 /// Return dimension sizes for [`Self::contracted_modes`].
465 ///
466 /// # Examples
467 ///
468 /// ```
469 /// use tenferro_einsum::{ContractionTree, Subscripts};
470 ///
471 /// let subs = Subscripts::new(&[&[0, 1, 2], &[1, 2, 3]], &[0, 3]);
472 /// let tree = ContractionTree::from_pairs(
473 /// &subs,
474 /// &[&[2, 3, 4], &[3, 4, 5]],
475 /// &[(0, 1)],
476 /// )
477 /// .unwrap();
478 /// let gemm = tree.step_plan(0).unwrap().gemm();
479 ///
480 /// assert_eq!(gemm.contracted_shape(), &[3, 4]);
481 /// ```
482 #[must_use]
483 pub fn contracted_shape(&self) -> &'a [usize] {
484 let inner: &'a InnerGemmPlan = self.inner;
485 inner.sum_sizes.as_slice()
486 }
487
488 /// Return modes shared by both operands and preserved in the output.
489 ///
490 /// # Examples
491 ///
492 /// ```
493 /// use tenferro_einsum::{ContractionTree, Subscripts};
494 ///
495 /// let subs = Subscripts::new(&[&[3, 0, 1], &[1, 2, 3]], &[3, 0, 2]);
496 /// let tree = ContractionTree::from_pairs(&subs, &[&[5, 2, 3], &[3, 4, 5]], &[(0, 1)]).unwrap();
497 /// let gemm = tree.step_plan(0).unwrap().gemm();
498 ///
499 /// assert_eq!(gemm.batch_modes(), &[3]);
500 /// ```
501 #[must_use]
502 pub fn batch_modes(&self) -> &'a [u32] {
503 let inner: &'a InnerGemmPlan = self.inner;
504 let batch_start = inner.lo_modes.len() + inner.ro_modes.len();
505 &inner.canonical_modes[batch_start..]
506 }
507
508 /// Return dimension sizes for [`Self::batch_modes`].
509 ///
510 /// # Examples
511 ///
512 /// ```
513 /// use tenferro_einsum::{ContractionTree, Subscripts};
514 ///
515 /// let subs = Subscripts::new(&[&[3, 0, 1], &[1, 2, 3]], &[3, 0, 2]);
516 /// let tree = ContractionTree::from_pairs(&subs, &[&[5, 2, 3], &[3, 4, 5]], &[(0, 1)]).unwrap();
517 /// let gemm = tree.step_plan(0).unwrap().gemm();
518 ///
519 /// assert_eq!(gemm.batch_shape(), &[5]);
520 /// ```
521 #[must_use]
522 pub fn batch_shape(&self) -> &'a [usize] {
523 let inner: &'a InnerGemmPlan = self.inner;
524 inner.batch_sizes.as_slice()
525 }
526
527 /// Return target mode order for preparing the left GEMM operand.
528 ///
529 /// # Examples
530 ///
531 /// ```
532 /// use tenferro_einsum::{ContractionTree, Subscripts};
533 ///
534 /// let subs = Subscripts::new(&[&[0, 1], &[1, 2]], &[0, 2]);
535 /// let tree = ContractionTree::from_pairs(&subs, &[&[2, 3], &[3, 4]], &[(0, 1)]).unwrap();
536 /// let gemm = tree.step_plan(0).unwrap().gemm();
537 ///
538 /// assert_eq!(gemm.lhs_target_modes(), &[0, 1]);
539 /// ```
540 #[must_use]
541 pub fn lhs_target_modes(&self) -> &'a [u32] {
542 let inner: &'a InnerGemmPlan = self.inner;
543 inner.target_a.as_slice()
544 }
545
546 /// Return target mode order for preparing the right GEMM operand.
547 ///
548 /// # Examples
549 ///
550 /// ```
551 /// use tenferro_einsum::{ContractionTree, Subscripts};
552 ///
553 /// let subs = Subscripts::new(&[&[0, 1], &[1, 2]], &[0, 2]);
554 /// let tree = ContractionTree::from_pairs(&subs, &[&[2, 3], &[3, 4]], &[(0, 1)]).unwrap();
555 /// let gemm = tree.step_plan(0).unwrap().gemm();
556 ///
557 /// assert_eq!(gemm.rhs_target_modes(), &[1, 2]);
558 /// ```
559 #[must_use]
560 pub fn rhs_target_modes(&self) -> &'a [u32] {
561 let inner: &'a InnerGemmPlan = self.inner;
562 inner.target_b.as_slice()
563 }
564
565 /// Return canonical output mode order before any final permutation.
566 ///
567 /// # Examples
568 ///
569 /// ```
570 /// use tenferro_einsum::{ContractionTree, Subscripts};
571 ///
572 /// let subs = Subscripts::new(&[&[0, 1], &[1, 2]], &[2, 0]);
573 /// let tree = ContractionTree::from_pairs(&subs, &[&[2, 3], &[3, 4]], &[(0, 1)]).unwrap();
574 /// let gemm = tree.step_plan(0).unwrap().gemm();
575 ///
576 /// assert_eq!(gemm.canonical_output_modes(), &[0, 2]);
577 /// ```
578 #[must_use]
579 pub fn canonical_output_modes(&self) -> &'a [u32] {
580 let inner: &'a InnerGemmPlan = self.inner;
581 inner.canonical_modes.as_slice()
582 }
583
584 /// Return fused left-only dimension size.
585 ///
586 /// # Examples
587 ///
588 /// ```
589 /// use tenferro_einsum::{ContractionTree, Subscripts};
590 ///
591 /// let subs = Subscripts::new(&[&[0, 1], &[1, 2]], &[0, 2]);
592 /// let tree = ContractionTree::from_pairs(&subs, &[&[2, 3], &[3, 4]], &[(0, 1)]).unwrap();
593 /// let gemm = tree.step_plan(0).unwrap().gemm();
594 ///
595 /// assert_eq!(gemm.m(), 2);
596 /// ```
597 #[must_use]
598 pub fn m(&self) -> usize {
599 self.inner.m
600 }
601
602 /// Return fused right-only dimension size.
603 ///
604 /// # Examples
605 ///
606 /// ```
607 /// use tenferro_einsum::{ContractionTree, Subscripts};
608 ///
609 /// let subs = Subscripts::new(&[&[0, 1], &[1, 2]], &[0, 2]);
610 /// let tree = ContractionTree::from_pairs(&subs, &[&[2, 3], &[3, 4]], &[(0, 1)]).unwrap();
611 /// let gemm = tree.step_plan(0).unwrap().gemm();
612 ///
613 /// assert_eq!(gemm.n(), 4);
614 /// ```
615 #[must_use]
616 pub fn n(&self) -> usize {
617 self.inner.n
618 }
619
620 /// Return fused contracted dimension size.
621 ///
622 /// # Examples
623 ///
624 /// ```
625 /// use tenferro_einsum::{ContractionTree, Subscripts};
626 ///
627 /// let subs = Subscripts::new(&[&[0, 1], &[1, 2]], &[0, 2]);
628 /// let tree = ContractionTree::from_pairs(&subs, &[&[2, 3], &[3, 4]], &[(0, 1)]).unwrap();
629 /// let gemm = tree.step_plan(0).unwrap().gemm();
630 ///
631 /// assert_eq!(gemm.k(), 3);
632 /// ```
633 #[must_use]
634 pub fn k(&self) -> usize {
635 self.inner.k
636 }
637
638 /// Return prepared left operand GEMM shape.
639 ///
640 /// # Examples
641 ///
642 /// ```
643 /// use tenferro_einsum::{ContractionTree, Subscripts};
644 ///
645 /// let subs = Subscripts::new(&[&[0, 1], &[1, 2]], &[0, 2]);
646 /// let tree = ContractionTree::from_pairs(&subs, &[&[2, 3], &[3, 4]], &[(0, 1)]).unwrap();
647 /// let gemm = tree.step_plan(0).unwrap().gemm();
648 ///
649 /// assert_eq!(gemm.lhs_gemm_shape(), &[2, 3]);
650 /// ```
651 #[must_use]
652 pub fn lhs_gemm_shape(&self) -> &'a [usize] {
653 let inner: &'a InnerGemmPlan = self.inner;
654 inner.a_gemm_shape.as_slice()
655 }
656
657 /// Return prepared right operand GEMM shape.
658 ///
659 /// # Examples
660 ///
661 /// ```
662 /// use tenferro_einsum::{ContractionTree, Subscripts};
663 ///
664 /// let subs = Subscripts::new(&[&[0, 1], &[1, 2]], &[0, 2]);
665 /// let tree = ContractionTree::from_pairs(&subs, &[&[2, 3], &[3, 4]], &[(0, 1)]).unwrap();
666 /// let gemm = tree.step_plan(0).unwrap().gemm();
667 ///
668 /// assert_eq!(gemm.rhs_gemm_shape(), &[3, 4]);
669 /// ```
670 #[must_use]
671 pub fn rhs_gemm_shape(&self) -> &'a [usize] {
672 let inner: &'a InnerGemmPlan = self.inner;
673 inner.b_gemm_shape.as_slice()
674 }
675
676 /// Return GEMM output shape before expanding fused dimensions.
677 ///
678 /// # Examples
679 ///
680 /// ```
681 /// use tenferro_einsum::{ContractionTree, Subscripts};
682 ///
683 /// let subs = Subscripts::new(&[&[0, 1], &[1, 2]], &[0, 2]);
684 /// let tree = ContractionTree::from_pairs(&subs, &[&[2, 3], &[3, 4]], &[(0, 1)]).unwrap();
685 /// let gemm = tree.step_plan(0).unwrap().gemm();
686 ///
687 /// assert_eq!(gemm.output_gemm_shape(), &[2, 4]);
688 /// ```
689 #[must_use]
690 pub fn output_gemm_shape(&self) -> &'a [usize] {
691 let inner: &'a InnerGemmPlan = self.inner;
692 inner.c_gemm_shape.as_slice()
693 }
694
695 /// Return expanded output shape in canonical output mode order.
696 ///
697 /// # Examples
698 ///
699 /// ```
700 /// use tenferro_einsum::{ContractionTree, Subscripts};
701 ///
702 /// let subs = Subscripts::new(&[&[0, 1], &[1, 2]], &[0, 2]);
703 /// let tree = ContractionTree::from_pairs(&subs, &[&[2, 3], &[3, 4]], &[(0, 1)]).unwrap();
704 /// let gemm = tree.step_plan(0).unwrap().gemm();
705 ///
706 /// assert_eq!(gemm.expanded_output_shape(), &[2, 4]);
707 /// ```
708 #[must_use]
709 pub fn expanded_output_shape(&self) -> &'a [usize] {
710 let inner: &'a InnerGemmPlan = self.inner;
711 inner.expanded_shape.as_slice()
712 }
713
714 /// Return whether canonical output requires a final permutation.
715 ///
716 /// # Examples
717 ///
718 /// ```
719 /// use tenferro_einsum::{ContractionTree, Subscripts};
720 ///
721 /// let subs = Subscripts::new(&[&[0, 1], &[1, 2]], &[2, 0]);
722 /// let tree = ContractionTree::from_pairs(&subs, &[&[2, 3], &[3, 4]], &[(0, 1)]).unwrap();
723 /// let gemm = tree.step_plan(0).unwrap().gemm();
724 ///
725 /// assert!(gemm.needs_final_permute());
726 /// ```
727 #[must_use]
728 pub fn needs_final_permute(&self) -> bool {
729 self.inner.needs_final_permute
730 }
731}