1use super::*;
2use num_complex::{Complex32, Complex64, ComplexFloat};
3use tenferro_algebra::Conjugate;
4
5#[doc(hidden)]
6pub trait SlogdetRruleDispatch<C>: crate::SlogdetDispatch<C> {
7 fn slogdet_rrule_dispatch(
8 ctx: &mut C,
9 tensor: &Tensor<Self>,
10 cotangent: &SlogdetCotangent<Self, Self::Real>,
11 ) -> AdResult<Tensor<Self>>;
12}
13
14fn slogdet_rrule_real_impl<T, C>(
15 ctx: &mut C,
16 tensor: &Tensor<T>,
17 cotangent: &SlogdetCotangent<T, T>,
18) -> AdResult<Tensor<T>>
19where
20 T: KernelLinalgScalar<Real = T> + num_traits::Float + crate::SlogdetDispatch<C>,
21 C: backend::TensorLinalgContextFor<T>
22 + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T>>,
23 C::Backend: 'static,
24{
25 require_linalg_support::<T, C>(backend::LinalgCapabilityOp::Slogdet, "slogdet_rrule")
26 .map_err(to_ad_err)?;
27
28 let (n, batch_dims) = validate_square(tensor).map_err(to_ad_err)?;
29 if n == 0 {
30 let dims = output_dims(&[n, n], batch_dims);
31 return Tensor::zeros(
32 &dims,
33 tensor.logical_memory_space(),
34 MemoryOrder::ColumnMajor,
35 )
36 .map_err(to_ad_err);
37 }
38
39 if let Some(ref dlog) = cotangent.logabsdet {
40 let a_inv = inv(ctx, tensor).map_err(to_ad_err)?;
41 let a_inv_t = matrix_transpose(&a_inv).map_err(to_ad_err)?;
42 let dlog_expanded = dlog
43 .unsqueeze(0)
44 .map_err(to_ad_err)?
45 .unsqueeze(0)
46 .map_err(to_ad_err)?
47 .broadcast(a_inv_t.dims())
48 .map_err(to_ad_err)?;
49 let grad_a = prims_bridge::scalar_binary_same_shape(
50 ctx,
51 &dlog_expanded,
52 &a_inv_t,
53 tenferro_prims::ScalarBinaryOp::Mul,
54 )
55 .map_err(to_ad_err)?;
56 Ok(grad_a)
57 } else {
58 let dims = output_dims(&[n, n], batch_dims);
59 Tensor::zeros(
60 &dims,
61 tensor.logical_memory_space(),
62 MemoryOrder::ColumnMajor,
63 )
64 .map_err(to_ad_err)
65 }
66}
67
68fn slogdet_rrule_complex_impl<T, R, C>(
69 ctx: &mut C,
70 tensor: &Tensor<T>,
71 cotangent: &SlogdetCotangent<T, R>,
72) -> AdResult<Tensor<T>>
73where
74 T: KernelLinalgScalar<Real = R>
75 + Conjugate
76 + ComplexFloat<Real = R>
77 + crate::SlogdetDispatch<C>,
78 T: crate::prims_bridge::ScaleTensorByRealSameShape<C>,
79 C: backend::TensorLinalgContextFor<T>
80 + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<R>>
81 + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T>>
82 + tenferro_prims::TensorComplexRealContextFor<T>
83 + tenferro_prims::TensorResolveConjContextFor<T>
84 + tenferro_prims::TensorMetadataContextFor,
85 C::Backend: 'static,
86 R: KernelLinalgScalar<Real = R> + num_traits::Float,
87 <C as tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<R>>>::ScalarBackend:
88 'static
89 + tenferro_prims::TensorAnalyticPrims<tenferro_algebra::Standard<R>, Context = C>
90 + tenferro_prims::TensorMetadataCastPrims<R, Context = C>,
91 <C as tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T>>>::ScalarBackend:
92 tenferro_prims::TensorMetadataCastPrims<T, Context = C>,
93 <C as tenferro_prims::TensorComplexRealContextFor<T>>::ComplexRealBackend:
94 tenferro_prims::TensorComplexRealPrims<T, Context = C, Real = R>,
95 <C as tenferro_prims::TensorMetadataContextFor>::MetadataBackend:
96 tenferro_prims::TensorMetadataPrims<Context = C>,
97{
98 require_linalg_support::<T, C>(backend::LinalgCapabilityOp::Slogdet, "slogdet_rrule")
99 .map_err(to_ad_err)?;
100
101 let (n, batch_dims) = validate_square(tensor).map_err(to_ad_err)?;
102 if n == 0 {
103 let dims = output_dims(&[n, n], batch_dims);
104 return Tensor::zeros(
105 &dims,
106 tensor.logical_memory_space(),
107 MemoryOrder::ColumnMajor,
108 )
109 .map_err(to_ad_err);
110 }
111
112 let a_inv = inv(ctx, tensor).map_err(to_ad_err)?;
113 let a_inv_h = matrix_adjoint_eager(ctx, &a_inv).map_err(to_ad_err)?;
114 let dims = output_dims(&[n, n], batch_dims);
115 let mut grad_a = Tensor::zeros(
116 &dims,
117 tensor.logical_memory_space(),
118 MemoryOrder::ColumnMajor,
119 )
120 .map_err(to_ad_err)?;
121
122 if let Some(ref dlog) = cotangent.logabsdet {
123 let dlog_expanded = dlog
124 .unsqueeze(0)
125 .map_err(to_ad_err)?
126 .unsqueeze(0)
127 .map_err(to_ad_err)?
128 .broadcast(a_inv_h.dims())
129 .map_err(to_ad_err)?;
130 grad_a = prims_bridge::complex_scale_same_shape(ctx, &a_inv_h, &dlog_expanded)
131 .map_err(to_ad_err)?;
132 }
133
134 if let Some(ref dsign) = cotangent.sign {
135 let result = slogdet(ctx, tensor).map_err(to_ad_err)?;
136 let dsign_conj =
137 prims_bridge::scalar_unary_same_shape(ctx, dsign, tenferro_prims::ScalarUnaryOp::Conj)
138 .map_err(to_ad_err)?;
139 let alpha = prims_bridge::scalar_binary_same_shape(
140 ctx,
141 &dsign_conj,
142 &result.sign,
143 tenferro_prims::ScalarBinaryOp::Mul,
144 )
145 .map_err(to_ad_err)?;
146 let alpha_conj =
147 prims_bridge::scalar_unary_same_shape(ctx, &alpha, tenferro_prims::ScalarUnaryOp::Conj)
148 .map_err(to_ad_err)?;
149 let alpha_skew = prims_bridge::scalar_binary_same_shape(
150 ctx,
151 &alpha_conj,
152 &alpha,
153 tenferro_prims::ScalarBinaryOp::Sub,
154 )
155 .map_err(to_ad_err)?;
156 let half = prims_bridge::full_like_constant(
157 scalar_from::<R>(0.5).map_err(to_ad_err)?,
158 alpha_skew.dims(),
159 tensor.logical_memory_space(),
160 )
161 .map_err(to_ad_err)?;
162 let sign_scale =
163 prims_bridge::complex_scale_same_shape(ctx, &alpha_skew, &half).map_err(to_ad_err)?;
164 let sign_scale_expanded = sign_scale
165 .unsqueeze(0)
166 .map_err(to_ad_err)?
167 .unsqueeze(0)
168 .map_err(to_ad_err)?
169 .broadcast(a_inv_h.dims())
170 .map_err(to_ad_err)?;
171 let sign_grad = prims_bridge::scalar_binary_same_shape(
172 ctx,
173 &sign_scale_expanded,
174 &a_inv_h,
175 tenferro_prims::ScalarBinaryOp::Mul,
176 )
177 .map_err(to_ad_err)?;
178 grad_a = prims_bridge::scalar_binary_same_shape(
179 ctx,
180 &grad_a,
181 &sign_grad,
182 tenferro_prims::ScalarBinaryOp::Add,
183 )
184 .map_err(to_ad_err)?;
185 }
186
187 Ok(grad_a)
188}
189
190pub fn slogdet_rrule<T, C>(
213 ctx: &mut C,
214 tensor: &Tensor<T>,
215 cotangent: &SlogdetCotangent<T, T::Real>,
216) -> AdResult<Tensor<T>>
217where
218 T: SlogdetRruleDispatch<C>,
219{
220 T::slogdet_rrule_dispatch(ctx, tensor, cotangent)
221}
222
223impl<C> SlogdetRruleDispatch<C> for f32
224where
225 f32: crate::SlogdetDispatch<C>,
226 C: backend::TensorLinalgContextFor<f32>
227 + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<f32>>,
228 C::Backend: 'static,
229{
230 fn slogdet_rrule_dispatch(
231 ctx: &mut C,
232 tensor: &Tensor<Self>,
233 cotangent: &SlogdetCotangent<Self, Self::Real>,
234 ) -> AdResult<Tensor<Self>> {
235 slogdet_rrule_real_impl(ctx, tensor, cotangent)
236 }
237}
238
239impl<C> SlogdetRruleDispatch<C> for f64
240where
241 f64: crate::SlogdetDispatch<C>,
242 C: backend::TensorLinalgContextFor<f64>
243 + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<f64>>,
244 C::Backend: 'static,
245{
246 fn slogdet_rrule_dispatch(
247 ctx: &mut C,
248 tensor: &Tensor<Self>,
249 cotangent: &SlogdetCotangent<Self, Self::Real>,
250 ) -> AdResult<Tensor<Self>> {
251 slogdet_rrule_real_impl(ctx, tensor, cotangent)
252 }
253}
254
255impl<C> SlogdetRruleDispatch<C> for Complex32
256where
257 Complex32: crate::SlogdetDispatch<C> + crate::prims_bridge::ScaleTensorByRealSameShape<C>,
258 C: backend::TensorLinalgContextFor<Complex32>
259 + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<f32>>
260 + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<Complex32>>
261 + tenferro_prims::TensorComplexRealContextFor<Complex32>
262 + tenferro_prims::TensorResolveConjContextFor<Complex32>
263 + tenferro_prims::TensorMetadataContextFor,
264 C::Backend: 'static,
265 <C as tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<f32>>>::ScalarBackend:
266 'static
267 + tenferro_prims::TensorAnalyticPrims<tenferro_algebra::Standard<f32>, Context = C>
268 + tenferro_prims::TensorMetadataCastPrims<f32, Context = C>,
269 <C as tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<Complex32>>>::ScalarBackend:
270 tenferro_prims::TensorMetadataCastPrims<Complex32, Context = C>,
271 <C as tenferro_prims::TensorComplexRealContextFor<Complex32>>::ComplexRealBackend:
272 tenferro_prims::TensorComplexRealPrims<Complex32, Context = C, Real = f32>,
273 <C as tenferro_prims::TensorMetadataContextFor>::MetadataBackend:
274 tenferro_prims::TensorMetadataPrims<Context = C>,
275{
276 fn slogdet_rrule_dispatch(
277 ctx: &mut C,
278 tensor: &Tensor<Self>,
279 cotangent: &SlogdetCotangent<Self, Self::Real>,
280 ) -> AdResult<Tensor<Self>> {
281 slogdet_rrule_complex_impl::<Complex32, f32, C>(ctx, tensor, cotangent)
282 }
283}
284
285impl<C> SlogdetRruleDispatch<C> for Complex64
286where
287 Complex64: crate::SlogdetDispatch<C> + crate::prims_bridge::ScaleTensorByRealSameShape<C>,
288 C: backend::TensorLinalgContextFor<Complex64>
289 + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<f64>>
290 + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<Complex64>>
291 + tenferro_prims::TensorComplexRealContextFor<Complex64>
292 + tenferro_prims::TensorResolveConjContextFor<Complex64>
293 + tenferro_prims::TensorMetadataContextFor,
294 C::Backend: 'static,
295 <C as tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<f64>>>::ScalarBackend:
296 'static
297 + tenferro_prims::TensorAnalyticPrims<tenferro_algebra::Standard<f64>, Context = C>
298 + tenferro_prims::TensorMetadataCastPrims<f64, Context = C>,
299 <C as tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<Complex64>>>::ScalarBackend:
300 tenferro_prims::TensorMetadataCastPrims<Complex64, Context = C>,
301 <C as tenferro_prims::TensorComplexRealContextFor<Complex64>>::ComplexRealBackend:
302 tenferro_prims::TensorComplexRealPrims<Complex64, Context = C, Real = f64>,
303 <C as tenferro_prims::TensorMetadataContextFor>::MetadataBackend:
304 tenferro_prims::TensorMetadataPrims<Context = C>,
305{
306 fn slogdet_rrule_dispatch(
307 ctx: &mut C,
308 tensor: &Tensor<Self>,
309 cotangent: &SlogdetCotangent<Self, Self::Real>,
310 ) -> AdResult<Tensor<Self>> {
311 slogdet_rrule_complex_impl::<Complex64, f64, C>(ctx, tensor, cotangent)
312 }
313}