tenferro_linalg/rrules/linear_systems/inverse.rs

1use super::*;
2use tenferro_algebra::Conjugate;
3
4/// Reverse-mode AD rule for matrix inverse (VJP / pullback).
5///
6/// `Ā = -A⁻ᵀ · cotangent · A⁻ᵀ`.
7///
8/// # Examples
9///
10/// ```
11/// use tenferro_linalg::inv_rrule;
12/// use tenferro_prims::CpuContext;
13/// use tenferro_tensor::{Tensor, MemoryOrder};
14/// use tenferro_device::LogicalMemorySpace;
15///
16/// let col = MemoryOrder::ColumnMajor;
17/// let mem = LogicalMemorySpace::MainMemory;
18/// let mut ctx = CpuContext::new(1);
19/// let a = Tensor::<f64>::eye(3, mem, col).unwrap();
20/// let cotangent = Tensor::<f64>::ones(&[3, 3], mem, col).unwrap();
21/// let grad_a = inv_rrule(&mut ctx, &a, &cotangent).unwrap();
22/// ```
23pub fn inv_rrule<T: KernelLinalgScalar + tenferro_algebra::Conjugate, C>(
24    ctx: &mut C,
25    tensor: &Tensor<T>,
26    cotangent: &Tensor<T>,
27) -> AdResult<Tensor<T>>
28where
29    T: KernelLinalgScalar + tenferro_algebra::Conjugate,
30    C: backend::TensorLinalgContextFor<T> + tenferro_prims::TensorResolveConjContextFor<T>,
31    C::Backend: 'static,
32{
33    require_linalg_support::<T, C>(backend::LinalgCapabilityOp::Inv, "inv_rrule")
34        .map_err(to_ad_err)?;
35
36    // dA = -B^H dB B^H where B = A^{-1}
37    let b_inv = inv(ctx, tensor).map_err(to_ad_err)?;
38    let (n, _) = validate_square(tensor).map_err(to_ad_err)?;
39
40    let bt = matrix_adjoint_eager(ctx, &b_inv).map_err(to_ad_err)?;
41    let bt_db = prims_bridge::batched_gemm_with_semiring_tensors(ctx, &bt, cotangent, n, n, n)
42        .map_err(to_ad_err)?;
43    let grad_a = prims_bridge::batched_gemm_alpha_tensors(ctx, &bt_db, &bt, n, n, n, -T::one())
44        .map_err(to_ad_err)?;
45
46    Ok(grad_a)
47}
48
49/// Reverse-mode AD rule for determinant (VJP / pullback).
50///
51/// `Ā = det(A) · cotangent · A⁻ᵀ`.
52///
53/// # Examples
54///
55/// ```
56/// use tenferro_linalg::det_rrule;
57/// use tenferro_prims::CpuContext;
58/// use tenferro_tensor::{Tensor, MemoryOrder};
59/// use tenferro_device::LogicalMemorySpace;
60///
61/// let col = MemoryOrder::ColumnMajor;
62/// let mem = LogicalMemorySpace::MainMemory;
63/// let mut ctx = CpuContext::new(1);
64/// let a = Tensor::<f64>::eye(3, mem, col).unwrap();
65/// let cotangent = Tensor::<f64>::ones(&[], mem, col).unwrap();
66/// let grad_a = det_rrule(&mut ctx, &a, &cotangent).unwrap();
67/// ```
68pub fn det_rrule<T: KernelLinalgScalar + Conjugate, C>(
69    ctx: &mut C,
70    tensor: &Tensor<T>,
71    cotangent: &Tensor<T>,
72) -> AdResult<Tensor<T>>
73where
74    T: KernelLinalgScalar + crate::prims_bridge::ScaleTensorByRealSameShape<C>,
75    C: backend::TensorLinalgContextFor<T>
76        + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T>>
77        + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T::Real>>
78        + tenferro_prims::TensorMetadataContextFor
79        + tenferro_prims::TensorResolveConjContextFor<T>,
80    C::Backend: 'static,
81    C::MetadataBackend: tenferro_prims::TensorMetadataPrims<Context = C>,
82    <C as tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T>>>::ScalarBackend:
83        tenferro_prims::TensorMetadataCastPrims<T, Context = C>,
84    <C as tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T::Real>>>::ScalarBackend:
85        tenferro_prims::TensorMetadataCastPrims<T::Real, Context = C>,
86{
87    require_linalg_support::<T, C>(backend::LinalgCapabilityOp::Det, "det_rrule")
88        .map_err(to_ad_err)?;
89
90    let (n, _) = validate_square(tensor).map_err(to_ad_err)?;
91    if n == 0 {
92        return Tensor::zeros(
93            tensor.dims(),
94            tensor.logical_memory_space(),
95            MemoryOrder::ColumnMajor,
96        )
97        .map_err(to_ad_err);
98    }
99
100    // dA = cotangent * conj(det(A)) * A^{-H}
101    let det_val = det(ctx, tensor).map_err(to_ad_err)?;
102    let a_inv = inv(ctx, tensor).map_err(to_ad_err)?;
103    let a_inv_h = matrix_adjoint_eager(ctx, &a_inv).map_err(to_ad_err)?;
104    let det_conj =
105        prims_bridge::scalar_unary_same_shape(ctx, &det_val, tenferro_prims::ScalarUnaryOp::Conj)
106            .map_err(to_ad_err)?;
107
108    // scale = cotangent * conj(det(A)), shape [batch...]
109    let scale = prims_bridge::scalar_binary_same_shape(
110        ctx,
111        cotangent,
112        &det_conj,
113        tenferro_prims::ScalarBinaryOp::Mul,
114    )
115    .map_err(to_ad_err)?;
116
117    // broadcast scale [batch...] → [1, 1, batch...] → [n, n, batch...]
118    let scale_expanded = scale
119        .unsqueeze(0)
120        .map_err(to_ad_err)?
121        .unsqueeze(0)
122        .map_err(to_ad_err)?
123        .broadcast(a_inv_h.dims())
124        .map_err(to_ad_err)?;
125    let grad_a = prims_bridge::scalar_binary_same_shape(
126        ctx,
127        &scale_expanded,
128        &a_inv_h,
129        tenferro_prims::ScalarBinaryOp::Mul,
130    )
131    .map_err(to_ad_err)?;
132
133    Ok(grad_a)
134}