// tenferro_linalg/rrules/matrix_functions.rs

use super::*;

3/// Reverse-mode AD rule for matrix exponential (VJP / pullback).
4///
5/// Computes the gradient of the input given a cotangent for `exp(A)`.
6/// Uses the auxiliary 2n x 2n matrix trick (PyTorch approach):
7///
8/// ```text
9/// M = [[A^T, cotangent], [0, A^T]]
10/// grad_A = top-right n×n block of exp(M)
11/// ```
12///
13/// # Examples
14///
15/// ```
16/// use tenferro_linalg::matrix_exp_rrule;
17/// use tenferro_prims::CpuContext;
18/// use tenferro_tensor::{Tensor, MemoryOrder};
19/// use tenferro_device::LogicalMemorySpace;
20///
21/// let col = MemoryOrder::ColumnMajor;
22/// let mem = LogicalMemorySpace::MainMemory;
23/// let mut ctx = CpuContext::new(1);
24/// let a = Tensor::<f64>::zeros(&[3, 3], mem, col).unwrap();
25/// let cotangent = Tensor::<f64>::ones(&[3, 3], mem, col).unwrap();
26/// let grad_a = matrix_exp_rrule(&mut ctx, &a, &cotangent).unwrap();
27/// ```
28pub fn matrix_exp_rrule<T: KernelLinalgScalar + tenferro_algebra::Conjugate, C>(
29    ctx: &mut C,
30    tensor: &Tensor<T>,
31    cotangent: &Tensor<T>,
32) -> AdResult<Tensor<T>>
33where
34    T: KernelLinalgScalar
35        + tenferro_algebra::Conjugate
36        + crate::prims_bridge::ScaleTensorByRealSameShape<C>
37        + crate::ad_helpers::MatrixExpAbsTensor<C>,
38    C: backend::TensorLinalgContextFor<T>
39        + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T>>
40        + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T::Real>>
41        + tenferro_prims::TensorSemiringContextFor<tenferro_algebra::Standard<T>>
42        + tenferro_prims::TensorResolveConjContextFor<T>,
43    T::Real: KernelLinalgScalar<Real = T::Real> + num_traits::Float,
44    <C as tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T::Real>>>::ScalarBackend:
45        tenferro_prims::TensorAnalyticPrims<tenferro_algebra::Standard<T::Real>, Context = C>,
46    C::Backend: 'static,
47{
48    require_linalg_support::<T, C>(backend::LinalgCapabilityOp::MatrixExp, "matrix_exp_rrule")
49        .map_err(to_ad_err)?;
50
51    let (n, batch_dims) = validate_square(tensor).map_err(to_ad_err)?;
52    let mut perm = vec![1, 0];
53    perm.extend(2..tensor.ndim());
54    let a_h_view = tensor.conj().permute(&perm).map_err(to_ad_err)?;
55    let a_h = crate::prims_bridge::resolve_conj(ctx, &a_h_view);
56    let zero = Tensor::<T>::zeros(
57        &output_dims(&[n, n], batch_dims),
58        tensor.logical_memory_space(),
59        MemoryOrder::ColumnMajor,
60    )
61    .map_err(to_ad_err)?;
62    let top = Tensor::cat(&[&a_h, cotangent], 1).map_err(to_ad_err)?;
63    let bottom = Tensor::cat(&[&zero, &a_h], 1).map_err(to_ad_err)?;
64    let m = Tensor::cat(&[&top, &bottom], 0).map_err(to_ad_err)?;
65    let exp_m = matrix_exp(ctx, &m).map_err(to_ad_err)?;
66
67    exp_m
68        .narrow(0, 0, n)
69        .and_then(|t| t.narrow(1, n, n))
70        .map_err(to_ad_err)
71}