1use std::cell::RefCell;
2use std::collections::HashMap;
3
4use tenferro_internal_ad_core::{LinearizableOp, LinearizedOp};
5use tenferro_internal_ad_linalg::{
6 CholeskyOp, DetOp, EigOp, EigenOp, InvOp, LstsqOp, LuOp, MatrixExpOp, NormOp, PInvOp, QrOp,
7 SlogdetOp, SolveOp, SolveTriangularOp, SvdOp,
8};
9use tenferro_internal_ad_ops::{AddOp, EinsumOp, ExpOp, SumOp};
10use tenferro_internal_frontend_core::DynTensor;
11use tenferro_linalg::{LuPivot, MatrixNormOrd, NormKind, SvdOptions, VectorNormOrd};
12
13use crate::{Error, Result, Tensor};
14
/// Result of a [`jvp`] call: the primal outputs of the traced function
/// together with one optional tangent per output.
#[derive(Debug)]
pub struct JvpResult {
    /// Primal outputs returned by the traced function, in order.
    pub outputs: Vec<Tensor>,
    /// Tangent for each entry of `outputs`, aligned by index; `None` means
    /// no tangent was recorded for that output during the trace.
    pub output_tangents: Vec<Option<Tensor>>,
}
20
/// Per-trace bookkeeping for one forward-mode JVP invocation.
#[derive(Default)]
struct ForwardJvpContext {
    // Maps a tensor's `forward_id()` to its tangent; an absent entry is
    // treated as "no tangent" (lookups return `None`).
    tangents: HashMap<usize, DynTensor>,
}
25
26impl ForwardJvpContext {
27 fn tangent_for_id(&self, id: usize) -> Option<DynTensor> {
28 self.tangents.get(&id).cloned()
29 }
30
31 fn set_tangent(&mut self, id: usize, tangent: DynTensor) {
32 self.tangents.insert(id, tangent);
33 }
34}
35
thread_local! {
    // Stack of active JVP contexts for this thread. `push_context` pushes a
    // new context and the returned guard's `Drop` pops it, so nested `jvp`
    // calls each see their own innermost tangent map.
    static FORWARD_JVP_STACK: RefCell<Vec<ForwardJvpContext>> = const { RefCell::new(Vec::new()) };
}
39
40struct ForwardJvpGuard;
41
42impl Drop for ForwardJvpGuard {
43 fn drop(&mut self) {
44 FORWARD_JVP_STACK.with(|stack| {
45 stack.borrow_mut().pop();
46 });
47 }
48}
49
50fn invalid_argument(message: impl Into<String>) -> Error {
51 Error::InvalidTensorOperands {
52 message: message.into(),
53 }
54}
55
56fn validate_compatibility(index: usize, primal: &Tensor, tangent: &Tensor) -> Result<()> {
57 if primal.scalar_type() != tangent.scalar_type() {
58 return Err(invalid_argument(format!(
59 "jvp tangent {index} dtype mismatch: primal={:?}, tangent={:?}",
60 primal.scalar_type(),
61 tangent.scalar_type()
62 )));
63 }
64 if primal.dims() != tangent.dims() {
65 return Err(invalid_argument(format!(
66 "jvp tangent {index} shape mismatch: primal={:?}, tangent={:?}",
67 primal.dims(),
68 tangent.dims()
69 )));
70 }
71 if primal.axis_classes() != tangent.axis_classes()
72 || primal.is_dense() != tangent.is_dense()
73 || primal.is_diag() != tangent.is_diag()
74 {
75 return Err(invalid_argument(format!(
76 "jvp tangent {index} layout mismatch: primal dense={} diag={} classes={:?}, tangent dense={} diag={} classes={:?}",
77 primal.is_dense(),
78 primal.is_diag(),
79 primal.axis_classes(),
80 tangent.is_dense(),
81 tangent.is_diag(),
82 tangent.axis_classes()
83 )));
84 }
85 Ok(())
86}
87
88fn push_context(ctx: ForwardJvpContext) -> ForwardJvpGuard {
89 FORWARD_JVP_STACK.with(|stack| {
90 stack.borrow_mut().push(ctx);
91 });
92 ForwardJvpGuard
93}
94
95fn with_current_context<R>(f: impl FnOnce(&ForwardJvpContext) -> R) -> Option<R> {
96 FORWARD_JVP_STACK.with(|stack| stack.borrow().last().map(f))
97}
98
99fn with_current_context_mut<R>(f: impl FnOnce(&mut ForwardJvpContext) -> R) -> Option<R> {
100 FORWARD_JVP_STACK.with(|stack| stack.borrow_mut().last_mut().map(f))
101}
102
103pub(crate) fn is_active() -> bool {
104 with_current_context(|_| ()).is_some()
105}
106
107pub(crate) fn tangent_for(tensor: &Tensor) -> Option<DynTensor> {
108 with_current_context(|ctx| ctx.tangent_for_id(tensor.forward_id())).flatten()
109}
110
111pub(crate) fn record_tangent(tensor: &Tensor, tangent: Option<DynTensor>) {
112 if let Some(tangent) = tangent {
113 let _ = with_current_context_mut(|ctx| {
114 ctx.set_tangent(tensor.forward_id(), tangent);
115 });
116 }
117}
118
119pub(crate) fn add_tangent(lhs: &Tensor, rhs: &Tensor, output: &Tensor) -> Result<()> {
120 if !is_active() {
121 return Ok(());
122 }
123 let linearized = AddOp.linearize(&[lhs.primal(), rhs.primal()], &[output.primal().clone()])?;
124 let mut outputs = linearized.jvp(&[tangent_for(lhs), tangent_for(rhs)])?;
125 record_tangent(output, outputs.pop().unwrap_or(None));
126 Ok(())
127}
128
129pub(crate) fn exp_tangent(input: &Tensor, output: &Tensor) -> Result<()> {
130 if !is_active() {
131 return Ok(());
132 }
133 let linearized = ExpOp.linearize(&[input.primal()], &[output.primal().clone()])?;
134 let mut outputs = linearized.jvp(&[tangent_for(input)])?;
135 record_tangent(output, outputs.pop().unwrap_or(None));
136 Ok(())
137}
138
139pub(crate) fn sum_tangent(input: &Tensor, output: &Tensor) -> Result<()> {
140 if !is_active() {
141 return Ok(());
142 }
143 let linearized = SumOp.linearize(&[input.primal()], &[output.primal().clone()])?;
144 let mut outputs = linearized.jvp(&[tangent_for(input)])?;
145 record_tangent(output, outputs.pop().unwrap_or(None));
146 Ok(())
147}
148
149pub(crate) fn qr_tangents(input: &Tensor, q: &Tensor, r: &Tensor) -> Result<()> {
150 if !is_active() {
151 return Ok(());
152 }
153 let linearized =
154 QrOp.linearize(&[input.primal()], &[q.primal().clone(), r.primal().clone()])?;
155 let outputs = linearized.jvp(&[tangent_for(input)])?;
156 let mut outputs = outputs.into_iter();
157 record_tangent(q, outputs.next().unwrap_or(None));
158 record_tangent(r, outputs.next().unwrap_or(None));
159 Ok(())
160}
161
162pub(crate) fn einsum_tangent(
163 subscripts: &str,
164 operands: &[&Tensor],
165 output: &Tensor,
166) -> Result<()> {
167 if !is_active() {
168 return Ok(());
169 }
170 let primals = operands
171 .iter()
172 .map(|tensor| tensor.primal())
173 .collect::<Vec<_>>();
174 let linearized = EinsumOp::new(subscripts).linearize(&primals, &[output.primal().clone()])?;
175 let tangents = operands
176 .iter()
177 .map(|tensor| tangent_for(tensor))
178 .collect::<Vec<_>>();
179 let mut outputs = linearized.jvp(&tangents)?;
180 record_tangent(output, outputs.pop().unwrap_or(None));
181 Ok(())
182}
183
184pub(crate) fn solve_tangent(lhs: &Tensor, rhs: &Tensor, output: &Tensor) -> Result<()> {
185 if !is_active() {
186 return Ok(());
187 }
188 let linearized =
189 SolveOp.linearize(&[lhs.primal(), rhs.primal()], &[output.primal().clone()])?;
190 let mut outputs = linearized.jvp(&[tangent_for(lhs), tangent_for(rhs)])?;
191 record_tangent(output, outputs.pop().unwrap_or(None));
192 Ok(())
193}
194
195pub(crate) fn lstsq_tangents(
196 lhs: &Tensor,
197 rhs: &Tensor,
198 x: &Tensor,
199 residual: &Tensor,
200) -> Result<()> {
201 if !is_active() {
202 return Ok(());
203 }
204 let linearized = LstsqOp.linearize(
205 &[lhs.primal(), rhs.primal()],
206 &[x.primal().clone(), residual.primal().clone()],
207 )?;
208 let mut outputs = linearized
209 .jvp(&[tangent_for(lhs), tangent_for(rhs)])?
210 .into_iter();
211 record_tangent(x, outputs.next().unwrap_or(None));
212 record_tangent(residual, outputs.next().unwrap_or(None));
213 Ok(())
214}
215
216pub(crate) fn solve_triangular_tangent(
217 lhs: &Tensor,
218 rhs: &Tensor,
219 output: &Tensor,
220 upper: bool,
221) -> Result<()> {
222 if !is_active() {
223 return Ok(());
224 }
225 let linearized = SolveTriangularOp::new(upper)
226 .linearize(&[lhs.primal(), rhs.primal()], &[output.primal().clone()])?;
227 let mut outputs = linearized.jvp(&[tangent_for(lhs), tangent_for(rhs)])?;
228 record_tangent(output, outputs.pop().unwrap_or(None));
229 Ok(())
230}
231
232pub(crate) fn det_tangent(input: &Tensor, output: &Tensor) -> Result<()> {
233 if !is_active() {
234 return Ok(());
235 }
236 let linearized = DetOp.linearize(&[input.primal()], &[output.primal().clone()])?;
237 let mut outputs = linearized.jvp(&[tangent_for(input)])?;
238 record_tangent(output, outputs.pop().unwrap_or(None));
239 Ok(())
240}
241
242pub(crate) fn inv_tangent(input: &Tensor, output: &Tensor) -> Result<()> {
243 if !is_active() {
244 return Ok(());
245 }
246 let linearized = InvOp.linearize(&[input.primal()], &[output.primal().clone()])?;
247 let mut outputs = linearized.jvp(&[tangent_for(input)])?;
248 record_tangent(output, outputs.pop().unwrap_or(None));
249 Ok(())
250}
251
252pub(crate) fn slogdet_tangents(input: &Tensor, sign: &Tensor, logabsdet: &Tensor) -> Result<()> {
253 if !is_active() {
254 return Ok(());
255 }
256 let linearized = SlogdetOp.linearize(
257 &[input.primal()],
258 &[sign.primal().clone(), logabsdet.primal().clone()],
259 )?;
260 let mut outputs = linearized.jvp(&[tangent_for(input)])?.into_iter();
261 record_tangent(sign, outputs.next().unwrap_or(None));
262 record_tangent(logabsdet, outputs.next().unwrap_or(None));
263 Ok(())
264}
265
266pub(crate) fn cholesky_tangent(input: &Tensor, output: &Tensor) -> Result<()> {
267 if !is_active() {
268 return Ok(());
269 }
270 let linearized = CholeskyOp.linearize(&[input.primal()], &[output.primal().clone()])?;
271 let mut outputs = linearized.jvp(&[tangent_for(input)])?;
272 record_tangent(output, outputs.pop().unwrap_or(None));
273 Ok(())
274}
275
276pub(crate) fn lu_tangents(
277 input: &Tensor,
278 p: &Tensor,
279 l: &Tensor,
280 u: &Tensor,
281 pivot: LuPivot,
282) -> Result<()> {
283 if !is_active() {
284 return Ok(());
285 }
286 let linearized = LuOp::new(pivot).linearize(
287 &[input.primal()],
288 &[p.primal().clone(), l.primal().clone(), u.primal().clone()],
289 )?;
290 let mut outputs = linearized.jvp(&[tangent_for(input)])?.into_iter();
291 record_tangent(p, outputs.next().unwrap_or(None));
292 record_tangent(l, outputs.next().unwrap_or(None));
293 record_tangent(u, outputs.next().unwrap_or(None));
294 Ok(())
295}
296
297pub(crate) fn norm_tangent(input: &Tensor, output: &Tensor, kind: NormKind) -> Result<()> {
298 if !is_active() {
299 return Ok(());
300 }
301 let linearized = NormOp::new(kind).linearize(&[input.primal()], &[output.primal().clone()])?;
302 let mut outputs = linearized.jvp(&[tangent_for(input)])?;
303 record_tangent(output, outputs.pop().unwrap_or(None));
304 Ok(())
305}
306
307pub(crate) fn vector_norm_tangent(
308 input: &Tensor,
309 output: &Tensor,
310 ord: VectorNormOrd,
311) -> Result<()> {
312 norm_tangent(input, output, map_vector_norm_ord(ord)?)
313}
314
315pub(crate) fn matrix_norm_tangent(
316 input: &Tensor,
317 output: &Tensor,
318 ord: MatrixNormOrd,
319) -> Result<()> {
320 norm_tangent(input, output, map_matrix_norm_ord(ord)?)
321}
322
323pub(crate) fn map_vector_norm_ord(ord: VectorNormOrd) -> Result<NormKind> {
324 match ord {
325 VectorNormOrd::P(1.0) => Ok(NormKind::L1),
326 VectorNormOrd::P(p) if p >= 1.0 => Ok(NormKind::Lp(p)),
327 VectorNormOrd::PosInf => Ok(NormKind::Inf),
328 VectorNormOrd::Zero | VectorNormOrd::NegInf | VectorNormOrd::P(_) => Err(invalid_argument(
329 format!("vector_norm order {ord:?} is not implemented yet"),
330 )),
331 }
332}
333
334pub(crate) fn map_matrix_norm_ord(ord: MatrixNormOrd) -> Result<NormKind> {
335 match ord {
336 MatrixNormOrd::Fro => Ok(NormKind::Fro),
337 MatrixNormOrd::Nuc => Ok(NormKind::Nuclear),
338 MatrixNormOrd::One => Ok(NormKind::L1),
339 MatrixNormOrd::Two => Ok(NormKind::Spectral),
340 MatrixNormOrd::PosInf => Ok(NormKind::Inf),
341 MatrixNormOrd::NegOne | MatrixNormOrd::NegTwo | MatrixNormOrd::NegInf => Err(
342 invalid_argument(format!("matrix_norm order {ord:?} is not implemented yet")),
343 ),
344 }
345}
346
347pub(crate) fn eig_tangents(input: &Tensor, values: &Tensor, vectors: &Tensor) -> Result<()> {
348 if !is_active() {
349 return Ok(());
350 }
351 let linearized = EigOp.linearize(
352 &[input.primal()],
353 &[values.primal().clone(), vectors.primal().clone()],
354 )?;
355 let mut outputs = linearized.jvp(&[tangent_for(input)])?.into_iter();
356 record_tangent(values, outputs.next().unwrap_or(None));
357 record_tangent(vectors, outputs.next().unwrap_or(None));
358 Ok(())
359}
360
361pub(crate) fn eigen_tangents(input: &Tensor, values: &Tensor, vectors: &Tensor) -> Result<()> {
362 if !is_active() {
363 return Ok(());
364 }
365 let linearized = EigenOp.linearize(
366 &[input.primal()],
367 &[values.primal().clone(), vectors.primal().clone()],
368 )?;
369 let mut outputs = linearized.jvp(&[tangent_for(input)])?.into_iter();
370 record_tangent(values, outputs.next().unwrap_or(None));
371 record_tangent(vectors, outputs.next().unwrap_or(None));
372 Ok(())
373}
374
375pub(crate) fn svd_tangents(
376 input: &Tensor,
377 u: &Tensor,
378 s: &Tensor,
379 vt: &Tensor,
380 options: Option<SvdOptions>,
381) -> Result<()> {
382 if !is_active() {
383 return Ok(());
384 }
385 let linearized = SvdOp::new(options).linearize(
386 &[input.primal()],
387 &[u.primal().clone(), s.primal().clone(), vt.primal().clone()],
388 )?;
389 let mut outputs = linearized.jvp(&[tangent_for(input)])?.into_iter();
390 record_tangent(u, outputs.next().unwrap_or(None));
391 record_tangent(s, outputs.next().unwrap_or(None));
392 record_tangent(vt, outputs.next().unwrap_or(None));
393 Ok(())
394}
395
396pub(crate) fn pinv_tangent(input: &Tensor, output: &Tensor, rcond: Option<f64>) -> Result<()> {
397 if !is_active() {
398 return Ok(());
399 }
400 let linearized = PInvOp::new(rcond).linearize(&[input.primal()], &[output.primal().clone()])?;
401 let mut outputs = linearized.jvp(&[tangent_for(input)])?;
402 record_tangent(output, outputs.pop().unwrap_or(None));
403 Ok(())
404}
405
406pub(crate) fn matrix_exp_tangent(input: &Tensor, output: &Tensor) -> Result<()> {
407 if !is_active() {
408 return Ok(());
409 }
410 let linearized = MatrixExpOp.linearize(&[input.primal()], &[output.primal().clone()])?;
411 let mut outputs = linearized.jvp(&[tangent_for(input)])?;
412 record_tangent(output, outputs.pop().unwrap_or(None));
413 Ok(())
414}
415
416pub fn jvp<F>(f: F, primals: &[Tensor], tangents: &[Option<Tensor>]) -> Result<JvpResult>
417where
418 F: FnOnce(&[Tensor]) -> Result<Vec<Tensor>>,
419{
420 if primals.len() != tangents.len() {
421 return Err(invalid_argument(format!(
422 "jvp expected {} tangents for {} primals",
423 primals.len(),
424 tangents.len()
425 )));
426 }
427 for (index, (primal, tangent)) in primals.iter().zip(tangents.iter()).enumerate() {
428 if let Some(tangent) = tangent {
429 validate_compatibility(index, primal, tangent)?;
430 }
431 }
432
433 let mut ctx = ForwardJvpContext::default();
434 for (primal, tangent) in primals.iter().zip(tangents.iter()) {
435 if let Some(tangent) = tangent {
436 ctx.set_tangent(primal.forward_id(), tangent.primal().clone());
437 }
438 }
439
440 let _guard = push_context(ctx);
441 let outputs = f(primals)?;
442 let output_tangents = outputs
443 .iter()
444 .map(|output| tangent_for(output).map(Tensor::new))
445 .collect();
446
447 Ok(JvpResult {
448 outputs,
449 output_tangents,
450 })
451}
452
/// Identity used to key tangents in the forward context: the address of the
/// tensor's primal `DynTensor`, cast to `usize`.
///
/// NOTE(review): this assumes `tensor.primal()` returns a heap-stable
/// reference (e.g. behind an `Arc`/`Box`) so the address survives moves of
/// the `Tensor` wrapper — the test below exercises exactly that. If a
/// tensor is dropped while a JVP context is live, a later allocation could
/// reuse the address and alias ids; confirm primals outlive the context.
pub(crate) fn forward_id(tensor: &Tensor) -> usize {
    tensor.primal() as *const DynTensor as usize
}
456
#[cfg(test)]
mod tests {
    use super::forward_id;
    use crate::Tensor;

    /// Moves `tensor` through a function boundary and returns it together
    /// with the id observed before the move completed.
    fn round_trip(tensor: Tensor) -> (Tensor, usize) {
        let observed = forward_id(&tensor);
        (tensor, observed)
    }

    #[test]
    fn forward_id_is_stable_across_moves() {
        let tensor = Tensor::from_slice(&[1.0_f64, 2.0], &[2]).unwrap();
        let original = forward_id(&tensor);
        let (tensor, after_move) = round_trip(tensor);
        assert_eq!(original, after_move);
        assert_eq!(original, forward_id(&tensor));
    }
}