//! Forward-mode AD (JVP) rules for matrix functions.
//! Path: tenferro_linalg/frules/matrix_functions.rs

use super::*;
2
3/// Forward-mode AD rule for matrix exponential (JVP / pushforward).
4///
5/// Computes `exp(A)` and the Frechet derivative `d(exp(A))` in the direction `dA`.
6/// Uses the auxiliary 2n x 2n matrix trick (PyTorch approach):
7///
8/// ```text
9/// M = [[A, dA], [0, A]]
10/// exp(A)    = top-left  n×n block of exp(M)
11/// d(exp(A)) = top-right n×n block of exp(M)
12/// ```
13///
14/// # Examples
15///
16/// ```
17/// use tenferro_linalg::matrix_exp_frule;
18/// use tenferro_prims::CpuContext;
19/// use tenferro_tensor::{Tensor, MemoryOrder};
20/// use tenferro_device::LogicalMemorySpace;
21///
22/// let col = MemoryOrder::ColumnMajor;
23/// let mem = LogicalMemorySpace::MainMemory;
24/// let mut ctx = CpuContext::new(1);
25/// let a = Tensor::<f64>::zeros(&[3, 3], mem, col).unwrap();
26/// let da = Tensor::<f64>::ones(&[3, 3], mem, col).unwrap();
27/// let (exp_a, dexp_a) = matrix_exp_frule(&mut ctx, &a, &da).unwrap();
28/// ```
29pub fn matrix_exp_frule<T: KernelLinalgScalar, C>(
30    ctx: &mut C,
31    tensor: &Tensor<T>,
32    tangent: &Tensor<T>,
33) -> AdResult<(Tensor<T>, Tensor<T>)>
34where
35    T: KernelLinalgScalar
36        + crate::prims_bridge::ScaleTensorByRealSameShape<C>
37        + crate::ad_helpers::MatrixExpAbsTensor<C>,
38    C: backend::TensorLinalgContextFor<T>
39        + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T>>
40        + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T::Real>>
41        + tenferro_prims::TensorSemiringContextFor<tenferro_algebra::Standard<T>>,
42    T::Real: KernelLinalgScalar<Real = T::Real> + num_traits::Float,
43    <C as tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T::Real>>>::ScalarBackend:
44        tenferro_prims::TensorAnalyticPrims<tenferro_algebra::Standard<T::Real>, Context = C>,
45    C::Backend: 'static,
46{
47    require_linalg_support::<T, C>(backend::LinalgCapabilityOp::MatrixExp, "matrix_exp_frule")
48        .map_err(to_ad_err)?;
49
50    let (n, batch_dims) = validate_square(tensor).map_err(to_ad_err)?;
51    let zero = Tensor::<T>::zeros(
52        &output_dims(&[n, n], batch_dims),
53        tensor.logical_memory_space(),
54        MemoryOrder::ColumnMajor,
55    )
56    .map_err(to_ad_err)?;
57    let top = Tensor::cat(&[tensor, tangent], 1).map_err(to_ad_err)?;
58    let bottom = Tensor::cat(&[&zero, tensor], 1).map_err(to_ad_err)?;
59    let m = Tensor::cat(&[&top, &bottom], 0).map_err(to_ad_err)?;
60    let exp_m = matrix_exp(ctx, &m).map_err(to_ad_err)?;
61
62    let result = exp_m
63        .narrow(0, 0, n)
64        .and_then(|t| t.narrow(1, 0, n))
65        .map_err(to_ad_err)?;
66    let tang = exp_m
67        .narrow(0, 0, n)
68        .and_then(|t| t.narrow(1, n, n))
69        .map_err(to_ad_err)?;
70    Ok((result, tang))
71}