tenferro_internal_ad_linalg/
linearized.rs

use chainrules_core::AutodiffError;
use num_complex::{Complex32, Complex64};
use num_traits::Zero;
use tenferro_algebra::{Conjugate, Scalar};
use tenferro_device::LogicalMemorySpace;
use tenferro_internal_ad_core::{
    AdResult, CheckpointHint, DynValue, LinearizableOp, LinearizedOp, Schema, SlotSchema,
};
use tenferro_internal_frontend_core::{DynTensor, DynTensorTyped, StructuredTensor};
use tenferro_internal_runtime::contracts::{LinalgRuntimeValue, RealLinalgRuntimeValue};
use tenferro_internal_runtime::dispatch::{
    with_linalg_runtime, LuLinalgDispatchValue, MatrixExpLinalgDispatchValue,
    NormLinalgDispatchValue, ScaledLinalgDispatchValue, ScaledRealLinalgDispatchValue,
    SlogdetLinalgDispatchValue,
};
use tenferro_linalg::backend::{CudaLinalgScalar, LinalgCapabilityOp};
use tenferro_linalg::{
    cholesky, cholesky_frule, cholesky_rrule, det, det_frule, det_rrule, eig, eig_frule, eig_rrule,
    eigen, eigen_frule, eigen_rrule, inv, inv_frule, inv_rrule, lstsq, lstsq_aux, lstsq_frule,
    lstsq_rrule, lu, lu_frule, lu_rrule, matrix_exp, matrix_exp_frule, matrix_exp_rrule, norm,
    norm_frule, norm_frule_complex, norm_rrule, norm_rrule_complex, pinv, pinv_frule, pinv_rrule,
    qr, qr_frule, qr_rrule, slogdet, slogdet_frule, slogdet_rrule, solve, solve_frule, solve_rrule,
    solve_triangular, solve_triangular_frule, solve_triangular_rrule, svd, svd_frule, svd_rrule,
    EigCotangent, EigenCotangent, KernelLinalgScalar, LuCotangent, LuPivot, NormKind, QrCotangent,
    SlogdetCotangent, SvdCotangent, SvdOptions,
};
use tenferro_tensor::{KeepCountScalar, MemoryOrder, Tensor as DenseTensor};

use crate::{Error, Result};

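// Marker types for the linear-algebra primitives exposed to autodiff. Each op
// carries only the non-tensor configuration needed to replay the primitive
// (e.g. `upper` for triangular solves, `pivot` for LU, `rcond` for pinv); the
// tensor operands themselves are captured later by the matching `*Linearized`
// state.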
#[derive(Clone, Copy)]
pub struct SolveOp;

#[derive(Clone, Copy)]
pub struct LstsqOp;

#[derive(Clone, Copy)]
pub struct SolveTriangularOp {
    upper: bool,
}

#[derive(Clone, Copy)]
pub struct NormOp {
    kind: NormKind,
}

#[derive(Clone, Copy)]
pub struct DetOp;

#[derive(Clone, Copy)]
pub struct InvOp;

#[derive(Clone, Copy)]
pub struct SlogdetOp;

#[derive(Clone, Copy)]
pub struct CholeskyOp;

#[derive(Clone, Copy)]
pub struct LuOp {
    pivot: LuPivot,
}

#[derive(Clone, Copy)]
pub struct QrOp;

#[derive(Clone, Copy)]
pub struct EigOp;

#[derive(Clone, Copy)]
pub struct EigenOp;

#[derive(Clone, Default)]
pub struct SvdOp {
    options: Option<SvdOptions>,
}

#[derive(Clone, Default)]
pub struct PInvOp {
    rcond: Option<f64>,
}

#[derive(Clone, Copy)]
pub struct MatrixExpOp;

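// Output bundles for the multi-output primitives. Tensor outputs are exposed
// as `DynValue` slots; auxiliary data (e.g. lstsq rank counts and singular
// values) is carried alongside.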
pub struct DynQrValues {
    pub q: DynValue,
    pub r: DynValue,
}

pub struct DynLstsqValues {
    pub solution: DynValue,
    pub residuals: DynValue,
    pub rank: Vec<usize>,
    pub singular_values: DynTensor,
}

pub struct DynLuValues {
    pub p: DynValue,
    pub l: DynValue,
    pub u: DynValue,
}

pub struct DynEigValues {
    pub values: DynValue,
    pub vectors: DynValue,
}

pub struct DynEigenValues {
    pub values: DynValue,
    pub vectors: DynValue,
}

pub struct DynSlogdetValues {
    pub sign: DynValue,
    pub logabsdet: DynValue,
}

pub struct DynSvdValues {
    pub u: DynValue,
    pub s: DynValue,
    pub vt: DynValue,
}

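// Linearization state: each `*Linearized` struct clones the primal inputs
// (and any op configuration) it needs so that `jvp` and `vjp` can later
// re-dispatch the matching `*_frule` and `*_rrule` kernels.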
#[doc(hidden)]
pub struct SolveLinearized {
    a: DynTensor,
    b: DynTensor,
}

#[doc(hidden)]
pub struct LstsqLinearized {
    a: DynTensor,
    b: DynTensor,
}

#[doc(hidden)]
pub struct SolveTriangularLinearized {
    a: DynTensor,
    b: DynTensor,
    upper: bool,
}

#[doc(hidden)]
pub struct NormLinearized {
    input: DynTensor,
    kind: NormKind,
}

#[doc(hidden)]
pub struct DetLinearized {
    input: DynTensor,
}

#[doc(hidden)]
pub struct InvLinearized {
    input: DynTensor,
}

#[doc(hidden)]
pub struct SlogdetLinearized {
    input: DynTensor,
}

#[doc(hidden)]
pub struct CholeskyLinearized {
    input: DynTensor,
}

#[doc(hidden)]
pub struct LuLinearized {
    input: DynTensor,
    pivot: LuPivot,
}

#[doc(hidden)]
pub struct QrLinearized {
    input: DynTensor,
}

#[doc(hidden)]
pub struct EigLinearized {
    input: DynTensor,
}

#[doc(hidden)]
pub struct EigenLinearized {
    input: DynTensor,
}

#[doc(hidden)]
pub struct SvdLinearized {
    input: DynTensor,
    options: Option<SvdOptions>,
}

#[doc(hidden)]
pub struct PInvLinearized {
    input: DynTensor,
    rcond: Option<f64>,
}

#[doc(hidden)]
pub struct MatrixExpLinearized {
    input: DynTensor,
}

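/// Schema in which every slot is differentiable and non-auxiliary.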
fn differentiable_schema(slots: usize) -> Schema {
    Schema {
        slots: (0..slots)
            .map(|_| SlotSchema {
                differentiable: true,
                auxiliary: false,
            })
            .collect(),
    }
}

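/// slogdet outputs `(sign, logabsdet)`: the sign is auxiliary and not
/// differentiated; only `logabsdet` carries derivatives.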
fn slogdet_output_schema() -> Schema {
    Schema {
        slots: vec![
            SlotSchema {
                differentiable: false,
                auxiliary: true,
            },
            SlotSchema {
                differentiable: true,
                auxiliary: false,
            },
        ],
    }
}

fn lstsq_output_schema() -> Schema {
    Schema {
        slots: vec![
            SlotSchema {
                differentiable: true,
                auxiliary: false,
            },
            SlotSchema {
                differentiable: false,
                auxiliary: true,
            },
        ],
    }
}

fn lu_output_schema() -> Schema {
    Schema {
        slots: vec![
            SlotSchema {
                differentiable: false,
                auxiliary: true,
            },
            SlotSchema {
                differentiable: true,
                auxiliary: false,
            },
            SlotSchema {
                differentiable: true,
                auxiliary: false,
            },
        ],
    }
}

fn invalid_argument(message: impl Into<String>) -> Error {
    AutodiffError::InvalidArgument(message.into()).into()
}

fn into_ad_error(error: Error) -> AutodiffError {
    match error {
        Error::Autodiff(error) => error,
        other => AutodiffError::InvalidArgument(other.to_string()),
    }
}

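// Expands a single closure body into the three runtime-context arms expected
// by `with_linalg_runtime`, so each typed helper below is written once.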
macro_rules! dispatch_linalg {
    ($ty:ty, $op:expr, $cap:expr, |$ctx:ident| $body:expr) => {{
        with_linalg_runtime::<$ty, _>($op, $cap, |$ctx| $body, |$ctx| $body, |$ctx| $body)
    }};
}

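// Conversion helpers between `DynTensor` and the dense, typed representation
// consumed by the linalg kernels. Dtype mismatches surface as
// `InvalidArgument` errors.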
fn dense_dyn_tensor_typed<T>(value: &DynTensor, context: &str) -> Result<DenseTensor<T>>
where
    T: DynTensorTyped + Copy,
{
    let structured = T::structured_ref(value)
        .ok_or_else(|| invalid_argument(format!("{context} requires matching dtypes")))?;
    structured.to_dense()
}

fn optional_dense_dyn_tensor_typed<T>(
    value: &Option<DynTensor>,
    context: &str,
) -> Result<Option<DenseTensor<T>>>
where
    T: DynTensorTyped + Copy,
{
    value
        .as_ref()
        .map(|tensor| dense_dyn_tensor_typed::<T>(tensor, context))
        .transpose()
}

fn dense_zeros_like<T>(like: &DenseTensor<T>) -> Result<DenseTensor<T>>
where
    T: Scalar + Zero + Copy,
{
    let total: usize = like.dims().iter().product();
    DenseTensor::from_slice(
        &vec![T::zero(); total],
        like.dims(),
        MemoryOrder::ColumnMajor,
    )
    .map_err(Error::from)
}

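/// Converts an optional tangent to dense, materializing zeros with the same
/// shape as `like` when the tangent is absent.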
fn dense_optional_or_zero<T>(
    value: &Option<DynTensor>,
    like: &DenseTensor<T>,
    context: &str,
) -> Result<DenseTensor<T>>
where
    T: DynTensorTyped + Scalar + Zero + Copy,
{
    optional_dense_dyn_tensor_typed::<T>(value, context)?.map_or_else(|| dense_zeros_like(like), Ok)
}

fn dyn_from_dense<T>(value: DenseTensor<T>) -> DynTensor
where
    T: DynTensorTyped + Copy,
{
    T::into_dyn(StructuredTensor::from(value))
}

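// Typed implementations. Each `*_primal_t` / `*_jvp_t` / `*_vjp_t` helper
// densifies its structured inputs, dispatches the corresponding kernel through
// the linalg runtime, and wraps the results back into `DynTensor`s. JVPs
// short-circuit when no input tangent is present; VJPs short-circuit when no
// input gradient is requested.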
fn solve_primal_t<T>(a: &StructuredTensor<T>, b: &StructuredTensor<T>) -> Result<DynTensor>
where
    T: LinalgRuntimeValue + DynTensorTyped + Copy,
{
    let dense_a = a.to_dense()?;
    let dense_b = b.to_dense()?;
    let output = dispatch_linalg!(T, "solve_dyn_value", LinalgCapabilityOp::Solve, |ctx| {
        solve(ctx, &dense_a, &dense_b).map_err(Error::from)
    })?;
    Ok(dyn_from_dense(output))
}

fn solve_jvp_t<T>(
    a: &StructuredTensor<T>,
    b: &StructuredTensor<T>,
    tangents: &[Option<DynTensor>],
) -> Result<Option<DynTensor>>
where
    T: LinalgRuntimeValue + DynTensorTyped + Scalar + Zero + Copy,
{
    if tangents.iter().all(Option::is_none) {
        return Ok(None);
    }
    let dense_a = a.to_dense()?;
    let dense_b = b.to_dense()?;
    let tangent_a = dense_optional_or_zero(&tangents[0], &dense_a, "solve_jvp tangent_a")?;
    let tangent_b = dense_optional_or_zero(&tangents[1], &dense_b, "solve_jvp tangent_b")?;
    let (_, tangent) = dispatch_linalg!(T, "solve_jvp", LinalgCapabilityOp::Solve, |ctx| {
        solve_frule(ctx, &dense_a, &dense_b, &tangent_a, &tangent_b).map_err(Error::from)
    })?;
    Ok(Some(dyn_from_dense(tangent)))
}

fn solve_vjp_t<T>(
    a: &StructuredTensor<T>,
    b: &StructuredTensor<T>,
    cotangent: &DynTensor,
    input_grad_mask: &[bool],
) -> Result<Vec<Option<DynTensor>>>
where
    T: LinalgRuntimeValue + DynTensorTyped + Copy,
{
    if !input_grad_mask.iter().any(|needed| *needed) {
        return Ok(vec![None, None]);
    }
    let dense_a = a.to_dense()?;
    let dense_b = b.to_dense()?;
    let dense_cotangent = dense_dyn_tensor_typed::<T>(cotangent, "solve_vjp")?;
    let grad = dispatch_linalg!(T, "solve_vjp", LinalgCapabilityOp::Solve, |ctx| {
        solve_rrule(ctx, &dense_a, &dense_b, &dense_cotangent).map_err(Error::from)
    })?;
    Ok(vec![
        input_grad_mask[0].then(|| dyn_from_dense(grad.a)),
        input_grad_mask[1].then(|| dyn_from_dense(grad.b)),
    ])
}

fn lstsq_primal_t<T>(a: &StructuredTensor<T>, b: &StructuredTensor<T>) -> Result<Vec<DynTensor>>
where
    T: RealLinalgRuntimeValue + DynTensorTyped + Conjugate + Copy,
{
    let dense_a = a.to_dense()?;
    let dense_b = b.to_dense()?;
    let output = dispatch_linalg!(T, "lstsq_dyn_values", LinalgCapabilityOp::Lstsq, |ctx| {
        lstsq(ctx, &dense_a, &dense_b).map_err(Error::from)
    })?;
    Ok(vec![
        dyn_from_dense(output.solution),
        dyn_from_dense(output.residuals),
    ])
}

fn lstsq_aux_t<T>(a: &StructuredTensor<T>) -> Result<(Vec<usize>, DynTensor)>
where
    T: RealLinalgRuntimeValue + DynTensorTyped + Conjugate + Copy,
{
    let dense_a = a.to_dense()?;
    let aux = dispatch_linalg!(T, "lstsq_aux", LinalgCapabilityOp::Lstsq, |ctx| {
        lstsq_aux(ctx, &dense_a).map_err(Error::from)
    })?;
    let rank = dense_rank_counts_to_vec(&aux.rank)?;
    Ok((rank, dyn_from_dense(aux.singular_values)))
}

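/// Copies the lstsq rank counts to host memory and converts each entry to
/// `usize`.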
fn dense_rank_counts_to_vec<T>(rank_counts: &DenseTensor<T>) -> Result<Vec<usize>>
where
    T: KeepCountScalar + Copy,
{
    let host = rank_counts.to_memory_space_async(LogicalMemorySpace::MainMemory)?;
    let contiguous = host.contiguous(MemoryOrder::ColumnMajor);
    let slice = contiguous.buffer().as_slice().ok_or_else(|| {
        invalid_argument("lstsq rank counts require host-accessible contiguous storage")
    })?;
    slice
        .iter()
        .map(|value| {
            value
                .to_usize()
                .ok_or_else(|| invalid_argument("failed to convert lstsq rank count to usize"))
        })
        .collect()
}

fn lstsq_jvp_t<T>(
    a: &StructuredTensor<T>,
    b: &StructuredTensor<T>,
    tangents: &[Option<DynTensor>],
) -> Result<Vec<Option<DynTensor>>>
where
    T: ScaledRealLinalgDispatchValue + DynTensorTyped + Scalar + Zero + Conjugate + Copy,
{
    if tangents.iter().all(Option::is_none) {
        return Ok(vec![None, None]);
    }
    let dense_a = a.to_dense()?;
    let dense_b = b.to_dense()?;
    let tangent_a = dense_optional_or_zero(&tangents[0], &dense_a, "lstsq_jvp tangent_a")?;
    let tangent_b = dense_optional_or_zero(&tangents[1], &dense_b, "lstsq_jvp tangent_b")?;
    let (_, tangent) = dispatch_linalg!(T, "lstsq_jvp", LinalgCapabilityOp::Lstsq, |ctx| {
        lstsq_frule(ctx, &dense_a, &dense_b, &tangent_a, &tangent_b).map_err(Error::from)
    })?;
    Ok(vec![
        Some(dyn_from_dense(tangent.solution)),
        Some(dyn_from_dense(tangent.residuals)),
    ])
}

fn lstsq_vjp_t<T>(
    a: &StructuredTensor<T>,
    b: &StructuredTensor<T>,
    output_cotangents: &[Option<DynTensor>],
    input_grad_mask: &[bool],
) -> Result<Vec<Option<DynTensor>>>
where
    T: ScaledRealLinalgDispatchValue + DynTensorTyped + Conjugate + Copy,
{
    if !input_grad_mask.iter().any(|needed| *needed) {
        return Ok(vec![None, None]);
    }
    let cotangent_solution = output_cotangents[0]
        .as_ref()
        .map(|cotangent| dense_dyn_tensor_typed::<T>(cotangent, "lstsq_vjp solution"))
        .transpose()?;
    let cotangent_residuals = output_cotangents[1]
        .as_ref()
        .map(|cotangent| dense_dyn_tensor_typed::<T>(cotangent, "lstsq_vjp residuals"))
        .transpose()?;
    if cotangent_solution.is_none() && cotangent_residuals.is_none() {
        return Ok(vec![None, None]);
    }
    let dense_a = a.to_dense()?;
    let dense_b = b.to_dense()?;
    let grad = dispatch_linalg!(T, "lstsq_vjp", LinalgCapabilityOp::Lstsq, |ctx| {
        lstsq_rrule(
            ctx,
            &dense_a,
            &dense_b,
            cotangent_solution.as_ref(),
            cotangent_residuals.as_ref(),
        )
        .map_err(Error::from)
    })?;
    Ok(vec![
        input_grad_mask[0].then(|| dyn_from_dense(grad.a)),
        input_grad_mask[1].then(|| dyn_from_dense(grad.b)),
    ])
}

fn solve_triangular_primal_t<T>(
    a: &StructuredTensor<T>,
    b: &StructuredTensor<T>,
    upper: bool,
) -> Result<DynTensor>
where
    T: LinalgRuntimeValue + DynTensorTyped + Copy,
{
    let dense_a = a.to_dense()?;
    let dense_b = b.to_dense()?;
    let output = dispatch_linalg!(
        T,
        "solve_triangular_dyn_value",
        LinalgCapabilityOp::SolveTriangular,
        |ctx| { solve_triangular(ctx, &dense_a, &dense_b, upper).map_err(Error::from) }
    )?;
    Ok(dyn_from_dense(output))
}

fn solve_triangular_jvp_t<T>(
    a: &StructuredTensor<T>,
    b: &StructuredTensor<T>,
    tangents: &[Option<DynTensor>],
    upper: bool,
) -> Result<Option<DynTensor>>
where
    T: LinalgRuntimeValue + DynTensorTyped + Scalar + Zero + Copy,
{
    if tangents.iter().all(Option::is_none) {
        return Ok(None);
    }
    let dense_a = a.to_dense()?;
    let dense_b = b.to_dense()?;
    let tangent_a =
        dense_optional_or_zero(&tangents[0], &dense_a, "solve_triangular_jvp tangent_a")?;
    let tangent_b =
        dense_optional_or_zero(&tangents[1], &dense_b, "solve_triangular_jvp tangent_b")?;
    let (_, tangent) = dispatch_linalg!(
        T,
        "solve_triangular_jvp",
        LinalgCapabilityOp::SolveTriangular,
        |ctx| {
            solve_triangular_frule(ctx, &dense_a, &dense_b, &tangent_a, &tangent_b, upper)
                .map_err(Error::from)
        }
    )?;
    Ok(Some(dyn_from_dense(tangent)))
}

fn solve_triangular_vjp_t<T>(
    a: &StructuredTensor<T>,
    b: &StructuredTensor<T>,
    cotangent: &DynTensor,
    input_grad_mask: &[bool],
    upper: bool,
) -> Result<Vec<Option<DynTensor>>>
where
    T: LinalgRuntimeValue + DynTensorTyped + Copy,
{
    if !input_grad_mask.iter().any(|needed| *needed) {
        return Ok(vec![None, None]);
    }
    let dense_a = a.to_dense()?;
    let dense_b = b.to_dense()?;
    let dense_cotangent = dense_dyn_tensor_typed::<T>(cotangent, "solve_triangular_vjp")?;
    let grad = dispatch_linalg!(
        T,
        "solve_triangular_vjp",
        LinalgCapabilityOp::SolveTriangular,
        |ctx| {
            solve_triangular_rrule(ctx, &dense_a, &dense_b, &dense_cotangent, upper)
                .map_err(Error::from)
        }
    )?;
    Ok(vec![
        input_grad_mask[0].then(|| dyn_from_dense(grad.a)),
        input_grad_mask[1].then(|| dyn_from_dense(grad.b)),
    ])
}

fn inv_primal_t<T>(input: &StructuredTensor<T>) -> Result<DynTensor>
where
    T: LinalgRuntimeValue + DynTensorTyped + Copy,
{
    let dense_input = input.to_dense()?;
    let output = dispatch_linalg!(T, "inv_dyn_value", LinalgCapabilityOp::Inv, |ctx| {
        inv(ctx, &dense_input).map_err(Error::from)
    })?;
    Ok(dyn_from_dense(output))
}

fn inv_jvp_t<T>(
    input: &StructuredTensor<T>,
    tangent: &Option<DynTensor>,
) -> Result<Option<DynTensor>>
where
    T: LinalgRuntimeValue + DynTensorTyped + Scalar + Zero + Copy,
{
    if tangent.is_none() {
        return Ok(None);
    }
    let dense_input = input.to_dense()?;
    let dense_tangent = dense_optional_or_zero(tangent, &dense_input, "inv_jvp tangent")?;
    let (_, output_tangent) = dispatch_linalg!(T, "inv_jvp", LinalgCapabilityOp::Inv, |ctx| {
        inv_frule(ctx, &dense_input, &dense_tangent).map_err(Error::from)
    })?;
    Ok(Some(dyn_from_dense(output_tangent)))
}

fn inv_vjp_t<T>(
    input: &StructuredTensor<T>,
    cotangent: &DynTensor,
    input_grad_mask: &[bool],
) -> Result<Vec<Option<DynTensor>>>
where
    T: LinalgRuntimeValue + DynTensorTyped + Conjugate + Copy,
{
    if !input_grad_mask[0] {
        return Ok(vec![None]);
    }
    let dense_input = input.to_dense()?;
    let dense_cotangent = dense_dyn_tensor_typed::<T>(cotangent, "inv_vjp")?;
    let grad = dispatch_linalg!(T, "inv_vjp", LinalgCapabilityOp::Inv, |ctx| {
        inv_rrule(ctx, &dense_input, &dense_cotangent).map_err(Error::from)
    })?;
    Ok(vec![Some(dyn_from_dense(grad))])
}

fn slogdet_primal_t<T>(input: &StructuredTensor<T>) -> Result<Vec<DynTensor>>
where
    T: SlogdetLinalgDispatchValue + DynTensorTyped + Copy,
    T::Real: DynTensorTyped,
{
    let dense_input = input.to_dense()?;
    let output = dispatch_linalg!(T, "slogdet_dyn_value", LinalgCapabilityOp::Slogdet, |ctx| {
        slogdet(ctx, &dense_input).map_err(Error::from)
    })?;
    Ok(vec![
        dyn_from_dense(output.sign),
        dyn_from_dense(output.logabsdet),
    ])
}

fn slogdet_jvp_t<T>(
    input: &StructuredTensor<T>,
    tangent: &Option<DynTensor>,
) -> Result<Vec<Option<DynTensor>>>
where
    T: SlogdetLinalgDispatchValue + DynTensorTyped + Scalar + Zero + Copy,
    T::Real: DynTensorTyped,
{
    if tangent.is_none() {
        return Ok(vec![None, None]);
    }
    let dense_input = input.to_dense()?;
    let dense_tangent = dense_optional_or_zero(tangent, &dense_input, "slogdet_jvp tangent")?;
    let (_, tangent_output) =
        dispatch_linalg!(T, "slogdet_jvp", LinalgCapabilityOp::Slogdet, |ctx| {
            slogdet_frule(ctx, &dense_input, &dense_tangent).map_err(Error::from)
        })?;
    Ok(vec![
        Some(dyn_from_dense(tangent_output.sign)),
        Some(dyn_from_dense(tangent_output.logabsdet)),
    ])
}

fn slogdet_vjp_t<T>(
    input: &StructuredTensor<T>,
    output_cotangents: &[Option<DynTensor>],
    input_grad_mask: &[bool],
) -> Result<Vec<Option<DynTensor>>>
where
    T: SlogdetLinalgDispatchValue + DynTensorTyped + Copy,
    T::Real: DynTensorTyped,
{
    if !input_grad_mask[0] {
        return Ok(vec![None]);
    }
    let dense_input = input.to_dense()?;
    let cotangent = SlogdetCotangent {
        sign: optional_dense_dyn_tensor_typed::<T>(&output_cotangents[0], "slogdet_vjp sign")?,
        logabsdet: optional_dense_dyn_tensor_typed::<T::Real>(
            &output_cotangents[1],
            "slogdet_vjp logabsdet",
        )?,
    };
    if cotangent.sign.is_none() && cotangent.logabsdet.is_none() {
        return Ok(vec![None]);
    }
    let grad = dispatch_linalg!(T, "slogdet_vjp", LinalgCapabilityOp::Slogdet, |ctx| {
        slogdet_rrule(ctx, &dense_input, &cotangent).map_err(Error::from)
    })?;
    Ok(vec![Some(dyn_from_dense(grad))])
}

fn cholesky_primal_t<T>(input: &StructuredTensor<T>) -> Result<DynTensor>
where
    T: LinalgRuntimeValue + DynTensorTyped + Copy,
{
    let dense_input = input.to_dense()?;
    let output = dispatch_linalg!(
        T,
        "cholesky_dyn_value",
        LinalgCapabilityOp::Cholesky,
        |ctx| { cholesky(ctx, &dense_input).map_err(Error::from) }
    )?;
    Ok(dyn_from_dense(output))
}

fn cholesky_jvp_t<T>(
    input: &StructuredTensor<T>,
    tangent: &Option<DynTensor>,
) -> Result<Option<DynTensor>>
where
    T: LinalgRuntimeValue + DynTensorTyped + Scalar + Zero + Conjugate + Copy,
{
    if tangent.is_none() {
        return Ok(None);
    }
    let dense_input = input.to_dense()?;
    let dense_tangent = dense_optional_or_zero(tangent, &dense_input, "cholesky_jvp tangent")?;
    let (_, output_tangent) =
        dispatch_linalg!(T, "cholesky_jvp", LinalgCapabilityOp::Cholesky, |ctx| {
            cholesky_frule(ctx, &dense_input, &dense_tangent).map_err(Error::from)
        })?;
    Ok(Some(dyn_from_dense(output_tangent)))
}

fn cholesky_vjp_t<T>(
    input: &StructuredTensor<T>,
    cotangent: &DynTensor,
    input_grad_mask: &[bool],
) -> Result<Vec<Option<DynTensor>>>
where
    T: LinalgRuntimeValue + DynTensorTyped + Conjugate + Copy,
{
    if !input_grad_mask[0] {
        return Ok(vec![None]);
    }
    let dense_input = input.to_dense()?;
    let dense_cotangent = dense_dyn_tensor_typed::<T>(cotangent, "cholesky_vjp")?;
    let grad = dispatch_linalg!(T, "cholesky_vjp", LinalgCapabilityOp::Cholesky, |ctx| {
        cholesky_rrule(ctx, &dense_input, &dense_cotangent).map_err(Error::from)
    })?;
    Ok(vec![Some(dyn_from_dense(grad))])
}

fn lu_primal_t<T>(input: &StructuredTensor<T>, pivot: LuPivot) -> Result<Vec<DynTensor>>
where
    T: LuLinalgDispatchValue + DynTensorTyped + Copy,
{
    let dense_input = input.to_dense()?;
    let output = dispatch_linalg!(T, "lu_dyn_value", LinalgCapabilityOp::LuFactor, |ctx| {
        lu(ctx, &dense_input, pivot).map_err(Error::from)
    })?;
    Ok(vec![
        dyn_from_dense(output.p),
        dyn_from_dense(output.l),
        dyn_from_dense(output.u),
    ])
}

fn lu_jvp_t<T>(
    input: &StructuredTensor<T>,
    tangent: &Option<DynTensor>,
    pivot: LuPivot,
) -> Result<Vec<Option<DynTensor>>>
where
    T: LuLinalgDispatchValue + DynTensorTyped + Scalar + Zero + Copy,
{
    if tangent.is_none() {
        return Ok(vec![None, None, None]);
    }
    let dense_input = input.to_dense()?;
    let dense_tangent = dense_optional_or_zero(tangent, &dense_input, "lu_jvp tangent")?;
    let (_, tangent_output) = dispatch_linalg!(T, "lu_jvp", LinalgCapabilityOp::LuFactor, |ctx| {
        lu_frule(ctx, &dense_input, &dense_tangent, pivot).map_err(Error::from)
    })?;
    Ok(vec![
        None,
        Some(dyn_from_dense(tangent_output.l)),
        Some(dyn_from_dense(tangent_output.u)),
    ])
}

fn lu_vjp_t<T>(
    input: &StructuredTensor<T>,
    output_cotangents: &[Option<DynTensor>],
    input_grad_mask: &[bool],
    pivot: LuPivot,
) -> Result<Vec<Option<DynTensor>>>
where
    T: LuLinalgDispatchValue + DynTensorTyped + Copy,
{
    if !input_grad_mask[0] {
        return Ok(vec![None]);
    }
    if output_cotangents
        .first()
        .and_then(|value| value.as_ref())
        .is_some()
    {
        return Err(invalid_argument(
            "lu permutation cotangent is unsupported; permutation output is auxiliary",
        ));
    }
    let dense_input = input.to_dense()?;
    let cotangent = LuCotangent {
        l: optional_dense_dyn_tensor_typed::<T>(&output_cotangents[1], "lu_vjp l")?,
        u: optional_dense_dyn_tensor_typed::<T>(&output_cotangents[2], "lu_vjp u")?,
    };
    if cotangent.l.is_none() && cotangent.u.is_none() {
        return Ok(vec![None]);
    }
    let grad = dispatch_linalg!(T, "lu_vjp", LinalgCapabilityOp::LuFactor, |ctx| {
        lu_rrule(ctx, &dense_input, &cotangent, pivot).map_err(Error::from)
    })?;
    Ok(vec![Some(dyn_from_dense(grad))])
}

fn pinv_primal_t<T>(input: &StructuredTensor<T>, rcond: Option<f64>) -> Result<DynTensor>
where
    T: ScaledLinalgDispatchValue + DynTensorTyped + Conjugate + Copy,
    T::Real: KeepCountScalar,
{
    let dense_input = input.to_dense()?;
    let output = dispatch_linalg!(T, "pinv_dyn_value", LinalgCapabilityOp::Pinv, |ctx| {
        pinv(ctx, &dense_input, rcond).map_err(Error::from)
    })?;
    Ok(dyn_from_dense(output))
}

fn pinv_jvp_t<T>(
    input: &StructuredTensor<T>,
    tangent: &Option<DynTensor>,
    rcond: Option<f64>,
) -> Result<Option<DynTensor>>
where
    T: ScaledLinalgDispatchValue + DynTensorTyped + Scalar + Zero + Conjugate + Copy,
    T::Real: KeepCountScalar,
{
    if tangent.is_none() {
        return Ok(None);
    }
    let dense_input = input.to_dense()?;
    let dense_tangent = dense_optional_or_zero(tangent, &dense_input, "pinv_jvp tangent")?;
    let (_, output_tangent) = dispatch_linalg!(T, "pinv_jvp", LinalgCapabilityOp::Pinv, |ctx| {
        pinv_frule(ctx, &dense_input, &dense_tangent, rcond).map_err(Error::from)
    })?;
    Ok(Some(dyn_from_dense(output_tangent)))
}

fn pinv_vjp_t<T>(
    input: &StructuredTensor<T>,
    cotangent: &DynTensor,
    input_grad_mask: &[bool],
    rcond: Option<f64>,
) -> Result<Vec<Option<DynTensor>>>
where
    T: ScaledLinalgDispatchValue + DynTensorTyped + Conjugate + Copy,
    T::Real: KeepCountScalar,
{
    if !input_grad_mask[0] {
        return Ok(vec![None]);
    }
    let dense_input = input.to_dense()?;
    let dense_cotangent = dense_dyn_tensor_typed::<T>(cotangent, "pinv_vjp")?;
    let grad = dispatch_linalg!(T, "pinv_vjp", LinalgCapabilityOp::Pinv, |ctx| {
        pinv_rrule(ctx, &dense_input, &dense_cotangent, rcond).map_err(Error::from)
    })?;
    Ok(vec![Some(dyn_from_dense(grad))])
}

fn matrix_exp_primal_t<T>(input: &StructuredTensor<T>) -> Result<DynTensor>
where
    T: MatrixExpLinalgDispatchValue + DynTensorTyped + Copy,
    T::Real: KernelLinalgScalar<Real = T::Real> + num_traits::Float,
{
    let dense_input = input.to_dense()?;
    let output = dispatch_linalg!(
        T,
        "matrix_exp_dyn_value",
        LinalgCapabilityOp::MatrixExp,
        |ctx| { matrix_exp(ctx, &dense_input).map_err(Error::from) }
    )?;
    Ok(dyn_from_dense(output))
}

fn matrix_exp_jvp_t<T>(
    input: &StructuredTensor<T>,
    tangent: &Option<DynTensor>,
) -> Result<Option<DynTensor>>
where
    T: MatrixExpLinalgDispatchValue + DynTensorTyped + Scalar + Zero + Copy,
    T::Real: KernelLinalgScalar<Real = T::Real> + num_traits::Float,
{
    if tangent.is_none() {
        return Ok(None);
    }
    let dense_input = input.to_dense()?;
    let dense_tangent = dense_optional_or_zero(tangent, &dense_input, "matrix_exp_jvp tangent")?;
    let (_, output_tangent) =
        dispatch_linalg!(T, "matrix_exp_jvp", LinalgCapabilityOp::MatrixExp, |ctx| {
            matrix_exp_frule(ctx, &dense_input, &dense_tangent).map_err(Error::from)
        })?;
    Ok(Some(dyn_from_dense(output_tangent)))
}

fn matrix_exp_vjp_t<T>(
    input: &StructuredTensor<T>,
    cotangent: &DynTensor,
    input_grad_mask: &[bool],
) -> Result<Vec<Option<DynTensor>>>
where
    T: MatrixExpLinalgDispatchValue + DynTensorTyped + Conjugate + Copy,
    T::Real: KernelLinalgScalar<Real = T::Real> + num_traits::Float,
{
    if !input_grad_mask[0] {
        return Ok(vec![None]);
    }
    let dense_input = input.to_dense()?;
    let dense_cotangent = dense_dyn_tensor_typed::<T>(cotangent, "matrix_exp_vjp")?;
    let grad = dispatch_linalg!(T, "matrix_exp_vjp", LinalgCapabilityOp::MatrixExp, |ctx| {
        matrix_exp_rrule(ctx, &dense_input, &dense_cotangent).map_err(Error::from)
    })?;
    Ok(vec![Some(dyn_from_dense(grad))])
}

fn eig_primal_t<T>(input: &StructuredTensor<T>) -> Result<Vec<DynTensor>>
where
    T: RealLinalgRuntimeValue
        + DynTensorTyped
        + num_traits::Float
        + KernelLinalgScalar<Real = T, Complex = num_complex::Complex<T>>
        + Copy,
    num_complex::Complex<T>: DynTensorTyped
        + KernelLinalgScalar<Real = T, Complex = num_complex::Complex<T>>
        + CudaLinalgScalar
        + Copy,
{
    let dense_input = input.to_dense()?;
    let output = dispatch_linalg!(T, "eig_dyn_value", LinalgCapabilityOp::Eig, |ctx| {
        eig(ctx, &dense_input).map_err(Error::from)
    })?;
    Ok(vec![
        dyn_from_dense(output.values),
        dyn_from_dense(output.vectors),
    ])
}

fn eig_jvp_t<T>(
    input: &StructuredTensor<T>,
    tangent: &Option<DynTensor>,
) -> Result<Vec<Option<DynTensor>>>
where
    T: RealLinalgRuntimeValue
        + DynTensorTyped
        + Scalar
        + Zero
        + num_traits::Float
        + KernelLinalgScalar<Real = T, Complex = num_complex::Complex<T>>
        + Copy,
    num_complex::Complex<T>: DynTensorTyped
        + KernelLinalgScalar<Real = T, Complex = num_complex::Complex<T>>
        + CudaLinalgScalar
        + Copy,
{
    if tangent.is_none() {
        return Ok(vec![None, None]);
    }
    let dense_input = input.to_dense()?;
    let dense_tangent = dense_optional_or_zero(tangent, &dense_input, "eig_jvp tangent")?;
    let (_, tangent_output) = dispatch_linalg!(T, "eig_jvp", LinalgCapabilityOp::Eig, |ctx| {
        eig_frule(ctx, &dense_input, &dense_tangent).map_err(Error::from)
    })?;
    Ok(vec![
        Some(dyn_from_dense(tangent_output.values)),
        Some(dyn_from_dense(tangent_output.vectors)),
    ])
}

fn eig_vjp_t<T>(
    input: &StructuredTensor<T>,
    output_cotangents: &[Option<DynTensor>],
    input_grad_mask: &[bool],
) -> Result<Vec<Option<DynTensor>>>
where
    T: RealLinalgRuntimeValue
        + DynTensorTyped
        + num_traits::Float
        + KernelLinalgScalar<Real = T, Complex = num_complex::Complex<T>>
        + Copy,
    num_complex::Complex<T>: DynTensorTyped
        + KernelLinalgScalar<Real = T, Complex = num_complex::Complex<T>>
        + CudaLinalgScalar
        + Copy,
{
    if !input_grad_mask[0] {
        return Ok(vec![None]);
    }
    let dense_input = input.to_dense()?;
    let cotangent = EigCotangent {
        values: optional_dense_dyn_tensor_typed::<num_complex::Complex<T>>(
            &output_cotangents[0],
            "eig_vjp values",
        )?,
        vectors: optional_dense_dyn_tensor_typed::<num_complex::Complex<T>>(
            &output_cotangents[1],
            "eig_vjp vectors",
        )?,
    };
    if cotangent.values.is_none() && cotangent.vectors.is_none() {
        return Ok(vec![None]);
    }
    let grad = dispatch_linalg!(T, "eig_vjp", LinalgCapabilityOp::Eig, |ctx| {
        eig_rrule(ctx, &dense_input, &cotangent).map_err(Error::from)
    })?;
    Ok(vec![Some(dyn_from_dense(grad))])
}

fn eigen_primal_t<T>(input: &StructuredTensor<T>) -> Result<Vec<DynTensor>>
where
    T: LinalgRuntimeValue + DynTensorTyped + Copy,
    T::Real: DynTensorTyped,
{
    let dense_input = input.to_dense()?;
    let output = dispatch_linalg!(T, "eigen_dyn_value", LinalgCapabilityOp::EigenSym, |ctx| {
        eigen(ctx, &dense_input).map_err(Error::from)
    })?;
    Ok(vec![
        dyn_from_dense(output.values),
        dyn_from_dense(output.vectors),
    ])
}

fn eigen_jvp_t<T>(
    input: &StructuredTensor<T>,
    tangent: &Option<DynTensor>,
) -> Result<Vec<Option<DynTensor>>>
where
    T: LinalgRuntimeValue + DynTensorTyped + Scalar + Zero + Conjugate + Copy,
    T::Real: KernelLinalgScalar<Real = T::Real> + num_traits::Float + DynTensorTyped,
{
    if tangent.is_none() {
        return Ok(vec![None, None]);
    }
    let dense_input = input.to_dense()?;
    let dense_tangent = dense_optional_or_zero(tangent, &dense_input, "eigen_jvp tangent")?;
    let (_, tangent_output) =
        dispatch_linalg!(T, "eigen_jvp", LinalgCapabilityOp::EigenSym, |ctx| {
            eigen_frule(ctx, &dense_input, &dense_tangent).map_err(Error::from)
        })?;
    Ok(vec![
        Some(dyn_from_dense(tangent_output.values)),
        Some(dyn_from_dense(tangent_output.vectors)),
    ])
}

fn eigen_vjp_t<T>(
    input: &StructuredTensor<T>,
    output_cotangents: &[Option<DynTensor>],
    input_grad_mask: &[bool],
) -> Result<Vec<Option<DynTensor>>>
where
    T: LinalgRuntimeValue + DynTensorTyped + Conjugate + Copy,
    T::Real: KernelLinalgScalar<Real = T::Real> + num_traits::Float + DynTensorTyped,
{
    if !input_grad_mask[0] {
        return Ok(vec![None]);
    }
    let dense_input = input.to_dense()?;
    let cotangent = EigenCotangent {
        values: optional_dense_dyn_tensor_typed::<T::Real>(
            &output_cotangents[0],
            "eigen_vjp values",
        )?,
        vectors: optional_dense_dyn_tensor_typed::<T>(&output_cotangents[1], "eigen_vjp vectors")?,
    };
    if cotangent.values.is_none() && cotangent.vectors.is_none() {
        return Ok(vec![None]);
    }
    let grad = dispatch_linalg!(T, "eigen_vjp", LinalgCapabilityOp::EigenSym, |ctx| {
        eigen_rrule(ctx, &dense_input, &cotangent).map_err(Error::from)
    })?;
    Ok(vec![Some(dyn_from_dense(grad))])
}

fn norm_primal_t<T>(input: &StructuredTensor<T>, kind: NormKind) -> Result<DynTensor>
where
    T: NormLinalgDispatchValue + DynTensorTyped + Copy,
    <T as tenferro_linalg::LinalgScalar>::Real: DynTensorTyped,
{
    let dense_input = input.to_dense()?;
    let output = dispatch_linalg!(T, "norm_dyn_value", LinalgCapabilityOp::Norm, |ctx| {
        norm(ctx, &dense_input, kind).map_err(Error::from)
    })?;
    Ok(dyn_from_dense(output))
}

fn norm_jvp_t<T>(
    input: &StructuredTensor<T>,
    tangent: &Option<DynTensor>,
    kind: NormKind,
) -> Result<Option<DynTensor>>
where
    T: RealLinalgRuntimeValue + DynTensorTyped + Scalar + Zero + Copy,
{
    if tangent.is_none() {
        return Ok(None);
    }
    let dense_input = input.to_dense()?;
    let dense_tangent = dense_optional_or_zero(tangent, &dense_input, "norm_jvp tangent")?;
    let (_, output_tangent) = dispatch_linalg!(T, "norm_jvp", LinalgCapabilityOp::Norm, |ctx| {
        norm_frule(ctx, &dense_input, &dense_tangent, kind).map_err(Error::from)
    })?;
    Ok(Some(dyn_from_dense(output_tangent)))
}

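// Complex norms take dedicated paths: the norm of a complex tensor is
// real-valued, so the output tangent/cotangent dtype (f32/f64) differs from
// the input dtype and the `*_complex` rules are used.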
fn norm_jvp_c32_t(
    input: &StructuredTensor<Complex32>,
    tangent: &Option<DynTensor>,
    kind: NormKind,
) -> Result<Option<DynTensor>> {
    if tangent.is_none() {
        return Ok(None);
    }
    let dense_input = input.to_dense()?;
    let dense_tangent = dense_optional_or_zero(tangent, &dense_input, "norm_jvp tangent")?;
    let (_, output_tangent) =
        dispatch_linalg!(Complex32, "norm_jvp", LinalgCapabilityOp::Norm, |ctx| {
            norm_frule_complex::<Complex32, f32, _>(ctx, &dense_input, &dense_tangent, kind)
                .map_err(Error::from)
        })?;
    Ok(Some(dyn_from_dense(output_tangent)))
}

fn norm_jvp_c64_t(
    input: &StructuredTensor<Complex64>,
    tangent: &Option<DynTensor>,
    kind: NormKind,
) -> Result<Option<DynTensor>> {
    if tangent.is_none() {
        return Ok(None);
    }
    let dense_input = input.to_dense()?;
    let dense_tangent = dense_optional_or_zero(tangent, &dense_input, "norm_jvp tangent")?;
    let (_, output_tangent) =
        dispatch_linalg!(Complex64, "norm_jvp", LinalgCapabilityOp::Norm, |ctx| {
            norm_frule_complex::<Complex64, f64, _>(ctx, &dense_input, &dense_tangent, kind)
                .map_err(Error::from)
        })?;
    Ok(Some(dyn_from_dense(output_tangent)))
}

fn norm_vjp_c32_t(
    input: &StructuredTensor<Complex32>,
    cotangent: &DynTensor,
    kind: NormKind,
    input_grad_mask: &[bool],
) -> Result<Vec<Option<DynTensor>>> {
    if !input_grad_mask[0] {
        return Ok(vec![None]);
    }
    let dense_input = input.to_dense()?;
    let dense_cotangent = dense_dyn_tensor_typed::<f32>(cotangent, "norm_vjp")?;
    let grad = dispatch_linalg!(Complex32, "norm_vjp", LinalgCapabilityOp::Norm, |ctx| {
        norm_rrule_complex::<Complex32, f32, _>(ctx, &dense_input, &dense_cotangent, kind)
            .map_err(Error::from)
    })?;
    Ok(vec![Some(dyn_from_dense(grad))])
}

fn norm_vjp_t<T>(
    input: &StructuredTensor<T>,
    cotangent: &DynTensor,
    kind: NormKind,
    input_grad_mask: &[bool],
) -> Result<Vec<Option<DynTensor>>>
where
    T: RealLinalgRuntimeValue + DynTensorTyped + Copy,
{
    if !input_grad_mask[0] {
        return Ok(vec![None]);
    }
    let dense_input = input.to_dense()?;
    let dense_cotangent = dense_dyn_tensor_typed::<T>(cotangent, "norm_vjp")?;
    let grad = dispatch_linalg!(T, "norm_vjp", LinalgCapabilityOp::Norm, |ctx| {
        norm_rrule(ctx, &dense_input, &dense_cotangent, kind).map_err(Error::from)
    })?;
    Ok(vec![Some(dyn_from_dense(grad))])
}

fn norm_vjp_c64_t(
    input: &StructuredTensor<Complex64>,
    cotangent: &DynTensor,
    kind: NormKind,
    input_grad_mask: &[bool],
) -> Result<Vec<Option<DynTensor>>> {
    if !input_grad_mask[0] {
        return Ok(vec![None]);
    }
    let dense_input = input.to_dense()?;
    let dense_cotangent = dense_dyn_tensor_typed::<f64>(cotangent, "norm_vjp")?;
    let grad = dispatch_linalg!(Complex64, "norm_vjp", LinalgCapabilityOp::Norm, |ctx| {
        norm_rrule_complex::<Complex64, f64, _>(ctx, &dense_input, &dense_cotangent, kind)
            .map_err(Error::from)
    })?;
    Ok(vec![Some(dyn_from_dense(grad))])
}

fn det_primal_t<T>(input: &StructuredTensor<T>) -> Result<DynTensor>
where
    T: ScaledLinalgDispatchValue + DynTensorTyped + Copy,
{
    let dense_input = input.to_dense()?;
    let output = dispatch_linalg!(T, "det_dyn_value", LinalgCapabilityOp::Det, |ctx| {
        det(ctx, &dense_input).map_err(Error::from)
    })?;
    Ok(dyn_from_dense(output))
}

fn det_jvp_t<T>(
    input: &StructuredTensor<T>,
    tangent: &Option<DynTensor>,
) -> Result<Option<DynTensor>>
where
    T: ScaledLinalgDispatchValue + DynTensorTyped + Scalar + Zero + Copy,
{
    if tangent.is_none() {
        return Ok(None);
    }
    let dense_input = input.to_dense()?;
    let dense_tangent = dense_optional_or_zero(tangent, &dense_input, "det_jvp tangent")?;
    let (_, output_tangent) = dispatch_linalg!(T, "det_jvp", LinalgCapabilityOp::Det, |ctx| {
        det_frule(ctx, &dense_input, &dense_tangent).map_err(Error::from)
    })?;
    Ok(Some(dyn_from_dense(output_tangent)))
}

fn det_vjp_t<T>(
    input: &StructuredTensor<T>,
    cotangent: &DynTensor,
    input_grad_mask: &[bool],
) -> Result<Vec<Option<DynTensor>>>
where
    T: ScaledLinalgDispatchValue + DynTensorTyped + Conjugate + Copy,
{
    if !input_grad_mask[0] {
        return Ok(vec![None]);
    }
    let dense_input = input.to_dense()?;
    let dense_cotangent = dense_dyn_tensor_typed::<T>(cotangent, "det_vjp")?;
    let grad = dispatch_linalg!(T, "det_vjp", LinalgCapabilityOp::Det, |ctx| {
        det_rrule(ctx, &dense_input, &dense_cotangent).map_err(Error::from)
    })?;
    Ok(vec![Some(dyn_from_dense(grad))])
}

fn qr_primal_t<T>(input: &StructuredTensor<T>) -> Result<Vec<DynTensor>>
where
    T: LinalgRuntimeValue + DynTensorTyped + Copy,
{
    let dense_input = input.to_dense()?;
    let output = dispatch_linalg!(T, "qr_dyn_value", LinalgCapabilityOp::Qr, |ctx| {
        qr(ctx, &dense_input).map_err(Error::from)
    })?;
    Ok(vec![dyn_from_dense(output.q), dyn_from_dense(output.r)])
}

fn qr_jvp_t<T>(
    input: &StructuredTensor<T>,
    tangent: &Option<DynTensor>,
) -> Result<Vec<Option<DynTensor>>>
where
    T: LinalgRuntimeValue + DynTensorTyped + Scalar + Zero + Copy,
{
    if tangent.is_none() {
        return Ok(vec![None, None]);
    }
    let dense_input = input.to_dense()?;
    let dense_tangent = dense_optional_or_zero(tangent, &dense_input, "qr_jvp tangent")?;
    let (_, tangent_output) = dispatch_linalg!(T, "qr_jvp", LinalgCapabilityOp::Qr, |ctx| {
        qr_frule(ctx, &dense_input, &dense_tangent).map_err(Error::from)
    })?;
    Ok(vec![
        Some(dyn_from_dense(tangent_output.q)),
        Some(dyn_from_dense(tangent_output.r)),
    ])
}

fn qr_vjp_t<T>(
    input: &StructuredTensor<T>,
    output_cotangents: &[Option<DynTensor>],
    input_grad_mask: &[bool],
) -> Result<Vec<Option<DynTensor>>>
where
    T: LinalgRuntimeValue + DynTensorTyped + Copy,
{
    if !input_grad_mask[0] {
        return Ok(vec![None]);
    }
    let dense_input = input.to_dense()?;
    let cotangent = QrCotangent {
        q: optional_dense_dyn_tensor_typed::<T>(&output_cotangents[0], "qr_vjp q")?,
        r: optional_dense_dyn_tensor_typed::<T>(&output_cotangents[1], "qr_vjp r")?,
    };
    if cotangent.q.is_none() && cotangent.r.is_none() {
        return Ok(vec![None]);
    }
    let grad = dispatch_linalg!(T, "qr_vjp", LinalgCapabilityOp::Qr, |ctx| {
        qr_rrule(ctx, &dense_input, &cotangent).map_err(Error::from)
    })?;
    Ok(vec![Some(dyn_from_dense(grad))])
}

fn svd_primal_t<T>(
    input: &StructuredTensor<T>,
    options: Option<&SvdOptions>,
) -> Result<Vec<DynTensor>>
where
    T: LinalgRuntimeValue + DynTensorTyped + Copy,
    T::Real: DynTensorTyped + Copy + tenferro_tensor::KeepCountScalar,
{
    let dense_input = input.to_dense()?;
    let output = dispatch_linalg!(T, "svd_dyn_value", LinalgCapabilityOp::ThinSvd, |ctx| {
        svd(ctx, &dense_input, options).map_err(Error::from)
    })?;
    Ok(vec![
        dyn_from_dense(output.u),
        dyn_from_dense(output.s),
        dyn_from_dense(output.vt),
    ])
}

fn svd_jvp_t<T>(
    input: &StructuredTensor<T>,
    tangent: &Option<DynTensor>,
    options: Option<&SvdOptions>,
) -> Result<Vec<Option<DynTensor>>>
where
    T: LinalgRuntimeValue + DynTensorTyped + Scalar + Zero + Copy,
    T::Real: DynTensorTyped + Copy + num_traits::Float + tenferro_tensor::KeepCountScalar,
{
    if tangent.is_none() {
        return Ok(vec![None, None, None]);
    }
    let dense_input = input.to_dense()?;
    let dense_tangent = dense_optional_or_zero(tangent, &dense_input, "svd_jvp tangent")?;
    let (_, tangent_output) = dispatch_linalg!(T, "svd_jvp", LinalgCapabilityOp::ThinSvd, |ctx| {
        svd_frule(ctx, &dense_input, &dense_tangent, options).map_err(Error::from)
    })?;
    Ok(vec![
        Some(dyn_from_dense(tangent_output.u)),
        Some(dyn_from_dense(tangent_output.s)),
        Some(dyn_from_dense(tangent_output.vt)),
    ])
}

fn svd_vjp_t<T>(
    input: &StructuredTensor<T>,
    output_cotangents: &[Option<DynTensor>],
    input_grad_mask: &[bool],
    options: Option<&SvdOptions>,
) -> Result<Vec<Option<DynTensor>>>
where
    T: LinalgRuntimeValue + DynTensorTyped + Copy,
    T::Real: DynTensorTyped + Copy + num_traits::Float + tenferro_tensor::KeepCountScalar,
{
    if !input_grad_mask[0] {
        return Ok(vec![None]);
    }
    let dense_input = input.to_dense()?;
    let cotangent = SvdCotangent {
        u: optional_dense_dyn_tensor_typed::<T>(&output_cotangents[0], "svd_vjp u")?,
        s: optional_dense_dyn_tensor_typed::<T::Real>(&output_cotangents[1], "svd_vjp s")?,
        vt: optional_dense_dyn_tensor_typed::<T>(&output_cotangents[2], "svd_vjp vt")?,
    };
    if cotangent.u.is_none() && cotangent.s.is_none() && cotangent.vt.is_none() {
        return Ok(vec![None]);
    }
    let grad = dispatch_linalg!(T, "svd_vjp", LinalgCapabilityOp::ThinSvd, |ctx| {
        svd_rrule(ctx, &dense_input, &cotangent, options).map_err(Error::from)
    })?;
    Ok(vec![Some(dyn_from_dense(grad))])
}

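// Constructors for the ops that carry configuration.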
impl NormOp {
    pub fn new(kind: NormKind) -> Self {
        Self { kind }
    }
}

impl SolveTriangularOp {
    pub fn new(upper: bool) -> Self {
        Self { upper }
    }
}

impl LuOp {
    pub fn new(pivot: LuPivot) -> Self {
        Self { pivot }
    }
}

impl SvdOp {
    pub fn new(options: Option<SvdOptions>) -> Self {
        Self { options }
    }
}

impl PInvOp {
    pub fn new(rcond: Option<f64>) -> Self {
        Self { rcond }
    }
}

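// `LinearizableOp` / `LinearizedOp` implementations. Each impl matches on the
// `DynTensor` dtype variants, delegates to the typed helpers above, and
// reports mismatched or unsupported dtypes as `InvalidArgument` autodiff
// errors.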
1422impl LinearizableOp<DynTensor> for SolveOp {
1423    type Linearized = SolveLinearized;
1424
1425    fn primal(&self, inputs: &[&DynTensor]) -> AdResult<Vec<DynTensor>> {
1426        let output = match (inputs[0], inputs[1]) {
1427            (DynTensor::F32(a), DynTensor::F32(b)) => solve_primal_t::<f32>(a, b),
1428            (DynTensor::F64(a), DynTensor::F64(b)) => solve_primal_t::<f64>(a, b),
1429            (DynTensor::C32(a), DynTensor::C32(b)) => solve_primal_t::<Complex32>(a, b),
1430            (DynTensor::C64(a), DynTensor::C64(b)) => solve_primal_t::<Complex64>(a, b),
1431            _ => Err(invalid_argument("solve requires matching dtypes")),
1432        }
1433        .map_err(into_ad_error)?;
1434        Ok(vec![output])
1435    }
1436
1437    fn input_schema(&self, _inputs: &[&DynTensor]) -> AdResult<Schema> {
1438        Ok(differentiable_schema(2))
1439    }
1440
1441    fn output_schema(&self, _inputs: &[&DynTensor], _outputs: &[DynTensor]) -> AdResult<Schema> {
1442        Ok(differentiable_schema(1))
1443    }
1444
1445    fn linearize(
1446        &self,
1447        inputs: &[&DynTensor],
1448        _outputs: &[DynTensor],
1449    ) -> AdResult<Self::Linearized> {
1450        Ok(SolveLinearized {
1451            a: inputs[0].clone(),
1452            b: inputs[1].clone(),
1453        })
1454    }
1455
1456    fn checkpoint_hint(&self) -> CheckpointHint {
1457        CheckpointHint::ExpensiveReplay
1458    }
1459}
1460
1461impl LinearizedOp<DynTensor> for SolveLinearized {
1462    fn jvp(&self, input_tangents: &[Option<DynTensor>]) -> AdResult<Vec<Option<DynTensor>>> {
1463        let tangent = match (&self.a, &self.b) {
1464            (DynTensor::F32(a), DynTensor::F32(b)) => solve_jvp_t::<f32>(a, b, input_tangents),
1465            (DynTensor::F64(a), DynTensor::F64(b)) => solve_jvp_t::<f64>(a, b, input_tangents),
1466            (DynTensor::C32(a), DynTensor::C32(b)) => {
1467                solve_jvp_t::<Complex32>(a, b, input_tangents)
1468            }
1469            (DynTensor::C64(a), DynTensor::C64(b)) => {
1470                solve_jvp_t::<Complex64>(a, b, input_tangents)
1471            }
1472            _ => Err(invalid_argument(
1473                "solve linearization requires matching dtypes",
1474            )),
1475        }
1476        .map_err(into_ad_error)?;
1477        Ok(vec![tangent])
1478    }
1479
1480    fn vjp(
1481        &self,
1482        output_cotangents: &[Option<DynTensor>],
1483        input_grad_mask: &[bool],
1484    ) -> AdResult<Vec<Option<DynTensor>>> {
1485        let Some(cotangent) = output_cotangents[0].as_ref() else {
1486            return Ok(vec![None, None]);
1487        };
1488        match (&self.a, &self.b) {
1489            (DynTensor::F32(a), DynTensor::F32(b)) => {
1490                solve_vjp_t::<f32>(a, b, cotangent, input_grad_mask)
1491            }
1492            (DynTensor::F64(a), DynTensor::F64(b)) => {
1493                solve_vjp_t::<f64>(a, b, cotangent, input_grad_mask)
1494            }
1495            (DynTensor::C32(a), DynTensor::C32(b)) => {
1496                solve_vjp_t::<Complex32>(a, b, cotangent, input_grad_mask)
1497            }
1498            (DynTensor::C64(a), DynTensor::C64(b)) => {
1499                solve_vjp_t::<Complex64>(a, b, cotangent, input_grad_mask)
1500            }
1501            _ => Err(invalid_argument(
1502                "solve linearization requires matching dtypes",
1503            )),
1504        }
1505        .map_err(into_ad_error)
1506    }
1507}
1508
1509impl LinearizableOp<DynTensor> for LstsqOp {
1510    type Linearized = LstsqLinearized;
1511
1512    fn primal(&self, inputs: &[&DynTensor]) -> AdResult<Vec<DynTensor>> {
1513        match (inputs[0], inputs[1]) {
1514            (DynTensor::F32(a), DynTensor::F32(b)) => lstsq_primal_t::<f32>(a, b),
1515            (DynTensor::F64(a), DynTensor::F64(b)) => lstsq_primal_t::<f64>(a, b),
1516            (DynTensor::C32(_), DynTensor::C32(_)) | (DynTensor::C64(_), DynTensor::C64(_)) => Err(
1517                invalid_argument("lstsq AD currently supports real dtypes only"),
1518            ),
1519            _ => Err(invalid_argument("lstsq requires matching dtypes")),
1520        }
1521        .map_err(into_ad_error)
1522    }
1523
1524    fn input_schema(&self, _inputs: &[&DynTensor]) -> AdResult<Schema> {
1525        Ok(differentiable_schema(2))
1526    }
1527
1528    fn output_schema(&self, _inputs: &[&DynTensor], _outputs: &[DynTensor]) -> AdResult<Schema> {
1529        Ok(lstsq_output_schema())
1530    }
1531
1532    fn linearize(
1533        &self,
1534        inputs: &[&DynTensor],
1535        _outputs: &[DynTensor],
1536    ) -> AdResult<Self::Linearized> {
1537        Ok(LstsqLinearized {
1538            a: inputs[0].clone(),
1539            b: inputs[1].clone(),
1540        })
1541    }
1542
1543    fn checkpoint_hint(&self) -> CheckpointHint {
1544        CheckpointHint::ExpensiveReplay
1545    }
1546}
1547
1548impl LinearizedOp<DynTensor> for LstsqLinearized {
1549    fn jvp(&self, input_tangents: &[Option<DynTensor>]) -> AdResult<Vec<Option<DynTensor>>> {
1550        match (&self.a, &self.b) {
1551            (DynTensor::F32(a), DynTensor::F32(b)) => lstsq_jvp_t::<f32>(a, b, input_tangents),
1552            (DynTensor::F64(a), DynTensor::F64(b)) => lstsq_jvp_t::<f64>(a, b, input_tangents),
1553            (DynTensor::C32(_), DynTensor::C32(_)) | (DynTensor::C64(_), DynTensor::C64(_)) => Err(
1554                invalid_argument("lstsq AD currently supports real dtypes only"),
1555            ),
1556            _ => Err(invalid_argument(
1557                "lstsq linearization requires matching dtypes",
1558            )),
1559        }
1560        .map_err(into_ad_error)
1561    }
1562
1563    fn vjp(
1564        &self,
1565        output_cotangents: &[Option<DynTensor>],
1566        input_grad_mask: &[bool],
1567    ) -> AdResult<Vec<Option<DynTensor>>> {
1568        match (&self.a, &self.b) {
1569            (DynTensor::F32(a), DynTensor::F32(b)) => {
1570                lstsq_vjp_t::<f32>(a, b, output_cotangents, input_grad_mask)
1571            }
1572            (DynTensor::F64(a), DynTensor::F64(b)) => {
1573                lstsq_vjp_t::<f64>(a, b, output_cotangents, input_grad_mask)
1574            }
1575            (DynTensor::C32(_), DynTensor::C32(_)) | (DynTensor::C64(_), DynTensor::C64(_)) => Err(
1576                invalid_argument("lstsq AD currently supports real dtypes only"),
1577            ),
1578            _ => Err(invalid_argument(
1579                "lstsq linearization requires matching dtypes",
1580            )),
1581        }
1582        .map_err(into_ad_error)
1583    }
1584}
1585
1586impl LinearizableOp<DynTensor> for NormOp {
1587    type Linearized = NormLinearized;
1588
1589    fn primal(&self, inputs: &[&DynTensor]) -> AdResult<Vec<DynTensor>> {
1590        let output = match inputs[0] {
1591            DynTensor::F32(input) => norm_primal_t::<f32>(input, self.kind),
1592            DynTensor::F64(input) => norm_primal_t::<f64>(input, self.kind),
1593            DynTensor::C32(input) => norm_primal_t::<Complex32>(input, self.kind),
1594            DynTensor::C64(input) => norm_primal_t::<Complex64>(input, self.kind),
1595        }
1596        .map_err(into_ad_error)?;
1597        Ok(vec![output])
1598    }
1599
1600    fn input_schema(&self, _inputs: &[&DynTensor]) -> AdResult<Schema> {
1601        Ok(differentiable_schema(1))
1602    }
1603
1604    fn output_schema(&self, _inputs: &[&DynTensor], _outputs: &[DynTensor]) -> AdResult<Schema> {
1605        Ok(differentiable_schema(1))
1606    }
1607
1608    fn linearize(
1609        &self,
1610        inputs: &[&DynTensor],
1611        _outputs: &[DynTensor],
1612    ) -> AdResult<Self::Linearized> {
1613        Ok(NormLinearized {
1614            input: inputs[0].clone(),
1615            kind: self.kind,
1616        })
1617    }
1618
1619    fn checkpoint_hint(&self) -> CheckpointHint {
1620        CheckpointHint::CheapReplay
1621    }
1622}
1623
1624impl LinearizableOp<DynTensor> for SolveTriangularOp {
1625    type Linearized = SolveTriangularLinearized;
1626
1627    fn primal(&self, inputs: &[&DynTensor]) -> AdResult<Vec<DynTensor>> {
1628        let output = match (inputs[0], inputs[1]) {
1629            (DynTensor::F32(a), DynTensor::F32(b)) => {
1630                solve_triangular_primal_t::<f32>(a, b, self.upper)
1631            }
1632            (DynTensor::F64(a), DynTensor::F64(b)) => {
1633                solve_triangular_primal_t::<f64>(a, b, self.upper)
1634            }
1635            (DynTensor::C32(a), DynTensor::C32(b)) => {
1636                solve_triangular_primal_t::<Complex32>(a, b, self.upper)
1637            }
1638            (DynTensor::C64(a), DynTensor::C64(b)) => {
1639                solve_triangular_primal_t::<Complex64>(a, b, self.upper)
1640            }
1641            _ => Err(invalid_argument(
1642                "solve_triangular requires matching dtypes",
1643            )),
1644        }
1645        .map_err(into_ad_error)?;
1646        Ok(vec![output])
1647    }
1648
1649    fn input_schema(&self, _inputs: &[&DynTensor]) -> AdResult<Schema> {
1650        Ok(differentiable_schema(2))
1651    }
1652
1653    fn output_schema(&self, _inputs: &[&DynTensor], _outputs: &[DynTensor]) -> AdResult<Schema> {
1654        Ok(differentiable_schema(1))
1655    }
1656
1657    fn linearize(
1658        &self,
1659        inputs: &[&DynTensor],
1660        _outputs: &[DynTensor],
1661    ) -> AdResult<Self::Linearized> {
1662        Ok(SolveTriangularLinearized {
1663            a: inputs[0].clone(),
1664            b: inputs[1].clone(),
1665            upper: self.upper,
1666        })
1667    }
1668
1669    fn checkpoint_hint(&self) -> CheckpointHint {
1670        CheckpointHint::ExpensiveReplay
1671    }
1672}
1673
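// Single-output linearized ops short-circuit their VJP: when the output cotangent is absent,
// all input gradients are `None` and no backend call is made.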
1674impl LinearizedOp<DynTensor> for SolveTriangularLinearized {
1675    fn jvp(&self, input_tangents: &[Option<DynTensor>]) -> AdResult<Vec<Option<DynTensor>>> {
1676        let tangent = match (&self.a, &self.b) {
1677            (DynTensor::F32(a), DynTensor::F32(b)) => {
1678                solve_triangular_jvp_t::<f32>(a, b, input_tangents, self.upper)
1679            }
1680            (DynTensor::F64(a), DynTensor::F64(b)) => {
1681                solve_triangular_jvp_t::<f64>(a, b, input_tangents, self.upper)
1682            }
1683            (DynTensor::C32(a), DynTensor::C32(b)) => {
1684                solve_triangular_jvp_t::<Complex32>(a, b, input_tangents, self.upper)
1685            }
1686            (DynTensor::C64(a), DynTensor::C64(b)) => {
1687                solve_triangular_jvp_t::<Complex64>(a, b, input_tangents, self.upper)
1688            }
1689            _ => Err(invalid_argument(
1690                "solve_triangular linearization requires matching dtypes",
1691            )),
1692        }
1693        .map_err(into_ad_error)?;
1694        Ok(vec![tangent])
1695    }
1696
1697    fn vjp(
1698        &self,
1699        output_cotangents: &[Option<DynTensor>],
1700        input_grad_mask: &[bool],
1701    ) -> AdResult<Vec<Option<DynTensor>>> {
1702        let Some(cotangent) = output_cotangents[0].as_ref() else {
1703            return Ok(vec![None, None]);
1704        };
1705        match (&self.a, &self.b) {
1706            (DynTensor::F32(a), DynTensor::F32(b)) => {
1707                solve_triangular_vjp_t::<f32>(a, b, cotangent, input_grad_mask, self.upper)
1708            }
1709            (DynTensor::F64(a), DynTensor::F64(b)) => {
1710                solve_triangular_vjp_t::<f64>(a, b, cotangent, input_grad_mask, self.upper)
1711            }
1712            (DynTensor::C32(a), DynTensor::C32(b)) => {
1713                solve_triangular_vjp_t::<Complex32>(a, b, cotangent, input_grad_mask, self.upper)
1714            }
1715            (DynTensor::C64(a), DynTensor::C64(b)) => {
1716                solve_triangular_vjp_t::<Complex64>(a, b, cotangent, input_grad_mask, self.upper)
1717            }
1718            _ => Err(invalid_argument(
1719                "solve_triangular linearization requires matching dtypes",
1720            )),
1721        }
1722        .map_err(into_ad_error)
1723    }
1724}
1725
1726impl LinearizedOp<DynTensor> for NormLinearized {
1727    fn jvp(&self, input_tangents: &[Option<DynTensor>]) -> AdResult<Vec<Option<DynTensor>>> {
1728        let tangent = match &self.input {
1729            DynTensor::F32(input) => norm_jvp_t::<f32>(input, &input_tangents[0], self.kind),
1730            DynTensor::F64(input) => norm_jvp_t::<f64>(input, &input_tangents[0], self.kind),
1731            DynTensor::C32(input) => norm_jvp_c32_t(input, &input_tangents[0], self.kind),
1732            DynTensor::C64(input) => norm_jvp_c64_t(input, &input_tangents[0], self.kind),
1733        }
1734        .map_err(into_ad_error)?;
1735        Ok(vec![tangent])
1736    }
1737
1738    fn vjp(
1739        &self,
1740        output_cotangents: &[Option<DynTensor>],
1741        input_grad_mask: &[bool],
1742    ) -> AdResult<Vec<Option<DynTensor>>> {
1743        let Some(cotangent) = output_cotangents[0].as_ref() else {
1744            return Ok(vec![None]);
1745        };
1746        match &self.input {
1747            DynTensor::F32(input) => {
1748                norm_vjp_t::<f32>(input, cotangent, self.kind, input_grad_mask)
1749            }
1750            DynTensor::F64(input) => {
1751                norm_vjp_t::<f64>(input, cotangent, self.kind, input_grad_mask)
1752            }
1753            DynTensor::C32(input) => norm_vjp_c32_t(input, cotangent, self.kind, input_grad_mask),
1754            DynTensor::C64(input) => norm_vjp_c64_t(input, cotangent, self.kind, input_grad_mask),
1755        }
1756        .map_err(into_ad_error)
1757    }
1758}
1759
1760impl LinearizableOp<DynTensor> for InvOp {
1761    type Linearized = InvLinearized;
1762
1763    fn primal(&self, inputs: &[&DynTensor]) -> AdResult<Vec<DynTensor>> {
1764        let output = match inputs[0] {
1765            DynTensor::F32(input) => inv_primal_t::<f32>(input),
1766            DynTensor::F64(input) => inv_primal_t::<f64>(input),
1767            DynTensor::C32(input) => inv_primal_t::<Complex32>(input),
1768            DynTensor::C64(input) => inv_primal_t::<Complex64>(input),
1769        }
1770        .map_err(into_ad_error)?;
1771        Ok(vec![output])
1772    }
1773
1774    fn input_schema(&self, _inputs: &[&DynTensor]) -> AdResult<Schema> {
1775        Ok(differentiable_schema(1))
1776    }
1777
1778    fn output_schema(&self, _inputs: &[&DynTensor], _outputs: &[DynTensor]) -> AdResult<Schema> {
1779        Ok(differentiable_schema(1))
1780    }
1781
1782    fn linearize(
1783        &self,
1784        inputs: &[&DynTensor],
1785        _outputs: &[DynTensor],
1786    ) -> AdResult<Self::Linearized> {
1787        Ok(InvLinearized {
1788            input: inputs[0].clone(),
1789        })
1790    }
1791
1792    fn checkpoint_hint(&self) -> CheckpointHint {
1793        CheckpointHint::ExpensiveReplay
1794    }
1795}
1796
1797impl LinearizedOp<DynTensor> for InvLinearized {
1798    fn jvp(&self, input_tangents: &[Option<DynTensor>]) -> AdResult<Vec<Option<DynTensor>>> {
1799        let tangent = match &self.input {
1800            DynTensor::F32(input) => inv_jvp_t::<f32>(input, &input_tangents[0]),
1801            DynTensor::F64(input) => inv_jvp_t::<f64>(input, &input_tangents[0]),
1802            DynTensor::C32(input) => inv_jvp_t::<Complex32>(input, &input_tangents[0]),
1803            DynTensor::C64(input) => inv_jvp_t::<Complex64>(input, &input_tangents[0]),
1804        }
1805        .map_err(into_ad_error)?;
1806        Ok(vec![tangent])
1807    }
1808
1809    fn vjp(
1810        &self,
1811        output_cotangents: &[Option<DynTensor>],
1812        input_grad_mask: &[bool],
1813    ) -> AdResult<Vec<Option<DynTensor>>> {
1814        let Some(cotangent) = output_cotangents[0].as_ref() else {
1815            return Ok(vec![None]);
1816        };
1817        match &self.input {
1818            DynTensor::F32(input) => inv_vjp_t::<f32>(input, cotangent, input_grad_mask),
1819            DynTensor::F64(input) => inv_vjp_t::<f64>(input, cotangent, input_grad_mask),
1820            DynTensor::C32(input) => inv_vjp_t::<Complex32>(input, cotangent, input_grad_mask),
1821            DynTensor::C64(input) => inv_vjp_t::<Complex64>(input, cotangent, input_grad_mask),
1822        }
1823        .map_err(into_ad_error)
1824    }
1825}
1826
1827impl LinearizableOp<DynTensor> for SlogdetOp {
1828    type Linearized = SlogdetLinearized;
1829
1830    fn primal(&self, inputs: &[&DynTensor]) -> AdResult<Vec<DynTensor>> {
1831        match inputs[0] {
1832            DynTensor::F32(input) => slogdet_primal_t::<f32>(input),
1833            DynTensor::F64(input) => slogdet_primal_t::<f64>(input),
1834            DynTensor::C32(input) => slogdet_primal_t::<Complex32>(input),
1835            DynTensor::C64(input) => slogdet_primal_t::<Complex64>(input),
1836        }
1837        .map_err(into_ad_error)
1838    }
1839
1840    fn input_schema(&self, _inputs: &[&DynTensor]) -> AdResult<Schema> {
1841        Ok(differentiable_schema(1))
1842    }
1843
1844    fn output_schema(&self, _inputs: &[&DynTensor], _outputs: &[DynTensor]) -> AdResult<Schema> {
1845        Ok(slogdet_output_schema())
1846    }
1847
1848    fn linearize(
1849        &self,
1850        inputs: &[&DynTensor],
1851        _outputs: &[DynTensor],
1852    ) -> AdResult<Self::Linearized> {
1853        Ok(SlogdetLinearized {
1854            input: inputs[0].clone(),
1855        })
1856    }
1857
1858    fn checkpoint_hint(&self) -> CheckpointHint {
1859        CheckpointHint::ExpensiveReplay
1860    }
1861}
1862
1863impl LinearizedOp<DynTensor> for SlogdetLinearized {
1864    fn jvp(&self, input_tangents: &[Option<DynTensor>]) -> AdResult<Vec<Option<DynTensor>>> {
1865        match &self.input {
1866            DynTensor::F32(input) => slogdet_jvp_t::<f32>(input, &input_tangents[0]),
1867            DynTensor::F64(input) => slogdet_jvp_t::<f64>(input, &input_tangents[0]),
1868            DynTensor::C32(input) => slogdet_jvp_t::<Complex32>(input, &input_tangents[0]),
1869            DynTensor::C64(input) => slogdet_jvp_t::<Complex64>(input, &input_tangents[0]),
1870        }
1871        .map_err(into_ad_error)
1872    }
1873
1874    fn vjp(
1875        &self,
1876        output_cotangents: &[Option<DynTensor>],
1877        input_grad_mask: &[bool],
1878    ) -> AdResult<Vec<Option<DynTensor>>> {
1879        match &self.input {
1880            DynTensor::F32(input) => {
1881                slogdet_vjp_t::<f32>(input, output_cotangents, input_grad_mask)
1882            }
1883            DynTensor::F64(input) => {
1884                slogdet_vjp_t::<f64>(input, output_cotangents, input_grad_mask)
1885            }
1886            DynTensor::C32(input) => {
1887                slogdet_vjp_t::<Complex32>(input, output_cotangents, input_grad_mask)
1888            }
1889            DynTensor::C64(input) => {
1890                slogdet_vjp_t::<Complex64>(input, output_cotangents, input_grad_mask)
1891            }
1892        }
1893        .map_err(into_ad_error)
1894    }
1895}
1896
1897impl LinearizableOp<DynTensor> for CholeskyOp {
1898    type Linearized = CholeskyLinearized;
1899
1900    fn primal(&self, inputs: &[&DynTensor]) -> AdResult<Vec<DynTensor>> {
1901        let output = match inputs[0] {
1902            DynTensor::F32(input) => cholesky_primal_t::<f32>(input),
1903            DynTensor::F64(input) => cholesky_primal_t::<f64>(input),
1904            DynTensor::C32(input) => cholesky_primal_t::<Complex32>(input),
1905            DynTensor::C64(input) => cholesky_primal_t::<Complex64>(input),
1906        }
1907        .map_err(into_ad_error)?;
1908        Ok(vec![output])
1909    }
1910
1911    fn input_schema(&self, _inputs: &[&DynTensor]) -> AdResult<Schema> {
1912        Ok(differentiable_schema(1))
1913    }
1914
1915    fn output_schema(&self, _inputs: &[&DynTensor], _outputs: &[DynTensor]) -> AdResult<Schema> {
1916        Ok(differentiable_schema(1))
1917    }
1918
1919    fn linearize(
1920        &self,
1921        inputs: &[&DynTensor],
1922        _outputs: &[DynTensor],
1923    ) -> AdResult<Self::Linearized> {
1924        Ok(CholeskyLinearized {
1925            input: inputs[0].clone(),
1926        })
1927    }
1928
1929    fn checkpoint_hint(&self) -> CheckpointHint {
1930        CheckpointHint::ExpensiveReplay
1931    }
1932}
1933
1934impl LinearizableOp<DynTensor> for LuOp {
1935    type Linearized = LuLinearized;
1936
1937    fn primal(&self, inputs: &[&DynTensor]) -> AdResult<Vec<DynTensor>> {
1938        match inputs[0] {
1939            DynTensor::F32(input) => lu_primal_t::<f32>(input, self.pivot),
1940            DynTensor::F64(input) => lu_primal_t::<f64>(input, self.pivot),
1941            DynTensor::C32(input) => lu_primal_t::<Complex32>(input, self.pivot),
1942            DynTensor::C64(input) => lu_primal_t::<Complex64>(input, self.pivot),
1943        }
1944        .map_err(into_ad_error)
1945    }
1946
1947    fn input_schema(&self, _inputs: &[&DynTensor]) -> AdResult<Schema> {
1948        Ok(differentiable_schema(1))
1949    }
1950
1951    fn output_schema(&self, _inputs: &[&DynTensor], _outputs: &[DynTensor]) -> AdResult<Schema> {
1952        Ok(lu_output_schema())
1953    }
1954
1955    fn linearize(
1956        &self,
1957        inputs: &[&DynTensor],
1958        _outputs: &[DynTensor],
1959    ) -> AdResult<Self::Linearized> {
1960        Ok(LuLinearized {
1961            input: inputs[0].clone(),
1962            pivot: self.pivot,
1963        })
1964    }
1965
1966    fn checkpoint_hint(&self) -> CheckpointHint {
1967        CheckpointHint::ExpensiveReplay
1968    }
1969}
1970
1971impl LinearizedOp<DynTensor> for LuLinearized {
1972    fn jvp(&self, input_tangents: &[Option<DynTensor>]) -> AdResult<Vec<Option<DynTensor>>> {
1973        match &self.input {
1974            DynTensor::F32(input) => lu_jvp_t::<f32>(input, &input_tangents[0], self.pivot),
1975            DynTensor::F64(input) => lu_jvp_t::<f64>(input, &input_tangents[0], self.pivot),
1976            DynTensor::C32(input) => lu_jvp_t::<Complex32>(input, &input_tangents[0], self.pivot),
1977            DynTensor::C64(input) => lu_jvp_t::<Complex64>(input, &input_tangents[0], self.pivot),
1978        }
1979        .map_err(into_ad_error)
1980    }
1981
1982    fn vjp(
1983        &self,
1984        output_cotangents: &[Option<DynTensor>],
1985        input_grad_mask: &[bool],
1986    ) -> AdResult<Vec<Option<DynTensor>>> {
1987        match &self.input {
1988            DynTensor::F32(input) => {
1989                lu_vjp_t::<f32>(input, output_cotangents, input_grad_mask, self.pivot)
1990            }
1991            DynTensor::F64(input) => {
1992                lu_vjp_t::<f64>(input, output_cotangents, input_grad_mask, self.pivot)
1993            }
1994            DynTensor::C32(input) => {
1995                lu_vjp_t::<Complex32>(input, output_cotangents, input_grad_mask, self.pivot)
1996            }
1997            DynTensor::C64(input) => {
1998                lu_vjp_t::<Complex64>(input, output_cotangents, input_grad_mask, self.pivot)
1999            }
2000        }
2001        .map_err(into_ad_error)
2002    }
2003}
2004
2005impl LinearizedOp<DynTensor> for CholeskyLinearized {
2006    fn jvp(&self, input_tangents: &[Option<DynTensor>]) -> AdResult<Vec<Option<DynTensor>>> {
2007        let tangent = match &self.input {
2008            DynTensor::F32(input) => cholesky_jvp_t::<f32>(input, &input_tangents[0]),
2009            DynTensor::F64(input) => cholesky_jvp_t::<f64>(input, &input_tangents[0]),
2010            DynTensor::C32(input) => cholesky_jvp_t::<Complex32>(input, &input_tangents[0]),
2011            DynTensor::C64(input) => cholesky_jvp_t::<Complex64>(input, &input_tangents[0]),
2012        }
2013        .map_err(into_ad_error)?;
2014        Ok(vec![tangent])
2015    }
2016
2017    fn vjp(
2018        &self,
2019        output_cotangents: &[Option<DynTensor>],
2020        input_grad_mask: &[bool],
2021    ) -> AdResult<Vec<Option<DynTensor>>> {
2022        let Some(cotangent) = output_cotangents[0].as_ref() else {
2023            return Ok(vec![None]);
2024        };
2025        match &self.input {
2026            DynTensor::F32(input) => cholesky_vjp_t::<f32>(input, cotangent, input_grad_mask),
2027            DynTensor::F64(input) => cholesky_vjp_t::<f64>(input, cotangent, input_grad_mask),
2028            DynTensor::C32(input) => cholesky_vjp_t::<Complex32>(input, cotangent, input_grad_mask),
2029            DynTensor::C64(input) => cholesky_vjp_t::<Complex64>(input, cotangent, input_grad_mask),
2030        }
2031        .map_err(into_ad_error)
2032    }
2033}
2034
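// Eig AD is currently limited to real input dtypes; complex inputs are rejected with an
// invalid-argument error in the primal and in the JVP/VJP below.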
2035impl LinearizableOp<DynTensor> for EigOp {
2036    type Linearized = EigLinearized;
2037
2038    fn primal(&self, inputs: &[&DynTensor]) -> AdResult<Vec<DynTensor>> {
2039        match inputs[0] {
2040            DynTensor::F32(input) => eig_primal_t::<f32>(input),
2041            DynTensor::F64(input) => eig_primal_t::<f64>(input),
2042            DynTensor::C32(_) | DynTensor::C64(_) => Err(invalid_argument(
2043                "eig AD currently supports real inputs only",
2044            )),
2045        }
2046        .map_err(into_ad_error)
2047    }
2048
2049    fn input_schema(&self, _inputs: &[&DynTensor]) -> AdResult<Schema> {
2050        Ok(differentiable_schema(1))
2051    }
2052
2053    fn output_schema(&self, _inputs: &[&DynTensor], _outputs: &[DynTensor]) -> AdResult<Schema> {
2054        Ok(differentiable_schema(2))
2055    }
2056
2057    fn linearize(
2058        &self,
2059        inputs: &[&DynTensor],
2060        _outputs: &[DynTensor],
2061    ) -> AdResult<Self::Linearized> {
2062        Ok(EigLinearized {
2063            input: inputs[0].clone(),
2064        })
2065    }
2066
2067    fn checkpoint_hint(&self) -> CheckpointHint {
2068        CheckpointHint::ExpensiveReplay
2069    }
2070}
2071
2072impl LinearizedOp<DynTensor> for EigLinearized {
2073    fn jvp(&self, input_tangents: &[Option<DynTensor>]) -> AdResult<Vec<Option<DynTensor>>> {
2074        match &self.input {
2075            DynTensor::F32(input) => eig_jvp_t::<f32>(input, &input_tangents[0]),
2076            DynTensor::F64(input) => eig_jvp_t::<f64>(input, &input_tangents[0]),
2077            DynTensor::C32(_) | DynTensor::C64(_) => Err(invalid_argument(
2078                "eig AD currently supports real inputs only",
2079            )),
2080        }
2081        .map_err(into_ad_error)
2082    }
2083
2084    fn vjp(
2085        &self,
2086        output_cotangents: &[Option<DynTensor>],
2087        input_grad_mask: &[bool],
2088    ) -> AdResult<Vec<Option<DynTensor>>> {
2089        match &self.input {
2090            DynTensor::F32(input) => eig_vjp_t::<f32>(input, output_cotangents, input_grad_mask),
2091            DynTensor::F64(input) => eig_vjp_t::<f64>(input, output_cotangents, input_grad_mask),
2092            DynTensor::C32(_) | DynTensor::C64(_) => Err(invalid_argument(
2093                "eig AD currently supports real inputs only",
2094            )),
2095        }
2096        .map_err(into_ad_error)
2097    }
2098}
2099
2100impl LinearizableOp<DynTensor> for EigenOp {
2101    type Linearized = EigenLinearized;
2102
2103    fn primal(&self, inputs: &[&DynTensor]) -> AdResult<Vec<DynTensor>> {
2104        match inputs[0] {
2105            DynTensor::F32(input) => eigen_primal_t::<f32>(input),
2106            DynTensor::F64(input) => eigen_primal_t::<f64>(input),
2107            DynTensor::C32(input) => eigen_primal_t::<Complex32>(input),
2108            DynTensor::C64(input) => eigen_primal_t::<Complex64>(input),
2109        }
2110        .map_err(into_ad_error)
2111    }
2112
2113    fn input_schema(&self, _inputs: &[&DynTensor]) -> AdResult<Schema> {
2114        Ok(differentiable_schema(1))
2115    }
2116
2117    fn output_schema(&self, _inputs: &[&DynTensor], _outputs: &[DynTensor]) -> AdResult<Schema> {
2118        Ok(differentiable_schema(2))
2119    }
2120
2121    fn linearize(
2122        &self,
2123        inputs: &[&DynTensor],
2124        _outputs: &[DynTensor],
2125    ) -> AdResult<Self::Linearized> {
2126        Ok(EigenLinearized {
2127            input: inputs[0].clone(),
2128        })
2129    }
2130
2131    fn checkpoint_hint(&self) -> CheckpointHint {
2132        CheckpointHint::ExpensiveReplay
2133    }
2134}
2135
2136impl LinearizedOp<DynTensor> for EigenLinearized {
2137    fn jvp(&self, input_tangents: &[Option<DynTensor>]) -> AdResult<Vec<Option<DynTensor>>> {
2138        match &self.input {
2139            DynTensor::F32(input) => eigen_jvp_t::<f32>(input, &input_tangents[0]),
2140            DynTensor::F64(input) => eigen_jvp_t::<f64>(input, &input_tangents[0]),
2141            DynTensor::C32(input) => eigen_jvp_t::<Complex32>(input, &input_tangents[0]),
2142            DynTensor::C64(input) => eigen_jvp_t::<Complex64>(input, &input_tangents[0]),
2143        }
2144        .map_err(into_ad_error)
2145    }
2146
2147    fn vjp(
2148        &self,
2149        output_cotangents: &[Option<DynTensor>],
2150        input_grad_mask: &[bool],
2151    ) -> AdResult<Vec<Option<DynTensor>>> {
2152        match &self.input {
2153            DynTensor::F32(input) => eigen_vjp_t::<f32>(input, output_cotangents, input_grad_mask),
2154            DynTensor::F64(input) => eigen_vjp_t::<f64>(input, output_cotangents, input_grad_mask),
2155            DynTensor::C32(input) => {
2156                eigen_vjp_t::<Complex32>(input, output_cotangents, input_grad_mask)
2157            }
2158            DynTensor::C64(input) => {
2159                eigen_vjp_t::<Complex64>(input, output_cotangents, input_grad_mask)
2160            }
2161        }
2162        .map_err(into_ad_error)
2163    }
2164}
2165
2166impl LinearizableOp<DynTensor> for DetOp {
2167    type Linearized = DetLinearized;
2168
2169    fn primal(&self, inputs: &[&DynTensor]) -> AdResult<Vec<DynTensor>> {
2170        let output = match inputs[0] {
2171            DynTensor::F32(input) => det_primal_t::<f32>(input),
2172            DynTensor::F64(input) => det_primal_t::<f64>(input),
2173            DynTensor::C32(input) => det_primal_t::<Complex32>(input),
2174            DynTensor::C64(input) => det_primal_t::<Complex64>(input),
2175        }
2176        .map_err(into_ad_error)?;
2177        Ok(vec![output])
2178    }
2179
2180    fn input_schema(&self, _inputs: &[&DynTensor]) -> AdResult<Schema> {
2181        Ok(differentiable_schema(1))
2182    }
2183
2184    fn output_schema(&self, _inputs: &[&DynTensor], _outputs: &[DynTensor]) -> AdResult<Schema> {
2185        Ok(differentiable_schema(1))
2186    }
2187
2188    fn linearize(
2189        &self,
2190        inputs: &[&DynTensor],
2191        _outputs: &[DynTensor],
2192    ) -> AdResult<Self::Linearized> {
2193        Ok(DetLinearized {
2194            input: inputs[0].clone(),
2195        })
2196    }
2197
2198    fn checkpoint_hint(&self) -> CheckpointHint {
2199        CheckpointHint::ExpensiveReplay
2200    }
2201}
2202
2203impl LinearizedOp<DynTensor> for DetLinearized {
2204    fn jvp(&self, input_tangents: &[Option<DynTensor>]) -> AdResult<Vec<Option<DynTensor>>> {
2205        let tangent = match &self.input {
2206            DynTensor::F32(input) => det_jvp_t::<f32>(input, &input_tangents[0]),
2207            DynTensor::F64(input) => det_jvp_t::<f64>(input, &input_tangents[0]),
2208            DynTensor::C32(input) => det_jvp_t::<Complex32>(input, &input_tangents[0]),
2209            DynTensor::C64(input) => det_jvp_t::<Complex64>(input, &input_tangents[0]),
2210        }
2211        .map_err(into_ad_error)?;
2212        Ok(vec![tangent])
2213    }
2214
2215    fn vjp(
2216        &self,
2217        output_cotangents: &[Option<DynTensor>],
2218        input_grad_mask: &[bool],
2219    ) -> AdResult<Vec<Option<DynTensor>>> {
2220        let Some(cotangent) = output_cotangents[0].as_ref() else {
2221            return Ok(vec![None]);
2222        };
2223        match &self.input {
2224            DynTensor::F32(input) => det_vjp_t::<f32>(input, cotangent, input_grad_mask),
2225            DynTensor::F64(input) => det_vjp_t::<f64>(input, cotangent, input_grad_mask),
2226            DynTensor::C32(input) => det_vjp_t::<Complex32>(input, cotangent, input_grad_mask),
2227            DynTensor::C64(input) => det_vjp_t::<Complex64>(input, cotangent, input_grad_mask),
2228        }
2229        .map_err(into_ad_error)
2230    }
2231}
2232
2233impl LinearizableOp<DynTensor> for PInvOp {
2234    type Linearized = PInvLinearized;
2235
2236    fn primal(&self, inputs: &[&DynTensor]) -> AdResult<Vec<DynTensor>> {
2237        let output = match inputs[0] {
2238            DynTensor::F32(input) => pinv_primal_t::<f32>(input, self.rcond),
2239            DynTensor::F64(input) => pinv_primal_t::<f64>(input, self.rcond),
2240            DynTensor::C32(input) => pinv_primal_t::<Complex32>(input, self.rcond),
2241            DynTensor::C64(input) => pinv_primal_t::<Complex64>(input, self.rcond),
2242        }
2243        .map_err(into_ad_error)?;
2244        Ok(vec![output])
2245    }
2246
2247    fn input_schema(&self, _inputs: &[&DynTensor]) -> AdResult<Schema> {
2248        Ok(differentiable_schema(1))
2249    }
2250
2251    fn output_schema(&self, _inputs: &[&DynTensor], _outputs: &[DynTensor]) -> AdResult<Schema> {
2252        Ok(differentiable_schema(1))
2253    }
2254
2255    fn linearize(
2256        &self,
2257        inputs: &[&DynTensor],
2258        _outputs: &[DynTensor],
2259    ) -> AdResult<Self::Linearized> {
2260        Ok(PInvLinearized {
2261            input: inputs[0].clone(),
2262            rcond: self.rcond,
2263        })
2264    }
2265
2266    fn checkpoint_hint(&self) -> CheckpointHint {
2267        CheckpointHint::ExpensiveReplay
2268    }
2269}
2270
2271impl LinearizedOp<DynTensor> for PInvLinearized {
2272    fn jvp(&self, input_tangents: &[Option<DynTensor>]) -> AdResult<Vec<Option<DynTensor>>> {
2273        let tangent = match &self.input {
2274            DynTensor::F32(input) => pinv_jvp_t::<f32>(input, &input_tangents[0], self.rcond),
2275            DynTensor::F64(input) => pinv_jvp_t::<f64>(input, &input_tangents[0], self.rcond),
2276            DynTensor::C32(input) => pinv_jvp_t::<Complex32>(input, &input_tangents[0], self.rcond),
2277            DynTensor::C64(input) => pinv_jvp_t::<Complex64>(input, &input_tangents[0], self.rcond),
2278        }
2279        .map_err(into_ad_error)?;
2280        Ok(vec![tangent])
2281    }
2282
2283    fn vjp(
2284        &self,
2285        output_cotangents: &[Option<DynTensor>],
2286        input_grad_mask: &[bool],
2287    ) -> AdResult<Vec<Option<DynTensor>>> {
2288        let Some(cotangent) = output_cotangents[0].as_ref() else {
2289            return Ok(vec![None]);
2290        };
2291        match &self.input {
2292            DynTensor::F32(input) => {
2293                pinv_vjp_t::<f32>(input, cotangent, input_grad_mask, self.rcond)
2294            }
2295            DynTensor::F64(input) => {
2296                pinv_vjp_t::<f64>(input, cotangent, input_grad_mask, self.rcond)
2297            }
2298            DynTensor::C32(input) => {
2299                pinv_vjp_t::<Complex32>(input, cotangent, input_grad_mask, self.rcond)
2300            }
2301            DynTensor::C64(input) => {
2302                pinv_vjp_t::<Complex64>(input, cotangent, input_grad_mask, self.rcond)
2303            }
2304        }
2305        .map_err(into_ad_error)
2306    }
2307}
2308
2309impl LinearizableOp<DynTensor> for MatrixExpOp {
2310    type Linearized = MatrixExpLinearized;
2311
2312    fn primal(&self, inputs: &[&DynTensor]) -> AdResult<Vec<DynTensor>> {
2313        let output = match inputs[0] {
2314            DynTensor::F32(input) => matrix_exp_primal_t::<f32>(input),
2315            DynTensor::F64(input) => matrix_exp_primal_t::<f64>(input),
2316            DynTensor::C32(input) => matrix_exp_primal_t::<Complex32>(input),
2317            DynTensor::C64(input) => matrix_exp_primal_t::<Complex64>(input),
2318        }
2319        .map_err(into_ad_error)?;
2320        Ok(vec![output])
2321    }
2322
2323    fn input_schema(&self, _inputs: &[&DynTensor]) -> AdResult<Schema> {
2324        Ok(differentiable_schema(1))
2325    }
2326
2327    fn output_schema(&self, _inputs: &[&DynTensor], _outputs: &[DynTensor]) -> AdResult<Schema> {
2328        Ok(differentiable_schema(1))
2329    }
2330
2331    fn linearize(
2332        &self,
2333        inputs: &[&DynTensor],
2334        _outputs: &[DynTensor],
2335    ) -> AdResult<Self::Linearized> {
2336        Ok(MatrixExpLinearized {
2337            input: inputs[0].clone(),
2338        })
2339    }
2340
2341    fn checkpoint_hint(&self) -> CheckpointHint {
2342        CheckpointHint::ExpensiveReplay
2343    }
2344}
2345
2346impl LinearizedOp<DynTensor> for MatrixExpLinearized {
2347    fn jvp(&self, input_tangents: &[Option<DynTensor>]) -> AdResult<Vec<Option<DynTensor>>> {
2348        let tangent = match &self.input {
2349            DynTensor::F32(input) => matrix_exp_jvp_t::<f32>(input, &input_tangents[0]),
2350            DynTensor::F64(input) => matrix_exp_jvp_t::<f64>(input, &input_tangents[0]),
2351            DynTensor::C32(input) => matrix_exp_jvp_t::<Complex32>(input, &input_tangents[0]),
2352            DynTensor::C64(input) => matrix_exp_jvp_t::<Complex64>(input, &input_tangents[0]),
2353        }
2354        .map_err(into_ad_error)?;
2355        Ok(vec![tangent])
2356    }
2357
2358    fn vjp(
2359        &self,
2360        output_cotangents: &[Option<DynTensor>],
2361        input_grad_mask: &[bool],
2362    ) -> AdResult<Vec<Option<DynTensor>>> {
2363        let Some(cotangent) = output_cotangents[0].as_ref() else {
2364            return Ok(vec![None]);
2365        };
2366        match &self.input {
2367            DynTensor::F32(input) => matrix_exp_vjp_t::<f32>(input, cotangent, input_grad_mask),
2368            DynTensor::F64(input) => matrix_exp_vjp_t::<f64>(input, cotangent, input_grad_mask),
2369            DynTensor::C32(input) => {
2370                matrix_exp_vjp_t::<Complex32>(input, cotangent, input_grad_mask)
2371            }
2372            DynTensor::C64(input) => {
2373                matrix_exp_vjp_t::<Complex64>(input, cotangent, input_grad_mask)
2374            }
2375        }
2376        .map_err(into_ad_error)
2377    }
2378}
2379
2380impl LinearizableOp<DynTensor> for QrOp {
2381    type Linearized = QrLinearized;
2382
2383    fn primal(&self, inputs: &[&DynTensor]) -> AdResult<Vec<DynTensor>> {
2384        match inputs[0] {
2385            DynTensor::F32(input) => qr_primal_t::<f32>(input),
2386            DynTensor::F64(input) => qr_primal_t::<f64>(input),
2387            DynTensor::C32(input) => qr_primal_t::<Complex32>(input),
2388            DynTensor::C64(input) => qr_primal_t::<Complex64>(input),
2389        }
2390        .map_err(into_ad_error)
2391    }
2392
2393    fn input_schema(&self, _inputs: &[&DynTensor]) -> AdResult<Schema> {
2394        Ok(differentiable_schema(1))
2395    }
2396
2397    fn output_schema(&self, _inputs: &[&DynTensor], _outputs: &[DynTensor]) -> AdResult<Schema> {
2398        Ok(differentiable_schema(2))
2399    }
2400
2401    fn linearize(
2402        &self,
2403        inputs: &[&DynTensor],
2404        _outputs: &[DynTensor],
2405    ) -> AdResult<Self::Linearized> {
2406        Ok(QrLinearized {
2407            input: inputs[0].clone(),
2408        })
2409    }
2410
2411    fn checkpoint_hint(&self) -> CheckpointHint {
2412        CheckpointHint::ExpensiveReplay
2413    }
2414}
2415
2416impl LinearizedOp<DynTensor> for QrLinearized {
2417    fn jvp(&self, input_tangents: &[Option<DynTensor>]) -> AdResult<Vec<Option<DynTensor>>> {
2418        match &self.input {
2419            DynTensor::F32(input) => qr_jvp_t::<f32>(input, &input_tangents[0]),
2420            DynTensor::F64(input) => qr_jvp_t::<f64>(input, &input_tangents[0]),
2421            DynTensor::C32(input) => qr_jvp_t::<Complex32>(input, &input_tangents[0]),
2422            DynTensor::C64(input) => qr_jvp_t::<Complex64>(input, &input_tangents[0]),
2423        }
2424        .map_err(into_ad_error)
2425    }
2426
2427    fn vjp(
2428        &self,
2429        output_cotangents: &[Option<DynTensor>],
2430        input_grad_mask: &[bool],
2431    ) -> AdResult<Vec<Option<DynTensor>>> {
2432        match &self.input {
2433            DynTensor::F32(input) => qr_vjp_t::<f32>(input, output_cotangents, input_grad_mask),
2434            DynTensor::F64(input) => qr_vjp_t::<f64>(input, output_cotangents, input_grad_mask),
2435            DynTensor::C32(input) => {
2436                qr_vjp_t::<Complex32>(input, output_cotangents, input_grad_mask)
2437            }
2438            DynTensor::C64(input) => {
2439                qr_vjp_t::<Complex64>(input, output_cotangents, input_grad_mask)
2440            }
2441        }
2442        .map_err(into_ad_error)
2443    }
2444}
2445
2446impl LinearizableOp<DynTensor> for SvdOp {
2447    type Linearized = SvdLinearized;
2448
2449    fn primal(&self, inputs: &[&DynTensor]) -> AdResult<Vec<DynTensor>> {
2450        match inputs[0] {
2451            DynTensor::F32(input) => svd_primal_t::<f32>(input, self.options.as_ref()),
2452            DynTensor::F64(input) => svd_primal_t::<f64>(input, self.options.as_ref()),
2453            DynTensor::C32(input) => svd_primal_t::<Complex32>(input, self.options.as_ref()),
2454            DynTensor::C64(input) => svd_primal_t::<Complex64>(input, self.options.as_ref()),
2455        }
2456        .map_err(into_ad_error)
2457    }
2458
2459    fn input_schema(&self, _inputs: &[&DynTensor]) -> AdResult<Schema> {
2460        Ok(differentiable_schema(1))
2461    }
2462
2463    fn output_schema(&self, _inputs: &[&DynTensor], _outputs: &[DynTensor]) -> AdResult<Schema> {
2464        Ok(differentiable_schema(3))
2465    }
2466
2467    fn linearize(
2468        &self,
2469        inputs: &[&DynTensor],
2470        _outputs: &[DynTensor],
2471    ) -> AdResult<Self::Linearized> {
2472        Ok(SvdLinearized {
2473            input: inputs[0].clone(),
2474            options: self.options.clone(),
2475        })
2476    }
2477
2478    fn checkpoint_hint(&self) -> CheckpointHint {
2479        CheckpointHint::ExpensiveReplay
2480    }
2481}
2482
2483impl LinearizedOp<DynTensor> for SvdLinearized {
2484    fn jvp(&self, input_tangents: &[Option<DynTensor>]) -> AdResult<Vec<Option<DynTensor>>> {
2485        match &self.input {
2486            DynTensor::F32(input) => {
2487                svd_jvp_t::<f32>(input, &input_tangents[0], self.options.as_ref())
2488            }
2489            DynTensor::F64(input) => {
2490                svd_jvp_t::<f64>(input, &input_tangents[0], self.options.as_ref())
2491            }
2492            DynTensor::C32(input) => {
2493                svd_jvp_t::<Complex32>(input, &input_tangents[0], self.options.as_ref())
2494            }
2495            DynTensor::C64(input) => {
2496                svd_jvp_t::<Complex64>(input, &input_tangents[0], self.options.as_ref())
2497            }
2498        }
2499        .map_err(into_ad_error)
2500    }
2501
2502    fn vjp(
2503        &self,
2504        output_cotangents: &[Option<DynTensor>],
2505        input_grad_mask: &[bool],
2506    ) -> AdResult<Vec<Option<DynTensor>>> {
2507        match &self.input {
2508            DynTensor::F32(input) => svd_vjp_t::<f32>(
2509                input,
2510                output_cotangents,
2511                input_grad_mask,
2512                self.options.as_ref(),
2513            ),
2514            DynTensor::F64(input) => svd_vjp_t::<f64>(
2515                input,
2516                output_cotangents,
2517                input_grad_mask,
2518                self.options.as_ref(),
2519            ),
2520            DynTensor::C32(input) => svd_vjp_t::<Complex32>(
2521                input,
2522                output_cotangents,
2523                input_grad_mask,
2524                self.options.as_ref(),
2525            ),
2526            DynTensor::C64(input) => svd_vjp_t::<Complex64>(
2527                input,
2528                output_cotangents,
2529                input_grad_mask,
2530                self.options.as_ref(),
2531            ),
2532        }
2533        .map_err(into_ad_error)
2534    }
2535}
2536
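/// Differentiable linear solve: applies [`SolveOp`] to `a` and `b` and returns the solution
/// as a single tracked value.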
2537pub fn solve_dyn_values(a: &DynValue, b: &DynValue) -> AdResult<DynValue> {
2538    SolveOp.apply_one(&[a, b])
2539}
2540
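/// Differentiable least-squares solve via [`LstsqOp`].
///
/// The solution and residuals come back as tracked values; the rank and singular values are
/// recomputed from the primal of `a` and returned as plain (non-differentiated) data. Only
/// real dtypes are supported.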
2541pub fn lstsq_dyn_values(a: &DynValue, b: &DynValue) -> AdResult<DynLstsqValues> {
2542    let mut outputs = LstsqOp.apply(&[a, b])?;
2543    if outputs.len() != 2 {
2544        return Err(AutodiffError::InvalidArgument(format!(
2545            "LstsqOp expected 2 outputs, got {}",
2546            outputs.len()
2547        )));
2548    }
2549    let residuals = outputs.pop().unwrap();
2550    let solution = outputs.pop().unwrap();
2551    let (rank, singular_values) = match a.primal() {
2552        DynTensor::F32(input) => lstsq_aux_t::<f32>(input),
2553        DynTensor::F64(input) => lstsq_aux_t::<f64>(input),
2554        DynTensor::C32(_) | DynTensor::C64(_) => Err(invalid_argument(
2555            "lstsq AD currently supports real dtypes only",
2556        )),
2557    }
2558    .map_err(into_ad_error)?;
2559    Ok(DynLstsqValues {
2560        solution,
2561        residuals,
2562        rank,
2563        singular_values,
2564    })
2565}
2566
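/// Differentiable triangular solve via [`SolveTriangularOp`]; `upper` selects whether `a` is
/// treated as upper- or lower-triangular.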
2567pub fn solve_triangular_dyn_value(a: &DynValue, b: &DynValue, upper: bool) -> AdResult<DynValue> {
2568    SolveTriangularOp::new(upper).apply_one(&[a, b])
2569}
2570
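/// Differentiable norm of `input` for the given [`NormKind`].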
2571pub fn norm_dyn_value(input: &DynValue, kind: NormKind) -> AdResult<DynValue> {
2572    NormOp::new(kind).apply_one(&[input])
2573}
2574
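/// Differentiable determinant via [`DetOp`].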
2575pub fn det_dyn_value(input: &DynValue) -> AdResult<DynValue> {
2576    DetOp.apply_one(&[input])
2577}
2578
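/// Differentiable matrix inverse via [`InvOp`].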
2579pub fn inv_dyn_value(input: &DynValue) -> AdResult<DynValue> {
2580    InvOp.apply_one(&[input])
2581}
2582
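/// Differentiable sign-and-log-determinant via [`SlogdetOp`]; returns the sign and the log of
/// the absolute determinant as separate tracked values.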
2583pub fn slogdet_dyn_value(input: &DynValue) -> AdResult<DynSlogdetValues> {
2584    let mut outputs = SlogdetOp.apply(&[input])?;
2585    if outputs.len() != 2 {
2586        return Err(AutodiffError::InvalidArgument(format!(
2587            "SlogdetOp expected 2 outputs, got {}",
2588            outputs.len()
2589        )));
2590    }
2591    let logabsdet = outputs.pop().unwrap();
2592    let sign = outputs.pop().unwrap();
2593    Ok(DynSlogdetValues { sign, logabsdet })
2594}
2595
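/// Differentiable Cholesky factorization via [`CholeskyOp`].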
2596pub fn cholesky_dyn_value(input: &DynValue) -> AdResult<DynValue> {
2597    CholeskyOp.apply_one(&[input])
2598}
2599
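/// Differentiable LU factorization via [`LuOp`] with the given pivoting strategy; returns the
/// `P`, `L`, and `U` factors.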
2600pub fn lu_dyn_value(input: &DynValue, pivot: LuPivot) -> AdResult<DynLuValues> {
2601    let mut outputs = LuOp::new(pivot).apply(&[input])?;
2602    if outputs.len() != 3 {
2603        return Err(AutodiffError::InvalidArgument(format!(
2604            "LuOp expected 3 outputs, got {}",
2605            outputs.len()
2606        )));
2607    }
2608    let u = outputs.pop().unwrap();
2609    let l = outputs.pop().unwrap();
2610    let p = outputs.pop().unwrap();
2611    Ok(DynLuValues { p, l, u })
2612}
2613
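/// Differentiable QR factorization via [`QrOp`]; returns the `Q` and `R` factors.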
2614pub fn qr_dyn_value(input: &DynValue) -> AdResult<DynQrValues> {
2615    let mut outputs = QrOp.apply(&[input])?;
2616    if outputs.len() != 2 {
2617        return Err(AutodiffError::InvalidArgument(format!(
2618            "QrOp expected 2 outputs, got {}",
2619            outputs.len()
2620        )));
2621    }
2622    let r = outputs.pop().unwrap();
2623    let q = outputs.pop().unwrap();
2624    Ok(DynQrValues { q, r })
2625}
2626
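/// Differentiable singular value decomposition via [`SvdOp`]; returns the `U`, `S`, and `Vt`
/// factors.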
2627pub fn svd_dyn_value(input: &DynValue, options: Option<SvdOptions>) -> AdResult<DynSvdValues> {
2628    let mut outputs = SvdOp::new(options).apply(&[input])?;
2629    if outputs.len() != 3 {
2630        return Err(AutodiffError::InvalidArgument(format!(
2631            "SvdOp expected 3 outputs, got {}",
2632            outputs.len()
2633        )));
2634    }
2635    let vt = outputs.pop().unwrap();
2636    let s = outputs.pop().unwrap();
2637    let u = outputs.pop().unwrap();
2638    Ok(DynSvdValues { u, s, vt })
2639}
2640
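/// Differentiable eigendecomposition via [`EigOp`]; returns the eigenvalues and eigenvectors.
/// Only real input dtypes are supported.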
2641pub fn eig_dyn_value(input: &DynValue) -> AdResult<DynEigValues> {
2642    let mut outputs = EigOp.apply(&[input])?;
2643    if outputs.len() != 2 {
2644        return Err(AutodiffError::InvalidArgument(format!(
2645            "EigOp expected 2 outputs, got {}",
2646            outputs.len()
2647        )));
2648    }
2649    let vectors = outputs.pop().unwrap();
2650    let values = outputs.pop().unwrap();
2651    Ok(DynEigValues { values, vectors })
2652}
2653
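/// Differentiable eigendecomposition via [`EigenOp`]; returns the eigenvalues and eigenvectors
/// for real and complex dtypes.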
2654pub fn eigen_dyn_value(input: &DynValue) -> AdResult<DynEigenValues> {
2655    let mut outputs = EigenOp.apply(&[input])?;
2656    if outputs.len() != 2 {
2657        return Err(AutodiffError::InvalidArgument(format!(
2658            "EigenOp expected 2 outputs, got {}",
2659            outputs.len()
2660        )));
2661    }
2662    let vectors = outputs.pop().unwrap();
2663    let values = outputs.pop().unwrap();
2664    Ok(DynEigenValues { values, vectors })
2665}
2666
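/// Differentiable Moore-Penrose pseudoinverse via [`PInvOp`]; `rcond` is forwarded to the
/// underlying pinv routines.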
2667pub fn pinv_dyn_value(input: &DynValue, rcond: Option<f64>) -> AdResult<DynValue> {
2668    PInvOp::new(rcond).apply_one(&[input])
2669}
2670
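/// Differentiable matrix exponential via [`MatrixExpOp`].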
2671pub fn matrix_exp_dyn_value(input: &DynValue) -> AdResult<DynValue> {
2672    MatrixExpOp.apply_one(&[input])
2673}