// tenferro_linalg/rrules/matrix_functions.rs
use super::*;

3pub fn matrix_exp_rrule<T: KernelLinalgScalar + tenferro_algebra::Conjugate, C>(
29 ctx: &mut C,
30 tensor: &Tensor<T>,
31 cotangent: &Tensor<T>,
32) -> AdResult<Tensor<T>>
33where
34 T: KernelLinalgScalar
35 + tenferro_algebra::Conjugate
36 + crate::prims_bridge::ScaleTensorByRealSameShape<C>
37 + crate::ad_helpers::MatrixExpAbsTensor<C>,
38 C: backend::TensorLinalgContextFor<T>
39 + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T>>
40 + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T::Real>>
41 + tenferro_prims::TensorSemiringContextFor<tenferro_algebra::Standard<T>>
42 + tenferro_prims::TensorResolveConjContextFor<T>,
43 T::Real: KernelLinalgScalar<Real = T::Real> + num_traits::Float,
44 <C as tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T::Real>>>::ScalarBackend:
45 tenferro_prims::TensorAnalyticPrims<tenferro_algebra::Standard<T::Real>, Context = C>,
46 C::Backend: 'static,
47{
48 require_linalg_support::<T, C>(backend::LinalgCapabilityOp::MatrixExp, "matrix_exp_rrule")
49 .map_err(to_ad_err)?;
50
51 let (n, batch_dims) = validate_square(tensor).map_err(to_ad_err)?;
52 let mut perm = vec![1, 0];
53 perm.extend(2..tensor.ndim());
54 let a_h_view = tensor.conj().permute(&perm).map_err(to_ad_err)?;
55 let a_h = crate::prims_bridge::resolve_conj(ctx, &a_h_view);
56 let zero = Tensor::<T>::zeros(
57 &output_dims(&[n, n], batch_dims),
58 tensor.logical_memory_space(),
59 MemoryOrder::ColumnMajor,
60 )
61 .map_err(to_ad_err)?;
62 let top = Tensor::cat(&[&a_h, cotangent], 1).map_err(to_ad_err)?;
63 let bottom = Tensor::cat(&[&zero, &a_h], 1).map_err(to_ad_err)?;
64 let m = Tensor::cat(&[&top, &bottom], 0).map_err(to_ad_err)?;
65 let exp_m = matrix_exp(ctx, &m).map_err(to_ad_err)?;
66
67 exp_m
68 .narrow(0, 0, n)
69 .and_then(|t| t.narrow(1, n, n))
70 .map_err(to_ad_err)
71}