tenferro_internal_ad_ops/linearized.rs

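//! Linearizable ops for the dtype-erased autodiff frontend: element-wise
//! add and exp, full-tensor sum, and einsum, each paired with a linearized
//! form providing the forward-mode (JVP) and reverse-mode (VJP) rules.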
use std::ops::Add;
use std::sync::Arc;

use num_complex::{Complex32, Complex64};
use num_traits::Zero;
use tenferro_algebra::Scalar;
use tenferro_internal_ad_core::{
    AdResult, AutodiffError, CheckpointHint, DynValue, LinearizableOp, LinearizedOp, Schema,
    SlotSchema,
};
use tenferro_internal_frontend_core::tensor_ops::{
    tensor_element, tensor_map_binary_typed, tensor_map_unary_typed,
};
use tenferro_internal_frontend_core::{DynTensor, DynTensorTyped, StructuredTensor};
use tenferro_tensor::{MemoryOrder, Tensor as DenseTensor};

use crate::math::{einsum_frule, einsum_primal, einsum_rrule};
use crate::{Error, Result};

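/// Element-wise addition of two tensors; both inputs must share a dtype.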
#[derive(Clone, Copy)]
pub struct AddOp;

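/// Element-wise exponential.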
#[derive(Clone, Copy)]
pub struct ExpOp;

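/// Reduction of every element into a single rank-0 (scalar) tensor.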
#[derive(Clone, Copy)]
pub struct SumOp;

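/// Einstein-summation contraction over one or more inputs, driven by a
/// subscripts string (e.g. `"ij,jk->ik"` for matrix multiplication).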
#[derive(Clone)]
pub struct EinsumOp {
    subscripts: Arc<str>,
}

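/// Linearization of [`AddOp`]; addition is linear, so no state is captured.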
#[doc(hidden)]
pub struct AddLinearized;

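/// Linearization of [`ExpOp`]; caches the primal output, which is reused as
/// the multiplier in both the JVP and the VJP (d/dx exp(x) = exp(x)).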
#[doc(hidden)]
pub struct ExpLinearized {
    output: DynTensor,
}

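/// Linearization of [`SumOp`]; keeps the input so its shape is available
/// when broadcasting the cotangent back in the VJP.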
#[doc(hidden)]
pub struct SumLinearized {
    input: DynTensor,
}

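/// Linearization of [`EinsumOp`]; the contraction is multilinear, so all
/// primal inputs are retained for both the JVP and the VJP.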
#[doc(hidden)]
pub struct EinsumLinearized {
    subscripts: Arc<str>,
    inputs: Vec<DynTensor>,
}

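/// Builds a [`Schema`] marking every one of `slots` slots as
/// differentiable and non-auxiliary.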
fn differentiable_schema(slots: usize) -> Schema {
    Schema {
        slots: (0..slots)
            .map(|_| SlotSchema {
                differentiable: true,
                auxiliary: false,
            })
            .collect(),
    }
}

fn invalid_argument(message: impl Into<String>) -> Error {
    AutodiffError::InvalidArgument(message.into()).into()
}

fn into_ad_error(error: Error) -> AutodiffError {
    match error {
        Error::Autodiff(error) => error,
        other => AutodiffError::InvalidArgument(other.to_string()),
    }
}

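/// Applies `f` element-wise over two structured tensors, preserving the
/// structure of `lhs`.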
fn structured_binary<T>(
    lhs: &StructuredTensor<T>,
    rhs: &StructuredTensor<T>,
    f: impl FnMut(T, T) -> T,
) -> Result<StructuredTensor<T>>
where
    T: Scalar + Copy,
{
    lhs.with_payload_like(tensor_map_binary_typed(lhs.payload(), rhs.payload(), f)?)
}

fn structured_unary<T, U>(
    input: &StructuredTensor<T>,
    f: impl FnMut(T) -> U,
) -> Result<StructuredTensor<U>>
where
    T: Scalar + Copy,
    U: Scalar + Copy,
{
    let payload = tensor_map_unary_typed(input.payload(), f)?;
    Ok(StructuredTensor::from(payload))
}

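/// Borrows the host-resident slice behind a dense tensor, or fails with an
/// invalid-argument error naming `context`.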
fn dense_host_slice<'a, T>(tensor: &'a DenseTensor<T>, context: &str) -> Result<&'a [T]> {
    tensor.buffer().as_slice().ok_or_else(|| {
        invalid_argument(format!("{context} requires host-accessible dense payload"))
    })
}

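/// Extracts the single element of a rank-0 tensor.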
fn scalar_from_rank0<T>(value: &StructuredTensor<T>, context: &str) -> Result<T>
where
    T: Scalar + Copy,
{
    if !value.logical_dims().is_empty() {
        return Err(invalid_argument(format!(
            "{context} requires a rank-0 tensor, got {:?}",
            value.logical_dims()
        )));
    }
    tensor_element(value.payload(), &[])
}

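/// Sums every element into a rank-0 tensor, densifying the input first.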
fn structured_sum_all<T>(input: &StructuredTensor<T>) -> Result<StructuredTensor<T>>
where
    T: Scalar + Copy + Zero + Add<Output = T>,
{
    let dense = input.to_dense()?;
    let mut acc = T::zero();
    for &value in dense_host_slice(&dense, "sum")? {
        acc = acc + value;
    }
    let payload = DenseTensor::from_slice(&[acc], &[], MemoryOrder::ColumnMajor)?;
    Ok(StructuredTensor::from(payload))
}

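/// Fills a tensor shaped like `like` with the single value held by the
/// rank-0 tensor `scalar` (the adjoint of a full reduction).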
fn structured_broadcast_scalar_like<T>(
    scalar: &StructuredTensor<T>,
    like: &StructuredTensor<T>,
) -> Result<StructuredTensor<T>>
where
    T: Scalar + Copy,
{
    let value = scalar_from_rank0(scalar, "broadcast_scalar_like")?;
    let total = like.logical_dims().iter().product();
    let payload = DenseTensor::from_slice(
        &vec![value; total],
        like.logical_dims(),
        MemoryOrder::ColumnMajor,
    )?;
    like.with_payload_like(payload)
}

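/// Element-wise addition over the dtype-erased [`DynTensor`]; the inputs
/// must carry the same scalar type. A minimal sketch, using only
/// constructors already visible in this module:
///
/// ```ignore
/// let payload = DenseTensor::from_slice(&[1.0f32, 2.0], &[2], MemoryOrder::ColumnMajor)?;
/// let x = DynTensor::F32(StructuredTensor::from(payload));
/// let y = dyn_add(&x, &x)?; // element-wise: [2.0, 4.0]
/// ```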
fn dyn_add(lhs: &DynTensor, rhs: &DynTensor) -> Result<DynTensor> {
    match (lhs, rhs) {
        (DynTensor::F32(lhs), DynTensor::F32(rhs)) => {
            Ok(DynTensor::F32(structured_binary(lhs, rhs, |x, y| x + y)?))
        }
        (DynTensor::F64(lhs), DynTensor::F64(rhs)) => {
            Ok(DynTensor::F64(structured_binary(lhs, rhs, |x, y| x + y)?))
        }
        (DynTensor::C32(lhs), DynTensor::C32(rhs)) => {
            Ok(DynTensor::C32(structured_binary(lhs, rhs, |x, y| x + y)?))
        }
        (DynTensor::C64(lhs), DynTensor::C64(rhs)) => {
            Ok(DynTensor::C64(structured_binary(lhs, rhs, |x, y| x + y)?))
        }
        _ => Err(invalid_argument(format!(
            "add requires matching dtypes, got lhs={:?}, rhs={:?}",
            lhs.scalar_type(),
            rhs.scalar_type()
        ))),
    }
}

fn dyn_mul(lhs: &DynTensor, rhs: &DynTensor) -> Result<DynTensor> {
    match (lhs, rhs) {
        (DynTensor::F32(lhs), DynTensor::F32(rhs)) => {
            Ok(DynTensor::F32(structured_binary(lhs, rhs, |x, y| x * y)?))
        }
        (DynTensor::F64(lhs), DynTensor::F64(rhs)) => {
            Ok(DynTensor::F64(structured_binary(lhs, rhs, |x, y| x * y)?))
        }
        (DynTensor::C32(lhs), DynTensor::C32(rhs)) => {
            Ok(DynTensor::C32(structured_binary(lhs, rhs, |x, y| x * y)?))
        }
        (DynTensor::C64(lhs), DynTensor::C64(rhs)) => {
            Ok(DynTensor::C64(structured_binary(lhs, rhs, |x, y| x * y)?))
        }
        _ => Err(invalid_argument(format!(
            "mul requires matching dtypes, got lhs={:?}, rhs={:?}",
            lhs.scalar_type(),
            rhs.scalar_type()
        ))),
    }
}

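/// Element-wise exponential over the dtype-erased [`DynTensor`].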
fn dyn_exp(input: &DynTensor) -> Result<DynTensor> {
    match input {
        DynTensor::F32(value) => Ok(DynTensor::F32(structured_unary(value, |x: f32| x.exp())?)),
        DynTensor::F64(value) => Ok(DynTensor::F64(structured_unary(value, |x: f64| x.exp())?)),
        DynTensor::C32(value) => Ok(DynTensor::C32(structured_unary(value, |z: Complex32| {
            z.exp()
        })?)),
        DynTensor::C64(value) => Ok(DynTensor::C64(structured_unary(value, |z: Complex64| {
            z.exp()
        })?)),
    }
}

fn dyn_sum_all(input: &DynTensor) -> Result<DynTensor> {
    match input {
        DynTensor::F32(value) => Ok(DynTensor::F32(structured_sum_all(value)?)),
        DynTensor::F64(value) => Ok(DynTensor::F64(structured_sum_all(value)?)),
        DynTensor::C32(value) => Ok(DynTensor::C32(structured_sum_all(value)?)),
        DynTensor::C64(value) => Ok(DynTensor::C64(structured_sum_all(value)?)),
    }
}

fn dyn_broadcast_scalar_like(scalar: &DynTensor, like: &DynTensor) -> Result<DynTensor> {
    match (scalar, like) {
        (DynTensor::F32(scalar), DynTensor::F32(like)) => Ok(DynTensor::F32(
            structured_broadcast_scalar_like(scalar, like)?,
        )),
        (DynTensor::F64(scalar), DynTensor::F64(like)) => Ok(DynTensor::F64(
            structured_broadcast_scalar_like(scalar, like)?,
        )),
        (DynTensor::C32(scalar), DynTensor::C32(like)) => Ok(DynTensor::C32(
            structured_broadcast_scalar_like(scalar, like)?,
        )),
        (DynTensor::C64(scalar), DynTensor::C64(like)) => Ok(DynTensor::C64(
            structured_broadcast_scalar_like(scalar, like)?,
        )),
        _ => Err(invalid_argument(format!(
            "broadcast requires matching dtypes, got scalar={:?}, like={:?}",
            scalar.scalar_type(),
            like.scalar_type()
        ))),
    }
}

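/// Downcasts a [`DynTensor`] to scalar type `T` and densifies it, failing
/// if the runtime dtype does not match.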
fn dense_dyn_tensor_typed<T>(value: &DynTensor, context: &str) -> Result<DenseTensor<T>>
where
    T: DynTensorTyped + Copy,
{
    let structured = T::structured_ref(value)
        .ok_or_else(|| invalid_argument(format!("{context} requires matching dtypes")))?;
    structured.to_dense()
}

fn collect_dense_dyn_tensors<T>(values: &[&DynTensor], context: &str) -> Result<Vec<DenseTensor<T>>>
where
    T: DynTensorTyped + Copy,
{
    values
        .iter()
        .map(|value| dense_dyn_tensor_typed::<T>(value, context))
        .collect()
}

fn optional_dense_dyn_tensor_typed<T>(
    value: &Option<DynTensor>,
    context: &str,
) -> Result<Option<DenseTensor<T>>>
where
    T: DynTensorTyped + Copy,
{
    value
        .as_ref()
        .map(|tensor| dense_dyn_tensor_typed::<T>(tensor, context))
        .transpose()
}

fn collect_optional_dense_dyn_tensors<T>(
    values: &[Option<DynTensor>],
    context: &str,
) -> Result<Vec<Option<DenseTensor<T>>>>
where
    T: DynTensorTyped + Copy,
{
    values
        .iter()
        .map(|value| optional_dense_dyn_tensor_typed::<T>(value, context))
        .collect()
}

fn dyn_from_dense<T>(value: DenseTensor<T>) -> DynTensor
where
    T: DynTensorTyped + Copy,
{
    T::into_dyn(StructuredTensor::from(value))
}

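/// Runs the einsum primal at a concrete scalar type `T`, densifying all
/// inputs and re-wrapping the dense result as a [`DynTensor`].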
fn dyn_einsum_primal_t<T>(subscripts: &str, inputs: &[&DynTensor]) -> Result<DynTensor>
where
    T: crate::runtime::contracts::EinsumRuntimeValue + DynTensorTyped + Copy,
{
    let dense_inputs = collect_dense_dyn_tensors::<T>(inputs, "einsum")?;
    let input_refs: Vec<&DenseTensor<T>> = dense_inputs.iter().collect();
    let output = einsum_primal(subscripts, &input_refs)?;
    Ok(dyn_from_dense(output))
}

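/// Forward-mode rule for einsum at scalar type `T`. Returns `None` when
/// every tangent is `None`; otherwise defers to `einsum_frule`, which by
/// multilinearity presumably sums the contractions obtained by substituting
/// each present tangent for its primal.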
fn dyn_einsum_jvp_t<T>(
    subscripts: &str,
    primals: &[DynTensor],
    tangents: &[Option<DynTensor>],
) -> Result<Option<DynTensor>>
where
    T: crate::runtime::contracts::EinsumRuntimeValue + DynTensorTyped + Copy,
{
    if tangents.iter().all(Option::is_none) {
        return Ok(None);
    }
    let primal_refs: Vec<&DynTensor> = primals.iter().collect();
    let dense_primals = collect_dense_dyn_tensors::<T>(&primal_refs, "einsum_jvp")?;
    let dense_tangents = collect_optional_dense_dyn_tensors::<T>(tangents, "einsum_jvp")?;
    let primal_refs: Vec<&DenseTensor<T>> = dense_primals.iter().collect();
    let tangent_refs: Vec<Option<&DenseTensor<T>>> =
        dense_tangents.iter().map(Option::as_ref).collect();
    let tangent = einsum_frule(subscripts, &primal_refs, &tangent_refs)?;
    Ok(Some(dyn_from_dense(tangent)))
}

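/// Reverse-mode rule for einsum at scalar type `T`: computes a gradient for
/// every input via `einsum_rrule`, then drops those not requested by
/// `input_grad_mask`.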
fn dyn_einsum_vjp_t<T>(
    subscripts: &str,
    inputs: &[DynTensor],
    cotangent: &DynTensor,
    input_grad_mask: &[bool],
) -> Result<Vec<Option<DynTensor>>>
where
    T: crate::runtime::contracts::EinsumRuntimeValue + DynTensorTyped + Copy,
{
    let input_refs: Vec<&DynTensor> = inputs.iter().collect();
    let dense_inputs = collect_dense_dyn_tensors::<T>(&input_refs, "einsum_vjp")?;
    let input_refs: Vec<&DenseTensor<T>> = dense_inputs.iter().collect();
    let dense_cotangent = dense_dyn_tensor_typed::<T>(cotangent, "einsum_vjp")?;
    let grads = einsum_rrule(subscripts, &input_refs, &dense_cotangent)?;
    Ok(grads
        .into_iter()
        .zip(input_grad_mask.iter().copied())
        .map(|(grad, needed)| needed.then(|| dyn_from_dense(grad)))
        .collect())
}

impl LinearizableOp<DynTensor> for AddOp {
    type Linearized = AddLinearized;

    fn primal(&self, inputs: &[&DynTensor]) -> AdResult<Vec<DynTensor>> {
        Ok(vec![dyn_add(inputs[0], inputs[1]).map_err(into_ad_error)?])
    }

    fn input_schema(&self, _inputs: &[&DynTensor]) -> AdResult<Schema> {
        Ok(differentiable_schema(2))
    }

    fn output_schema(&self, _inputs: &[&DynTensor], _outputs: &[DynTensor]) -> AdResult<Schema> {
        Ok(differentiable_schema(1))
    }

    fn linearize(
        &self,
        _inputs: &[&DynTensor],
        _outputs: &[DynTensor],
    ) -> AdResult<Self::Linearized> {
        Ok(AddLinearized)
    }

    fn checkpoint_hint(&self) -> CheckpointHint {
        CheckpointHint::CheapReplay
    }
}

impl LinearizedOp<DynTensor> for AddLinearized {
    fn jvp(&self, input_tangents: &[Option<DynTensor>]) -> AdResult<Vec<Option<DynTensor>>> {
        // d(x + y) = dx + dy; a missing tangent is an implicit zero.
        let tangent = match (&input_tangents[0], &input_tangents[1]) {
            (None, None) => None,
            (Some(lhs), None) => Some(lhs.clone()),
            (None, Some(rhs)) => Some(rhs.clone()),
            (Some(lhs), Some(rhs)) => Some(dyn_add(lhs, rhs).map_err(into_ad_error)?),
        };
        Ok(vec![tangent])
    }

    fn vjp(
        &self,
        output_cotangents: &[Option<DynTensor>],
        input_grad_mask: &[bool],
    ) -> AdResult<Vec<Option<DynTensor>>> {
        // The cotangent flows through unchanged to each requested input.
        let grad = output_cotangents[0].clone();
        Ok(vec![
            input_grad_mask[0].then(|| grad.clone()).flatten(),
            input_grad_mask[1].then_some(grad).flatten(),
        ])
    }
}

impl LinearizableOp<DynTensor> for ExpOp {
    type Linearized = ExpLinearized;

    fn primal(&self, inputs: &[&DynTensor]) -> AdResult<Vec<DynTensor>> {
        Ok(vec![dyn_exp(inputs[0]).map_err(into_ad_error)?])
    }

    fn input_schema(&self, _inputs: &[&DynTensor]) -> AdResult<Schema> {
        Ok(differentiable_schema(1))
    }

    fn output_schema(&self, _inputs: &[&DynTensor], _outputs: &[DynTensor]) -> AdResult<Schema> {
        Ok(differentiable_schema(1))
    }

    fn linearize(
        &self,
        _inputs: &[&DynTensor],
        outputs: &[DynTensor],
    ) -> AdResult<Self::Linearized> {
        // Capture the primal output, which is also the derivative of exp.
        Ok(ExpLinearized {
            output: outputs[0].clone(),
        })
    }

    fn checkpoint_hint(&self) -> CheckpointHint {
        CheckpointHint::CheapReplay
    }
}

impl LinearizedOp<DynTensor> for ExpLinearized {
    fn jvp(&self, input_tangents: &[Option<DynTensor>]) -> AdResult<Vec<Option<DynTensor>>> {
        // d exp(x) = exp(x) dx: scale the tangent by the cached output.
        Ok(vec![match &input_tangents[0] {
            Some(tangent) => Some(dyn_mul(&self.output, tangent).map_err(into_ad_error)?),
            None => None,
        }])
    }

    fn vjp(
        &self,
        output_cotangents: &[Option<DynTensor>],
        input_grad_mask: &[bool],
    ) -> AdResult<Vec<Option<DynTensor>>> {
        // The reverse rule scales the cotangent by the same cached output.
        Ok(vec![if input_grad_mask[0] {
            match &output_cotangents[0] {
                Some(grad_out) => Some(dyn_mul(&self.output, grad_out).map_err(into_ad_error)?),
                None => None,
            }
        } else {
            None
        }])
    }
}

impl LinearizableOp<DynTensor> for SumOp {
    type Linearized = SumLinearized;

    fn primal(&self, inputs: &[&DynTensor]) -> AdResult<Vec<DynTensor>> {
        Ok(vec![dyn_sum_all(inputs[0]).map_err(into_ad_error)?])
    }

    fn input_schema(&self, _inputs: &[&DynTensor]) -> AdResult<Schema> {
        Ok(differentiable_schema(1))
    }

    fn output_schema(&self, _inputs: &[&DynTensor], _outputs: &[DynTensor]) -> AdResult<Schema> {
        Ok(differentiable_schema(1))
    }

    fn linearize(
        &self,
        inputs: &[&DynTensor],
        _outputs: &[DynTensor],
    ) -> AdResult<Self::Linearized> {
        Ok(SumLinearized {
            input: inputs[0].clone(),
        })
    }

    fn checkpoint_hint(&self) -> CheckpointHint {
        CheckpointHint::CheapReplay
    }
}

impl LinearizedOp<DynTensor> for SumLinearized {
    fn jvp(&self, input_tangents: &[Option<DynTensor>]) -> AdResult<Vec<Option<DynTensor>>> {
        // Summation is linear: the tangent of the sum is the sum of the tangent.
        Ok(vec![match &input_tangents[0] {
            Some(tangent) => Some(dyn_sum_all(tangent).map_err(into_ad_error)?),
            None => None,
        }])
    }

    fn vjp(
        &self,
        output_cotangents: &[Option<DynTensor>],
        input_grad_mask: &[bool],
    ) -> AdResult<Vec<Option<DynTensor>>> {
        // Every input element contributes once to the sum, so the rank-0
        // cotangent is broadcast back to the stored input's shape.
        Ok(vec![if input_grad_mask[0] {
            match &output_cotangents[0] {
                Some(grad_out) => {
                    Some(dyn_broadcast_scalar_like(grad_out, &self.input).map_err(into_ad_error)?)
                }
                None => None,
            }
        } else {
            None
        }])
    }
}

impl EinsumOp {
    /// Creates an einsum op from explicit subscripts, e.g. `"ij,jk->ik"`.
    pub fn new(subscripts: impl Into<String>) -> Self {
        Self {
            subscripts: Arc::<str>::from(subscripts.into()),
        }
    }
}

impl LinearizableOp<DynTensor> for EinsumOp {
    type Linearized = EinsumLinearized;

    fn primal(&self, inputs: &[&DynTensor]) -> AdResult<Vec<DynTensor>> {
        // Dispatch on the dtype of the first input; the typed helpers reject
        // any remaining input whose dtype does not match.
        let output = match inputs.first() {
            Some(DynTensor::F32(_)) => dyn_einsum_primal_t::<f32>(&self.subscripts, inputs),
            Some(DynTensor::F64(_)) => dyn_einsum_primal_t::<f64>(&self.subscripts, inputs),
            Some(DynTensor::C32(_)) => dyn_einsum_primal_t::<Complex32>(&self.subscripts, inputs),
            Some(DynTensor::C64(_)) => dyn_einsum_primal_t::<Complex64>(&self.subscripts, inputs),
            None => Err(invalid_argument("einsum requires at least one input")),
        }
        .map_err(into_ad_error)?;
        Ok(vec![output])
    }

    fn input_schema(&self, inputs: &[&DynTensor]) -> AdResult<Schema> {
        Ok(differentiable_schema(inputs.len()))
    }

    fn output_schema(&self, _inputs: &[&DynTensor], _outputs: &[DynTensor]) -> AdResult<Schema> {
        Ok(differentiable_schema(1))
    }

    fn linearize(
        &self,
        inputs: &[&DynTensor],
        _outputs: &[DynTensor],
    ) -> AdResult<Self::Linearized> {
        Ok(EinsumLinearized {
            subscripts: self.subscripts.clone(),
            inputs: inputs.iter().map(|input| (*input).clone()).collect(),
        })
    }

    fn checkpoint_hint(&self) -> CheckpointHint {
        // Unlike the element-wise ops above, replaying a contraction is costly.
        CheckpointHint::ExpensiveReplay
    }
}

impl LinearizedOp<DynTensor> for EinsumLinearized {
    fn jvp(&self, input_tangents: &[Option<DynTensor>]) -> AdResult<Vec<Option<DynTensor>>> {
        let tangent = match self.inputs.first() {
            Some(DynTensor::F32(_)) => {
                dyn_einsum_jvp_t::<f32>(&self.subscripts, &self.inputs, input_tangents)
            }
            Some(DynTensor::F64(_)) => {
                dyn_einsum_jvp_t::<f64>(&self.subscripts, &self.inputs, input_tangents)
            }
            Some(DynTensor::C32(_)) => {
                dyn_einsum_jvp_t::<Complex32>(&self.subscripts, &self.inputs, input_tangents)
            }
            Some(DynTensor::C64(_)) => {
                dyn_einsum_jvp_t::<Complex64>(&self.subscripts, &self.inputs, input_tangents)
            }
            None => Err(invalid_argument(
                "einsum linearization requires at least one input",
            )),
        }
        .map_err(into_ad_error)?;
        Ok(vec![tangent])
    }

    fn vjp(
        &self,
        output_cotangents: &[Option<DynTensor>],
        input_grad_mask: &[bool],
    ) -> AdResult<Vec<Option<DynTensor>>> {
        // No cotangent flowing in means no gradients flowing out.
        let Some(cotangent) = output_cotangents[0].as_ref() else {
            return Ok((0..self.inputs.len()).map(|_| None).collect());
        };
        match self.inputs.first() {
            Some(DynTensor::F32(_)) => {
                dyn_einsum_vjp_t::<f32>(&self.subscripts, &self.inputs, cotangent, input_grad_mask)
            }
            Some(DynTensor::F64(_)) => {
                dyn_einsum_vjp_t::<f64>(&self.subscripts, &self.inputs, cotangent, input_grad_mask)
            }
            Some(DynTensor::C32(_)) => dyn_einsum_vjp_t::<Complex32>(
                &self.subscripts,
                &self.inputs,
                cotangent,
                input_grad_mask,
            ),
            Some(DynTensor::C64(_)) => dyn_einsum_vjp_t::<Complex64>(
                &self.subscripts,
                &self.inputs,
                cotangent,
                input_grad_mask,
            ),
            None => Err(invalid_argument(
                "einsum linearization requires at least one input",
            )),
        }
        .map_err(into_ad_error)
    }
}

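// Convenience wrappers over `LinearizableOp::apply_one`. A hedged usage
// sketch, assuming `x: DynValue` was obtained from the surrounding autodiff
// frontend:
//
//     let y = exp_dyn_value(&x)?;
//     let loss = sum_dyn_value(&y)?; // rank-0 result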
pub fn add_dyn_values(lhs: &DynValue, rhs: &DynValue) -> AdResult<DynValue> {
    AddOp.apply_one(&[lhs, rhs])
}

pub fn exp_dyn_value(input: &DynValue) -> AdResult<DynValue> {
    ExpOp.apply_one(&[input])
}

pub fn sum_dyn_value(input: &DynValue) -> AdResult<DynValue> {
    SumOp.apply_one(&[input])
}

pub fn einsum_dyn_values(subscripts: &str, inputs: &[&DynValue]) -> AdResult<DynValue> {
    EinsumOp::new(subscripts).apply_one(inputs)
}