tenferro_linalg/frules/linear_systems/
inv_det.rs

1use super::*;
2
3/// Forward-mode AD rule for matrix inverse (JVP / pushforward).
4///
5/// # Examples
6///
7/// ```
8/// use tenferro_linalg::inv_frule;
9/// use tenferro_prims::CpuContext;
10/// use tenferro_tensor::{Tensor, MemoryOrder};
11/// use tenferro_device::LogicalMemorySpace;
12///
13/// let col = MemoryOrder::ColumnMajor;
14/// let mem = LogicalMemorySpace::MainMemory;
15/// let mut ctx = CpuContext::new(1);
16/// let a = Tensor::<f64>::eye(3, mem, col).unwrap();
17/// let da = Tensor::<f64>::ones(&[3, 3], mem, col).unwrap();
18/// let (a_inv, da_inv) = inv_frule(&mut ctx, &a, &da).unwrap();
19/// ```
20pub fn inv_frule<T: KernelLinalgScalar, C>(
21    ctx: &mut C,
22    tensor: &Tensor<T>,
23    tangent: &Tensor<T>,
24) -> AdResult<(Tensor<T>, Tensor<T>)>
25where
26    T: KernelLinalgScalar,
27    C: backend::TensorLinalgContextFor<T>
28        + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T>>,
29    C::Backend: 'static,
30    <C as tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T>>>::ScalarBackend:
31        'static + tenferro_prims::TensorAnalyticPrims<tenferro_algebra::Standard<T>, Context = C>,
32{
33    require_linalg_support::<T, C>(backend::LinalgCapabilityOp::Inv, "inv_frule")
34        .map_err(to_ad_err)?;
35
36    // dB = -B dA B where B = A^{-1}
37    let b_inv = inv(ctx, tensor).map_err(to_ad_err)?;
38    let (n, _) = validate_square(tensor).map_err(to_ad_err)?;
39
40    let b_da = prims_bridge::batched_gemm_with_semiring_tensors(ctx, &b_inv, tangent, n, n, n)
41        .map_err(to_ad_err)?;
42    let db = prims_bridge::batched_gemm_alpha_tensors(ctx, &b_da, &b_inv, n, n, n, -T::one())
43        .map_err(to_ad_err)?;
44
45    Ok((b_inv, db))
46}
47
48/// Forward-mode AD rule for determinant (JVP / pushforward).
49///
50/// # Examples
51///
52/// ```
53/// use tenferro_linalg::det_frule;
54/// use tenferro_prims::CpuContext;
55/// use tenferro_tensor::{Tensor, MemoryOrder};
56/// use tenferro_device::LogicalMemorySpace;
57///
58/// let col = MemoryOrder::ColumnMajor;
59/// let mem = LogicalMemorySpace::MainMemory;
60/// let mut ctx = CpuContext::new(1);
61/// let a = Tensor::<f64>::eye(3, mem, col).unwrap();
62/// let da = Tensor::<f64>::ones(&[3, 3], mem, col).unwrap();
63/// let (d, dd) = det_frule(&mut ctx, &a, &da).unwrap();
64/// ```
65pub fn det_frule<T: KernelLinalgScalar, C>(
66    ctx: &mut C,
67    tensor: &Tensor<T>,
68    tangent: &Tensor<T>,
69) -> AdResult<(Tensor<T>, Tensor<T>)>
70where
71    T: KernelLinalgScalar + crate::prims_bridge::ScaleTensorByRealSameShape<C>,
72    C: backend::TensorLinalgContextFor<T>
73        + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T>>
74        + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T::Real>>
75        + tenferro_prims::TensorMetadataContextFor,
76    C::Backend: 'static,
77    C::MetadataBackend: tenferro_prims::TensorMetadataPrims<Context = C>,
78    <C as tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T>>>::ScalarBackend:
79        tenferro_prims::TensorMetadataCastPrims<T, Context = C>,
80    <C as tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T::Real>>>::ScalarBackend:
81        tenferro_prims::TensorMetadataCastPrims<T::Real, Context = C>,
82{
83    require_linalg_support::<T, C>(backend::LinalgCapabilityOp::Det, "det_frule")
84        .map_err(to_ad_err)?;
85
86    // d(det) = det(A) * tr(A^{-1} dA)
87    let d = det(ctx, tensor).map_err(to_ad_err)?;
88    let (n, _) = validate_square(tensor).map_err(to_ad_err)?;
89    if n == 0 {
90        let dd = Tensor::zeros(
91            d.dims(),
92            tensor.logical_memory_space(),
93            MemoryOrder::ColumnMajor,
94        )
95        .map_err(to_ad_err)?;
96        return Ok((d, dd));
97    }
98
99    let a_inv = inv(ctx, tensor).map_err(to_ad_err)?;
100    let a_inv_da = prims_bridge::batched_gemm_with_semiring_tensors(ctx, &a_inv, tangent, n, n, n)
101        .map_err(to_ad_err)?;
102    let trace = trace_tensor(ctx, &a_inv_da).map_err(to_ad_err)?;
103
104    // dd = det(A) * trace
105    let dd = prims_bridge::scalar_binary_same_shape(
106        ctx,
107        &d,
108        &trace,
109        tenferro_prims::ScalarBinaryOp::Mul,
110    )
111    .map_err(to_ad_err)?;
112
113    Ok((d, dd))
114}