1use tenferro_algebra::{Conjugate, HasAlgebra, Scalar, Semiring};
2use tenferro_device::Result;
3use tenferro_prims::TensorTempPoolContext;
4use tenferro_tensor::{MemoryOrder, Tensor};
5
6use crate::ad::delta::{prepare_reverse_context, DeltaContext, ReverseContext};
7use crate::api::{einsum_with_subscripts, einsum_with_subscripts_into};
8use crate::execution::backend::{BackendContext, EinsumBackend};
9use crate::execution::execute::execute_nested;
10use crate::syntax::nested::NestedEinsum;
11use crate::syntax::subscripts::Subscripts;
12
13fn apply_embed<Alg, Backend>(
14 ctx: &mut BackendContext<Alg, Backend>,
15 dctx: &DeltaContext<Alg::Scalar>,
16 base: Tensor<Alg::Scalar>,
17) -> Result<Tensor<Alg::Scalar>>
18where
19 Alg: Semiring,
20 Alg::Scalar: Scalar + Conjugate + HasAlgebra<Algebra = Alg>,
21 Backend: EinsumBackend<Alg>,
22 BackendContext<Alg, Backend>: TensorTempPoolContext,
23{
24 if let Some(ref es) = dctx.embed_subs {
25 einsum_with_subscripts::<Alg, Backend>(ctx, es, &[&base], None)
26 } else {
27 Ok(base)
28 }
29}
30
31fn eval_with_embed<Alg, Backend>(
35 ctx: &mut BackendContext<Alg, Backend>,
36 rctx: &ReverseContext<Alg::Scalar>,
37 leading: &Tensor<Alg::Scalar>,
38) -> Result<Tensor<Alg::Scalar>>
39where
40 Alg: Semiring,
41 Alg::Scalar: Scalar + Conjugate + HasAlgebra<Algebra = Alg>,
42 Backend: EinsumBackend<Alg>,
43 BackendContext<Alg, Backend>: TensorTempPoolContext,
44{
45 let ops = rctx.assemble_rev_operands(leading);
46 let base = einsum_with_subscripts::<Alg, Backend>(ctx, &rctx.dctx.base_subs, &ops, None)?;
47 apply_embed::<Alg, Backend>(ctx, &rctx.dctx, base)
48}
49pub fn einsum_rrule<Alg, Backend>(
64 ctx: &mut BackendContext<Alg, Backend>,
65 subscripts: &str,
66 operands: &[&Tensor<Alg::Scalar>],
67 cotangent: &Tensor<Alg::Scalar>,
68) -> Result<Vec<Tensor<Alg::Scalar>>>
69where
70 Alg: Semiring,
71 Alg::Scalar: Scalar + Conjugate + HasAlgebra<Algebra = Alg>,
72 Backend: EinsumBackend<Alg>,
73 BackendContext<Alg, Backend>: TensorTempPoolContext,
74{
75 let subs = Subscripts::parse(subscripts)?;
76 let n = operands.len();
77 let mut grads = Vec::with_capacity(n);
78
79 let shapes: Vec<&[usize]> = operands.iter().map(|op| op.dims()).collect();
81 let size_dict = crate::execution::util::build_size_dict(&subs, &shapes, None)?;
82
83 for k in 0..n {
84 let rctx = prepare_reverse_context::<Alg::Scalar>(&subs, operands, k, &size_dict)?;
85 let grad = eval_with_embed::<Alg, Backend>(ctx, &rctx, cotangent)?;
86 grads.push(grad);
87 }
88
89 Ok(grads)
90}
91
92pub fn einsum_frule<Alg, Backend>(
107 ctx: &mut BackendContext<Alg, Backend>,
108 subscripts: &str,
109 primals: &[&Tensor<Alg::Scalar>],
110 tangents: &[Option<&Tensor<Alg::Scalar>>],
111) -> Result<Tensor<Alg::Scalar>>
112where
113 Alg: Semiring,
114 Alg::Scalar: Scalar + Conjugate + HasAlgebra<Algebra = Alg>,
115 Backend: EinsumBackend<Alg>,
116 BackendContext<Alg, Backend>: TensorTempPoolContext,
117{
118 let subs = Subscripts::parse(subscripts)?;
119 let nested = if subscripts.contains('(') {
120 Some(NestedEinsum::parse(subscripts)?)
121 } else {
122 None
123 };
124 einsum_frule_impl::<Alg, Backend>(ctx, &subs, nested.as_ref(), primals, tangents)
125}
126
/// Core of the forward-mode (JVP) rule for einsum.
///
/// For each operand `k` with a tangent, evaluates the original contraction
/// with operand `k` replaced by its tangent, and sums the resulting terms.
/// By the product rule this sum is the tangent of the einsum output.
///
/// * `subs` — parsed flat subscripts of the contraction.
/// * `nested` — optional nested-einsum tree; when present, evaluation goes
///   through `execute_nested` instead of the flat path.
/// * `primals` — the original operands.
/// * `tangents` — per-operand tangents; `None` entries contribute nothing.
///
/// Returns the tangent of the output, or an all-zeros tensor of the output
/// shape when every tangent is `None`.
pub(crate) fn einsum_frule_impl<Alg, Backend>(
    ctx: &mut BackendContext<Alg, Backend>,
    subs: &Subscripts,
    nested: Option<&NestedEinsum>,
    primals: &[&Tensor<Alg::Scalar>],
    tangents: &[Option<&Tensor<Alg::Scalar>>],
) -> Result<Tensor<Alg::Scalar>>
where
    Alg: Semiring,
    Alg::Scalar: Scalar + Conjugate + HasAlgebra<Algebra = Alg>,
    Backend: EinsumBackend<Alg>,
    BackendContext<Alg, Backend>: TensorTempPoolContext,
{
    let n = primals.len();
    // Running sum of product-rule terms; stays None until the first
    // operand with a tangent is seen.
    let mut result: Option<Tensor<Alg::Scalar>> = None;

    for k in 0..n {
        if let Some(tangent_k) = tangents[k] {
            // Same operand list as the primal contraction, but with slot k
            // swapped for its tangent.
            let mut ops: Vec<&Tensor<Alg::Scalar>> = primals.to_vec();
            ops[k] = tangent_k;

            match &mut result {
                None => {
                    // First term: evaluate into a fresh tensor.
                    let term = if let Some(nested) = nested {
                        execute_nested::<Alg, Backend>(ctx, nested, &ops, None)?
                    } else {
                        einsum_with_subscripts::<Alg, Backend>(ctx, subs, &ops, None)?
                    };
                    result = Some(term);
                }
                Some(existing) => {
                    let one = <Alg::Scalar as num_traits::One>::one();
                    if let Some(nested) = nested {
                        // Nested evaluation has no accumulate-into entry
                        // point, so evaluate the term separately and fold it
                        // into `existing` via an identity einsum
                        // (output labels -> output labels) with alpha = beta
                        // = one, i.e. existing += term — presumably; the
                        // alpha/beta convention of
                        // einsum_with_subscripts_into is not visible here.
                        let term = execute_nested::<Alg, Backend>(ctx, nested, &ops, None)?;
                        let out_labels: &[u32] = &subs.output;
                        let identity_subs = Subscripts::new(&[out_labels], out_labels);
                        einsum_with_subscripts_into::<Alg, Backend>(
                            ctx,
                            &identity_subs,
                            &[&term],
                            one,
                            one,
                            existing,
                            None,
                        )?;
                    } else {
                        // Flat path: accumulate the contraction directly
                        // into `existing` without materializing the term.
                        einsum_with_subscripts_into::<Alg, Backend>(
                            ctx, subs, &ops, one, one, existing, None,
                        )?;
                    }
                }
            }
        }
    }

    match result {
        Some(r) => Ok(r),
        // No tangent was supplied: the JVP is identically zero. Run the
        // primal contraction once purely to learn the output's shape and
        // memory space, then allocate a matching zero tensor.
        None => {
            let primal_out = if let Some(nested) = nested {
                execute_nested::<Alg, Backend>(ctx, nested, primals, None)?
            } else {
                einsum_with_subscripts::<Alg, Backend>(ctx, subs, primals, None)?
            };
            Tensor::<Alg::Scalar>::zeros(
                primal_out.dims(),
                primal_out.logical_memory_space(),
                MemoryOrder::ColumnMajor,
            )
        }
    }
}