tenferro_internal_ad_surface/jvp.rs

use std::cell::RefCell;
use std::collections::HashMap;

use tenferro_internal_ad_core::{LinearizableOp, LinearizedOp};
use tenferro_internal_ad_linalg::{
    CholeskyOp, DetOp, EigOp, EigenOp, InvOp, LstsqOp, LuOp, MatrixExpOp, NormOp, PInvOp, QrOp,
    SlogdetOp, SolveOp, SolveTriangularOp, SvdOp,
};
use tenferro_internal_ad_ops::{AddOp, EinsumOp, ExpOp, SumOp};
use tenferro_internal_frontend_core::DynTensor;
use tenferro_linalg::{LuPivot, MatrixNormOrd, NormKind, SvdOptions, VectorNormOrd};

use crate::{Error, Result, Tensor};

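/// Result of a [`jvp`] call: the primal outputs of the traced function,
/// paired one-to-one with the tangent recorded for each output (`None`
/// when no seeded input tangent reached that output).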
#[derive(Debug)]
pub struct JvpResult {
    pub outputs: Vec<Tensor>,
    pub output_tangents: Vec<Option<Tensor>>,
}

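/// Per-invocation tangent store for forward mode: maps a tensor's
/// `forward_id` to the tangent propagated for it so far.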
#[derive(Default)]
struct ForwardJvpContext {
    tangents: HashMap<usize, DynTensor>,
}

impl ForwardJvpContext {
    fn tangent_for_id(&self, id: usize) -> Option<DynTensor> {
        self.tangents.get(&id).cloned()
    }

    fn set_tangent(&mut self, id: usize, tangent: DynTensor) {
        self.tangents.insert(id, tangent);
    }
}

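// A stack rather than a single slot, so nested `jvp` calls each get their
// own tangent context; the op hooks always read the innermost one.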
thread_local! {
    static FORWARD_JVP_STACK: RefCell<Vec<ForwardJvpContext>> = const { RefCell::new(Vec::new()) };
}

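/// RAII guard returned by `push_context`; popping in `Drop` keeps the
/// stack balanced even if the traced closure returns early or panics.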
struct ForwardJvpGuard;

impl Drop for ForwardJvpGuard {
    fn drop(&mut self) {
        FORWARD_JVP_STACK.with(|stack| {
            stack.borrow_mut().pop();
        });
    }
}

fn invalid_argument(message: impl Into<String>) -> Error {
    Error::InvalidTensorOperands {
        message: message.into(),
    }
}

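/// A tangent must mirror its primal exactly: same dtype, same shape, and
/// the same layout (dense/diag flags and axis classes).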
fn validate_compatibility(index: usize, primal: &Tensor, tangent: &Tensor) -> Result<()> {
    if primal.scalar_type() != tangent.scalar_type() {
        return Err(invalid_argument(format!(
            "jvp tangent {index} dtype mismatch: primal={:?}, tangent={:?}",
            primal.scalar_type(),
            tangent.scalar_type()
        )));
    }
    if primal.dims() != tangent.dims() {
        return Err(invalid_argument(format!(
            "jvp tangent {index} shape mismatch: primal={:?}, tangent={:?}",
            primal.dims(),
            tangent.dims()
        )));
    }
    if primal.axis_classes() != tangent.axis_classes()
        || primal.is_dense() != tangent.is_dense()
        || primal.is_diag() != tangent.is_diag()
    {
        return Err(invalid_argument(format!(
            "jvp tangent {index} layout mismatch: primal dense={} diag={} classes={:?}, tangent dense={} diag={} classes={:?}",
            primal.is_dense(),
            primal.is_diag(),
            primal.axis_classes(),
            tangent.is_dense(),
            tangent.is_diag(),
            tangent.axis_classes()
        )));
    }
    Ok(())
}

fn push_context(ctx: ForwardJvpContext) -> ForwardJvpGuard {
    FORWARD_JVP_STACK.with(|stack| {
        stack.borrow_mut().push(ctx);
    });
    ForwardJvpGuard
}

fn with_current_context<R>(f: impl FnOnce(&ForwardJvpContext) -> R) -> Option<R> {
    FORWARD_JVP_STACK.with(|stack| stack.borrow().last().map(f))
}

fn with_current_context_mut<R>(f: impl FnOnce(&mut ForwardJvpContext) -> R) -> Option<R> {
    FORWARD_JVP_STACK.with(|stack| stack.borrow_mut().last_mut().map(f))
}

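// Hooks used by the op wrappers below: `is_active` gates all tangent work,
// `tangent_for` reads a tensor's tangent from the innermost context, and
// `record_tangent` stores a freshly computed output tangent.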
pub(crate) fn is_active() -> bool {
    with_current_context(|_| ()).is_some()
}

pub(crate) fn tangent_for(tensor: &Tensor) -> Option<DynTensor> {
    with_current_context(|ctx| ctx.tangent_for_id(tensor.forward_id())).flatten()
}

pub(crate) fn record_tangent(tensor: &Tensor, tangent: Option<DynTensor>) {
    if let Some(tangent) = tangent {
        let _ = with_current_context_mut(|ctx| {
            ctx.set_tangent(tensor.forward_id(), tangent);
        });
    }
}

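// Every hook below follows the same three steps: re-linearize the op at
// the recorded primals, push the input tangents through the linearization
// with `jvp`, and record the resulting tangent(s) on the output(s).
// `add_tangent` is the simplest instance. A hypothetical deduplication of
// the unary hooks (exp, sum, det, inv, cholesky, matrix_exp) could look
// like the sketch below — assuming `LinearizableOp::linearize` accepts the
// same argument shapes these call sites use:
//
//     fn unary_tangent<Op: LinearizableOp>(
//         op: &Op,
//         input: &Tensor,
//         output: &Tensor,
//     ) -> Result<()> {
//         if !is_active() {
//             return Ok(());
//         }
//         let linearized = op.linearize(&[input.primal()], &[output.primal().clone()])?;
//         let mut outputs = linearized.jvp(&[tangent_for(input)])?;
//         record_tangent(output, outputs.pop().unwrap_or(None));
//         Ok(())
//     }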
pub(crate) fn add_tangent(lhs: &Tensor, rhs: &Tensor, output: &Tensor) -> Result<()> {
    if !is_active() {
        return Ok(());
    }
    let linearized = AddOp.linearize(&[lhs.primal(), rhs.primal()], &[output.primal().clone()])?;
    let mut outputs = linearized.jvp(&[tangent_for(lhs), tangent_for(rhs)])?;
    record_tangent(output, outputs.pop().unwrap_or(None));
    Ok(())
}

pub(crate) fn exp_tangent(input: &Tensor, output: &Tensor) -> Result<()> {
    if !is_active() {
        return Ok(());
    }
    let linearized = ExpOp.linearize(&[input.primal()], &[output.primal().clone()])?;
    let mut outputs = linearized.jvp(&[tangent_for(input)])?;
    record_tangent(output, outputs.pop().unwrap_or(None));
    Ok(())
}

pub(crate) fn sum_tangent(input: &Tensor, output: &Tensor) -> Result<()> {
    if !is_active() {
        return Ok(());
    }
    let linearized = SumOp.linearize(&[input.primal()], &[output.primal().clone()])?;
    let mut outputs = linearized.jvp(&[tangent_for(input)])?;
    record_tangent(output, outputs.pop().unwrap_or(None));
    Ok(())
}

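// Multi-output factorizations (QR, LU, SVD, ...) receive one tangent per
// output; `jvp` yields them in the same order the outputs were passed to
// `linearize`.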
pub(crate) fn qr_tangents(input: &Tensor, q: &Tensor, r: &Tensor) -> Result<()> {
    if !is_active() {
        return Ok(());
    }
    let linearized =
        QrOp.linearize(&[input.primal()], &[q.primal().clone(), r.primal().clone()])?;
    let mut outputs = linearized.jvp(&[tangent_for(input)])?.into_iter();
    record_tangent(q, outputs.next().unwrap_or(None));
    record_tangent(r, outputs.next().unwrap_or(None));
    Ok(())
}

pub(crate) fn einsum_tangent(
    subscripts: &str,
    operands: &[&Tensor],
    output: &Tensor,
) -> Result<()> {
    if !is_active() {
        return Ok(());
    }
    let primals = operands
        .iter()
        .map(|tensor| tensor.primal())
        .collect::<Vec<_>>();
    let linearized = EinsumOp::new(subscripts).linearize(&primals, &[output.primal().clone()])?;
    let tangents = operands
        .iter()
        .map(|tensor| tangent_for(tensor))
        .collect::<Vec<_>>();
    let mut outputs = linearized.jvp(&tangents)?;
    record_tangent(output, outputs.pop().unwrap_or(None));
    Ok(())
}

pub(crate) fn solve_tangent(lhs: &Tensor, rhs: &Tensor, output: &Tensor) -> Result<()> {
    if !is_active() {
        return Ok(());
    }
    let linearized =
        SolveOp.linearize(&[lhs.primal(), rhs.primal()], &[output.primal().clone()])?;
    let mut outputs = linearized.jvp(&[tangent_for(lhs), tangent_for(rhs)])?;
    record_tangent(output, outputs.pop().unwrap_or(None));
    Ok(())
}

pub(crate) fn lstsq_tangents(
    lhs: &Tensor,
    rhs: &Tensor,
    x: &Tensor,
    residual: &Tensor,
) -> Result<()> {
    if !is_active() {
        return Ok(());
    }
    let linearized = LstsqOp.linearize(
        &[lhs.primal(), rhs.primal()],
        &[x.primal().clone(), residual.primal().clone()],
    )?;
    let mut outputs = linearized
        .jvp(&[tangent_for(lhs), tangent_for(rhs)])?
        .into_iter();
    record_tangent(x, outputs.next().unwrap_or(None));
    record_tangent(residual, outputs.next().unwrap_or(None));
    Ok(())
}

pub(crate) fn solve_triangular_tangent(
    lhs: &Tensor,
    rhs: &Tensor,
    output: &Tensor,
    upper: bool,
) -> Result<()> {
    if !is_active() {
        return Ok(());
    }
    let linearized = SolveTriangularOp::new(upper)
        .linearize(&[lhs.primal(), rhs.primal()], &[output.primal().clone()])?;
    let mut outputs = linearized.jvp(&[tangent_for(lhs), tangent_for(rhs)])?;
    record_tangent(output, outputs.pop().unwrap_or(None));
    Ok(())
}

pub(crate) fn det_tangent(input: &Tensor, output: &Tensor) -> Result<()> {
    if !is_active() {
        return Ok(());
    }
    let linearized = DetOp.linearize(&[input.primal()], &[output.primal().clone()])?;
    let mut outputs = linearized.jvp(&[tangent_for(input)])?;
    record_tangent(output, outputs.pop().unwrap_or(None));
    Ok(())
}

pub(crate) fn inv_tangent(input: &Tensor, output: &Tensor) -> Result<()> {
    if !is_active() {
        return Ok(());
    }
    let linearized = InvOp.linearize(&[input.primal()], &[output.primal().clone()])?;
    let mut outputs = linearized.jvp(&[tangent_for(input)])?;
    record_tangent(output, outputs.pop().unwrap_or(None));
    Ok(())
}

pub(crate) fn slogdet_tangents(input: &Tensor, sign: &Tensor, logabsdet: &Tensor) -> Result<()> {
    if !is_active() {
        return Ok(());
    }
    let linearized = SlogdetOp.linearize(
        &[input.primal()],
        &[sign.primal().clone(), logabsdet.primal().clone()],
    )?;
    let mut outputs = linearized.jvp(&[tangent_for(input)])?.into_iter();
    record_tangent(sign, outputs.next().unwrap_or(None));
    record_tangent(logabsdet, outputs.next().unwrap_or(None));
    Ok(())
}

pub(crate) fn cholesky_tangent(input: &Tensor, output: &Tensor) -> Result<()> {
    if !is_active() {
        return Ok(());
    }
    let linearized = CholeskyOp.linearize(&[input.primal()], &[output.primal().clone()])?;
    let mut outputs = linearized.jvp(&[tangent_for(input)])?;
    record_tangent(output, outputs.pop().unwrap_or(None));
    Ok(())
}

pub(crate) fn lu_tangents(
    input: &Tensor,
    p: &Tensor,
    l: &Tensor,
    u: &Tensor,
    pivot: LuPivot,
) -> Result<()> {
    if !is_active() {
        return Ok(());
    }
    let linearized = LuOp::new(pivot).linearize(
        &[input.primal()],
        &[p.primal().clone(), l.primal().clone(), u.primal().clone()],
    )?;
    let mut outputs = linearized.jvp(&[tangent_for(input)])?.into_iter();
    record_tangent(p, outputs.next().unwrap_or(None));
    record_tangent(l, outputs.next().unwrap_or(None));
    record_tangent(u, outputs.next().unwrap_or(None));
    Ok(())
}

pub(crate) fn norm_tangent(input: &Tensor, output: &Tensor, kind: NormKind) -> Result<()> {
    if !is_active() {
        return Ok(());
    }
    let linearized = NormOp::new(kind).linearize(&[input.primal()], &[output.primal().clone()])?;
    let mut outputs = linearized.jvp(&[tangent_for(input)])?;
    record_tangent(output, outputs.pop().unwrap_or(None));
    Ok(())
}

pub(crate) fn vector_norm_tangent(
    input: &Tensor,
    output: &Tensor,
    ord: VectorNormOrd,
) -> Result<()> {
    norm_tangent(input, output, map_vector_norm_ord(ord)?)
}

pub(crate) fn matrix_norm_tangent(
    input: &Tensor,
    output: &Tensor,
    ord: MatrixNormOrd,
) -> Result<()> {
    norm_tangent(input, output, map_matrix_norm_ord(ord)?)
}

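// Translation from the public norm-order enums to the internal `NormKind`
// consumed by `NormOp`. Orders without a linearization yet (`Zero`,
// `NegInf`, and p < 1 for vectors; negative orders for matrices) are
// rejected rather than silently mis-differentiated.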
pub(crate) fn map_vector_norm_ord(ord: VectorNormOrd) -> Result<NormKind> {
    match ord {
        VectorNormOrd::P(p) if p == 1.0 => Ok(NormKind::L1),
        VectorNormOrd::P(p) if p >= 1.0 => Ok(NormKind::Lp(p)),
        VectorNormOrd::PosInf => Ok(NormKind::Inf),
        VectorNormOrd::Zero | VectorNormOrd::NegInf | VectorNormOrd::P(_) => Err(invalid_argument(
            format!("vector_norm order {ord:?} is not implemented yet"),
        )),
    }
}

pub(crate) fn map_matrix_norm_ord(ord: MatrixNormOrd) -> Result<NormKind> {
    match ord {
        MatrixNormOrd::Fro => Ok(NormKind::Fro),
        MatrixNormOrd::Nuc => Ok(NormKind::Nuclear),
        MatrixNormOrd::One => Ok(NormKind::L1),
        MatrixNormOrd::Two => Ok(NormKind::Spectral),
        MatrixNormOrd::PosInf => Ok(NormKind::Inf),
        MatrixNormOrd::NegOne | MatrixNormOrd::NegTwo | MatrixNormOrd::NegInf => Err(
            invalid_argument(format!("matrix_norm order {ord:?} is not implemented yet")),
        ),
    }
}

pub(crate) fn eig_tangents(input: &Tensor, values: &Tensor, vectors: &Tensor) -> Result<()> {
    if !is_active() {
        return Ok(());
    }
    let linearized = EigOp.linearize(
        &[input.primal()],
        &[values.primal().clone(), vectors.primal().clone()],
    )?;
    let mut outputs = linearized.jvp(&[tangent_for(input)])?.into_iter();
    record_tangent(values, outputs.next().unwrap_or(None));
    record_tangent(vectors, outputs.next().unwrap_or(None));
    Ok(())
}

pub(crate) fn eigen_tangents(input: &Tensor, values: &Tensor, vectors: &Tensor) -> Result<()> {
    if !is_active() {
        return Ok(());
    }
    let linearized = EigenOp.linearize(
        &[input.primal()],
        &[values.primal().clone(), vectors.primal().clone()],
    )?;
    let mut outputs = linearized.jvp(&[tangent_for(input)])?.into_iter();
    record_tangent(values, outputs.next().unwrap_or(None));
    record_tangent(vectors, outputs.next().unwrap_or(None));
    Ok(())
}

pub(crate) fn svd_tangents(
    input: &Tensor,
    u: &Tensor,
    s: &Tensor,
    vt: &Tensor,
    options: Option<SvdOptions>,
) -> Result<()> {
    if !is_active() {
        return Ok(());
    }
    let linearized = SvdOp::new(options).linearize(
        &[input.primal()],
        &[u.primal().clone(), s.primal().clone(), vt.primal().clone()],
    )?;
    let mut outputs = linearized.jvp(&[tangent_for(input)])?.into_iter();
    record_tangent(u, outputs.next().unwrap_or(None));
    record_tangent(s, outputs.next().unwrap_or(None));
    record_tangent(vt, outputs.next().unwrap_or(None));
    Ok(())
}

pub(crate) fn pinv_tangent(input: &Tensor, output: &Tensor, rcond: Option<f64>) -> Result<()> {
    if !is_active() {
        return Ok(());
    }
    let linearized = PInvOp::new(rcond).linearize(&[input.primal()], &[output.primal().clone()])?;
    let mut outputs = linearized.jvp(&[tangent_for(input)])?;
    record_tangent(output, outputs.pop().unwrap_or(None));
    Ok(())
}

pub(crate) fn matrix_exp_tangent(input: &Tensor, output: &Tensor) -> Result<()> {
    if !is_active() {
        return Ok(());
    }
    let linearized = MatrixExpOp.linearize(&[input.primal()], &[output.primal().clone()])?;
    let mut outputs = linearized.jvp(&[tangent_for(input)])?;
    record_tangent(output, outputs.pop().unwrap_or(None));
    Ok(())
}

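/// Computes the outputs and Jacobian-vector products of `f` in one forward pass.
///
/// Each tangent, when present, must match its primal in dtype, shape, and
/// layout. While `f` runs, a thread-local context propagates tangents
/// through every differentiable op it executes; outputs never reached by a
/// seeded tangent report `None`.
///
/// A minimal sketch of the intended call shape (`Tensor::sum` is assumed
/// here for illustration; adjust to the real `Tensor` API):
///
/// ```ignore
/// let x = Tensor::from_slice(&[1.0_f64, 2.0], &[2])?;
/// let dx = Tensor::from_slice(&[1.0_f64, 0.0], &[2])?; // seed direction
/// let result = jvp(
///     |inputs| Ok(vec![inputs[0].sum()?]), // hypothetical op on Tensor
///     &[x],
///     &[Some(dx)],
/// )?;
/// // result.outputs[0] is sum(x); result.output_tangents[0] is the
/// // directional derivative of sum along dx (here: 1.0).
/// ```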
pub fn jvp<F>(f: F, primals: &[Tensor], tangents: &[Option<Tensor>]) -> Result<JvpResult>
where
    F: FnOnce(&[Tensor]) -> Result<Vec<Tensor>>,
{
    if primals.len() != tangents.len() {
        return Err(invalid_argument(format!(
            "jvp tangent count mismatch: {} primals but {} tangents",
            primals.len(),
            tangents.len()
        )));
    }
    for (index, (primal, tangent)) in primals.iter().zip(tangents.iter()).enumerate() {
        if let Some(tangent) = tangent {
            validate_compatibility(index, primal, tangent)?;
        }
    }

    let mut ctx = ForwardJvpContext::default();
    for (primal, tangent) in primals.iter().zip(tangents.iter()) {
        if let Some(tangent) = tangent {
            ctx.set_tangent(primal.forward_id(), tangent.primal().clone());
        }
    }

    let _guard = push_context(ctx);
    let outputs = f(primals)?;
    let output_tangents = outputs
        .iter()
        .map(|output| tangent_for(output).map(Tensor::new))
        .collect();

    Ok(JvpResult {
        outputs,
        output_tangents,
    })
}

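// The identity is the address of the primal `DynTensor`, not of the
// `Tensor` wrapper: as long as the primal sits behind a stable heap
// allocation, the id survives moves of the wrapper (exercised by the
// test below).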
pub(crate) fn forward_id(tensor: &Tensor) -> usize {
    tensor.primal() as *const DynTensor as usize
}

#[cfg(test)]
mod tests {
    use super::forward_id;
    use crate::Tensor;

    fn round_trip(tensor: Tensor) -> (Tensor, usize) {
        let id = forward_id(&tensor);
        (tensor, id)
    }

    #[test]
    fn forward_id_is_stable_across_moves() {
        let x = Tensor::from_slice(&[1.0_f64, 2.0], &[2]).unwrap();
        let before = forward_id(&x);
        let (x, moved_id) = round_trip(x);
        assert_eq!(before, moved_id);
        assert_eq!(before, forward_id(&x));
    }
}