// tenferro_internal_ad_ops/math.rs

use tenferro_algebra::{Scalar, Standard};
use tenferro_einsum as tf_einsum;
use tenferro_linalg::SolveGrad;
use tenferro_tensor::Tensor;

use chainrules_core::AutodiffError;

use crate::runtime::contracts::{EinsumRuntimeValue, LinalgRuntimeValue};
use crate::runtime::dispatch::{dispatch_einsum_runtime, with_linalg_runtime};
use crate::{Error, Result};

/// Evaluates the einsum primal (the forward computation itself) over dense
/// operands — note this is NOT a VJP; the reverse-mode rule is
/// `einsum_rrule` below.
///
/// `subscripts` is the einsum specification (e.g. `"i,i->"`); `operands` are
/// the dense input tensors, in the order the subscripts reference them.
/// Dispatches through the runtime selected by `dispatch_einsum_runtime!`
/// (the tests configure a CPU runtime via `set_default_runtime`).
pub fn einsum_primal<'a, T>(subscripts: &'a str, operands: &'a [&'a Tensor<T>]) -> Result<Tensor<T>>
where
    T: EinsumRuntimeValue,
{
    // The trailing `None` is an extra option accepted by `tf_einsum::einsum`;
    // its meaning is not visible from here — presumably an optional output /
    // config slot. Backend errors are converted into this crate's `Error`.
    dispatch_einsum_runtime!(T, "einsum", |ctx, Backend| {
        tf_einsum::einsum::<Standard<T>, Backend>(ctx, subscripts, operands, None)
            .map_err(Error::from)
    })
}

23/// Stateless reverse-mode rule (VJP) for einsum over dense primals.
24pub fn einsum_rrule<'a, T>(
25    subscripts: &'a str,
26    operands: &'a [&'a Tensor<T>],
27    cotangent: &Tensor<T>,
28) -> Result<Vec<Tensor<T>>>
29where
30    T: EinsumRuntimeValue,
31{
32    dispatch_einsum_runtime!(T, "einsum_rrule", |ctx, Backend| {
33        tf_einsum::einsum_rrule::<Standard<T>, Backend>(ctx, subscripts, operands, cotangent)
34            .map_err(Error::from)
35    })
36}
37
38/// Stateless forward-mode rule (JVP) for einsum over dense primals.
39pub fn einsum_frule<'a, T>(
40    subscripts: &'a str,
41    primals: &'a [&'a Tensor<T>],
42    tangents: &'a [Option<&'a Tensor<T>>],
43) -> Result<Tensor<T>>
44where
45    T: EinsumRuntimeValue,
46{
47    if primals.len() != tangents.len() {
48        return Err(Error::Autodiff(AutodiffError::InvalidArgument(format!(
49            "einsum_frule requires tangents.len() == primals.len(), got {} vs {}",
50            tangents.len(),
51            primals.len()
52        ))));
53    }
54
55    dispatch_einsum_runtime!(T, "einsum_frule", |ctx, Backend| {
56        tf_einsum::einsum_frule::<Standard<T>, Backend>(ctx, subscripts, primals, tangents)
57            .map_err(Error::from)
58    })
59}
60
/// Stateless reverse-mode rule (VJP) for triangular solve.
///
/// Given the output `cotangent` of `solve_triangular(a, b)`, returns a
/// `SolveGrad` bundling the gradients for `a` (field `.a`, same dims as `a`)
/// and `b` (field `.b`, same dims as `b`) — see the shape test below.
/// `upper` presumably selects the upper vs lower triangle of `a` — TODO
/// confirm against `tenferro_linalg::solve_triangular_rrule`.
///
/// Fails with `Error::RuntimeNotConfigured` when no default runtime has been
/// installed (exercised in the tests).
pub fn solve_triangular_rrule<T>(
    a: &Tensor<T>,
    b: &Tensor<T>,
    cotangent: &Tensor<T>,
    upper: bool,
) -> Result<SolveGrad<T>>
where
    T: Scalar + LinalgRuntimeValue,
{
    // NOTE(review): the three closures are textually identical but presumably
    // each receives a differently-typed backend context (one per runtime
    // flavor), so they cannot be collapsed into a single closure — confirm
    // against `with_linalg_runtime`'s signature before deduplicating.
    with_linalg_runtime::<T, _>(
        "solve_triangular_rrule",
        tenferro_linalg::backend::LinalgCapabilityOp::SolveTriangular,
        |ctx| {
            tenferro_linalg::solve_triangular_rrule::<T, _>(ctx, a, b, cotangent, upper)
                .map_err(Error::from)
        },
        |ctx| {
            tenferro_linalg::solve_triangular_rrule::<T, _>(ctx, a, b, cotangent, upper)
                .map_err(Error::from)
        },
        |ctx| {
            tenferro_linalg::solve_triangular_rrule::<T, _>(ctx, a, b, cotangent, upper)
                .map_err(Error::from)
        },
    )
}

#[cfg(test)]
mod tests {
    use super::*;
    use tenferro_internal_runtime::{set_default_runtime, RuntimeContext};
    use tenferro_prims::CpuContext;
    use tenferro_tensor::MemoryOrder;

    /// Builds a column-major dense f64 tensor from a flat value slice.
    fn dense_f64(values: &[f64], dims: &[usize]) -> Tensor<f64> {
        Tensor::from_slice(values, dims, MemoryOrder::ColumnMajor).unwrap()
    }

    /// The frule must reject a tangent slice shorter than the primal slice
    /// before ever dispatching to a backend (no runtime is configured here).
    #[test]
    fn einsum_frule_rejects_tangent_arity_mismatch() {
        let a = dense_f64(&[1.0, 2.0], &[2]);
        let b = dense_f64(&[3.0, 4.0], &[2]);

        let err = einsum_frule("i,i->", &[&a, &b], &[Some(&a)]).unwrap_err();
        let rendered = err.to_string();
        assert!(rendered.contains("einsum_frule requires tangents.len() == primals.len()"));
    }

    /// Smoke-tests primal, frule and rrule on a dot product under the CPU
    /// runtime: x·y, its JVP for dx (dy = None), and both VJPs.
    #[test]
    fn einsum_helpers_run_with_cpu_runtime() {
        let _guard = set_default_runtime(RuntimeContext::Cpu(CpuContext::new(1)));

        let lhs = dense_f64(&[1.0, 2.0], &[2]);
        let rhs = dense_f64(&[3.0, 5.0], &[2]);
        let lhs_tangent = dense_f64(&[0.5, -1.0], &[2]);
        let seed = dense_f64(&[2.0], &[]);

        // Primal: 1*3 + 2*5 = 13.
        let value = einsum_primal("i,i->", &[&lhs, &rhs]).unwrap();
        assert_eq!(value.buffer().as_slice().unwrap(), &[13.0]);

        // JVP: dx·y = 0.5*3 + (-1)*5 = -3.5 (rhs carries no tangent).
        let jvp = einsum_frule("i,i->", &[&lhs, &rhs], &[Some(&lhs_tangent), None]).unwrap();
        assert_eq!(jvp.buffer().as_slice().unwrap(), &[-3.5]);

        // VJPs: seed * rhs and seed * lhs respectively.
        let vjps = einsum_rrule("i,i->", &[&lhs, &rhs], &seed).unwrap();
        assert_eq!(vjps[0].buffer().as_slice().unwrap(), &[6.0, 10.0]);
        assert_eq!(vjps[1].buffer().as_slice().unwrap(), &[2.0, 4.0]);
    }

    /// Without a configured runtime the solve rule must fail fast; with one,
    /// gradient shapes must mirror the inputs.
    #[test]
    fn solve_triangular_rrule_requires_runtime_and_preserves_shapes() {
        let lower = dense_f64(&[2.0, 0.0, 1.0, 3.0], &[2, 2]);
        let rhs = dense_f64(&[4.0, 9.0], &[2]);
        let seed = dense_f64(&[1.0, 2.0], &[2]);

        // NOTE(review): this leg assumes no default runtime is active. If the
        // harness runs tests in parallel and the runtime default is process-
        // global, another test's `set_default_runtime` guard could race with
        // it — confirm the guard is thread-local or serialize these tests.
        let err = solve_triangular_rrule(&lower, &rhs, &seed, false).unwrap_err();
        assert!(matches!(err, Error::RuntimeNotConfigured));

        let _guard = set_default_runtime(RuntimeContext::Cpu(CpuContext::new(1)));
        let grad = solve_triangular_rrule(&lower, &rhs, &seed, false).unwrap();
        assert_eq!(grad.a.dims(), &[2, 2]);
        assert_eq!(grad.b.dims(), &[2]);
    }
}
143}