tenferro_linalg/frules/
matrix_functions.rs1use super::*;
2
3pub fn matrix_exp_frule<T: KernelLinalgScalar, C>(
30 ctx: &mut C,
31 tensor: &Tensor<T>,
32 tangent: &Tensor<T>,
33) -> AdResult<(Tensor<T>, Tensor<T>)>
34where
35 T: KernelLinalgScalar
36 + crate::prims_bridge::ScaleTensorByRealSameShape<C>
37 + crate::ad_helpers::MatrixExpAbsTensor<C>,
38 C: backend::TensorLinalgContextFor<T>
39 + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T>>
40 + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T::Real>>
41 + tenferro_prims::TensorSemiringContextFor<tenferro_algebra::Standard<T>>,
42 T::Real: KernelLinalgScalar<Real = T::Real> + num_traits::Float,
43 <C as tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T::Real>>>::ScalarBackend:
44 tenferro_prims::TensorAnalyticPrims<tenferro_algebra::Standard<T::Real>, Context = C>,
45 C::Backend: 'static,
46{
47 require_linalg_support::<T, C>(backend::LinalgCapabilityOp::MatrixExp, "matrix_exp_frule")
48 .map_err(to_ad_err)?;
49
50 let (n, batch_dims) = validate_square(tensor).map_err(to_ad_err)?;
51 let zero = Tensor::<T>::zeros(
52 &output_dims(&[n, n], batch_dims),
53 tensor.logical_memory_space(),
54 MemoryOrder::ColumnMajor,
55 )
56 .map_err(to_ad_err)?;
57 let top = Tensor::cat(&[tensor, tangent], 1).map_err(to_ad_err)?;
58 let bottom = Tensor::cat(&[&zero, tensor], 1).map_err(to_ad_err)?;
59 let m = Tensor::cat(&[&top, &bottom], 0).map_err(to_ad_err)?;
60 let exp_m = matrix_exp(ctx, &m).map_err(to_ad_err)?;
61
62 let result = exp_m
63 .narrow(0, 0, n)
64 .and_then(|t| t.narrow(1, 0, n))
65 .map_err(to_ad_err)?;
66 let tang = exp_m
67 .narrow(0, 0, n)
68 .and_then(|t| t.narrow(1, n, n))
69 .map_err(to_ad_err)?;
70 Ok((result, tang))
71}