tenferro_internal_ad_ops/
math.rs1use tenferro_algebra::{Scalar, Standard};
2use tenferro_einsum as tf_einsum;
3use tenferro_linalg::SolveGrad;
4use tenferro_tensor::Tensor;
5
6use chainrules_core::AutodiffError;
7
8use crate::runtime::contracts::{EinsumRuntimeValue, LinalgRuntimeValue};
9use crate::runtime::dispatch::{dispatch_einsum_runtime, with_linalg_runtime};
10use crate::{Error, Result};
11
12pub fn einsum_primal<'a, T>(subscripts: &'a str, operands: &'a [&'a Tensor<T>]) -> Result<Tensor<T>>
14where
15 T: EinsumRuntimeValue,
16{
17 dispatch_einsum_runtime!(T, "einsum", |ctx, Backend| {
18 tf_einsum::einsum::<Standard<T>, Backend>(ctx, subscripts, operands, None)
19 .map_err(Error::from)
20 })
21}
22
23pub fn einsum_rrule<'a, T>(
25 subscripts: &'a str,
26 operands: &'a [&'a Tensor<T>],
27 cotangent: &Tensor<T>,
28) -> Result<Vec<Tensor<T>>>
29where
30 T: EinsumRuntimeValue,
31{
32 dispatch_einsum_runtime!(T, "einsum_rrule", |ctx, Backend| {
33 tf_einsum::einsum_rrule::<Standard<T>, Backend>(ctx, subscripts, operands, cotangent)
34 .map_err(Error::from)
35 })
36}
37
38pub fn einsum_frule<'a, T>(
40 subscripts: &'a str,
41 primals: &'a [&'a Tensor<T>],
42 tangents: &'a [Option<&'a Tensor<T>>],
43) -> Result<Tensor<T>>
44where
45 T: EinsumRuntimeValue,
46{
47 if primals.len() != tangents.len() {
48 return Err(Error::Autodiff(AutodiffError::InvalidArgument(format!(
49 "einsum_frule requires tangents.len() == primals.len(), got {} vs {}",
50 tangents.len(),
51 primals.len()
52 ))));
53 }
54
55 dispatch_einsum_runtime!(T, "einsum_frule", |ctx, Backend| {
56 tf_einsum::einsum_frule::<Standard<T>, Backend>(ctx, subscripts, primals, tangents)
57 .map_err(Error::from)
58 })
59}
60
61pub fn solve_triangular_rrule<T>(
63 a: &Tensor<T>,
64 b: &Tensor<T>,
65 cotangent: &Tensor<T>,
66 upper: bool,
67) -> Result<SolveGrad<T>>
68where
69 T: Scalar + LinalgRuntimeValue,
70{
71 with_linalg_runtime::<T, _>(
72 "solve_triangular_rrule",
73 tenferro_linalg::backend::LinalgCapabilityOp::SolveTriangular,
74 |ctx| {
75 tenferro_linalg::solve_triangular_rrule::<T, _>(ctx, a, b, cotangent, upper)
76 .map_err(Error::from)
77 },
78 |ctx| {
79 tenferro_linalg::solve_triangular_rrule::<T, _>(ctx, a, b, cotangent, upper)
80 .map_err(Error::from)
81 },
82 |ctx| {
83 tenferro_linalg::solve_triangular_rrule::<T, _>(ctx, a, b, cotangent, upper)
84 .map_err(Error::from)
85 },
86 )
87}
88
#[cfg(test)]
mod tests {
    use super::*;
    use tenferro_internal_runtime::{set_default_runtime, RuntimeContext};
    use tenferro_prims::CpuContext;
    use tenferro_tensor::MemoryOrder;

    /// Builds an `f64` tensor of shape `dims` from `values`, laid out in
    /// column-major order. Panics if `Tensor::from_slice` rejects the input.
    fn dense_f64(values: &[f64], dims: &[usize]) -> Tensor<f64> {
        let built = Tensor::from_slice(values, dims, MemoryOrder::ColumnMajor);
        built.unwrap()
    }

    #[test]
    fn einsum_frule_rejects_tangent_arity_mismatch() {
        let lhs = dense_f64(&[1.0, 2.0], &[2]);
        let rhs = dense_f64(&[3.0, 4.0], &[2]);

        // Two primals but only one tangent slot. No runtime is installed in
        // this test, so the arity check has to fire before any dispatch.
        let outcome = einsum_frule("i,i->", &[&lhs, &rhs], &[Some(&lhs)]);
        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("einsum_frule requires tangents.len() == primals.len()"));
    }

    #[test]
    fn einsum_helpers_run_with_cpu_runtime() {
        // Keep the returned value alive for the whole test body.
        let _runtime = set_default_runtime(RuntimeContext::Cpu(CpuContext::new(1)));

        let lhs = dense_f64(&[1.0, 2.0], &[2]);
        let rhs = dense_f64(&[3.0, 5.0], &[2]);
        let lhs_tangent = dense_f64(&[0.5, -1.0], &[2]);
        let out_cotangent = dense_f64(&[2.0], &[]);

        let value = einsum_primal("i,i->", &[&lhs, &rhs]).unwrap();
        let pushforward =
            einsum_frule("i,i->", &[&lhs, &rhs], &[Some(&lhs_tangent), None]).unwrap();
        let pullback = einsum_rrule("i,i->", &[&lhs, &rhs], &out_cotangent).unwrap();

        // 1*3 + 2*5 = 13
        assert_eq!(value.buffer().as_slice().unwrap(), &[13.0]);
        // 0.5*3 + (-1)*5 = -3.5; the `None` slot adds nothing.
        assert_eq!(pushforward.buffer().as_slice().unwrap(), &[-3.5]);
        // Gradient for the first operand: 2 * [3, 5].
        assert_eq!(pullback[0].buffer().as_slice().unwrap(), &[6.0, 10.0]);
        // Gradient for the second operand: 2 * [1, 2].
        assert_eq!(pullback[1].buffer().as_slice().unwrap(), &[2.0, 4.0]);
    }

    #[test]
    fn solve_triangular_rrule_requires_runtime_and_preserves_shapes() {
        let matrix = dense_f64(&[2.0, 0.0, 1.0, 3.0], &[2, 2]);
        let rhs = dense_f64(&[4.0, 9.0], &[2]);
        let out_cotangent = dense_f64(&[1.0, 2.0], &[2]);

        // Order matters: the first call must happen before a runtime is
        // installed so that the "not configured" path is exercised.
        let missing_runtime = solve_triangular_rrule(&matrix, &rhs, &out_cotangent, false);
        let failure = missing_runtime.unwrap_err();
        assert!(matches!(failure, Error::RuntimeNotConfigured));

        let _runtime = set_default_runtime(RuntimeContext::Cpu(CpuContext::new(1)));
        let grads = solve_triangular_rrule(&matrix, &rhs, &out_cotangent, false).unwrap();

        // Each gradient must have the shape of the input it belongs to.
        assert_eq!(grads.a.dims(), &[2, 2]);
        assert_eq!(grads.b.dims(), &[2]);
    }
}