tenferro_linalg/frules/linear_systems/
inv_det.rs1use super::*;
2
3pub fn inv_frule<T: KernelLinalgScalar, C>(
21 ctx: &mut C,
22 tensor: &Tensor<T>,
23 tangent: &Tensor<T>,
24) -> AdResult<(Tensor<T>, Tensor<T>)>
25where
26 T: KernelLinalgScalar,
27 C: backend::TensorLinalgContextFor<T>
28 + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T>>,
29 C::Backend: 'static,
30 <C as tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T>>>::ScalarBackend:
31 'static + tenferro_prims::TensorAnalyticPrims<tenferro_algebra::Standard<T>, Context = C>,
32{
33 require_linalg_support::<T, C>(backend::LinalgCapabilityOp::Inv, "inv_frule")
34 .map_err(to_ad_err)?;
35
36 let b_inv = inv(ctx, tensor).map_err(to_ad_err)?;
38 let (n, _) = validate_square(tensor).map_err(to_ad_err)?;
39
40 let b_da = prims_bridge::batched_gemm_with_semiring_tensors(ctx, &b_inv, tangent, n, n, n)
41 .map_err(to_ad_err)?;
42 let db = prims_bridge::batched_gemm_alpha_tensors(ctx, &b_da, &b_inv, n, n, n, -T::one())
43 .map_err(to_ad_err)?;
44
45 Ok((b_inv, db))
46}
47
48pub fn det_frule<T: KernelLinalgScalar, C>(
66 ctx: &mut C,
67 tensor: &Tensor<T>,
68 tangent: &Tensor<T>,
69) -> AdResult<(Tensor<T>, Tensor<T>)>
70where
71 T: KernelLinalgScalar + crate::prims_bridge::ScaleTensorByRealSameShape<C>,
72 C: backend::TensorLinalgContextFor<T>
73 + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T>>
74 + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T::Real>>
75 + tenferro_prims::TensorMetadataContextFor,
76 C::Backend: 'static,
77 C::MetadataBackend: tenferro_prims::TensorMetadataPrims<Context = C>,
78 <C as tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T>>>::ScalarBackend:
79 tenferro_prims::TensorMetadataCastPrims<T, Context = C>,
80 <C as tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T::Real>>>::ScalarBackend:
81 tenferro_prims::TensorMetadataCastPrims<T::Real, Context = C>,
82{
83 require_linalg_support::<T, C>(backend::LinalgCapabilityOp::Det, "det_frule")
84 .map_err(to_ad_err)?;
85
86 let d = det(ctx, tensor).map_err(to_ad_err)?;
88 let (n, _) = validate_square(tensor).map_err(to_ad_err)?;
89 if n == 0 {
90 let dd = Tensor::zeros(
91 d.dims(),
92 tensor.logical_memory_space(),
93 MemoryOrder::ColumnMajor,
94 )
95 .map_err(to_ad_err)?;
96 return Ok((d, dd));
97 }
98
99 let a_inv = inv(ctx, tensor).map_err(to_ad_err)?;
100 let a_inv_da = prims_bridge::batched_gemm_with_semiring_tensors(ctx, &a_inv, tangent, n, n, n)
101 .map_err(to_ad_err)?;
102 let trace = trace_tensor(ctx, &a_inv_da).map_err(to_ad_err)?;
103
104 let dd = prims_bridge::scalar_binary_same_shape(
106 ctx,
107 &d,
108 &trace,
109 tenferro_prims::ScalarBinaryOp::Mul,
110 )
111 .map_err(to_ad_err)?;
112
113 Ok((d, dd))
114}