// tenferro_einsum/ad/rules.rs

1use tenferro_algebra::{Conjugate, HasAlgebra, Scalar, Semiring};
2use tenferro_device::Result;
3use tenferro_prims::TensorTempPoolContext;
4use tenferro_tensor::{MemoryOrder, Tensor};
5
6use crate::ad::delta::{prepare_reverse_context, DeltaContext, ReverseContext};
7use crate::api::{einsum_with_subscripts, einsum_with_subscripts_into};
8use crate::execution::backend::{BackendContext, EinsumBackend};
9use crate::execution::execute::execute_nested;
10use crate::syntax::nested::NestedEinsum;
11use crate::syntax::subscripts::Subscripts;
12
13fn apply_embed<Alg, Backend>(
14    ctx: &mut BackendContext<Alg, Backend>,
15    dctx: &DeltaContext<Alg::Scalar>,
16    base: Tensor<Alg::Scalar>,
17) -> Result<Tensor<Alg::Scalar>>
18where
19    Alg: Semiring,
20    Alg::Scalar: Scalar + Conjugate + HasAlgebra<Algebra = Alg>,
21    Backend: EinsumBackend<Alg>,
22    BackendContext<Alg, Backend>: TensorTempPoolContext,
23{
24    if let Some(ref es) = dctx.embed_subs {
25        einsum_with_subscripts::<Alg, Backend>(ctx, es, &[&base], None)
26    } else {
27        Ok(base)
28    }
29}
30
31/// Evaluate the base reverse einsum and, if needed, apply a diagonal-embedding
32/// pass so that repeated output labels (e.g. trace `"ii->"`) are correctly
33/// handled in the reverse-mode gradient.
34fn eval_with_embed<Alg, Backend>(
35    ctx: &mut BackendContext<Alg, Backend>,
36    rctx: &ReverseContext<Alg::Scalar>,
37    leading: &Tensor<Alg::Scalar>,
38) -> Result<Tensor<Alg::Scalar>>
39where
40    Alg: Semiring,
41    Alg::Scalar: Scalar + Conjugate + HasAlgebra<Algebra = Alg>,
42    Backend: EinsumBackend<Alg>,
43    BackendContext<Alg, Backend>: TensorTempPoolContext,
44{
45    let ops = rctx.assemble_rev_operands(leading);
46    let base = einsum_with_subscripts::<Alg, Backend>(ctx, &rctx.dctx.base_subs, &ops, None)?;
47    apply_embed::<Alg, Backend>(ctx, &rctx.dctx, base)
48}
49/// Reverse-mode rule (rrule) for einsum without building a global tape.
50///
51/// # Examples
52///
53/// ```ignore
54/// use tenferro_algebra::Standard;
55/// use tenferro_einsum::einsum_rrule;
56/// use tenferro_prims::{CpuBackend, CpuContext};
57///
58/// let mut ctx = CpuContext::new(1);
59/// let grads = einsum_rrule::<Standard<f64>, CpuBackend>(&mut ctx, "ij,jk->ik", &[&a, &b], &dc)
60///     .unwrap();
61/// assert_eq!(grads.len(), 2);
62/// ```
63pub fn einsum_rrule<Alg, Backend>(
64    ctx: &mut BackendContext<Alg, Backend>,
65    subscripts: &str,
66    operands: &[&Tensor<Alg::Scalar>],
67    cotangent: &Tensor<Alg::Scalar>,
68) -> Result<Vec<Tensor<Alg::Scalar>>>
69where
70    Alg: Semiring,
71    Alg::Scalar: Scalar + Conjugate + HasAlgebra<Algebra = Alg>,
72    Backend: EinsumBackend<Alg>,
73    BackendContext<Alg, Backend>: TensorTempPoolContext,
74{
75    let subs = Subscripts::parse(subscripts)?;
76    let n = operands.len();
77    let mut grads = Vec::with_capacity(n);
78
79    // Build a size dictionary mapping labels to dimensions.
80    let shapes: Vec<&[usize]> = operands.iter().map(|op| op.dims()).collect();
81    let size_dict = crate::execution::util::build_size_dict(&subs, &shapes, None)?;
82
83    for k in 0..n {
84        let rctx = prepare_reverse_context::<Alg::Scalar>(&subs, operands, k, &size_dict)?;
85        let grad = eval_with_embed::<Alg, Backend>(ctx, &rctx, cotangent)?;
86        grads.push(grad);
87    }
88
89    Ok(grads)
90}
91
92/// Forward-mode rule (frule) for einsum without building a global tape.
93///
94/// # Examples
95///
96/// ```ignore
97/// use tenferro_algebra::Standard;
98/// use tenferro_einsum::einsum_frule;
99/// use tenferro_prims::{CpuBackend, CpuContext};
100///
101/// let mut ctx = CpuContext::new(1);
102/// let tangent =
103///     einsum_frule::<Standard<f64>, CpuBackend>(&mut ctx, "ij,jk->ik", &[&a, &b], &[Some(&da), None])
104///         .unwrap();
105/// ```
106pub fn einsum_frule<Alg, Backend>(
107    ctx: &mut BackendContext<Alg, Backend>,
108    subscripts: &str,
109    primals: &[&Tensor<Alg::Scalar>],
110    tangents: &[Option<&Tensor<Alg::Scalar>>],
111) -> Result<Tensor<Alg::Scalar>>
112where
113    Alg: Semiring,
114    Alg::Scalar: Scalar + Conjugate + HasAlgebra<Algebra = Alg>,
115    Backend: EinsumBackend<Alg>,
116    BackendContext<Alg, Backend>: TensorTempPoolContext,
117{
118    let subs = Subscripts::parse(subscripts)?;
119    let nested = if subscripts.contains('(') {
120        Some(NestedEinsum::parse(subscripts)?)
121    } else {
122        None
123    };
124    einsum_frule_impl::<Alg, Backend>(ctx, &subs, nested.as_ref(), primals, tangents)
125}
126
127/// Internal frule implementation with pre-parsed subscripts.
128pub(crate) fn einsum_frule_impl<Alg, Backend>(
129    ctx: &mut BackendContext<Alg, Backend>,
130    subs: &Subscripts,
131    nested: Option<&NestedEinsum>,
132    primals: &[&Tensor<Alg::Scalar>],
133    tangents: &[Option<&Tensor<Alg::Scalar>>],
134) -> Result<Tensor<Alg::Scalar>>
135where
136    Alg: Semiring,
137    Alg::Scalar: Scalar + Conjugate + HasAlgebra<Algebra = Alg>,
138    Backend: EinsumBackend<Alg>,
139    BackendContext<Alg, Backend>: TensorTempPoolContext,
140{
141    let n = primals.len();
142    let mut result: Option<Tensor<Alg::Scalar>> = None;
143
144    for k in 0..n {
145        if let Some(tangent_k) = tangents[k] {
146            let mut ops: Vec<&Tensor<Alg::Scalar>> = primals.to_vec();
147            ops[k] = tangent_k;
148
149            match &mut result {
150                None => {
151                    let term = if let Some(nested) = nested {
152                        execute_nested::<Alg, Backend>(ctx, nested, &ops, None)?
153                    } else {
154                        einsum_with_subscripts::<Alg, Backend>(ctx, subs, &ops, None)?
155                    };
156                    result = Some(term);
157                }
158                Some(existing) => {
159                    let one = <Alg::Scalar as num_traits::One>::one();
160                    if let Some(nested) = nested {
161                        // Nested einsum does not support _into; materialize then
162                        // accumulate via an identity contraction (the nested tree
163                        // may have fewer roots than `subs.inputs`).
164                        let term = execute_nested::<Alg, Backend>(ctx, nested, &ops, None)?;
165                        let out_labels: &[u32] = &subs.output;
166                        let identity_subs = Subscripts::new(&[out_labels], out_labels);
167                        einsum_with_subscripts_into::<Alg, Backend>(
168                            ctx,
169                            &identity_subs,
170                            &[&term],
171                            one,
172                            one,
173                            existing,
174                            None,
175                        )?;
176                    } else {
177                        einsum_with_subscripts_into::<Alg, Backend>(
178                            ctx, subs, &ops, one, one, existing, None,
179                        )?;
180                    }
181                }
182            }
183        }
184    }
185
186    match result {
187        Some(r) => Ok(r),
188        None => {
189            let primal_out = if let Some(nested) = nested {
190                execute_nested::<Alg, Backend>(ctx, nested, primals, None)?
191            } else {
192                einsum_with_subscripts::<Alg, Backend>(ctx, subs, primals, None)?
193            };
194            Tensor::<Alg::Scalar>::zeros(
195                primal_out.dims(),
196                primal_out.logical_memory_space(),
197                MemoryOrder::ColumnMajor,
198            )
199        }
200    }
201}