// tenferro_linalg/rrules/linear_systems/inverse.rs
use super::*;
use tenferro_algebra::Conjugate;

4pub fn inv_rrule<T: KernelLinalgScalar + tenferro_algebra::Conjugate, C>(
24 ctx: &mut C,
25 tensor: &Tensor<T>,
26 cotangent: &Tensor<T>,
27) -> AdResult<Tensor<T>>
28where
29 T: KernelLinalgScalar + tenferro_algebra::Conjugate,
30 C: backend::TensorLinalgContextFor<T> + tenferro_prims::TensorResolveConjContextFor<T>,
31 C::Backend: 'static,
32{
33 require_linalg_support::<T, C>(backend::LinalgCapabilityOp::Inv, "inv_rrule")
34 .map_err(to_ad_err)?;
35
36 let b_inv = inv(ctx, tensor).map_err(to_ad_err)?;
38 let (n, _) = validate_square(tensor).map_err(to_ad_err)?;
39
40 let bt = matrix_adjoint_eager(ctx, &b_inv).map_err(to_ad_err)?;
41 let bt_db = prims_bridge::batched_gemm_with_semiring_tensors(ctx, &bt, cotangent, n, n, n)
42 .map_err(to_ad_err)?;
43 let grad_a = prims_bridge::batched_gemm_alpha_tensors(ctx, &bt_db, &bt, n, n, n, -T::one())
44 .map_err(to_ad_err)?;
45
46 Ok(grad_a)
47}
48
49pub fn det_rrule<T: KernelLinalgScalar + Conjugate, C>(
69 ctx: &mut C,
70 tensor: &Tensor<T>,
71 cotangent: &Tensor<T>,
72) -> AdResult<Tensor<T>>
73where
74 T: KernelLinalgScalar + crate::prims_bridge::ScaleTensorByRealSameShape<C>,
75 C: backend::TensorLinalgContextFor<T>
76 + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T>>
77 + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T::Real>>
78 + tenferro_prims::TensorMetadataContextFor
79 + tenferro_prims::TensorResolveConjContextFor<T>,
80 C::Backend: 'static,
81 C::MetadataBackend: tenferro_prims::TensorMetadataPrims<Context = C>,
82 <C as tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T>>>::ScalarBackend:
83 tenferro_prims::TensorMetadataCastPrims<T, Context = C>,
84 <C as tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T::Real>>>::ScalarBackend:
85 tenferro_prims::TensorMetadataCastPrims<T::Real, Context = C>,
86{
87 require_linalg_support::<T, C>(backend::LinalgCapabilityOp::Det, "det_rrule")
88 .map_err(to_ad_err)?;
89
90 let (n, _) = validate_square(tensor).map_err(to_ad_err)?;
91 if n == 0 {
92 return Tensor::zeros(
93 tensor.dims(),
94 tensor.logical_memory_space(),
95 MemoryOrder::ColumnMajor,
96 )
97 .map_err(to_ad_err);
98 }
99
100 let det_val = det(ctx, tensor).map_err(to_ad_err)?;
102 let a_inv = inv(ctx, tensor).map_err(to_ad_err)?;
103 let a_inv_h = matrix_adjoint_eager(ctx, &a_inv).map_err(to_ad_err)?;
104 let det_conj =
105 prims_bridge::scalar_unary_same_shape(ctx, &det_val, tenferro_prims::ScalarUnaryOp::Conj)
106 .map_err(to_ad_err)?;
107
108 let scale = prims_bridge::scalar_binary_same_shape(
110 ctx,
111 cotangent,
112 &det_conj,
113 tenferro_prims::ScalarBinaryOp::Mul,
114 )
115 .map_err(to_ad_err)?;
116
117 let scale_expanded = scale
119 .unsqueeze(0)
120 .map_err(to_ad_err)?
121 .unsqueeze(0)
122 .map_err(to_ad_err)?
123 .broadcast(a_inv_h.dims())
124 .map_err(to_ad_err)?;
125 let grad_a = prims_bridge::scalar_binary_same_shape(
126 ctx,
127 &scale_expanded,
128 &a_inv_h,
129 tenferro_prims::ScalarBinaryOp::Mul,
130 )
131 .map_err(to_ad_err)?;
132
133 Ok(grad_a)
134}