// tenferro_einsum/lowering.rs

//! Read-only einsum lowering plans.
//!
//! These wrappers expose shape and mode metadata computed by the ordinary
//! einsum planner without exposing graph-building or arithmetic execution.
//!
//! # Examples
//!
//! ```
//! use tenferro_einsum::{ContractionTree, Subscripts};
//!
//! let subs = Subscripts::new(&[&[0, 1], &[1, 2]], &[0, 2]);
//! let tree = ContractionTree::from_pairs(&subs, &[&[2, 3], &[3, 4]], &[(0, 1)]).unwrap();
//! let gemm = tree.step_plan(0).unwrap().gemm();
//!
//! assert_eq!(gemm.contracted_modes(), &[1]);
//! ```

18use crate::planning::plan::{
19    DiagPlan as InnerDiagPlan, DiagStage as InnerDiagStage, GemmPlan as InnerGemmPlan,
20    ReducePlan as InnerReducePlan, StepPlan as InnerStepPlan,
21};
22
23/// Read-only lowering data for one pairwise contraction step.
24///
25/// # Examples
26///
27/// ```
28/// use tenferro_einsum::{ContractionTree, Subscripts};
29///
30/// let subs = Subscripts::new(&[&[0, 1], &[1, 2]], &[0, 2]);
31/// let tree = ContractionTree::from_pairs(&subs, &[&[2, 3], &[3, 4]], &[(0, 1)]).unwrap();
32/// let step = tree.step_plan(0).unwrap();
33///
34/// assert_eq!(step.gemm().contracted_modes(), &[1]);
35/// ```
36#[derive(Clone, Copy, Debug)]
37pub struct PairwiseStepPlan<'a> {
38    inner: &'a InnerStepPlan,
39}
40
41impl<'a> PairwiseStepPlan<'a> {
42    pub(crate) fn new(inner: &'a InnerStepPlan) -> Self {
43        Self { inner }
44    }
45
46    /// Return the left operand diagonal extraction plan, if any.
47    ///
48    /// # Examples
49    ///
50    /// ```
51    /// use tenferro_einsum::{ContractionTree, Subscripts};
52    ///
53    /// let subs = Subscripts::new(&[&[0, 1], &[1, 2]], &[0, 2]);
54    /// let tree = ContractionTree::from_pairs(&subs, &[&[2, 3], &[3, 4]], &[(0, 1)]).unwrap();
55    ///
56    /// assert!(tree.step_plan(0).unwrap().lhs_diag().is_none());
57    /// ```
58    #[must_use]
59    pub fn lhs_diag(&self) -> Option<DiagPlan<'a>> {
60        let inner: &'a InnerStepPlan = self.inner;
61        inner.diag_a.as_ref().map(DiagPlan::new)
62    }
63
64    /// Return the right operand diagonal extraction plan, if any.
65    ///
66    /// # Examples
67    ///
68    /// ```
69    /// use tenferro_einsum::{ContractionTree, Subscripts};
70    ///
71    /// let subs = Subscripts::new(&[&[0, 1], &[1, 2]], &[0, 2]);
72    /// let tree = ContractionTree::from_pairs(&subs, &[&[2, 3], &[3, 4]], &[(0, 1)]).unwrap();
73    ///
74    /// assert!(tree.step_plan(0).unwrap().rhs_diag().is_none());
75    /// ```
76    #[must_use]
77    pub fn rhs_diag(&self) -> Option<DiagPlan<'a>> {
78        let inner: &'a InnerStepPlan = self.inner;
79        inner.diag_b.as_ref().map(DiagPlan::new)
80    }
81
82    /// Return the left operand pre-reduction plan, if any.
83    ///
84    /// # Examples
85    ///
86    /// ```
87    /// use tenferro_einsum::{ContractionTree, Subscripts};
88    ///
89    /// let subs = Subscripts::new(&[&[0, 1], &[1, 2]], &[2]);
90    /// let tree = ContractionTree::from_pairs(&subs, &[&[2, 3], &[3, 4]], &[(0, 1)]).unwrap();
91    /// let reduce = tree.step_plan(0).unwrap().lhs_reduce().unwrap();
92    ///
93    /// assert_eq!(reduce.kept_subs(), &[1]);
94    /// ```
95    #[must_use]
96    pub fn lhs_reduce(&self) -> Option<ReducePlan<'a>> {
97        let inner: &'a InnerStepPlan = self.inner;
98        inner.gemm.reduce_a.as_ref().map(ReducePlan::new)
99    }
100
101    /// Return the right operand pre-reduction plan, if any.
102    ///
103    /// # Examples
104    ///
105    /// ```
106    /// use tenferro_einsum::{ContractionTree, Subscripts};
107    ///
108    /// let subs = Subscripts::new(&[&[0, 1], &[1, 2]], &[0]);
109    /// let tree = ContractionTree::from_pairs(&subs, &[&[2, 3], &[3, 4]], &[(0, 1)]).unwrap();
110    /// let reduce = tree.step_plan(0).unwrap().rhs_reduce().unwrap();
111    ///
112    /// assert_eq!(reduce.kept_subs(), &[1]);
113    /// ```
114    #[must_use]
115    pub fn rhs_reduce(&self) -> Option<ReducePlan<'a>> {
116        let inner: &'a InnerStepPlan = self.inner;
117        inner.gemm.reduce_b.as_ref().map(ReducePlan::new)
118    }
119
120    /// Return the GEMM decomposition for this pairwise step.
121    ///
122    /// # Examples
123    ///
124    /// ```
125    /// use tenferro_einsum::{ContractionTree, Subscripts};
126    ///
127    /// let subs = Subscripts::new(&[&[0, 1], &[1, 2]], &[0, 2]);
128    /// let tree = ContractionTree::from_pairs(&subs, &[&[2, 3], &[3, 4]], &[(0, 1)]).unwrap();
129    /// let gemm = tree.step_plan(0).unwrap().gemm();
130    ///
131    /// assert_eq!(gemm.output_gemm_shape(), &[2, 4]);
132    /// ```
133    #[must_use]
134    pub fn gemm(&self) -> GemmPlan<'a> {
135        let inner: &'a InnerStepPlan = self.inner;
136        GemmPlan::new(&inner.gemm)
137    }
138}
139
140/// Read-only diagonal extraction plan for one operand.
141///
142/// # Examples
143///
144/// ```
145/// use tenferro_einsum::{ContractionTree, Subscripts};
146///
147/// let subs = Subscripts::new(&[&[0, 0], &[0]], &[0]);
148/// let tree = ContractionTree::from_pairs(&subs, &[&[3, 3], &[3]], &[(0, 1)]).unwrap();
149/// let diag = tree.step_plan(0).unwrap().lhs_diag().unwrap();
150///
151/// assert_eq!(diag.result_subs(), &[0]);
152/// ```
153#[derive(Clone, Copy, Debug)]
154pub struct DiagPlan<'a> {
155    inner: &'a InnerDiagPlan,
156}
157
158impl<'a> DiagPlan<'a> {
159    fn new(inner: &'a InnerDiagPlan) -> Self {
160        Self { inner }
161    }
162
163    /// Return the sequential diagonal extraction stages.
164    ///
165    /// # Examples
166    ///
167    /// ```
168    /// use tenferro_einsum::{ContractionTree, Subscripts};
169    ///
170    /// let subs = Subscripts::new(&[&[0, 0], &[0]], &[0]);
171    /// let tree = ContractionTree::from_pairs(&subs, &[&[3, 3], &[3]], &[(0, 1)]).unwrap();
172    /// let diag = tree.step_plan(0).unwrap().lhs_diag().unwrap();
173    /// let mut stages = diag.stages();
174    ///
175    /// assert_eq!(stages.next().unwrap().axis_pairs(), &[(0, 1)]);
176    /// assert!(stages.next().is_none());
177    /// ```
178    pub fn stages(self) -> impl ExactSizeIterator<Item = DiagStage<'a>> + 'a {
179        let inner: &'a InnerDiagPlan = self.inner;
180        inner.stages.iter().map(DiagStage::new)
181    }
182
183    /// Return final subscripts after all diagonal extraction stages.
184    ///
185    /// # Examples
186    ///
187    /// ```
188    /// use tenferro_einsum::{ContractionTree, Subscripts};
189    ///
190    /// let subs = Subscripts::new(&[&[0, 0], &[0]], &[0]);
191    /// let tree = ContractionTree::from_pairs(&subs, &[&[3, 3], &[3]], &[(0, 1)]).unwrap();
192    /// let diag = tree.step_plan(0).unwrap().lhs_diag().unwrap();
193    ///
194    /// assert_eq!(diag.result_subs(), &[0]);
195    /// ```
196    #[must_use]
197    pub fn result_subs(&self) -> &'a [u32] {
198        let inner: &'a InnerDiagPlan = self.inner;
199        inner.result_subs.as_slice()
200    }
201}
202
203/// Read-only metadata for one diagonal extraction stage.
204///
205/// # Examples
206///
207/// ```
208/// use tenferro_einsum::{ContractionTree, Subscripts};
209///
210/// let subs = Subscripts::new(&[&[0, 0], &[0]], &[0]);
211/// let tree = ContractionTree::from_pairs(&subs, &[&[3, 3], &[3]], &[(0, 1)]).unwrap();
212/// let stage = tree.step_plan(0).unwrap().lhs_diag().unwrap().stages().next().unwrap();
213///
214/// assert_eq!(stage.axis_pairs(), &[(0, 1)]);
215/// ```
216#[derive(Clone, Copy, Debug)]
217pub struct DiagStage<'a> {
218    inner: &'a InnerDiagStage,
219}
220
221impl<'a> DiagStage<'a> {
222    fn new(inner: &'a InnerDiagStage) -> Self {
223        Self { inner }
224    }
225
226    /// Return axis pairs extracted by this diagonal stage.
227    ///
228    /// # Examples
229    ///
230    /// ```
231    /// use tenferro_einsum::{ContractionTree, Subscripts};
232    ///
233    /// let subs = Subscripts::new(&[&[0, 0], &[0]], &[0]);
234    /// let tree = ContractionTree::from_pairs(&subs, &[&[3, 3], &[3]], &[(0, 1)]).unwrap();
235    /// let stage = tree.step_plan(0).unwrap().lhs_diag().unwrap().stages().next().unwrap();
236    ///
237    /// assert_eq!(stage.axis_pairs(), &[(0, 1)]);
238    /// ```
239    #[must_use]
240    pub fn axis_pairs(&self) -> &'a [(usize, usize)] {
241        let inner: &'a InnerDiagStage = self.inner;
242        inner.axis_pairs.as_slice()
243    }
244
245    /// Return subscripts after this diagonal stage.
246    ///
247    /// # Examples
248    ///
249    /// ```
250    /// use tenferro_einsum::{ContractionTree, Subscripts};
251    ///
252    /// let subs = Subscripts::new(&[&[0, 0], &[0]], &[0]);
253    /// let tree = ContractionTree::from_pairs(&subs, &[&[3, 3], &[3]], &[(0, 1)]).unwrap();
254    /// let stage = tree.step_plan(0).unwrap().lhs_diag().unwrap().stages().next().unwrap();
255    ///
256    /// assert_eq!(stage.result_subs(), &[0]);
257    /// ```
258    #[must_use]
259    pub fn result_subs(&self) -> &'a [u32] {
260        let inner: &'a InnerDiagStage = self.inner;
261        inner.result_subs.as_slice()
262    }
263}
264
265/// Read-only pre-reduction plan for axes unique to one operand.
266///
267/// # Examples
268///
269/// ```
270/// use tenferro_einsum::{ContractionTree, Subscripts};
271///
272/// let subs = Subscripts::new(&[&[0, 1], &[1, 2]], &[2]);
273/// let tree = ContractionTree::from_pairs(&subs, &[&[2, 3], &[3, 4]], &[(0, 1)]).unwrap();
274/// let reduce = tree.step_plan(0).unwrap().lhs_reduce().unwrap();
275///
276/// assert_eq!(reduce.out_shape(), &[3]);
277/// ```
278#[derive(Clone, Copy, Debug)]
279pub struct ReducePlan<'a> {
280    inner: &'a InnerReducePlan,
281}
282
283impl<'a> ReducePlan<'a> {
284    fn new(inner: &'a InnerReducePlan) -> Self {
285        Self { inner }
286    }
287
288    /// Return subscripts before pre-reduction.
289    ///
290    /// # Examples
291    ///
292    /// ```
293    /// use tenferro_einsum::{ContractionTree, Subscripts};
294    ///
295    /// let subs = Subscripts::new(&[&[0, 1], &[1, 2]], &[2]);
296    /// let tree = ContractionTree::from_pairs(&subs, &[&[2, 3], &[3, 4]], &[(0, 1)]).unwrap();
297    /// let reduce = tree.step_plan(0).unwrap().lhs_reduce().unwrap();
298    ///
299    /// assert_eq!(reduce.original_subs(), &[0, 1]);
300    /// ```
301    #[must_use]
302    pub fn original_subs(&self) -> &'a [u32] {
303        let inner: &'a InnerReducePlan = self.inner;
304        inner.original_subs.as_slice()
305    }
306
307    /// Return subscripts kept after pre-reduction.
308    ///
309    /// # Examples
310    ///
311    /// ```
312    /// use tenferro_einsum::{ContractionTree, Subscripts};
313    ///
314    /// let subs = Subscripts::new(&[&[0, 1], &[1, 2]], &[2]);
315    /// let tree = ContractionTree::from_pairs(&subs, &[&[2, 3], &[3, 4]], &[(0, 1)]).unwrap();
316    /// let reduce = tree.step_plan(0).unwrap().lhs_reduce().unwrap();
317    ///
318    /// assert_eq!(reduce.kept_subs(), &[1]);
319    /// ```
320    #[must_use]
321    pub fn kept_subs(&self) -> &'a [u32] {
322        let inner: &'a InnerReducePlan = self.inner;
323        inner.kept_subs.as_slice()
324    }
325
326    /// Return the shape after pre-reduction.
327    ///
328    /// # Examples
329    ///
330    /// ```
331    /// use tenferro_einsum::{ContractionTree, Subscripts};
332    ///
333    /// let subs = Subscripts::new(&[&[0, 1], &[1, 2]], &[2]);
334    /// let tree = ContractionTree::from_pairs(&subs, &[&[2, 3], &[3, 4]], &[(0, 1)]).unwrap();
335    /// let reduce = tree.step_plan(0).unwrap().lhs_reduce().unwrap();
336    ///
337    /// assert_eq!(reduce.out_shape(), &[3]);
338    /// ```
339    #[must_use]
340    pub fn out_shape(&self) -> &'a [usize] {
341        let inner: &'a InnerReducePlan = self.inner;
342        inner.out_shape.as_slice()
343    }
344}
345
346/// Read-only GEMM decomposition plan for a pairwise contraction step.
347///
348/// # Examples
349///
350/// ```
351/// use tenferro_einsum::{ContractionTree, Subscripts};
352///
353/// let subs = Subscripts::new(&[&[0, 1], &[1, 2]], &[0, 2]);
354/// let tree = ContractionTree::from_pairs(&subs, &[&[2, 3], &[3, 4]], &[(0, 1)]).unwrap();
355/// let gemm = tree.step_plan(0).unwrap().gemm();
356///
357/// assert_eq!(gemm.lhs_gemm_shape(), &[2, 3]);
358/// ```
359#[derive(Clone, Copy, Debug)]
360pub struct GemmPlan<'a> {
361    inner: &'a InnerGemmPlan,
362}
363
364impl<'a> GemmPlan<'a> {
365    fn new(inner: &'a InnerGemmPlan) -> Self {
366        Self { inner }
367    }
368
369    /// Return modes present only on the left operand and output.
370    ///
371    /// # Examples
372    ///
373    /// ```
374    /// use tenferro_einsum::{ContractionTree, Subscripts};
375    ///
376    /// let subs = Subscripts::new(&[&[0, 1], &[1, 2]], &[0, 2]);
377    /// let tree = ContractionTree::from_pairs(&subs, &[&[2, 3], &[3, 4]], &[(0, 1)]).unwrap();
378    /// let gemm = tree.step_plan(0).unwrap().gemm();
379    ///
380    /// assert_eq!(gemm.left_only_modes(), &[0]);
381    /// ```
382    #[must_use]
383    pub fn left_only_modes(&self) -> &'a [u32] {
384        let inner: &'a InnerGemmPlan = self.inner;
385        inner.lo_modes.as_slice()
386    }
387
388    /// Return dimension sizes for [`Self::left_only_modes`].
389    ///
390    /// # Examples
391    ///
392    /// ```
393    /// use tenferro_einsum::{ContractionTree, Subscripts};
394    ///
395    /// let subs = Subscripts::new(&[&[0, 1], &[1, 2]], &[0, 2]);
396    /// let tree = ContractionTree::from_pairs(&subs, &[&[2, 3], &[3, 4]], &[(0, 1)]).unwrap();
397    /// let gemm = tree.step_plan(0).unwrap().gemm();
398    ///
399    /// assert_eq!(gemm.left_only_shape(), &[2]);
400    /// ```
401    #[must_use]
402    pub fn left_only_shape(&self) -> &'a [usize] {
403        let inner: &'a InnerGemmPlan = self.inner;
404        inner.lo_sizes.as_slice()
405    }
406
407    /// Return modes present only on the right operand and output.
408    ///
409    /// # Examples
410    ///
411    /// ```
412    /// use tenferro_einsum::{ContractionTree, Subscripts};
413    ///
414    /// let subs = Subscripts::new(&[&[0, 1], &[1, 2]], &[0, 2]);
415    /// let tree = ContractionTree::from_pairs(&subs, &[&[2, 3], &[3, 4]], &[(0, 1)]).unwrap();
416    /// let gemm = tree.step_plan(0).unwrap().gemm();
417    ///
418    /// assert_eq!(gemm.right_only_modes(), &[2]);
419    /// ```
420    #[must_use]
421    pub fn right_only_modes(&self) -> &'a [u32] {
422        let inner: &'a InnerGemmPlan = self.inner;
423        inner.ro_modes.as_slice()
424    }
425
426    /// Return dimension sizes for [`Self::right_only_modes`].
427    ///
428    /// # Examples
429    ///
430    /// ```
431    /// use tenferro_einsum::{ContractionTree, Subscripts};
432    ///
433    /// let subs = Subscripts::new(&[&[0, 1], &[1, 2]], &[0, 2]);
434    /// let tree = ContractionTree::from_pairs(&subs, &[&[2, 3], &[3, 4]], &[(0, 1)]).unwrap();
435    /// let gemm = tree.step_plan(0).unwrap().gemm();
436    ///
437    /// assert_eq!(gemm.right_only_shape(), &[4]);
438    /// ```
439    #[must_use]
440    pub fn right_only_shape(&self) -> &'a [usize] {
441        let inner: &'a InnerGemmPlan = self.inner;
442        inner.ro_sizes.as_slice()
443    }
444
445    /// Return modes contracted between the pairwise operands.
446    ///
447    /// # Examples
448    ///
449    /// ```
450    /// use tenferro_einsum::{ContractionTree, Subscripts};
451    ///
452    /// let subs = Subscripts::new(&[&[0, 1], &[1, 2]], &[0, 2]);
453    /// let tree = ContractionTree::from_pairs(&subs, &[&[2, 3], &[3, 4]], &[(0, 1)]).unwrap();
454    /// let gemm = tree.step_plan(0).unwrap().gemm();
455    ///
456    /// assert_eq!(gemm.contracted_modes(), &[1]);
457    /// ```
458    #[must_use]
459    pub fn contracted_modes(&self) -> &'a [u32] {
460        let inner: &'a InnerGemmPlan = self.inner;
461        inner.sum_modes.as_slice()
462    }
463
464    /// Return dimension sizes for [`Self::contracted_modes`].
465    ///
466    /// # Examples
467    ///
468    /// ```
469    /// use tenferro_einsum::{ContractionTree, Subscripts};
470    ///
471    /// let subs = Subscripts::new(&[&[0, 1, 2], &[1, 2, 3]], &[0, 3]);
472    /// let tree = ContractionTree::from_pairs(
473    ///     &subs,
474    ///     &[&[2, 3, 4], &[3, 4, 5]],
475    ///     &[(0, 1)],
476    /// )
477    /// .unwrap();
478    /// let gemm = tree.step_plan(0).unwrap().gemm();
479    ///
480    /// assert_eq!(gemm.contracted_shape(), &[3, 4]);
481    /// ```
482    #[must_use]
483    pub fn contracted_shape(&self) -> &'a [usize] {
484        let inner: &'a InnerGemmPlan = self.inner;
485        inner.sum_sizes.as_slice()
486    }
487
488    /// Return modes shared by both operands and preserved in the output.
489    ///
490    /// # Examples
491    ///
492    /// ```
493    /// use tenferro_einsum::{ContractionTree, Subscripts};
494    ///
495    /// let subs = Subscripts::new(&[&[3, 0, 1], &[1, 2, 3]], &[3, 0, 2]);
496    /// let tree = ContractionTree::from_pairs(&subs, &[&[5, 2, 3], &[3, 4, 5]], &[(0, 1)]).unwrap();
497    /// let gemm = tree.step_plan(0).unwrap().gemm();
498    ///
499    /// assert_eq!(gemm.batch_modes(), &[3]);
500    /// ```
501    #[must_use]
502    pub fn batch_modes(&self) -> &'a [u32] {
503        let inner: &'a InnerGemmPlan = self.inner;
504        let batch_start = inner.lo_modes.len() + inner.ro_modes.len();
505        &inner.canonical_modes[batch_start..]
506    }
507
508    /// Return dimension sizes for [`Self::batch_modes`].
509    ///
510    /// # Examples
511    ///
512    /// ```
513    /// use tenferro_einsum::{ContractionTree, Subscripts};
514    ///
515    /// let subs = Subscripts::new(&[&[3, 0, 1], &[1, 2, 3]], &[3, 0, 2]);
516    /// let tree = ContractionTree::from_pairs(&subs, &[&[5, 2, 3], &[3, 4, 5]], &[(0, 1)]).unwrap();
517    /// let gemm = tree.step_plan(0).unwrap().gemm();
518    ///
519    /// assert_eq!(gemm.batch_shape(), &[5]);
520    /// ```
521    #[must_use]
522    pub fn batch_shape(&self) -> &'a [usize] {
523        let inner: &'a InnerGemmPlan = self.inner;
524        inner.batch_sizes.as_slice()
525    }
526
527    /// Return target mode order for preparing the left GEMM operand.
528    ///
529    /// # Examples
530    ///
531    /// ```
532    /// use tenferro_einsum::{ContractionTree, Subscripts};
533    ///
534    /// let subs = Subscripts::new(&[&[0, 1], &[1, 2]], &[0, 2]);
535    /// let tree = ContractionTree::from_pairs(&subs, &[&[2, 3], &[3, 4]], &[(0, 1)]).unwrap();
536    /// let gemm = tree.step_plan(0).unwrap().gemm();
537    ///
538    /// assert_eq!(gemm.lhs_target_modes(), &[0, 1]);
539    /// ```
540    #[must_use]
541    pub fn lhs_target_modes(&self) -> &'a [u32] {
542        let inner: &'a InnerGemmPlan = self.inner;
543        inner.target_a.as_slice()
544    }
545
546    /// Return target mode order for preparing the right GEMM operand.
547    ///
548    /// # Examples
549    ///
550    /// ```
551    /// use tenferro_einsum::{ContractionTree, Subscripts};
552    ///
553    /// let subs = Subscripts::new(&[&[0, 1], &[1, 2]], &[0, 2]);
554    /// let tree = ContractionTree::from_pairs(&subs, &[&[2, 3], &[3, 4]], &[(0, 1)]).unwrap();
555    /// let gemm = tree.step_plan(0).unwrap().gemm();
556    ///
557    /// assert_eq!(gemm.rhs_target_modes(), &[1, 2]);
558    /// ```
559    #[must_use]
560    pub fn rhs_target_modes(&self) -> &'a [u32] {
561        let inner: &'a InnerGemmPlan = self.inner;
562        inner.target_b.as_slice()
563    }
564
565    /// Return canonical output mode order before any final permutation.
566    ///
567    /// # Examples
568    ///
569    /// ```
570    /// use tenferro_einsum::{ContractionTree, Subscripts};
571    ///
572    /// let subs = Subscripts::new(&[&[0, 1], &[1, 2]], &[2, 0]);
573    /// let tree = ContractionTree::from_pairs(&subs, &[&[2, 3], &[3, 4]], &[(0, 1)]).unwrap();
574    /// let gemm = tree.step_plan(0).unwrap().gemm();
575    ///
576    /// assert_eq!(gemm.canonical_output_modes(), &[0, 2]);
577    /// ```
578    #[must_use]
579    pub fn canonical_output_modes(&self) -> &'a [u32] {
580        let inner: &'a InnerGemmPlan = self.inner;
581        inner.canonical_modes.as_slice()
582    }
583
584    /// Return fused left-only dimension size.
585    ///
586    /// # Examples
587    ///
588    /// ```
589    /// use tenferro_einsum::{ContractionTree, Subscripts};
590    ///
591    /// let subs = Subscripts::new(&[&[0, 1], &[1, 2]], &[0, 2]);
592    /// let tree = ContractionTree::from_pairs(&subs, &[&[2, 3], &[3, 4]], &[(0, 1)]).unwrap();
593    /// let gemm = tree.step_plan(0).unwrap().gemm();
594    ///
595    /// assert_eq!(gemm.m(), 2);
596    /// ```
597    #[must_use]
598    pub fn m(&self) -> usize {
599        self.inner.m
600    }
601
602    /// Return fused right-only dimension size.
603    ///
604    /// # Examples
605    ///
606    /// ```
607    /// use tenferro_einsum::{ContractionTree, Subscripts};
608    ///
609    /// let subs = Subscripts::new(&[&[0, 1], &[1, 2]], &[0, 2]);
610    /// let tree = ContractionTree::from_pairs(&subs, &[&[2, 3], &[3, 4]], &[(0, 1)]).unwrap();
611    /// let gemm = tree.step_plan(0).unwrap().gemm();
612    ///
613    /// assert_eq!(gemm.n(), 4);
614    /// ```
615    #[must_use]
616    pub fn n(&self) -> usize {
617        self.inner.n
618    }
619
620    /// Return fused contracted dimension size.
621    ///
622    /// # Examples
623    ///
624    /// ```
625    /// use tenferro_einsum::{ContractionTree, Subscripts};
626    ///
627    /// let subs = Subscripts::new(&[&[0, 1], &[1, 2]], &[0, 2]);
628    /// let tree = ContractionTree::from_pairs(&subs, &[&[2, 3], &[3, 4]], &[(0, 1)]).unwrap();
629    /// let gemm = tree.step_plan(0).unwrap().gemm();
630    ///
631    /// assert_eq!(gemm.k(), 3);
632    /// ```
633    #[must_use]
634    pub fn k(&self) -> usize {
635        self.inner.k
636    }
637
638    /// Return prepared left operand GEMM shape.
639    ///
640    /// # Examples
641    ///
642    /// ```
643    /// use tenferro_einsum::{ContractionTree, Subscripts};
644    ///
645    /// let subs = Subscripts::new(&[&[0, 1], &[1, 2]], &[0, 2]);
646    /// let tree = ContractionTree::from_pairs(&subs, &[&[2, 3], &[3, 4]], &[(0, 1)]).unwrap();
647    /// let gemm = tree.step_plan(0).unwrap().gemm();
648    ///
649    /// assert_eq!(gemm.lhs_gemm_shape(), &[2, 3]);
650    /// ```
651    #[must_use]
652    pub fn lhs_gemm_shape(&self) -> &'a [usize] {
653        let inner: &'a InnerGemmPlan = self.inner;
654        inner.a_gemm_shape.as_slice()
655    }
656
657    /// Return prepared right operand GEMM shape.
658    ///
659    /// # Examples
660    ///
661    /// ```
662    /// use tenferro_einsum::{ContractionTree, Subscripts};
663    ///
664    /// let subs = Subscripts::new(&[&[0, 1], &[1, 2]], &[0, 2]);
665    /// let tree = ContractionTree::from_pairs(&subs, &[&[2, 3], &[3, 4]], &[(0, 1)]).unwrap();
666    /// let gemm = tree.step_plan(0).unwrap().gemm();
667    ///
668    /// assert_eq!(gemm.rhs_gemm_shape(), &[3, 4]);
669    /// ```
670    #[must_use]
671    pub fn rhs_gemm_shape(&self) -> &'a [usize] {
672        let inner: &'a InnerGemmPlan = self.inner;
673        inner.b_gemm_shape.as_slice()
674    }
675
676    /// Return GEMM output shape before expanding fused dimensions.
677    ///
678    /// # Examples
679    ///
680    /// ```
681    /// use tenferro_einsum::{ContractionTree, Subscripts};
682    ///
683    /// let subs = Subscripts::new(&[&[0, 1], &[1, 2]], &[0, 2]);
684    /// let tree = ContractionTree::from_pairs(&subs, &[&[2, 3], &[3, 4]], &[(0, 1)]).unwrap();
685    /// let gemm = tree.step_plan(0).unwrap().gemm();
686    ///
687    /// assert_eq!(gemm.output_gemm_shape(), &[2, 4]);
688    /// ```
689    #[must_use]
690    pub fn output_gemm_shape(&self) -> &'a [usize] {
691        let inner: &'a InnerGemmPlan = self.inner;
692        inner.c_gemm_shape.as_slice()
693    }
694
695    /// Return expanded output shape in canonical output mode order.
696    ///
697    /// # Examples
698    ///
699    /// ```
700    /// use tenferro_einsum::{ContractionTree, Subscripts};
701    ///
702    /// let subs = Subscripts::new(&[&[0, 1], &[1, 2]], &[0, 2]);
703    /// let tree = ContractionTree::from_pairs(&subs, &[&[2, 3], &[3, 4]], &[(0, 1)]).unwrap();
704    /// let gemm = tree.step_plan(0).unwrap().gemm();
705    ///
706    /// assert_eq!(gemm.expanded_output_shape(), &[2, 4]);
707    /// ```
708    #[must_use]
709    pub fn expanded_output_shape(&self) -> &'a [usize] {
710        let inner: &'a InnerGemmPlan = self.inner;
711        inner.expanded_shape.as_slice()
712    }
713
714    /// Return whether canonical output requires a final permutation.
715    ///
716    /// # Examples
717    ///
718    /// ```
719    /// use tenferro_einsum::{ContractionTree, Subscripts};
720    ///
721    /// let subs = Subscripts::new(&[&[0, 1], &[1, 2]], &[2, 0]);
722    /// let tree = ContractionTree::from_pairs(&subs, &[&[2, 3], &[3, 4]], &[(0, 1)]).unwrap();
723    /// let gemm = tree.step_plan(0).unwrap().gemm();
724    ///
725    /// assert!(gemm.needs_final_permute());
726    /// ```
727    #[must_use]
728    pub fn needs_final_permute(&self) -> bool {
729        self.inner.needs_final_permute
730    }
731}