use chainrules_core::AutodiffError;
use num_complex::{Complex32, Complex64};
use num_traits::Zero;
use tenferro_algebra::{Conjugate, Scalar};
use tenferro_device::LogicalMemorySpace;
use tenferro_internal_ad_core::{
    AdResult, CheckpointHint, DynValue, LinearizableOp, LinearizedOp, Schema, SlotSchema,
};
use tenferro_internal_frontend_core::{DynTensor, DynTensorTyped, StructuredTensor};
use tenferro_internal_runtime::contracts::{LinalgRuntimeValue, RealLinalgRuntimeValue};
use tenferro_internal_runtime::dispatch::{
    with_linalg_runtime, LuLinalgDispatchValue, MatrixExpLinalgDispatchValue,
    NormLinalgDispatchValue, ScaledLinalgDispatchValue, ScaledRealLinalgDispatchValue,
    SlogdetLinalgDispatchValue,
};
use tenferro_linalg::backend::{CudaLinalgScalar, LinalgCapabilityOp};
use tenferro_linalg::{
    cholesky, cholesky_frule, cholesky_rrule, det, det_frule, det_rrule, eig, eig_frule, eig_rrule,
    eigen, eigen_frule, eigen_rrule, inv, inv_frule, inv_rrule, lstsq, lstsq_aux, lstsq_frule,
    lstsq_rrule, lu, lu_frule, lu_rrule, matrix_exp, matrix_exp_frule, matrix_exp_rrule, norm,
    norm_frule, norm_frule_complex, norm_rrule, norm_rrule_complex, pinv, pinv_frule, pinv_rrule,
    qr, qr_frule, qr_rrule, slogdet, slogdet_frule, slogdet_rrule, solve, solve_frule, solve_rrule,
    solve_triangular, solve_triangular_frule, solve_triangular_rrule, svd, svd_frule, svd_rrule,
    EigCotangent, EigenCotangent, KernelLinalgScalar, LuCotangent, LuPivot, NormKind, QrCotangent,
    SlogdetCotangent, SvdCotangent, SvdOptions,
};
use tenferro_tensor::{KeepCountScalar, MemoryOrder, Tensor as DenseTensor};

use crate::{Error, Result};

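// Differentiable wrappers for the dense linear algebra ops in `tenferro_linalg`.
//
// Each `*Op` marker type implements `LinearizableOp<DynTensor>`: `primal`
// computes the forward value, `linearize` captures the inputs the derivative
// rules need, and the resulting `*Linearized` state answers `jvp`
// (forward-mode tangent) and `vjp` (reverse-mode cotangent) queries.
//
// A minimal usage sketch (hypothetical driver code; `a`, `b`, and `cotangent`
// stand in for real `DynTensor` values, and error handling is elided):
//
//     let outputs = SolveOp.primal(&[&a, &b])?;
//     let lin = SolveOp.linearize(&[&a, &b], &outputs)?;
//     let grads = lin.vjp(&[Some(cotangent)], &[true, true])?;
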
#[derive(Clone, Copy)]
pub struct SolveOp;

#[derive(Clone, Copy)]
pub struct LstsqOp;

#[derive(Clone, Copy)]
pub struct SolveTriangularOp {
    upper: bool,
}

#[derive(Clone, Copy)]
pub struct NormOp {
    kind: NormKind,
}

#[derive(Clone, Copy)]
pub struct DetOp;

#[derive(Clone, Copy)]
pub struct InvOp;

#[derive(Clone, Copy)]
pub struct SlogdetOp;

#[derive(Clone, Copy)]
pub struct CholeskyOp;

#[derive(Clone, Copy)]
pub struct LuOp {
    pivot: LuPivot,
}

#[derive(Clone, Copy)]
pub struct QrOp;

#[derive(Clone, Copy)]
pub struct EigOp;

#[derive(Clone, Copy)]
pub struct EigenOp;

#[derive(Clone, Default)]
pub struct SvdOp {
    options: Option<SvdOptions>,
}

#[derive(Clone, Default)]
pub struct PInvOp {
    rcond: Option<f64>,
}

#[derive(Clone, Copy)]
pub struct MatrixExpOp;

pub struct DynQrValues {
    pub q: DynValue,
    pub r: DynValue,
}

pub struct DynLstsqValues {
    pub solution: DynValue,
    pub residuals: DynValue,
    pub rank: Vec<usize>,
    pub singular_values: DynTensor,
}

pub struct DynLuValues {
    pub p: DynValue,
    pub l: DynValue,
    pub u: DynValue,
}

pub struct DynEigValues {
    pub values: DynValue,
    pub vectors: DynValue,
}

pub struct DynEigenValues {
    pub values: DynValue,
    pub vectors: DynValue,
}

pub struct DynSlogdetValues {
    pub sign: DynValue,
    pub logabsdet: DynValue,
}

pub struct DynSvdValues {
    pub u: DynValue,
    pub s: DynValue,
    pub vt: DynValue,
}

#[doc(hidden)]
pub struct SolveLinearized {
    a: DynTensor,
    b: DynTensor,
}

#[doc(hidden)]
pub struct LstsqLinearized {
    a: DynTensor,
    b: DynTensor,
}

#[doc(hidden)]
pub struct SolveTriangularLinearized {
    a: DynTensor,
    b: DynTensor,
    upper: bool,
}

#[doc(hidden)]
pub struct NormLinearized {
    input: DynTensor,
    kind: NormKind,
}

#[doc(hidden)]
pub struct DetLinearized {
    input: DynTensor,
}

#[doc(hidden)]
pub struct InvLinearized {
    input: DynTensor,
}

#[doc(hidden)]
pub struct SlogdetLinearized {
    input: DynTensor,
}

#[doc(hidden)]
pub struct CholeskyLinearized {
    input: DynTensor,
}

#[doc(hidden)]
pub struct LuLinearized {
    input: DynTensor,
    pivot: LuPivot,
}

#[doc(hidden)]
pub struct QrLinearized {
    input: DynTensor,
}

#[doc(hidden)]
pub struct EigLinearized {
    input: DynTensor,
}

#[doc(hidden)]
pub struct EigenLinearized {
    input: DynTensor,
}

#[doc(hidden)]
pub struct SvdLinearized {
    input: DynTensor,
    options: Option<SvdOptions>,
}

#[doc(hidden)]
pub struct PInvLinearized {
    input: DynTensor,
    rcond: Option<f64>,
}

#[doc(hidden)]
pub struct MatrixExpLinearized {
    input: DynTensor,
}

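// Builds a schema whose slots are all differentiable (and none auxiliary);
// used by ops where every input or output participates in AD.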
fn differentiable_schema(slots: usize) -> Schema {
    Schema {
        slots: (0..slots)
            .map(|_| SlotSchema {
                differentiable: true,
                auxiliary: false,
            })
            .collect(),
    }
}

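// slogdet outputs `(sign, logabsdet)`: the sign is treated as an auxiliary,
// non-differentiable slot; only `logabsdet` carries derivatives.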
fn slogdet_output_schema() -> Schema {
    Schema {
        slots: vec![
            SlotSchema {
                differentiable: false,
                auxiliary: true,
            },
            SlotSchema {
                differentiable: true,
                auxiliary: false,
            },
        ],
    }
}

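// lstsq outputs `(solution, residuals)`: the solution slot is differentiable;
// the residuals slot is exposed as auxiliary.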
fn lstsq_output_schema() -> Schema {
    Schema {
        slots: vec![
            SlotSchema {
                differentiable: true,
                auxiliary: false,
            },
            SlotSchema {
                differentiable: false,
                auxiliary: true,
            },
        ],
    }
}

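// lu outputs `(p, l, u)`: the permutation is discrete and therefore
// auxiliary; the `L` and `U` factors are differentiable.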
fn lu_output_schema() -> Schema {
    Schema {
        slots: vec![
            SlotSchema {
                differentiable: false,
                auxiliary: true,
            },
            SlotSchema {
                differentiable: true,
                auxiliary: false,
            },
            SlotSchema {
                differentiable: true,
                auxiliary: false,
            },
        ],
    }
}

fn invalid_argument(message: impl Into<String>) -> Error {
    AutodiffError::InvalidArgument(message.into()).into()
}

fn into_ad_error(error: Error) -> AutodiffError {
    match error {
        Error::Autodiff(error) => error,
        other => AutodiffError::InvalidArgument(other.to_string()),
    }
}

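// Dispatches `$body` to whichever linalg runtime backs scalar type `$ty`.
// The closure body is written out once per `with_linalg_runtime` arm so that
// each arm monomorphizes against its own context type.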
macro_rules! dispatch_linalg {
    ($ty:ty, $op:expr, $cap:expr, |$ctx:ident| $body:expr) => {{
        with_linalg_runtime::<$ty, _>($op, $cap, |$ctx| $body, |$ctx| $body, |$ctx| $body)
    }};
}

fn dense_dyn_tensor_typed<T>(value: &DynTensor, context: &str) -> Result<DenseTensor<T>>
where
    T: DynTensorTyped + Copy,
{
    let structured = T::structured_ref(value)
        .ok_or_else(|| invalid_argument(format!("{context} requires matching dtypes")))?;
    structured.to_dense()
}

fn optional_dense_dyn_tensor_typed<T>(
    value: &Option<DynTensor>,
    context: &str,
) -> Result<Option<DenseTensor<T>>>
where
    T: DynTensorTyped + Copy,
{
    value
        .as_ref()
        .map(|tensor| dense_dyn_tensor_typed::<T>(tensor, context))
        .transpose()
}

fn dense_zeros_like<T>(like: &DenseTensor<T>) -> Result<DenseTensor<T>>
where
    T: Scalar + Zero + Copy,
{
    let total: usize = like.dims().iter().product();
    DenseTensor::from_slice(
        &vec![T::zero(); total],
        like.dims(),
        MemoryOrder::ColumnMajor,
    )
    .map_err(Error::from)
}

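// Missing tangent slots are materialized as zero tensors so the kernel-level
// frules can assume a dense input for every argument.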
fn dense_optional_or_zero<T>(
    value: &Option<DynTensor>,
    like: &DenseTensor<T>,
    context: &str,
) -> Result<DenseTensor<T>>
where
    T: DynTensorTyped + Scalar + Zero + Copy,
{
    optional_dense_dyn_tensor_typed::<T>(value, context)?.map_or_else(|| dense_zeros_like(like), Ok)
}

fn dyn_from_dense<T>(value: DenseTensor<T>) -> DynTensor
where
    T: DynTensorTyped + Copy,
{
    T::into_dyn(StructuredTensor::from(value))
}

fn solve_primal_t<T>(a: &StructuredTensor<T>, b: &StructuredTensor<T>) -> Result<DynTensor>
where
    T: LinalgRuntimeValue + DynTensorTyped + Copy,
{
    let dense_a = a.to_dense()?;
    let dense_b = b.to_dense()?;
    let output = dispatch_linalg!(T, "solve_dyn_value", LinalgCapabilityOp::Solve, |ctx| {
        solve(ctx, &dense_a, &dense_b).map_err(Error::from)
    })?;
    Ok(dyn_from_dense(output))
}

fn solve_jvp_t<T>(
    a: &StructuredTensor<T>,
    b: &StructuredTensor<T>,
    tangents: &[Option<DynTensor>],
) -> Result<Option<DynTensor>>
where
    T: LinalgRuntimeValue + DynTensorTyped + Scalar + Zero + Copy,
{
    if tangents.iter().all(Option::is_none) {
        return Ok(None);
    }
    let dense_a = a.to_dense()?;
    let dense_b = b.to_dense()?;
    let tangent_a = dense_optional_or_zero(&tangents[0], &dense_a, "solve_jvp tangent_a")?;
    let tangent_b = dense_optional_or_zero(&tangents[1], &dense_b, "solve_jvp tangent_b")?;
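    // `solve_frule` also recomputes the primal; only the tangent is kept here.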
    let (_, tangent) = dispatch_linalg!(T, "solve_jvp", LinalgCapabilityOp::Solve, |ctx| {
        solve_frule(ctx, &dense_a, &dense_b, &tangent_a, &tangent_b).map_err(Error::from)
    })?;
    Ok(Some(dyn_from_dense(tangent)))
}

fn solve_vjp_t<T>(
    a: &StructuredTensor<T>,
    b: &StructuredTensor<T>,
    cotangent: &DynTensor,
    input_grad_mask: &[bool],
) -> Result<Vec<Option<DynTensor>>>
where
    T: LinalgRuntimeValue + DynTensorTyped + Copy,
{
    if !input_grad_mask.iter().any(|needed| *needed) {
        return Ok(vec![None, None]);
    }
    let dense_a = a.to_dense()?;
    let dense_b = b.to_dense()?;
    let dense_cotangent = dense_dyn_tensor_typed::<T>(cotangent, "solve_vjp")?;
    let grad = dispatch_linalg!(T, "solve_vjp", LinalgCapabilityOp::Solve, |ctx| {
        solve_rrule(ctx, &dense_a, &dense_b, &dense_cotangent).map_err(Error::from)
    })?;
    Ok(vec![
        input_grad_mask[0].then(|| dyn_from_dense(grad.a)),
        input_grad_mask[1].then(|| dyn_from_dense(grad.b)),
    ])
}

fn lstsq_primal_t<T>(a: &StructuredTensor<T>, b: &StructuredTensor<T>) -> Result<Vec<DynTensor>>
where
    T: RealLinalgRuntimeValue + DynTensorTyped + Conjugate + Copy,
{
    let dense_a = a.to_dense()?;
    let dense_b = b.to_dense()?;
    let output = dispatch_linalg!(T, "lstsq_dyn_values", LinalgCapabilityOp::Lstsq, |ctx| {
        lstsq(ctx, &dense_a, &dense_b).map_err(Error::from)
    })?;
    Ok(vec![
        dyn_from_dense(output.solution),
        dyn_from_dense(output.residuals),
    ])
}

fn lstsq_aux_t<T>(a: &StructuredTensor<T>) -> Result<(Vec<usize>, DynTensor)>
where
    T: RealLinalgRuntimeValue + DynTensorTyped + Conjugate + Copy,
{
    let dense_a = a.to_dense()?;
    let aux = dispatch_linalg!(T, "lstsq_aux", LinalgCapabilityOp::Lstsq, |ctx| {
        lstsq_aux(ctx, &dense_a).map_err(Error::from)
    })?;
    let rank = dense_rank_counts_to_vec(&aux.rank)?;
    Ok((rank, dyn_from_dense(aux.singular_values)))
}

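// Copies the (possibly device-resident) rank counts back to host memory and
// converts each entry to `usize`.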
fn dense_rank_counts_to_vec<T>(rank_counts: &DenseTensor<T>) -> Result<Vec<usize>>
where
    T: KeepCountScalar + Copy,
{
    let host = rank_counts.to_memory_space_async(LogicalMemorySpace::MainMemory)?;
    let contiguous = host.contiguous(MemoryOrder::ColumnMajor);
    let slice = contiguous.buffer().as_slice().ok_or_else(|| {
        invalid_argument("lstsq rank counts require host-accessible contiguous storage")
    })?;
    slice
        .iter()
        .map(|value| {
            value
                .to_usize()
                .ok_or_else(|| invalid_argument("failed to convert lstsq rank count to usize"))
        })
        .collect()
}

fn lstsq_jvp_t<T>(
    a: &StructuredTensor<T>,
    b: &StructuredTensor<T>,
    tangents: &[Option<DynTensor>],
) -> Result<Vec<Option<DynTensor>>>
where
    T: ScaledRealLinalgDispatchValue + DynTensorTyped + Scalar + Zero + Conjugate + Copy,
{
    if tangents.iter().all(Option::is_none) {
        return Ok(vec![None, None]);
    }
    let dense_a = a.to_dense()?;
    let dense_b = b.to_dense()?;
    let tangent_a = dense_optional_or_zero(&tangents[0], &dense_a, "lstsq_jvp tangent_a")?;
    let tangent_b = dense_optional_or_zero(&tangents[1], &dense_b, "lstsq_jvp tangent_b")?;
    let (_, tangent) = dispatch_linalg!(T, "lstsq_jvp", LinalgCapabilityOp::Lstsq, |ctx| {
        lstsq_frule(ctx, &dense_a, &dense_b, &tangent_a, &tangent_b).map_err(Error::from)
    })?;
    Ok(vec![
        Some(dyn_from_dense(tangent.solution)),
        Some(dyn_from_dense(tangent.residuals)),
    ])
}

fn lstsq_vjp_t<T>(
    a: &StructuredTensor<T>,
    b: &StructuredTensor<T>,
    output_cotangents: &[Option<DynTensor>],
    input_grad_mask: &[bool],
) -> Result<Vec<Option<DynTensor>>>
where
    T: ScaledRealLinalgDispatchValue + DynTensorTyped + Conjugate + Copy,
{
    if !input_grad_mask.iter().any(|needed| *needed) {
        return Ok(vec![None, None]);
    }
    let cotangent_solution = output_cotangents[0]
        .as_ref()
        .map(|cotangent| dense_dyn_tensor_typed::<T>(cotangent, "lstsq_vjp solution"))
        .transpose()?;
    let cotangent_residuals = output_cotangents[1]
        .as_ref()
        .map(|cotangent| dense_dyn_tensor_typed::<T>(cotangent, "lstsq_vjp residuals"))
        .transpose()?;
    if cotangent_solution.is_none() && cotangent_residuals.is_none() {
        return Ok(vec![None, None]);
    }
    let dense_a = a.to_dense()?;
    let dense_b = b.to_dense()?;
    let grad = dispatch_linalg!(T, "lstsq_vjp", LinalgCapabilityOp::Lstsq, |ctx| {
        lstsq_rrule(
            ctx,
            &dense_a,
            &dense_b,
            cotangent_solution.as_ref(),
            cotangent_residuals.as_ref(),
        )
        .map_err(Error::from)
    })?;
    Ok(vec![
        input_grad_mask[0].then(|| dyn_from_dense(grad.a)),
        input_grad_mask[1].then(|| dyn_from_dense(grad.b)),
    ])
}

fn solve_triangular_primal_t<T>(
    a: &StructuredTensor<T>,
    b: &StructuredTensor<T>,
    upper: bool,
) -> Result<DynTensor>
where
    T: LinalgRuntimeValue + DynTensorTyped + Copy,
{
    let dense_a = a.to_dense()?;
    let dense_b = b.to_dense()?;
    let output = dispatch_linalg!(
        T,
        "solve_triangular_dyn_value",
        LinalgCapabilityOp::SolveTriangular,
        |ctx| { solve_triangular(ctx, &dense_a, &dense_b, upper).map_err(Error::from) }
    )?;
    Ok(dyn_from_dense(output))
}

fn solve_triangular_jvp_t<T>(
    a: &StructuredTensor<T>,
    b: &StructuredTensor<T>,
    tangents: &[Option<DynTensor>],
    upper: bool,
) -> Result<Option<DynTensor>>
where
    T: LinalgRuntimeValue + DynTensorTyped + Scalar + Zero + Copy,
{
    if tangents.iter().all(Option::is_none) {
        return Ok(None);
    }
    let dense_a = a.to_dense()?;
    let dense_b = b.to_dense()?;
    let tangent_a =
        dense_optional_or_zero(&tangents[0], &dense_a, "solve_triangular_jvp tangent_a")?;
    let tangent_b =
        dense_optional_or_zero(&tangents[1], &dense_b, "solve_triangular_jvp tangent_b")?;
    let (_, tangent) = dispatch_linalg!(
        T,
        "solve_triangular_jvp",
        LinalgCapabilityOp::SolveTriangular,
        |ctx| {
            solve_triangular_frule(ctx, &dense_a, &dense_b, &tangent_a, &tangent_b, upper)
                .map_err(Error::from)
        }
    )?;
    Ok(Some(dyn_from_dense(tangent)))
}

fn solve_triangular_vjp_t<T>(
    a: &StructuredTensor<T>,
    b: &StructuredTensor<T>,
    cotangent: &DynTensor,
    input_grad_mask: &[bool],
    upper: bool,
) -> Result<Vec<Option<DynTensor>>>
where
    T: LinalgRuntimeValue + DynTensorTyped + Copy,
{
    if !input_grad_mask.iter().any(|needed| *needed) {
        return Ok(vec![None, None]);
    }
    let dense_a = a.to_dense()?;
    let dense_b = b.to_dense()?;
    let dense_cotangent = dense_dyn_tensor_typed::<T>(cotangent, "solve_triangular_vjp")?;
    let grad = dispatch_linalg!(
        T,
        "solve_triangular_vjp",
        LinalgCapabilityOp::SolveTriangular,
        |ctx| {
            solve_triangular_rrule(ctx, &dense_a, &dense_b, &dense_cotangent, upper)
                .map_err(Error::from)
        }
    )?;
    Ok(vec![
        input_grad_mask[0].then(|| dyn_from_dense(grad.a)),
        input_grad_mask[1].then(|| dyn_from_dense(grad.b)),
    ])
}

fn inv_primal_t<T>(input: &StructuredTensor<T>) -> Result<DynTensor>
where
    T: LinalgRuntimeValue + DynTensorTyped + Copy,
{
    let dense_input = input.to_dense()?;
    let output = dispatch_linalg!(T, "inv_dyn_value", LinalgCapabilityOp::Inv, |ctx| {
        inv(ctx, &dense_input).map_err(Error::from)
    })?;
    Ok(dyn_from_dense(output))
}

fn inv_jvp_t<T>(
    input: &StructuredTensor<T>,
    tangent: &Option<DynTensor>,
) -> Result<Option<DynTensor>>
where
    T: LinalgRuntimeValue + DynTensorTyped + Scalar + Zero + Copy,
{
    if tangent.is_none() {
        return Ok(None);
    }
    let dense_input = input.to_dense()?;
    let dense_tangent = dense_optional_or_zero(tangent, &dense_input, "inv_jvp tangent")?;
    let (_, output_tangent) = dispatch_linalg!(T, "inv_jvp", LinalgCapabilityOp::Inv, |ctx| {
        inv_frule(ctx, &dense_input, &dense_tangent).map_err(Error::from)
    })?;
    Ok(Some(dyn_from_dense(output_tangent)))
}

fn inv_vjp_t<T>(
    input: &StructuredTensor<T>,
    cotangent: &DynTensor,
    input_grad_mask: &[bool],
) -> Result<Vec<Option<DynTensor>>>
where
    T: LinalgRuntimeValue + DynTensorTyped + Conjugate + Copy,
{
    if !input_grad_mask[0] {
        return Ok(vec![None]);
    }
    let dense_input = input.to_dense()?;
    let dense_cotangent = dense_dyn_tensor_typed::<T>(cotangent, "inv_vjp")?;
    let grad = dispatch_linalg!(T, "inv_vjp", LinalgCapabilityOp::Inv, |ctx| {
        inv_rrule(ctx, &dense_input, &dense_cotangent).map_err(Error::from)
    })?;
    Ok(vec![Some(dyn_from_dense(grad))])
}

fn slogdet_primal_t<T>(input: &StructuredTensor<T>) -> Result<Vec<DynTensor>>
where
    T: SlogdetLinalgDispatchValue + DynTensorTyped + Copy,
    T::Real: DynTensorTyped,
{
    let dense_input = input.to_dense()?;
    let output = dispatch_linalg!(T, "slogdet_dyn_value", LinalgCapabilityOp::Slogdet, |ctx| {
        slogdet(ctx, &dense_input).map_err(Error::from)
    })?;
    Ok(vec![
        dyn_from_dense(output.sign),
        dyn_from_dense(output.logabsdet),
    ])
}

fn slogdet_jvp_t<T>(
    input: &StructuredTensor<T>,
    tangent: &Option<DynTensor>,
) -> Result<Vec<Option<DynTensor>>>
where
    T: SlogdetLinalgDispatchValue + DynTensorTyped + Scalar + Zero + Copy,
    T::Real: DynTensorTyped,
{
    if tangent.is_none() {
        return Ok(vec![None, None]);
    }
    let dense_input = input.to_dense()?;
    let dense_tangent = dense_optional_or_zero(tangent, &dense_input, "slogdet_jvp tangent")?;
    let (_, tangent_output) =
        dispatch_linalg!(T, "slogdet_jvp", LinalgCapabilityOp::Slogdet, |ctx| {
            slogdet_frule(ctx, &dense_input, &dense_tangent).map_err(Error::from)
        })?;
    Ok(vec![
        Some(dyn_from_dense(tangent_output.sign)),
        Some(dyn_from_dense(tangent_output.logabsdet)),
    ])
}

fn slogdet_vjp_t<T>(
    input: &StructuredTensor<T>,
    output_cotangents: &[Option<DynTensor>],
    input_grad_mask: &[bool],
) -> Result<Vec<Option<DynTensor>>>
where
    T: SlogdetLinalgDispatchValue + DynTensorTyped + Copy,
    T::Real: DynTensorTyped,
{
    if !input_grad_mask[0] {
        return Ok(vec![None]);
    }
    let dense_input = input.to_dense()?;
    let cotangent = SlogdetCotangent {
        sign: optional_dense_dyn_tensor_typed::<T>(&output_cotangents[0], "slogdet_vjp sign")?,
        logabsdet: optional_dense_dyn_tensor_typed::<T::Real>(
            &output_cotangents[1],
            "slogdet_vjp logabsdet",
        )?,
    };
    if cotangent.sign.is_none() && cotangent.logabsdet.is_none() {
        return Ok(vec![None]);
    }
    let grad = dispatch_linalg!(T, "slogdet_vjp", LinalgCapabilityOp::Slogdet, |ctx| {
        slogdet_rrule(ctx, &dense_input, &cotangent).map_err(Error::from)
    })?;
    Ok(vec![Some(dyn_from_dense(grad))])
}

fn cholesky_primal_t<T>(input: &StructuredTensor<T>) -> Result<DynTensor>
where
    T: LinalgRuntimeValue + DynTensorTyped + Copy,
{
    let dense_input = input.to_dense()?;
    let output = dispatch_linalg!(
        T,
        "cholesky_dyn_value",
        LinalgCapabilityOp::Cholesky,
        |ctx| { cholesky(ctx, &dense_input).map_err(Error::from) }
    )?;
    Ok(dyn_from_dense(output))
}

fn cholesky_jvp_t<T>(
    input: &StructuredTensor<T>,
    tangent: &Option<DynTensor>,
) -> Result<Option<DynTensor>>
where
    T: LinalgRuntimeValue + DynTensorTyped + Scalar + Zero + Conjugate + Copy,
{
    if tangent.is_none() {
        return Ok(None);
    }
    let dense_input = input.to_dense()?;
    let dense_tangent = dense_optional_or_zero(tangent, &dense_input, "cholesky_jvp tangent")?;
    let (_, output_tangent) =
        dispatch_linalg!(T, "cholesky_jvp", LinalgCapabilityOp::Cholesky, |ctx| {
            cholesky_frule(ctx, &dense_input, &dense_tangent).map_err(Error::from)
        })?;
    Ok(Some(dyn_from_dense(output_tangent)))
}

fn cholesky_vjp_t<T>(
    input: &StructuredTensor<T>,
    cotangent: &DynTensor,
    input_grad_mask: &[bool],
) -> Result<Vec<Option<DynTensor>>>
where
    T: LinalgRuntimeValue + DynTensorTyped + Conjugate + Copy,
{
    if !input_grad_mask[0] {
        return Ok(vec![None]);
    }
    let dense_input = input.to_dense()?;
    let dense_cotangent = dense_dyn_tensor_typed::<T>(cotangent, "cholesky_vjp")?;
    let grad = dispatch_linalg!(T, "cholesky_vjp", LinalgCapabilityOp::Cholesky, |ctx| {
        cholesky_rrule(ctx, &dense_input, &dense_cotangent).map_err(Error::from)
    })?;
    Ok(vec![Some(dyn_from_dense(grad))])
}

fn lu_primal_t<T>(input: &StructuredTensor<T>, pivot: LuPivot) -> Result<Vec<DynTensor>>
where
    T: LuLinalgDispatchValue + DynTensorTyped + Copy,
{
    let dense_input = input.to_dense()?;
    let output = dispatch_linalg!(T, "lu_dyn_value", LinalgCapabilityOp::LuFactor, |ctx| {
        lu(ctx, &dense_input, pivot).map_err(Error::from)
    })?;
    Ok(vec![
        dyn_from_dense(output.p),
        dyn_from_dense(output.l),
        dyn_from_dense(output.u),
    ])
}

fn lu_jvp_t<T>(
    input: &StructuredTensor<T>,
    tangent: &Option<DynTensor>,
    pivot: LuPivot,
) -> Result<Vec<Option<DynTensor>>>
where
    T: LuLinalgDispatchValue + DynTensorTyped + Scalar + Zero + Copy,
{
    if tangent.is_none() {
        return Ok(vec![None, None, None]);
    }
    let dense_input = input.to_dense()?;
    let dense_tangent = dense_optional_or_zero(tangent, &dense_input, "lu_jvp tangent")?;
    let (_, tangent_output) = dispatch_linalg!(T, "lu_jvp", LinalgCapabilityOp::LuFactor, |ctx| {
        lu_frule(ctx, &dense_input, &dense_tangent, pivot).map_err(Error::from)
    })?;
    Ok(vec![
        None,
        Some(dyn_from_dense(tangent_output.l)),
        Some(dyn_from_dense(tangent_output.u)),
    ])
}

fn lu_vjp_t<T>(
    input: &StructuredTensor<T>,
    output_cotangents: &[Option<DynTensor>],
    input_grad_mask: &[bool],
    pivot: LuPivot,
) -> Result<Vec<Option<DynTensor>>>
where
    T: LuLinalgDispatchValue + DynTensorTyped + Copy,
{
    if !input_grad_mask[0] {
        return Ok(vec![None]);
    }
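    // The permutation output is auxiliary, so a cotangent flowing into it
    // cannot be consumed and is reported as an error.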
    if output_cotangents
        .first()
        .and_then(|value| value.as_ref())
        .is_some()
    {
        return Err(invalid_argument(
            "lu permutation cotangent is unsupported; permutation output is auxiliary",
        ));
    }
    let dense_input = input.to_dense()?;
    let cotangent = LuCotangent {
        l: optional_dense_dyn_tensor_typed::<T>(&output_cotangents[1], "lu_vjp l")?,
        u: optional_dense_dyn_tensor_typed::<T>(&output_cotangents[2], "lu_vjp u")?,
    };
    if cotangent.l.is_none() && cotangent.u.is_none() {
        return Ok(vec![None]);
    }
    let grad = dispatch_linalg!(T, "lu_vjp", LinalgCapabilityOp::LuFactor, |ctx| {
        lu_rrule(ctx, &dense_input, &cotangent, pivot).map_err(Error::from)
    })?;
    Ok(vec![Some(dyn_from_dense(grad))])
}

fn pinv_primal_t<T>(input: &StructuredTensor<T>, rcond: Option<f64>) -> Result<DynTensor>
where
    T: ScaledLinalgDispatchValue + DynTensorTyped + Conjugate + Copy,
    T::Real: KeepCountScalar,
{
    let dense_input = input.to_dense()?;
    let output = dispatch_linalg!(T, "pinv_dyn_value", LinalgCapabilityOp::Pinv, |ctx| {
        pinv(ctx, &dense_input, rcond).map_err(Error::from)
    })?;
    Ok(dyn_from_dense(output))
}

fn pinv_jvp_t<T>(
    input: &StructuredTensor<T>,
    tangent: &Option<DynTensor>,
    rcond: Option<f64>,
) -> Result<Option<DynTensor>>
where
    T: ScaledLinalgDispatchValue + DynTensorTyped + Scalar + Zero + Conjugate + Copy,
    T::Real: KeepCountScalar,
{
    if tangent.is_none() {
        return Ok(None);
    }
    let dense_input = input.to_dense()?;
    let dense_tangent = dense_optional_or_zero(tangent, &dense_input, "pinv_jvp tangent")?;
    let (_, output_tangent) = dispatch_linalg!(T, "pinv_jvp", LinalgCapabilityOp::Pinv, |ctx| {
        pinv_frule(ctx, &dense_input, &dense_tangent, rcond).map_err(Error::from)
    })?;
    Ok(Some(dyn_from_dense(output_tangent)))
}

fn pinv_vjp_t<T>(
    input: &StructuredTensor<T>,
    cotangent: &DynTensor,
    input_grad_mask: &[bool],
    rcond: Option<f64>,
) -> Result<Vec<Option<DynTensor>>>
where
    T: ScaledLinalgDispatchValue + DynTensorTyped + Conjugate + Copy,
    T::Real: KeepCountScalar,
{
    if !input_grad_mask[0] {
        return Ok(vec![None]);
    }
    let dense_input = input.to_dense()?;
    let dense_cotangent = dense_dyn_tensor_typed::<T>(cotangent, "pinv_vjp")?;
    let grad = dispatch_linalg!(T, "pinv_vjp", LinalgCapabilityOp::Pinv, |ctx| {
        pinv_rrule(ctx, &dense_input, &dense_cotangent, rcond).map_err(Error::from)
    })?;
    Ok(vec![Some(dyn_from_dense(grad))])
}

fn matrix_exp_primal_t<T>(input: &StructuredTensor<T>) -> Result<DynTensor>
where
    T: MatrixExpLinalgDispatchValue + DynTensorTyped + Copy,
    T::Real: KernelLinalgScalar<Real = T::Real> + num_traits::Float,
{
    let dense_input = input.to_dense()?;
    let output = dispatch_linalg!(
        T,
        "matrix_exp_dyn_value",
        LinalgCapabilityOp::MatrixExp,
        |ctx| { matrix_exp(ctx, &dense_input).map_err(Error::from) }
    )?;
    Ok(dyn_from_dense(output))
}

fn matrix_exp_jvp_t<T>(
    input: &StructuredTensor<T>,
    tangent: &Option<DynTensor>,
) -> Result<Option<DynTensor>>
where
    T: MatrixExpLinalgDispatchValue + DynTensorTyped + Scalar + Zero + Copy,
    T::Real: KernelLinalgScalar<Real = T::Real> + num_traits::Float,
{
    if tangent.is_none() {
        return Ok(None);
    }
    let dense_input = input.to_dense()?;
    let dense_tangent = dense_optional_or_zero(tangent, &dense_input, "matrix_exp_jvp tangent")?;
    let (_, output_tangent) =
        dispatch_linalg!(T, "matrix_exp_jvp", LinalgCapabilityOp::MatrixExp, |ctx| {
            matrix_exp_frule(ctx, &dense_input, &dense_tangent).map_err(Error::from)
        })?;
    Ok(Some(dyn_from_dense(output_tangent)))
}

fn matrix_exp_vjp_t<T>(
    input: &StructuredTensor<T>,
    cotangent: &DynTensor,
    input_grad_mask: &[bool],
) -> Result<Vec<Option<DynTensor>>>
where
    T: MatrixExpLinalgDispatchValue + DynTensorTyped + Conjugate + Copy,
    T::Real: KernelLinalgScalar<Real = T::Real> + num_traits::Float,
{
    if !input_grad_mask[0] {
        return Ok(vec![None]);
    }
    let dense_input = input.to_dense()?;
    let dense_cotangent = dense_dyn_tensor_typed::<T>(cotangent, "matrix_exp_vjp")?;
    let grad = dispatch_linalg!(T, "matrix_exp_vjp", LinalgCapabilityOp::MatrixExp, |ctx| {
        matrix_exp_rrule(ctx, &dense_input, &dense_cotangent).map_err(Error::from)
    })?;
    Ok(vec![Some(dyn_from_dense(grad))])
}

fn eig_primal_t<T>(input: &StructuredTensor<T>) -> Result<Vec<DynTensor>>
where
    T: RealLinalgRuntimeValue
        + DynTensorTyped
        + num_traits::Float
        + KernelLinalgScalar<Real = T, Complex = num_complex::Complex<T>>
        + Copy,
    num_complex::Complex<T>: DynTensorTyped
        + KernelLinalgScalar<Real = T, Complex = num_complex::Complex<T>>
        + CudaLinalgScalar
        + Copy,
{
    let dense_input = input.to_dense()?;
    let output = dispatch_linalg!(T, "eig_dyn_value", LinalgCapabilityOp::Eig, |ctx| {
        eig(ctx, &dense_input).map_err(Error::from)
    })?;
    Ok(vec![
        dyn_from_dense(output.values),
        dyn_from_dense(output.vectors),
    ])
}

fn eig_jvp_t<T>(
    input: &StructuredTensor<T>,
    tangent: &Option<DynTensor>,
) -> Result<Vec<Option<DynTensor>>>
where
    T: RealLinalgRuntimeValue
        + DynTensorTyped
        + Scalar
        + Zero
        + num_traits::Float
        + KernelLinalgScalar<Real = T, Complex = num_complex::Complex<T>>
        + Copy,
    num_complex::Complex<T>: DynTensorTyped
        + KernelLinalgScalar<Real = T, Complex = num_complex::Complex<T>>
        + CudaLinalgScalar
        + Copy,
{
    if tangent.is_none() {
        return Ok(vec![None, None]);
    }
    let dense_input = input.to_dense()?;
    let dense_tangent = dense_optional_or_zero(tangent, &dense_input, "eig_jvp tangent")?;
    let (_, tangent_output) = dispatch_linalg!(T, "eig_jvp", LinalgCapabilityOp::Eig, |ctx| {
        eig_frule(ctx, &dense_input, &dense_tangent).map_err(Error::from)
    })?;
    Ok(vec![
        Some(dyn_from_dense(tangent_output.values)),
        Some(dyn_from_dense(tangent_output.vectors)),
    ])
}

fn eig_vjp_t<T>(
    input: &StructuredTensor<T>,
    output_cotangents: &[Option<DynTensor>],
    input_grad_mask: &[bool],
) -> Result<Vec<Option<DynTensor>>>
where
    T: RealLinalgRuntimeValue
        + DynTensorTyped
        + num_traits::Float
        + KernelLinalgScalar<Real = T, Complex = num_complex::Complex<T>>
        + Copy,
    num_complex::Complex<T>: DynTensorTyped
        + KernelLinalgScalar<Real = T, Complex = num_complex::Complex<T>>
        + CudaLinalgScalar
        + Copy,
{
    if !input_grad_mask[0] {
        return Ok(vec![None]);
    }
    let dense_input = input.to_dense()?;
    let cotangent = EigCotangent {
        values: optional_dense_dyn_tensor_typed::<num_complex::Complex<T>>(
            &output_cotangents[0],
            "eig_vjp values",
        )?,
        vectors: optional_dense_dyn_tensor_typed::<num_complex::Complex<T>>(
            &output_cotangents[1],
            "eig_vjp vectors",
        )?,
    };
    if cotangent.values.is_none() && cotangent.vectors.is_none() {
        return Ok(vec![None]);
    }
    let grad = dispatch_linalg!(T, "eig_vjp", LinalgCapabilityOp::Eig, |ctx| {
        eig_rrule(ctx, &dense_input, &cotangent).map_err(Error::from)
    })?;
    Ok(vec![Some(dyn_from_dense(grad))])
}

fn eigen_primal_t<T>(input: &StructuredTensor<T>) -> Result<Vec<DynTensor>>
where
    T: LinalgRuntimeValue + DynTensorTyped + Copy,
    T::Real: DynTensorTyped,
{
    let dense_input = input.to_dense()?;
    let output = dispatch_linalg!(T, "eigen_dyn_value", LinalgCapabilityOp::EigenSym, |ctx| {
        eigen(ctx, &dense_input).map_err(Error::from)
    })?;
    Ok(vec![
        dyn_from_dense(output.values),
        dyn_from_dense(output.vectors),
    ])
}

fn eigen_jvp_t<T>(
    input: &StructuredTensor<T>,
    tangent: &Option<DynTensor>,
) -> Result<Vec<Option<DynTensor>>>
where
    T: LinalgRuntimeValue + DynTensorTyped + Scalar + Zero + Conjugate + Copy,
    T::Real: KernelLinalgScalar<Real = T::Real> + num_traits::Float + DynTensorTyped,
{
    if tangent.is_none() {
        return Ok(vec![None, None]);
    }
    let dense_input = input.to_dense()?;
    let dense_tangent = dense_optional_or_zero(tangent, &dense_input, "eigen_jvp tangent")?;
    let (_, tangent_output) =
        dispatch_linalg!(T, "eigen_jvp", LinalgCapabilityOp::EigenSym, |ctx| {
            eigen_frule(ctx, &dense_input, &dense_tangent).map_err(Error::from)
        })?;
    Ok(vec![
        Some(dyn_from_dense(tangent_output.values)),
        Some(dyn_from_dense(tangent_output.vectors)),
    ])
}

fn eigen_vjp_t<T>(
    input: &StructuredTensor<T>,
    output_cotangents: &[Option<DynTensor>],
    input_grad_mask: &[bool],
) -> Result<Vec<Option<DynTensor>>>
where
    T: LinalgRuntimeValue + DynTensorTyped + Conjugate + Copy,
    T::Real: KernelLinalgScalar<Real = T::Real> + num_traits::Float + DynTensorTyped,
{
    if !input_grad_mask[0] {
        return Ok(vec![None]);
    }
    let dense_input = input.to_dense()?;
    let cotangent = EigenCotangent {
        values: optional_dense_dyn_tensor_typed::<T::Real>(
            &output_cotangents[0],
            "eigen_vjp values",
        )?,
        vectors: optional_dense_dyn_tensor_typed::<T>(&output_cotangents[1], "eigen_vjp vectors")?,
    };
    if cotangent.values.is_none() && cotangent.vectors.is_none() {
        return Ok(vec![None]);
    }
    let grad = dispatch_linalg!(T, "eigen_vjp", LinalgCapabilityOp::EigenSym, |ctx| {
        eigen_rrule(ctx, &dense_input, &cotangent).map_err(Error::from)
    })?;
    Ok(vec![Some(dyn_from_dense(grad))])
}

fn norm_primal_t<T>(input: &StructuredTensor<T>, kind: NormKind) -> Result<DynTensor>
where
    T: NormLinalgDispatchValue + DynTensorTyped + Copy,
    <T as tenferro_linalg::LinalgScalar>::Real: DynTensorTyped,
{
    let dense_input = input.to_dense()?;
    let output = dispatch_linalg!(T, "norm_dyn_value", LinalgCapabilityOp::Norm, |ctx| {
        norm(ctx, &dense_input, kind).map_err(Error::from)
    })?;
    Ok(dyn_from_dense(output))
}

fn norm_jvp_t<T>(
    input: &StructuredTensor<T>,
    tangent: &Option<DynTensor>,
    kind: NormKind,
) -> Result<Option<DynTensor>>
where
    T: RealLinalgRuntimeValue + DynTensorTyped + Scalar + Zero + Copy,
{
    if tangent.is_none() {
        return Ok(None);
    }
    let dense_input = input.to_dense()?;
    let dense_tangent = dense_optional_or_zero(tangent, &dense_input, "norm_jvp tangent")?;
    let (_, output_tangent) = dispatch_linalg!(T, "norm_jvp", LinalgCapabilityOp::Norm, |ctx| {
        norm_frule(ctx, &dense_input, &dense_tangent, kind).map_err(Error::from)
    })?;
    Ok(Some(dyn_from_dense(output_tangent)))
}

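// Complex norms produce real-valued outputs, so their derivative rules use
// the corresponding real scalar type (`f32`/`f64`) for output tangents and
// cotangents and route through the `*_complex` variants of the norm rules.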
fn norm_jvp_c32_t(
    input: &StructuredTensor<Complex32>,
    tangent: &Option<DynTensor>,
    kind: NormKind,
) -> Result<Option<DynTensor>> {
    if tangent.is_none() {
        return Ok(None);
    }
    let dense_input = input.to_dense()?;
    let dense_tangent = dense_optional_or_zero(tangent, &dense_input, "norm_jvp tangent")?;
    let (_, output_tangent) =
        dispatch_linalg!(Complex32, "norm_jvp", LinalgCapabilityOp::Norm, |ctx| {
            norm_frule_complex::<Complex32, f32, _>(ctx, &dense_input, &dense_tangent, kind)
                .map_err(Error::from)
        })?;
    Ok(Some(dyn_from_dense(output_tangent)))
}

fn norm_jvp_c64_t(
    input: &StructuredTensor<Complex64>,
    tangent: &Option<DynTensor>,
    kind: NormKind,
) -> Result<Option<DynTensor>> {
    if tangent.is_none() {
        return Ok(None);
    }
    let dense_input = input.to_dense()?;
    let dense_tangent = dense_optional_or_zero(tangent, &dense_input, "norm_jvp tangent")?;
    let (_, output_tangent) =
        dispatch_linalg!(Complex64, "norm_jvp", LinalgCapabilityOp::Norm, |ctx| {
            norm_frule_complex::<Complex64, f64, _>(ctx, &dense_input, &dense_tangent, kind)
                .map_err(Error::from)
        })?;
    Ok(Some(dyn_from_dense(output_tangent)))
}

fn norm_vjp_c32_t(
    input: &StructuredTensor<Complex32>,
    cotangent: &DynTensor,
    kind: NormKind,
    input_grad_mask: &[bool],
) -> Result<Vec<Option<DynTensor>>> {
    if !input_grad_mask[0] {
        return Ok(vec![None]);
    }
    let dense_input = input.to_dense()?;
    let dense_cotangent = dense_dyn_tensor_typed::<f32>(cotangent, "norm_vjp")?;
    let grad = dispatch_linalg!(Complex32, "norm_vjp", LinalgCapabilityOp::Norm, |ctx| {
        norm_rrule_complex::<Complex32, f32, _>(ctx, &dense_input, &dense_cotangent, kind)
            .map_err(Error::from)
    })?;
    Ok(vec![Some(dyn_from_dense(grad))])
}

fn norm_vjp_t<T>(
    input: &StructuredTensor<T>,
    cotangent: &DynTensor,
    kind: NormKind,
    input_grad_mask: &[bool],
) -> Result<Vec<Option<DynTensor>>>
where
    T: RealLinalgRuntimeValue + DynTensorTyped + Copy,
{
    if !input_grad_mask[0] {
        return Ok(vec![None]);
    }
    let dense_input = input.to_dense()?;
    let dense_cotangent = dense_dyn_tensor_typed::<T>(cotangent, "norm_vjp")?;
    let grad = dispatch_linalg!(T, "norm_vjp", LinalgCapabilityOp::Norm, |ctx| {
        norm_rrule(ctx, &dense_input, &dense_cotangent, kind).map_err(Error::from)
    })?;
    Ok(vec![Some(dyn_from_dense(grad))])
}

fn norm_vjp_c64_t(
    input: &StructuredTensor<Complex64>,
    cotangent: &DynTensor,
    kind: NormKind,
    input_grad_mask: &[bool],
) -> Result<Vec<Option<DynTensor>>> {
    if !input_grad_mask[0] {
        return Ok(vec![None]);
    }
    let dense_input = input.to_dense()?;
    let dense_cotangent = dense_dyn_tensor_typed::<f64>(cotangent, "norm_vjp")?;
    let grad = dispatch_linalg!(Complex64, "norm_vjp", LinalgCapabilityOp::Norm, |ctx| {
        norm_rrule_complex::<Complex64, f64, _>(ctx, &dense_input, &dense_cotangent, kind)
            .map_err(Error::from)
    })?;
    Ok(vec![Some(dyn_from_dense(grad))])
}

fn det_primal_t<T>(input: &StructuredTensor<T>) -> Result<DynTensor>
where
    T: ScaledLinalgDispatchValue + DynTensorTyped + Copy,
{
    let dense_input = input.to_dense()?;
    let output = dispatch_linalg!(T, "det_dyn_value", LinalgCapabilityOp::Det, |ctx| {
        det(ctx, &dense_input).map_err(Error::from)
    })?;
    Ok(dyn_from_dense(output))
}

fn det_jvp_t<T>(
    input: &StructuredTensor<T>,
    tangent: &Option<DynTensor>,
) -> Result<Option<DynTensor>>
where
    T: ScaledLinalgDispatchValue + DynTensorTyped + Scalar + Zero + Copy,
{
    if tangent.is_none() {
        return Ok(None);
    }
    let dense_input = input.to_dense()?;
    let dense_tangent = dense_optional_or_zero(tangent, &dense_input, "det_jvp tangent")?;
    let (_, output_tangent) = dispatch_linalg!(T, "det_jvp", LinalgCapabilityOp::Det, |ctx| {
        det_frule(ctx, &dense_input, &dense_tangent).map_err(Error::from)
    })?;
    Ok(Some(dyn_from_dense(output_tangent)))
}

fn det_vjp_t<T>(
    input: &StructuredTensor<T>,
    cotangent: &DynTensor,
    input_grad_mask: &[bool],
) -> Result<Vec<Option<DynTensor>>>
where
    T: ScaledLinalgDispatchValue + DynTensorTyped + Conjugate + Copy,
{
    if !input_grad_mask[0] {
        return Ok(vec![None]);
    }
    let dense_input = input.to_dense()?;
    let dense_cotangent = dense_dyn_tensor_typed::<T>(cotangent, "det_vjp")?;
    let grad = dispatch_linalg!(T, "det_vjp", LinalgCapabilityOp::Det, |ctx| {
        det_rrule(ctx, &dense_input, &dense_cotangent).map_err(Error::from)
    })?;
    Ok(vec![Some(dyn_from_dense(grad))])
}

fn qr_primal_t<T>(input: &StructuredTensor<T>) -> Result<Vec<DynTensor>>
where
    T: LinalgRuntimeValue + DynTensorTyped + Copy,
{
    let dense_input = input.to_dense()?;
    let output = dispatch_linalg!(T, "qr_dyn_value", LinalgCapabilityOp::Qr, |ctx| {
        qr(ctx, &dense_input).map_err(Error::from)
    })?;
    Ok(vec![dyn_from_dense(output.q), dyn_from_dense(output.r)])
}

fn qr_jvp_t<T>(
    input: &StructuredTensor<T>,
    tangent: &Option<DynTensor>,
) -> Result<Vec<Option<DynTensor>>>
where
    T: LinalgRuntimeValue + DynTensorTyped + Scalar + Zero + Copy,
{
    if tangent.is_none() {
        return Ok(vec![None, None]);
    }
    let dense_input = input.to_dense()?;
    let dense_tangent = dense_optional_or_zero(tangent, &dense_input, "qr_jvp tangent")?;
    let (_, tangent_output) = dispatch_linalg!(T, "qr_jvp", LinalgCapabilityOp::Qr, |ctx| {
        qr_frule(ctx, &dense_input, &dense_tangent).map_err(Error::from)
    })?;
    Ok(vec![
        Some(dyn_from_dense(tangent_output.q)),
        Some(dyn_from_dense(tangent_output.r)),
    ])
}

fn qr_vjp_t<T>(
    input: &StructuredTensor<T>,
    output_cotangents: &[Option<DynTensor>],
    input_grad_mask: &[bool],
) -> Result<Vec<Option<DynTensor>>>
where
    T: LinalgRuntimeValue + DynTensorTyped + Copy,
{
    if !input_grad_mask[0] {
        return Ok(vec![None]);
    }
    let dense_input = input.to_dense()?;
    let cotangent = QrCotangent {
        q: optional_dense_dyn_tensor_typed::<T>(&output_cotangents[0], "qr_vjp q")?,
        r: optional_dense_dyn_tensor_typed::<T>(&output_cotangents[1], "qr_vjp r")?,
    };
    if cotangent.q.is_none() && cotangent.r.is_none() {
        return Ok(vec![None]);
    }
    let grad = dispatch_linalg!(T, "qr_vjp", LinalgCapabilityOp::Qr, |ctx| {
        qr_rrule(ctx, &dense_input, &cotangent).map_err(Error::from)
    })?;
    Ok(vec![Some(dyn_from_dense(grad))])
}

fn svd_primal_t<T>(
    input: &StructuredTensor<T>,
    options: Option<&SvdOptions>,
) -> Result<Vec<DynTensor>>
where
    T: LinalgRuntimeValue + DynTensorTyped + Copy,
    T::Real: DynTensorTyped + Copy + tenferro_tensor::KeepCountScalar,
{
    let dense_input = input.to_dense()?;
    let output = dispatch_linalg!(T, "svd_dyn_value", LinalgCapabilityOp::ThinSvd, |ctx| {
        svd(ctx, &dense_input, options).map_err(Error::from)
    })?;
    Ok(vec![
        dyn_from_dense(output.u),
        dyn_from_dense(output.s),
        dyn_from_dense(output.vt),
    ])
}

fn svd_jvp_t<T>(
    input: &StructuredTensor<T>,
    tangent: &Option<DynTensor>,
    options: Option<&SvdOptions>,
) -> Result<Vec<Option<DynTensor>>>
where
    T: LinalgRuntimeValue + DynTensorTyped + Scalar + Zero + Copy,
    T::Real: DynTensorTyped + Copy + num_traits::Float + tenferro_tensor::KeepCountScalar,
{
    if tangent.is_none() {
        return Ok(vec![None, None, None]);
    }
    let dense_input = input.to_dense()?;
    let dense_tangent = dense_optional_or_zero(tangent, &dense_input, "svd_jvp tangent")?;
    let (_, tangent_output) = dispatch_linalg!(T, "svd_jvp", LinalgCapabilityOp::ThinSvd, |ctx| {
        svd_frule(ctx, &dense_input, &dense_tangent, options).map_err(Error::from)
    })?;
    Ok(vec![
        Some(dyn_from_dense(tangent_output.u)),
        Some(dyn_from_dense(tangent_output.s)),
        Some(dyn_from_dense(tangent_output.vt)),
    ])
}

fn svd_vjp_t<T>(
    input: &StructuredTensor<T>,
    output_cotangents: &[Option<DynTensor>],
    input_grad_mask: &[bool],
    options: Option<&SvdOptions>,
) -> Result<Vec<Option<DynTensor>>>
where
    T: LinalgRuntimeValue + DynTensorTyped + Copy,
    T::Real: DynTensorTyped + Copy + num_traits::Float + tenferro_tensor::KeepCountScalar,
{
    if !input_grad_mask[0] {
        return Ok(vec![None]);
    }
    let dense_input = input.to_dense()?;
    let cotangent = SvdCotangent {
        u: optional_dense_dyn_tensor_typed::<T>(&output_cotangents[0], "svd_vjp u")?,
        s: optional_dense_dyn_tensor_typed::<T::Real>(&output_cotangents[1], "svd_vjp s")?,
        vt: optional_dense_dyn_tensor_typed::<T>(&output_cotangents[2], "svd_vjp vt")?,
    };
    if cotangent.u.is_none() && cotangent.s.is_none() && cotangent.vt.is_none() {
        return Ok(vec![None]);
    }
    let grad = dispatch_linalg!(T, "svd_vjp", LinalgCapabilityOp::ThinSvd, |ctx| {
        svd_rrule(ctx, &dense_input, &cotangent, options).map_err(Error::from)
    })?;
    Ok(vec![Some(dyn_from_dense(grad))])
}

impl NormOp {
    pub fn new(kind: NormKind) -> Self {
        Self { kind }
    }
}

impl SolveTriangularOp {
    pub fn new(upper: bool) -> Self {
        Self { upper }
    }
}

impl LuOp {
    pub fn new(pivot: LuPivot) -> Self {
        Self { pivot }
    }
}

impl SvdOp {
    pub fn new(options: Option<SvdOptions>) -> Self {
        Self { options }
    }
}

impl PInvOp {
    pub fn new(rcond: Option<f64>) -> Self {
        Self { rcond }
    }
}

impl LinearizableOp<DynTensor> for SolveOp {
    type Linearized = SolveLinearized;

    fn primal(&self, inputs: &[&DynTensor]) -> AdResult<Vec<DynTensor>> {
        let output = match (inputs[0], inputs[1]) {
            (DynTensor::F32(a), DynTensor::F32(b)) => solve_primal_t::<f32>(a, b),
            (DynTensor::F64(a), DynTensor::F64(b)) => solve_primal_t::<f64>(a, b),
            (DynTensor::C32(a), DynTensor::C32(b)) => solve_primal_t::<Complex32>(a, b),
            (DynTensor::C64(a), DynTensor::C64(b)) => solve_primal_t::<Complex64>(a, b),
            _ => Err(invalid_argument("solve requires matching dtypes")),
        }
        .map_err(into_ad_error)?;
        Ok(vec![output])
    }

    fn input_schema(&self, _inputs: &[&DynTensor]) -> AdResult<Schema> {
        Ok(differentiable_schema(2))
    }

    fn output_schema(&self, _inputs: &[&DynTensor], _outputs: &[DynTensor]) -> AdResult<Schema> {
        Ok(differentiable_schema(1))
    }

    fn linearize(
        &self,
        inputs: &[&DynTensor],
        _outputs: &[DynTensor],
    ) -> AdResult<Self::Linearized> {
        Ok(SolveLinearized {
            a: inputs[0].clone(),
            b: inputs[1].clone(),
        })
    }

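    // Recomputing a factorization-backed solve is costly, so the engine is
    // advised to keep the linearized state rather than replay the op.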
    fn checkpoint_hint(&self) -> CheckpointHint {
        CheckpointHint::ExpensiveReplay
    }
}

impl LinearizedOp<DynTensor> for SolveLinearized {
    fn jvp(&self, input_tangents: &[Option<DynTensor>]) -> AdResult<Vec<Option<DynTensor>>> {
        let tangent = match (&self.a, &self.b) {
            (DynTensor::F32(a), DynTensor::F32(b)) => solve_jvp_t::<f32>(a, b, input_tangents),
            (DynTensor::F64(a), DynTensor::F64(b)) => solve_jvp_t::<f64>(a, b, input_tangents),
            (DynTensor::C32(a), DynTensor::C32(b)) => {
                solve_jvp_t::<Complex32>(a, b, input_tangents)
            }
            (DynTensor::C64(a), DynTensor::C64(b)) => {
                solve_jvp_t::<Complex64>(a, b, input_tangents)
            }
            _ => Err(invalid_argument(
                "solve linearization requires matching dtypes",
            )),
        }
        .map_err(into_ad_error)?;
        Ok(vec![tangent])
    }

    fn vjp(
        &self,
        output_cotangents: &[Option<DynTensor>],
        input_grad_mask: &[bool],
    ) -> AdResult<Vec<Option<DynTensor>>> {
        let Some(cotangent) = output_cotangents[0].as_ref() else {
            return Ok(vec![None, None]);
        };
        match (&self.a, &self.b) {
            (DynTensor::F32(a), DynTensor::F32(b)) => {
                solve_vjp_t::<f32>(a, b, cotangent, input_grad_mask)
            }
            (DynTensor::F64(a), DynTensor::F64(b)) => {
                solve_vjp_t::<f64>(a, b, cotangent, input_grad_mask)
            }
            (DynTensor::C32(a), DynTensor::C32(b)) => {
                solve_vjp_t::<Complex32>(a, b, cotangent, input_grad_mask)
            }
            (DynTensor::C64(a), DynTensor::C64(b)) => {
                solve_vjp_t::<Complex64>(a, b, cotangent, input_grad_mask)
            }
            _ => Err(invalid_argument(
                "solve linearization requires matching dtypes",
            )),
        }
        .map_err(into_ad_error)
    }
}

impl LinearizableOp<DynTensor> for LstsqOp {
    type Linearized = LstsqLinearized;

    fn primal(&self, inputs: &[&DynTensor]) -> AdResult<Vec<DynTensor>> {
        match (inputs[0], inputs[1]) {
            (DynTensor::F32(a), DynTensor::F32(b)) => lstsq_primal_t::<f32>(a, b),
            (DynTensor::F64(a), DynTensor::F64(b)) => lstsq_primal_t::<f64>(a, b),
            (DynTensor::C32(_), DynTensor::C32(_)) | (DynTensor::C64(_), DynTensor::C64(_)) => Err(
                invalid_argument("lstsq AD currently supports real dtypes only"),
            ),
            _ => Err(invalid_argument("lstsq requires matching dtypes")),
        }
        .map_err(into_ad_error)
    }

    fn input_schema(&self, _inputs: &[&DynTensor]) -> AdResult<Schema> {
        Ok(differentiable_schema(2))
    }

    fn output_schema(&self, _inputs: &[&DynTensor], _outputs: &[DynTensor]) -> AdResult<Schema> {
        Ok(lstsq_output_schema())
    }

    fn linearize(
        &self,
        inputs: &[&DynTensor],
        _outputs: &[DynTensor],
    ) -> AdResult<Self::Linearized> {
        Ok(LstsqLinearized {
            a: inputs[0].clone(),
            b: inputs[1].clone(),
        })
    }

    fn checkpoint_hint(&self) -> CheckpointHint {
        CheckpointHint::ExpensiveReplay
    }
}

impl LinearizedOp<DynTensor> for LstsqLinearized {
    fn jvp(&self, input_tangents: &[Option<DynTensor>]) -> AdResult<Vec<Option<DynTensor>>> {
        match (&self.a, &self.b) {
            (DynTensor::F32(a), DynTensor::F32(b)) => lstsq_jvp_t::<f32>(a, b, input_tangents),
            (DynTensor::F64(a), DynTensor::F64(b)) => lstsq_jvp_t::<f64>(a, b, input_tangents),
            (DynTensor::C32(_), DynTensor::C32(_)) | (DynTensor::C64(_), DynTensor::C64(_)) => Err(
                invalid_argument("lstsq AD currently supports real dtypes only"),
            ),
            _ => Err(invalid_argument(
                "lstsq linearization requires matching dtypes",
            )),
        }
        .map_err(into_ad_error)
    }

    fn vjp(
        &self,
        output_cotangents: &[Option<DynTensor>],
        input_grad_mask: &[bool],
    ) -> AdResult<Vec<Option<DynTensor>>> {
        match (&self.a, &self.b) {
            (DynTensor::F32(a), DynTensor::F32(b)) => {
                lstsq_vjp_t::<f32>(a, b, output_cotangents, input_grad_mask)
            }
            (DynTensor::F64(a), DynTensor::F64(b)) => {
                lstsq_vjp_t::<f64>(a, b, output_cotangents, input_grad_mask)
            }
            (DynTensor::C32(_), DynTensor::C32(_)) | (DynTensor::C64(_), DynTensor::C64(_)) => Err(
                invalid_argument("lstsq AD currently supports real dtypes only"),
            ),
            _ => Err(invalid_argument(
                "lstsq linearization requires matching dtypes",
            )),
        }
        .map_err(into_ad_error)
    }
}

impl LinearizableOp<DynTensor> for NormOp {
    type Linearized = NormLinearized;

    fn primal(&self, inputs: &[&DynTensor]) -> AdResult<Vec<DynTensor>> {
        let output = match inputs[0] {
            DynTensor::F32(input) => norm_primal_t::<f32>(input, self.kind),
            DynTensor::F64(input) => norm_primal_t::<f64>(input, self.kind),
            DynTensor::C32(input) => norm_primal_t::<Complex32>(input, self.kind),
            DynTensor::C64(input) => norm_primal_t::<Complex64>(input, self.kind),
        }
        .map_err(into_ad_error)?;
        Ok(vec![output])
    }

    fn input_schema(&self, _inputs: &[&DynTensor]) -> AdResult<Schema> {
        Ok(differentiable_schema(1))
    }

    fn output_schema(&self, _inputs: &[&DynTensor], _outputs: &[DynTensor]) -> AdResult<Schema> {
        Ok(differentiable_schema(1))
    }

    fn linearize(
        &self,
        inputs: &[&DynTensor],
        _outputs: &[DynTensor],
    ) -> AdResult<Self::Linearized> {
        Ok(NormLinearized {
            input: inputs[0].clone(),
            kind: self.kind,
        })
    }

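    // Norms are cheap to recompute relative to storing their inputs, so
    // replaying the op is preferred over checkpointing.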
1619 fn checkpoint_hint(&self) -> CheckpointHint {
1620 CheckpointHint::CheapReplay
1621 }
1622}
1623
1624impl LinearizableOp<DynTensor> for SolveTriangularOp {
1625 type Linearized = SolveTriangularLinearized;
1626
1627 fn primal(&self, inputs: &[&DynTensor]) -> AdResult<Vec<DynTensor>> {
1628 let output = match (inputs[0], inputs[1]) {
1629 (DynTensor::F32(a), DynTensor::F32(b)) => {
1630 solve_triangular_primal_t::<f32>(a, b, self.upper)
1631 }
1632 (DynTensor::F64(a), DynTensor::F64(b)) => {
1633 solve_triangular_primal_t::<f64>(a, b, self.upper)
1634 }
1635 (DynTensor::C32(a), DynTensor::C32(b)) => {
1636 solve_triangular_primal_t::<Complex32>(a, b, self.upper)
1637 }
1638 (DynTensor::C64(a), DynTensor::C64(b)) => {
1639 solve_triangular_primal_t::<Complex64>(a, b, self.upper)
1640 }
1641 _ => Err(invalid_argument(
1642 "solve_triangular requires matching dtypes",
1643 )),
1644 }
1645 .map_err(into_ad_error)?;
1646 Ok(vec![output])
1647 }
1648
1649 fn input_schema(&self, _inputs: &[&DynTensor]) -> AdResult<Schema> {
1650 Ok(differentiable_schema(2))
1651 }
1652
1653 fn output_schema(&self, _inputs: &[&DynTensor], _outputs: &[DynTensor]) -> AdResult<Schema> {
1654 Ok(differentiable_schema(1))
1655 }
1656
1657 fn linearize(
1658 &self,
1659 inputs: &[&DynTensor],
1660 _outputs: &[DynTensor],
1661 ) -> AdResult<Self::Linearized> {
1662 Ok(SolveTriangularLinearized {
1663 a: inputs[0].clone(),
1664 b: inputs[1].clone(),
1665 upper: self.upper,
1666 })
1667 }
1668
1669 fn checkpoint_hint(&self) -> CheckpointHint {
1670 CheckpointHint::ExpensiveReplay
1671 }
1672}
1673
1674impl LinearizedOp<DynTensor> for SolveTriangularLinearized {
1675 fn jvp(&self, input_tangents: &[Option<DynTensor>]) -> AdResult<Vec<Option<DynTensor>>> {
1676 let tangent = match (&self.a, &self.b) {
1677 (DynTensor::F32(a), DynTensor::F32(b)) => {
1678 solve_triangular_jvp_t::<f32>(a, b, input_tangents, self.upper)
1679 }
1680 (DynTensor::F64(a), DynTensor::F64(b)) => {
1681 solve_triangular_jvp_t::<f64>(a, b, input_tangents, self.upper)
1682 }
1683 (DynTensor::C32(a), DynTensor::C32(b)) => {
1684 solve_triangular_jvp_t::<Complex32>(a, b, input_tangents, self.upper)
1685 }
1686 (DynTensor::C64(a), DynTensor::C64(b)) => {
1687 solve_triangular_jvp_t::<Complex64>(a, b, input_tangents, self.upper)
1688 }
1689 _ => Err(invalid_argument(
1690 "solve_triangular linearization requires matching dtypes",
1691 )),
1692 }
1693 .map_err(into_ad_error)?;
1694 Ok(vec![tangent])
1695 }
1696
1697 fn vjp(
1698 &self,
1699 output_cotangents: &[Option<DynTensor>],
1700 input_grad_mask: &[bool],
1701 ) -> AdResult<Vec<Option<DynTensor>>> {
1702 let Some(cotangent) = output_cotangents[0].as_ref() else {
1703 return Ok(vec![None, None]);
1704 };
1705 match (&self.a, &self.b) {
1706 (DynTensor::F32(a), DynTensor::F32(b)) => {
1707 solve_triangular_vjp_t::<f32>(a, b, cotangent, input_grad_mask, self.upper)
1708 }
1709 (DynTensor::F64(a), DynTensor::F64(b)) => {
1710 solve_triangular_vjp_t::<f64>(a, b, cotangent, input_grad_mask, self.upper)
1711 }
1712 (DynTensor::C32(a), DynTensor::C32(b)) => {
1713 solve_triangular_vjp_t::<Complex32>(a, b, cotangent, input_grad_mask, self.upper)
1714 }
1715 (DynTensor::C64(a), DynTensor::C64(b)) => {
1716 solve_triangular_vjp_t::<Complex64>(a, b, cotangent, input_grad_mask, self.upper)
1717 }
1718 _ => Err(invalid_argument(
1719 "solve_triangular linearization requires matching dtypes",
1720 )),
1721 }
1722 .map_err(into_ad_error)
1723 }
1724}
1725
impl LinearizedOp<DynTensor> for NormLinearized {
    fn jvp(&self, input_tangents: &[Option<DynTensor>]) -> AdResult<Vec<Option<DynTensor>>> {
        let tangent = match &self.input {
            DynTensor::F32(input) => norm_jvp_t::<f32>(input, &input_tangents[0], self.kind),
            DynTensor::F64(input) => norm_jvp_t::<f64>(input, &input_tangents[0], self.kind),
            DynTensor::C32(input) => norm_jvp_c32_t(input, &input_tangents[0], self.kind),
            DynTensor::C64(input) => norm_jvp_c64_t(input, &input_tangents[0], self.kind),
        }
        .map_err(into_ad_error)?;
        Ok(vec![tangent])
    }

    fn vjp(
        &self,
        output_cotangents: &[Option<DynTensor>],
        input_grad_mask: &[bool],
    ) -> AdResult<Vec<Option<DynTensor>>> {
        let Some(cotangent) = output_cotangents[0].as_ref() else {
            return Ok(vec![None]);
        };
        match &self.input {
            DynTensor::F32(input) => {
                norm_vjp_t::<f32>(input, cotangent, self.kind, input_grad_mask)
            }
            DynTensor::F64(input) => {
                norm_vjp_t::<f64>(input, cotangent, self.kind, input_grad_mask)
            }
            DynTensor::C32(input) => norm_vjp_c32_t(input, cotangent, self.kind, input_grad_mask),
            DynTensor::C64(input) => norm_vjp_c64_t(input, cotangent, self.kind, input_grad_mask),
        }
        .map_err(into_ad_error)
    }
}

impl LinearizableOp<DynTensor> for InvOp {
    type Linearized = InvLinearized;

    fn primal(&self, inputs: &[&DynTensor]) -> AdResult<Vec<DynTensor>> {
        let output = match inputs[0] {
            DynTensor::F32(input) => inv_primal_t::<f32>(input),
            DynTensor::F64(input) => inv_primal_t::<f64>(input),
            DynTensor::C32(input) => inv_primal_t::<Complex32>(input),
            DynTensor::C64(input) => inv_primal_t::<Complex64>(input),
        }
        .map_err(into_ad_error)?;
        Ok(vec![output])
    }

    fn input_schema(&self, _inputs: &[&DynTensor]) -> AdResult<Schema> {
        Ok(differentiable_schema(1))
    }

    fn output_schema(&self, _inputs: &[&DynTensor], _outputs: &[DynTensor]) -> AdResult<Schema> {
        Ok(differentiable_schema(1))
    }

    fn linearize(
        &self,
        inputs: &[&DynTensor],
        _outputs: &[DynTensor],
    ) -> AdResult<Self::Linearized> {
        Ok(InvLinearized {
            input: inputs[0].clone(),
        })
    }

    fn checkpoint_hint(&self) -> CheckpointHint {
        CheckpointHint::ExpensiveReplay
    }
}

impl LinearizedOp<DynTensor> for InvLinearized {
    fn jvp(&self, input_tangents: &[Option<DynTensor>]) -> AdResult<Vec<Option<DynTensor>>> {
        let tangent = match &self.input {
            DynTensor::F32(input) => inv_jvp_t::<f32>(input, &input_tangents[0]),
            DynTensor::F64(input) => inv_jvp_t::<f64>(input, &input_tangents[0]),
            DynTensor::C32(input) => inv_jvp_t::<Complex32>(input, &input_tangents[0]),
            DynTensor::C64(input) => inv_jvp_t::<Complex64>(input, &input_tangents[0]),
        }
        .map_err(into_ad_error)?;
        Ok(vec![tangent])
    }

    fn vjp(
        &self,
        output_cotangents: &[Option<DynTensor>],
        input_grad_mask: &[bool],
    ) -> AdResult<Vec<Option<DynTensor>>> {
        let Some(cotangent) = output_cotangents[0].as_ref() else {
            return Ok(vec![None]);
        };
        match &self.input {
            DynTensor::F32(input) => inv_vjp_t::<f32>(input, cotangent, input_grad_mask),
            DynTensor::F64(input) => inv_vjp_t::<f64>(input, cotangent, input_grad_mask),
            DynTensor::C32(input) => inv_vjp_t::<Complex32>(input, cotangent, input_grad_mask),
            DynTensor::C64(input) => inv_vjp_t::<Complex64>(input, cotangent, input_grad_mask),
        }
        .map_err(into_ad_error)
    }
}

impl LinearizableOp<DynTensor> for SlogdetOp {
    type Linearized = SlogdetLinearized;

    fn primal(&self, inputs: &[&DynTensor]) -> AdResult<Vec<DynTensor>> {
        match inputs[0] {
            DynTensor::F32(input) => slogdet_primal_t::<f32>(input),
            DynTensor::F64(input) => slogdet_primal_t::<f64>(input),
            DynTensor::C32(input) => slogdet_primal_t::<Complex32>(input),
            DynTensor::C64(input) => slogdet_primal_t::<Complex64>(input),
        }
        .map_err(into_ad_error)
    }

    fn input_schema(&self, _inputs: &[&DynTensor]) -> AdResult<Schema> {
        Ok(differentiable_schema(1))
    }

    fn output_schema(&self, _inputs: &[&DynTensor], _outputs: &[DynTensor]) -> AdResult<Schema> {
        Ok(slogdet_output_schema())
    }

    fn linearize(
        &self,
        inputs: &[&DynTensor],
        _outputs: &[DynTensor],
    ) -> AdResult<Self::Linearized> {
        Ok(SlogdetLinearized {
            input: inputs[0].clone(),
        })
    }

    fn checkpoint_hint(&self) -> CheckpointHint {
        CheckpointHint::ExpensiveReplay
    }
}

impl LinearizedOp<DynTensor> for SlogdetLinearized {
    fn jvp(&self, input_tangents: &[Option<DynTensor>]) -> AdResult<Vec<Option<DynTensor>>> {
        match &self.input {
            DynTensor::F32(input) => slogdet_jvp_t::<f32>(input, &input_tangents[0]),
            DynTensor::F64(input) => slogdet_jvp_t::<f64>(input, &input_tangents[0]),
            DynTensor::C32(input) => slogdet_jvp_t::<Complex32>(input, &input_tangents[0]),
            DynTensor::C64(input) => slogdet_jvp_t::<Complex64>(input, &input_tangents[0]),
        }
        .map_err(into_ad_error)
    }

    fn vjp(
        &self,
        output_cotangents: &[Option<DynTensor>],
        input_grad_mask: &[bool],
    ) -> AdResult<Vec<Option<DynTensor>>> {
        match &self.input {
            DynTensor::F32(input) => {
                slogdet_vjp_t::<f32>(input, output_cotangents, input_grad_mask)
            }
            DynTensor::F64(input) => {
                slogdet_vjp_t::<f64>(input, output_cotangents, input_grad_mask)
            }
            DynTensor::C32(input) => {
                slogdet_vjp_t::<Complex32>(input, output_cotangents, input_grad_mask)
            }
            DynTensor::C64(input) => {
                slogdet_vjp_t::<Complex64>(input, output_cotangents, input_grad_mask)
            }
        }
        .map_err(into_ad_error)
    }
}

impl LinearizableOp<DynTensor> for CholeskyOp {
    type Linearized = CholeskyLinearized;

    fn primal(&self, inputs: &[&DynTensor]) -> AdResult<Vec<DynTensor>> {
        let output = match inputs[0] {
            DynTensor::F32(input) => cholesky_primal_t::<f32>(input),
            DynTensor::F64(input) => cholesky_primal_t::<f64>(input),
            DynTensor::C32(input) => cholesky_primal_t::<Complex32>(input),
            DynTensor::C64(input) => cholesky_primal_t::<Complex64>(input),
        }
        .map_err(into_ad_error)?;
        Ok(vec![output])
    }

    fn input_schema(&self, _inputs: &[&DynTensor]) -> AdResult<Schema> {
        Ok(differentiable_schema(1))
    }

    fn output_schema(&self, _inputs: &[&DynTensor], _outputs: &[DynTensor]) -> AdResult<Schema> {
        Ok(differentiable_schema(1))
    }

    fn linearize(
        &self,
        inputs: &[&DynTensor],
        _outputs: &[DynTensor],
    ) -> AdResult<Self::Linearized> {
        Ok(CholeskyLinearized {
            input: inputs[0].clone(),
        })
    }

    fn checkpoint_hint(&self) -> CheckpointHint {
        CheckpointHint::ExpensiveReplay
    }
}

impl LinearizableOp<DynTensor> for LuOp {
    type Linearized = LuLinearized;

    fn primal(&self, inputs: &[&DynTensor]) -> AdResult<Vec<DynTensor>> {
        match inputs[0] {
            DynTensor::F32(input) => lu_primal_t::<f32>(input, self.pivot),
            DynTensor::F64(input) => lu_primal_t::<f64>(input, self.pivot),
            DynTensor::C32(input) => lu_primal_t::<Complex32>(input, self.pivot),
            DynTensor::C64(input) => lu_primal_t::<Complex64>(input, self.pivot),
        }
        .map_err(into_ad_error)
    }

    fn input_schema(&self, _inputs: &[&DynTensor]) -> AdResult<Schema> {
        Ok(differentiable_schema(1))
    }

    fn output_schema(&self, _inputs: &[&DynTensor], _outputs: &[DynTensor]) -> AdResult<Schema> {
        Ok(lu_output_schema())
    }

    fn linearize(
        &self,
        inputs: &[&DynTensor],
        _outputs: &[DynTensor],
    ) -> AdResult<Self::Linearized> {
        Ok(LuLinearized {
            input: inputs[0].clone(),
            pivot: self.pivot,
        })
    }

    fn checkpoint_hint(&self) -> CheckpointHint {
        CheckpointHint::ExpensiveReplay
    }
}

impl LinearizedOp<DynTensor> for LuLinearized {
    fn jvp(&self, input_tangents: &[Option<DynTensor>]) -> AdResult<Vec<Option<DynTensor>>> {
        match &self.input {
            DynTensor::F32(input) => lu_jvp_t::<f32>(input, &input_tangents[0], self.pivot),
            DynTensor::F64(input) => lu_jvp_t::<f64>(input, &input_tangents[0], self.pivot),
            DynTensor::C32(input) => lu_jvp_t::<Complex32>(input, &input_tangents[0], self.pivot),
            DynTensor::C64(input) => lu_jvp_t::<Complex64>(input, &input_tangents[0], self.pivot),
        }
        .map_err(into_ad_error)
    }

    fn vjp(
        &self,
        output_cotangents: &[Option<DynTensor>],
        input_grad_mask: &[bool],
    ) -> AdResult<Vec<Option<DynTensor>>> {
        match &self.input {
            DynTensor::F32(input) => {
                lu_vjp_t::<f32>(input, output_cotangents, input_grad_mask, self.pivot)
            }
            DynTensor::F64(input) => {
                lu_vjp_t::<f64>(input, output_cotangents, input_grad_mask, self.pivot)
            }
            DynTensor::C32(input) => {
                lu_vjp_t::<Complex32>(input, output_cotangents, input_grad_mask, self.pivot)
            }
            DynTensor::C64(input) => {
                lu_vjp_t::<Complex64>(input, output_cotangents, input_grad_mask, self.pivot)
            }
        }
        .map_err(into_ad_error)
    }
}

impl LinearizedOp<DynTensor> for CholeskyLinearized {
    fn jvp(&self, input_tangents: &[Option<DynTensor>]) -> AdResult<Vec<Option<DynTensor>>> {
        let tangent = match &self.input {
            DynTensor::F32(input) => cholesky_jvp_t::<f32>(input, &input_tangents[0]),
            DynTensor::F64(input) => cholesky_jvp_t::<f64>(input, &input_tangents[0]),
            DynTensor::C32(input) => cholesky_jvp_t::<Complex32>(input, &input_tangents[0]),
            DynTensor::C64(input) => cholesky_jvp_t::<Complex64>(input, &input_tangents[0]),
        }
        .map_err(into_ad_error)?;
        Ok(vec![tangent])
    }

    fn vjp(
        &self,
        output_cotangents: &[Option<DynTensor>],
        input_grad_mask: &[bool],
    ) -> AdResult<Vec<Option<DynTensor>>> {
        let Some(cotangent) = output_cotangents[0].as_ref() else {
            return Ok(vec![None]);
        };
        match &self.input {
            DynTensor::F32(input) => cholesky_vjp_t::<f32>(input, cotangent, input_grad_mask),
            DynTensor::F64(input) => cholesky_vjp_t::<f64>(input, cotangent, input_grad_mask),
            DynTensor::C32(input) => cholesky_vjp_t::<Complex32>(input, cotangent, input_grad_mask),
            DynTensor::C64(input) => cholesky_vjp_t::<Complex64>(input, cotangent, input_grad_mask),
        }
        .map_err(into_ad_error)
    }
}

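// Note: AD for `eig` (the general eigendecomposition) is currently restricted
// to real inputs; complex inputs are rejected with `InvalidArgument` in
// `primal`, `jvp`, and `vjp` alike. `EigenOp` further below accepts all four
// dtypes.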
impl LinearizableOp<DynTensor> for EigOp {
    type Linearized = EigLinearized;

    fn primal(&self, inputs: &[&DynTensor]) -> AdResult<Vec<DynTensor>> {
        match inputs[0] {
            DynTensor::F32(input) => eig_primal_t::<f32>(input),
            DynTensor::F64(input) => eig_primal_t::<f64>(input),
            DynTensor::C32(_) | DynTensor::C64(_) => Err(invalid_argument(
                "eig AD currently supports real inputs only",
            )),
        }
        .map_err(into_ad_error)
    }

    fn input_schema(&self, _inputs: &[&DynTensor]) -> AdResult<Schema> {
        Ok(differentiable_schema(1))
    }

    fn output_schema(&self, _inputs: &[&DynTensor], _outputs: &[DynTensor]) -> AdResult<Schema> {
        Ok(differentiable_schema(2))
    }

    fn linearize(
        &self,
        inputs: &[&DynTensor],
        _outputs: &[DynTensor],
    ) -> AdResult<Self::Linearized> {
        Ok(EigLinearized {
            input: inputs[0].clone(),
        })
    }

    fn checkpoint_hint(&self) -> CheckpointHint {
        CheckpointHint::ExpensiveReplay
    }
}

impl LinearizedOp<DynTensor> for EigLinearized {
    fn jvp(&self, input_tangents: &[Option<DynTensor>]) -> AdResult<Vec<Option<DynTensor>>> {
        match &self.input {
            DynTensor::F32(input) => eig_jvp_t::<f32>(input, &input_tangents[0]),
            DynTensor::F64(input) => eig_jvp_t::<f64>(input, &input_tangents[0]),
            DynTensor::C32(_) | DynTensor::C64(_) => Err(invalid_argument(
                "eig AD currently supports real inputs only",
            )),
        }
        .map_err(into_ad_error)
    }

    fn vjp(
        &self,
        output_cotangents: &[Option<DynTensor>],
        input_grad_mask: &[bool],
    ) -> AdResult<Vec<Option<DynTensor>>> {
        match &self.input {
            DynTensor::F32(input) => eig_vjp_t::<f32>(input, output_cotangents, input_grad_mask),
            DynTensor::F64(input) => eig_vjp_t::<f64>(input, output_cotangents, input_grad_mask),
            DynTensor::C32(_) | DynTensor::C64(_) => Err(invalid_argument(
                "eig AD currently supports real inputs only",
            )),
        }
        .map_err(into_ad_error)
    }
}

impl LinearizableOp<DynTensor> for EigenOp {
    type Linearized = EigenLinearized;

    fn primal(&self, inputs: &[&DynTensor]) -> AdResult<Vec<DynTensor>> {
        match inputs[0] {
            DynTensor::F32(input) => eigen_primal_t::<f32>(input),
            DynTensor::F64(input) => eigen_primal_t::<f64>(input),
            DynTensor::C32(input) => eigen_primal_t::<Complex32>(input),
            DynTensor::C64(input) => eigen_primal_t::<Complex64>(input),
        }
        .map_err(into_ad_error)
    }

    fn input_schema(&self, _inputs: &[&DynTensor]) -> AdResult<Schema> {
        Ok(differentiable_schema(1))
    }

    fn output_schema(&self, _inputs: &[&DynTensor], _outputs: &[DynTensor]) -> AdResult<Schema> {
        Ok(differentiable_schema(2))
    }

    fn linearize(
        &self,
        inputs: &[&DynTensor],
        _outputs: &[DynTensor],
    ) -> AdResult<Self::Linearized> {
        Ok(EigenLinearized {
            input: inputs[0].clone(),
        })
    }

    fn checkpoint_hint(&self) -> CheckpointHint {
        CheckpointHint::ExpensiveReplay
    }
}

impl LinearizedOp<DynTensor> for EigenLinearized {
    fn jvp(&self, input_tangents: &[Option<DynTensor>]) -> AdResult<Vec<Option<DynTensor>>> {
        match &self.input {
            DynTensor::F32(input) => eigen_jvp_t::<f32>(input, &input_tangents[0]),
            DynTensor::F64(input) => eigen_jvp_t::<f64>(input, &input_tangents[0]),
            DynTensor::C32(input) => eigen_jvp_t::<Complex32>(input, &input_tangents[0]),
            DynTensor::C64(input) => eigen_jvp_t::<Complex64>(input, &input_tangents[0]),
        }
        .map_err(into_ad_error)
    }

    fn vjp(
        &self,
        output_cotangents: &[Option<DynTensor>],
        input_grad_mask: &[bool],
    ) -> AdResult<Vec<Option<DynTensor>>> {
        match &self.input {
            DynTensor::F32(input) => eigen_vjp_t::<f32>(input, output_cotangents, input_grad_mask),
            DynTensor::F64(input) => eigen_vjp_t::<f64>(input, output_cotangents, input_grad_mask),
            DynTensor::C32(input) => {
                eigen_vjp_t::<Complex32>(input, output_cotangents, input_grad_mask)
            }
            DynTensor::C64(input) => {
                eigen_vjp_t::<Complex64>(input, output_cotangents, input_grad_mask)
            }
        }
        .map_err(into_ad_error)
    }
}

impl LinearizableOp<DynTensor> for DetOp {
    type Linearized = DetLinearized;

    fn primal(&self, inputs: &[&DynTensor]) -> AdResult<Vec<DynTensor>> {
        let output = match inputs[0] {
            DynTensor::F32(input) => det_primal_t::<f32>(input),
            DynTensor::F64(input) => det_primal_t::<f64>(input),
            DynTensor::C32(input) => det_primal_t::<Complex32>(input),
            DynTensor::C64(input) => det_primal_t::<Complex64>(input),
        }
        .map_err(into_ad_error)?;
        Ok(vec![output])
    }

    fn input_schema(&self, _inputs: &[&DynTensor]) -> AdResult<Schema> {
        Ok(differentiable_schema(1))
    }

    fn output_schema(&self, _inputs: &[&DynTensor], _outputs: &[DynTensor]) -> AdResult<Schema> {
        Ok(differentiable_schema(1))
    }

    fn linearize(
        &self,
        inputs: &[&DynTensor],
        _outputs: &[DynTensor],
    ) -> AdResult<Self::Linearized> {
        Ok(DetLinearized {
            input: inputs[0].clone(),
        })
    }

    fn checkpoint_hint(&self) -> CheckpointHint {
        CheckpointHint::ExpensiveReplay
    }
}

impl LinearizedOp<DynTensor> for DetLinearized {
    fn jvp(&self, input_tangents: &[Option<DynTensor>]) -> AdResult<Vec<Option<DynTensor>>> {
        let tangent = match &self.input {
            DynTensor::F32(input) => det_jvp_t::<f32>(input, &input_tangents[0]),
            DynTensor::F64(input) => det_jvp_t::<f64>(input, &input_tangents[0]),
            DynTensor::C32(input) => det_jvp_t::<Complex32>(input, &input_tangents[0]),
            DynTensor::C64(input) => det_jvp_t::<Complex64>(input, &input_tangents[0]),
        }
        .map_err(into_ad_error)?;
        Ok(vec![tangent])
    }

    fn vjp(
        &self,
        output_cotangents: &[Option<DynTensor>],
        input_grad_mask: &[bool],
    ) -> AdResult<Vec<Option<DynTensor>>> {
        let Some(cotangent) = output_cotangents[0].as_ref() else {
            return Ok(vec![None]);
        };
        match &self.input {
            DynTensor::F32(input) => det_vjp_t::<f32>(input, cotangent, input_grad_mask),
            DynTensor::F64(input) => det_vjp_t::<f64>(input, cotangent, input_grad_mask),
            DynTensor::C32(input) => det_vjp_t::<Complex32>(input, cotangent, input_grad_mask),
            DynTensor::C64(input) => det_vjp_t::<Complex64>(input, cotangent, input_grad_mask),
        }
        .map_err(into_ad_error)
    }
}

impl LinearizableOp<DynTensor> for PInvOp {
    type Linearized = PInvLinearized;

    fn primal(&self, inputs: &[&DynTensor]) -> AdResult<Vec<DynTensor>> {
        let output = match inputs[0] {
            DynTensor::F32(input) => pinv_primal_t::<f32>(input, self.rcond),
            DynTensor::F64(input) => pinv_primal_t::<f64>(input, self.rcond),
            DynTensor::C32(input) => pinv_primal_t::<Complex32>(input, self.rcond),
            DynTensor::C64(input) => pinv_primal_t::<Complex64>(input, self.rcond),
        }
        .map_err(into_ad_error)?;
        Ok(vec![output])
    }

    fn input_schema(&self, _inputs: &[&DynTensor]) -> AdResult<Schema> {
        Ok(differentiable_schema(1))
    }

    fn output_schema(&self, _inputs: &[&DynTensor], _outputs: &[DynTensor]) -> AdResult<Schema> {
        Ok(differentiable_schema(1))
    }

    fn linearize(
        &self,
        inputs: &[&DynTensor],
        _outputs: &[DynTensor],
    ) -> AdResult<Self::Linearized> {
        Ok(PInvLinearized {
            input: inputs[0].clone(),
            rcond: self.rcond,
        })
    }

    fn checkpoint_hint(&self) -> CheckpointHint {
        CheckpointHint::ExpensiveReplay
    }
}

impl LinearizedOp<DynTensor> for PInvLinearized {
    fn jvp(&self, input_tangents: &[Option<DynTensor>]) -> AdResult<Vec<Option<DynTensor>>> {
        let tangent = match &self.input {
            DynTensor::F32(input) => pinv_jvp_t::<f32>(input, &input_tangents[0], self.rcond),
            DynTensor::F64(input) => pinv_jvp_t::<f64>(input, &input_tangents[0], self.rcond),
            DynTensor::C32(input) => pinv_jvp_t::<Complex32>(input, &input_tangents[0], self.rcond),
            DynTensor::C64(input) => pinv_jvp_t::<Complex64>(input, &input_tangents[0], self.rcond),
        }
        .map_err(into_ad_error)?;
        Ok(vec![tangent])
    }

    fn vjp(
        &self,
        output_cotangents: &[Option<DynTensor>],
        input_grad_mask: &[bool],
    ) -> AdResult<Vec<Option<DynTensor>>> {
        let Some(cotangent) = output_cotangents[0].as_ref() else {
            return Ok(vec![None]);
        };
        match &self.input {
            DynTensor::F32(input) => {
                pinv_vjp_t::<f32>(input, cotangent, input_grad_mask, self.rcond)
            }
            DynTensor::F64(input) => {
                pinv_vjp_t::<f64>(input, cotangent, input_grad_mask, self.rcond)
            }
            DynTensor::C32(input) => {
                pinv_vjp_t::<Complex32>(input, cotangent, input_grad_mask, self.rcond)
            }
            DynTensor::C64(input) => {
                pinv_vjp_t::<Complex64>(input, cotangent, input_grad_mask, self.rcond)
            }
        }
        .map_err(into_ad_error)
    }
}

impl LinearizableOp<DynTensor> for MatrixExpOp {
    type Linearized = MatrixExpLinearized;

    fn primal(&self, inputs: &[&DynTensor]) -> AdResult<Vec<DynTensor>> {
        let output = match inputs[0] {
            DynTensor::F32(input) => matrix_exp_primal_t::<f32>(input),
            DynTensor::F64(input) => matrix_exp_primal_t::<f64>(input),
            DynTensor::C32(input) => matrix_exp_primal_t::<Complex32>(input),
            DynTensor::C64(input) => matrix_exp_primal_t::<Complex64>(input),
        }
        .map_err(into_ad_error)?;
        Ok(vec![output])
    }

    fn input_schema(&self, _inputs: &[&DynTensor]) -> AdResult<Schema> {
        Ok(differentiable_schema(1))
    }

    fn output_schema(&self, _inputs: &[&DynTensor], _outputs: &[DynTensor]) -> AdResult<Schema> {
        Ok(differentiable_schema(1))
    }

    fn linearize(
        &self,
        inputs: &[&DynTensor],
        _outputs: &[DynTensor],
    ) -> AdResult<Self::Linearized> {
        Ok(MatrixExpLinearized {
            input: inputs[0].clone(),
        })
    }

    fn checkpoint_hint(&self) -> CheckpointHint {
        CheckpointHint::ExpensiveReplay
    }
}

impl LinearizedOp<DynTensor> for MatrixExpLinearized {
    fn jvp(&self, input_tangents: &[Option<DynTensor>]) -> AdResult<Vec<Option<DynTensor>>> {
        let tangent = match &self.input {
            DynTensor::F32(input) => matrix_exp_jvp_t::<f32>(input, &input_tangents[0]),
            DynTensor::F64(input) => matrix_exp_jvp_t::<f64>(input, &input_tangents[0]),
            DynTensor::C32(input) => matrix_exp_jvp_t::<Complex32>(input, &input_tangents[0]),
            DynTensor::C64(input) => matrix_exp_jvp_t::<Complex64>(input, &input_tangents[0]),
        }
        .map_err(into_ad_error)?;
        Ok(vec![tangent])
    }

    fn vjp(
        &self,
        output_cotangents: &[Option<DynTensor>],
        input_grad_mask: &[bool],
    ) -> AdResult<Vec<Option<DynTensor>>> {
        let Some(cotangent) = output_cotangents[0].as_ref() else {
            return Ok(vec![None]);
        };
        match &self.input {
            DynTensor::F32(input) => matrix_exp_vjp_t::<f32>(input, cotangent, input_grad_mask),
            DynTensor::F64(input) => matrix_exp_vjp_t::<f64>(input, cotangent, input_grad_mask),
            DynTensor::C32(input) => {
                matrix_exp_vjp_t::<Complex32>(input, cotangent, input_grad_mask)
            }
            DynTensor::C64(input) => {
                matrix_exp_vjp_t::<Complex64>(input, cotangent, input_grad_mask)
            }
        }
        .map_err(into_ad_error)
    }
}

impl LinearizableOp<DynTensor> for QrOp {
    type Linearized = QrLinearized;

    fn primal(&self, inputs: &[&DynTensor]) -> AdResult<Vec<DynTensor>> {
        match inputs[0] {
            DynTensor::F32(input) => qr_primal_t::<f32>(input),
            DynTensor::F64(input) => qr_primal_t::<f64>(input),
            DynTensor::C32(input) => qr_primal_t::<Complex32>(input),
            DynTensor::C64(input) => qr_primal_t::<Complex64>(input),
        }
        .map_err(into_ad_error)
    }

    fn input_schema(&self, _inputs: &[&DynTensor]) -> AdResult<Schema> {
        Ok(differentiable_schema(1))
    }

    fn output_schema(&self, _inputs: &[&DynTensor], _outputs: &[DynTensor]) -> AdResult<Schema> {
        Ok(differentiable_schema(2))
    }

    fn linearize(
        &self,
        inputs: &[&DynTensor],
        _outputs: &[DynTensor],
    ) -> AdResult<Self::Linearized> {
        Ok(QrLinearized {
            input: inputs[0].clone(),
        })
    }

    fn checkpoint_hint(&self) -> CheckpointHint {
        CheckpointHint::ExpensiveReplay
    }
}

impl LinearizedOp<DynTensor> for QrLinearized {
    fn jvp(&self, input_tangents: &[Option<DynTensor>]) -> AdResult<Vec<Option<DynTensor>>> {
        match &self.input {
            DynTensor::F32(input) => qr_jvp_t::<f32>(input, &input_tangents[0]),
            DynTensor::F64(input) => qr_jvp_t::<f64>(input, &input_tangents[0]),
            DynTensor::C32(input) => qr_jvp_t::<Complex32>(input, &input_tangents[0]),
            DynTensor::C64(input) => qr_jvp_t::<Complex64>(input, &input_tangents[0]),
        }
        .map_err(into_ad_error)
    }

    fn vjp(
        &self,
        output_cotangents: &[Option<DynTensor>],
        input_grad_mask: &[bool],
    ) -> AdResult<Vec<Option<DynTensor>>> {
        match &self.input {
            DynTensor::F32(input) => qr_vjp_t::<f32>(input, output_cotangents, input_grad_mask),
            DynTensor::F64(input) => qr_vjp_t::<f64>(input, output_cotangents, input_grad_mask),
            DynTensor::C32(input) => {
                qr_vjp_t::<Complex32>(input, output_cotangents, input_grad_mask)
            }
            DynTensor::C64(input) => {
                qr_vjp_t::<Complex64>(input, output_cotangents, input_grad_mask)
            }
        }
        .map_err(into_ad_error)
    }
}

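// `SvdOp` threads its optional `SvdOptions` through `primal`, `jvp`, and
// `vjp`, so the linearized rules always see the same options the primal
// factorization was computed with.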
impl LinearizableOp<DynTensor> for SvdOp {
    type Linearized = SvdLinearized;

    fn primal(&self, inputs: &[&DynTensor]) -> AdResult<Vec<DynTensor>> {
        match inputs[0] {
            DynTensor::F32(input) => svd_primal_t::<f32>(input, self.options.as_ref()),
            DynTensor::F64(input) => svd_primal_t::<f64>(input, self.options.as_ref()),
            DynTensor::C32(input) => svd_primal_t::<Complex32>(input, self.options.as_ref()),
            DynTensor::C64(input) => svd_primal_t::<Complex64>(input, self.options.as_ref()),
        }
        .map_err(into_ad_error)
    }

    fn input_schema(&self, _inputs: &[&DynTensor]) -> AdResult<Schema> {
        Ok(differentiable_schema(1))
    }

    fn output_schema(&self, _inputs: &[&DynTensor], _outputs: &[DynTensor]) -> AdResult<Schema> {
        Ok(differentiable_schema(3))
    }

    fn linearize(
        &self,
        inputs: &[&DynTensor],
        _outputs: &[DynTensor],
    ) -> AdResult<Self::Linearized> {
        Ok(SvdLinearized {
            input: inputs[0].clone(),
            options: self.options.clone(),
        })
    }

    fn checkpoint_hint(&self) -> CheckpointHint {
        CheckpointHint::ExpensiveReplay
    }
}

impl LinearizedOp<DynTensor> for SvdLinearized {
    fn jvp(&self, input_tangents: &[Option<DynTensor>]) -> AdResult<Vec<Option<DynTensor>>> {
        match &self.input {
            DynTensor::F32(input) => {
                svd_jvp_t::<f32>(input, &input_tangents[0], self.options.as_ref())
            }
            DynTensor::F64(input) => {
                svd_jvp_t::<f64>(input, &input_tangents[0], self.options.as_ref())
            }
            DynTensor::C32(input) => {
                svd_jvp_t::<Complex32>(input, &input_tangents[0], self.options.as_ref())
            }
            DynTensor::C64(input) => {
                svd_jvp_t::<Complex64>(input, &input_tangents[0], self.options.as_ref())
            }
        }
        .map_err(into_ad_error)
    }

    fn vjp(
        &self,
        output_cotangents: &[Option<DynTensor>],
        input_grad_mask: &[bool],
    ) -> AdResult<Vec<Option<DynTensor>>> {
        match &self.input {
            DynTensor::F32(input) => svd_vjp_t::<f32>(
                input,
                output_cotangents,
                input_grad_mask,
                self.options.as_ref(),
            ),
            DynTensor::F64(input) => svd_vjp_t::<f64>(
                input,
                output_cotangents,
                input_grad_mask,
                self.options.as_ref(),
            ),
            DynTensor::C32(input) => svd_vjp_t::<Complex32>(
                input,
                output_cotangents,
                input_grad_mask,
                self.options.as_ref(),
            ),
            DynTensor::C64(input) => svd_vjp_t::<Complex64>(
                input,
                output_cotangents,
                input_grad_mask,
                self.options.as_ref(),
            ),
        }
        .map_err(into_ad_error)
    }
}

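/// Differentiable linear solve: applies [`SolveOp`] so that the solution of
/// `a · x = b` participates in forward- and reverse-mode AD. The
/// linearization captures both operands, so gradients can flow to `a` and
/// `b` alike (subject to the grad mask at VJP time).
///
/// A minimal usage sketch; how a `DynValue` is built from a tensor depends
/// on the AD frontend, so `track` below is a hypothetical stand-in for that
/// step:
///
/// ```ignore
/// let a = track(DynTensor::F64(a_dense)); // hypothetical lifting helper
/// let b = track(DynTensor::F64(b_dense)); // hypothetical lifting helper
/// let x = solve_dyn_values(&a, &b)?;
/// ```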
pub fn solve_dyn_values(a: &DynValue, b: &DynValue) -> AdResult<DynValue> {
    SolveOp.apply_one(&[a, b])
}

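/// Differentiable least squares via [`LstsqOp`]. The solution and residuals
/// come back as tracked values; the rank and singular values are plain
/// auxiliaries recomputed from the primal of `a` and do not participate in
/// AD. The auxiliary pass currently supports real dtypes only.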
pub fn lstsq_dyn_values(a: &DynValue, b: &DynValue) -> AdResult<DynLstsqValues> {
    let mut outputs = LstsqOp.apply(&[a, b])?;
    if outputs.len() != 2 {
        return Err(AutodiffError::InvalidArgument(format!(
            "LstsqOp expected 2 outputs, got {}",
            outputs.len()
        )));
    }
    let residuals = outputs.pop().unwrap();
    let solution = outputs.pop().unwrap();
    let (rank, singular_values) = match a.primal() {
        DynTensor::F32(input) => lstsq_aux_t::<f32>(input),
        DynTensor::F64(input) => lstsq_aux_t::<f64>(input),
        DynTensor::C32(_) | DynTensor::C64(_) => Err(invalid_argument(
            "lstsq AD currently supports real dtypes only",
        )),
    }
    .map_err(into_ad_error)?;
    Ok(DynLstsqValues {
        solution,
        residuals,
        rank,
        singular_values,
    })
}

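/// Differentiable triangular solve via [`SolveTriangularOp`]; `upper`
/// selects whether `a` is treated as upper- or lower-triangular.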
pub fn solve_triangular_dyn_value(a: &DynValue, b: &DynValue, upper: bool) -> AdResult<DynValue> {
    SolveTriangularOp::new(upper).apply_one(&[a, b])
}

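/// Differentiable norm via [`NormOp`]; `kind` selects which norm to compute
/// (see [`NormKind`]). Complex inputs dispatch to the dedicated
/// `norm_jvp_c32_t`/`norm_jvp_c64_t` helpers: the norm of a complex tensor
/// is real-valued, so its tangent is real as well.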
pub fn norm_dyn_value(input: &DynValue, kind: NormKind) -> AdResult<DynValue> {
    NormOp::new(kind).apply_one(&[input])
}

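/// Differentiable determinant via [`DetOp`]. For matrices whose determinant
/// may over- or underflow, [`slogdet_dyn_value`] is usually the numerically
/// safer choice.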
pub fn det_dyn_value(input: &DynValue) -> AdResult<DynValue> {
    DetOp.apply_one(&[input])
}

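/// Differentiable matrix inverse via [`InvOp`]. The linearization keeps a
/// clone of the input, from which the derivative rules can be evaluated;
/// the underlying mathematics is the standard identity
/// `d(A^-1) = -A^-1 · dA · A^-1` (and its adjoint for reverse mode).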
pub fn inv_dyn_value(input: &DynValue) -> AdResult<DynValue> {
    InvOp.apply_one(&[input])
}

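/// Differentiable sign/log-determinant via [`SlogdetOp`], returned as the
/// pair `(sign, logabsdet)` so that `det = sign · exp(logabsdet)` can be
/// recovered without the overflow risk of computing `det` directly.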
pub fn slogdet_dyn_value(input: &DynValue) -> AdResult<DynSlogdetValues> {
    let mut outputs = SlogdetOp.apply(&[input])?;
    if outputs.len() != 2 {
        return Err(AutodiffError::InvalidArgument(format!(
            "SlogdetOp expected 2 outputs, got {}",
            outputs.len()
        )));
    }
    let logabsdet = outputs.pop().unwrap();
    let sign = outputs.pop().unwrap();
    Ok(DynSlogdetValues { sign, logabsdet })
}

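/// Differentiable Cholesky factorization via [`CholeskyOp`], returning the
/// triangular factor as a single tracked value.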
pub fn cholesky_dyn_value(input: &DynValue) -> AdResult<DynValue> {
    CholeskyOp.apply_one(&[input])
}

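/// Differentiable LU factorization via [`LuOp`], returning the permutation
/// and triangular factors `p`, `l`, `u` (in that order) as tracked values;
/// `pivot` selects the pivoting strategy (see [`LuPivot`]).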
pub fn lu_dyn_value(input: &DynValue, pivot: LuPivot) -> AdResult<DynLuValues> {
    let mut outputs = LuOp::new(pivot).apply(&[input])?;
    if outputs.len() != 3 {
        return Err(AutodiffError::InvalidArgument(format!(
            "LuOp expected 3 outputs, got {}",
            outputs.len()
        )));
    }
    let u = outputs.pop().unwrap();
    let l = outputs.pop().unwrap();
    let p = outputs.pop().unwrap();
    Ok(DynLuValues { p, l, u })
}

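/// Differentiable QR factorization via [`QrOp`], returning the factors `q`
/// and `r` as tracked values.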
pub fn qr_dyn_value(input: &DynValue) -> AdResult<DynQrValues> {
    let mut outputs = QrOp.apply(&[input])?;
    if outputs.len() != 2 {
        return Err(AutodiffError::InvalidArgument(format!(
            "QrOp expected 2 outputs, got {}",
            outputs.len()
        )));
    }
    let r = outputs.pop().unwrap();
    let q = outputs.pop().unwrap();
    Ok(DynQrValues { q, r })
}

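/// Differentiable SVD via [`SvdOp`], returning `u`, `s`, and `vt` as tracked
/// values. Any `options` are forwarded unchanged to the primal factorization
/// and to its JVP/VJP rules.
///
/// Sketch of use; as with [`solve_dyn_values`], `track` is a hypothetical
/// stand-in for however the AD frontend lifts a tensor into a `DynValue`:
///
/// ```ignore
/// let a = track(DynTensor::F64(a_dense)); // hypothetical lifting helper
/// let DynSvdValues { u, s, vt } = svd_dyn_value(&a, None)?;
/// ```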
pub fn svd_dyn_value(input: &DynValue, options: Option<SvdOptions>) -> AdResult<DynSvdValues> {
    let mut outputs = SvdOp::new(options).apply(&[input])?;
    if outputs.len() != 3 {
        return Err(AutodiffError::InvalidArgument(format!(
            "SvdOp expected 3 outputs, got {}",
            outputs.len()
        )));
    }
    let vt = outputs.pop().unwrap();
    let s = outputs.pop().unwrap();
    let u = outputs.pop().unwrap();
    Ok(DynSvdValues { u, s, vt })
}

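/// Differentiable general eigendecomposition via [`EigOp`], returning
/// eigenvalues and eigenvectors as tracked values. AD for this op currently
/// supports real inputs only; complex inputs fail with `InvalidArgument`.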
pub fn eig_dyn_value(input: &DynValue) -> AdResult<DynEigValues> {
    let mut outputs = EigOp.apply(&[input])?;
    if outputs.len() != 2 {
        return Err(AutodiffError::InvalidArgument(format!(
            "EigOp expected 2 outputs, got {}",
            outputs.len()
        )));
    }
    let vectors = outputs.pop().unwrap();
    let values = outputs.pop().unwrap();
    Ok(DynEigValues { values, vectors })
}

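/// Differentiable eigendecomposition via [`EigenOp`], returning eigenvalues
/// and eigenvectors as tracked values. Unlike [`EigOp`], all four dtypes are
/// supported.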
pub fn eigen_dyn_value(input: &DynValue) -> AdResult<DynEigenValues> {
    let mut outputs = EigenOp.apply(&[input])?;
    if outputs.len() != 2 {
        return Err(AutodiffError::InvalidArgument(format!(
            "EigenOp expected 2 outputs, got {}",
            outputs.len()
        )));
    }
    let vectors = outputs.pop().unwrap();
    let values = outputs.pop().unwrap();
    Ok(DynEigenValues { values, vectors })
}

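/// Differentiable Moore-Penrose pseudo-inverse via [`PInvOp`]. `rcond`, when
/// given, is forwarded to the primal and to both derivative rules; by the
/// usual `pinv` convention it is the relative cutoff below which singular
/// values are treated as zero.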
pub fn pinv_dyn_value(input: &DynValue, rcond: Option<f64>) -> AdResult<DynValue> {
    PInvOp::new(rcond).apply_one(&[input])
}

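/// Differentiable matrix exponential via [`MatrixExpOp`].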
pub fn matrix_exp_dyn_value(input: &DynValue) -> AdResult<DynValue> {
    MatrixExpOp.apply_one(&[input])
}