1use super::*;
2use num_complex::{Complex32, Complex64, ComplexFloat};
3use tenferro_algebra::Conjugate;
4
/// Per-scalar-type dispatch for the forward-mode (tangent) rule of `slogdet`.
///
/// This file implements it for `f32`, `f64`, `Complex32` and `Complex64`;
/// the public entry point `slogdet_frule` simply forwards to
/// [`SlogdetFruleDispatch::slogdet_frule_dispatch`].
#[doc(hidden)]
pub trait SlogdetFruleDispatch<C>: crate::SlogdetDispatch<C> {
    /// Computes `slogdet(tensor)` together with its tangent along `tangent`.
    ///
    /// Returns `(primal, dual)`: `primal` is the ordinary `slogdet` output and
    /// `dual` holds the directional derivatives of its `sign` and `logabsdet`
    /// fields in the direction `tangent`.
    fn slogdet_frule_dispatch(
        ctx: &mut C,
        tensor: &Tensor<Self>,
        tangent: &Tensor<Self>,
    ) -> AdResult<(
        SlogdetResult<Self, Self::Real>,
        SlogdetResult<Self, Self::Real>,
    )>;
}
16
17fn slogdet_frule_real_impl<T, C>(
18 ctx: &mut C,
19 tensor: &Tensor<T>,
20 tangent: &Tensor<T>,
21) -> AdResult<(SlogdetResult<T, T::Real>, SlogdetResult<T, T::Real>)>
22where
23 T: KernelLinalgScalar<Real = T> + num_traits::Float + crate::SlogdetDispatch<C>,
24 C: backend::TensorLinalgContextFor<T>
25 + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T>>,
26 C::Backend: 'static,
27 <C as tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T>>>::ScalarBackend:
28 'static + tenferro_prims::TensorAnalyticPrims<tenferro_algebra::Standard<T>, Context = C>,
29{
30 require_linalg_support::<T, C>(backend::LinalgCapabilityOp::Slogdet, "slogdet_frule")
31 .map_err(to_ad_err)?;
32
33 let result = slogdet(ctx, tensor).map_err(to_ad_err)?;
34 let (n, _batch_dims) = validate_square(tensor).map_err(to_ad_err)?;
35 if n == 0 {
36 let dsign = Tensor::zeros(
37 result.sign.dims(),
38 tensor.logical_memory_space(),
39 MemoryOrder::ColumnMajor,
40 )
41 .map_err(to_ad_err)?;
42 let dlog = Tensor::zeros(
43 result.logabsdet.dims(),
44 tensor.logical_memory_space(),
45 MemoryOrder::ColumnMajor,
46 )
47 .map_err(to_ad_err)?;
48 return Ok((
49 result,
50 SlogdetResult {
51 sign: dsign,
52 logabsdet: dlog,
53 },
54 ));
55 }
56
57 let a_inv = inv(ctx, tensor).map_err(to_ad_err)?;
58 let a_inv_da = prims_bridge::batched_gemm_with_semiring_tensors(ctx, &a_inv, tangent, n, n, n)
59 .map_err(to_ad_err)?;
60 let dlog = trace_tensor(ctx, &a_inv_da).map_err(to_ad_err)?;
61 let dsign = Tensor::zeros(
62 result.sign.dims(),
63 tensor.logical_memory_space(),
64 MemoryOrder::ColumnMajor,
65 )
66 .map_err(to_ad_err)?;
67
68 let dresult = SlogdetResult {
69 sign: dsign,
70 logabsdet: dlog,
71 };
72 Ok((result, dresult))
73}
74
75fn slogdet_frule_complex_impl<T, R, C>(
76 ctx: &mut C,
77 tensor: &Tensor<T>,
78 tangent: &Tensor<T>,
79) -> AdResult<(SlogdetResult<T, R>, SlogdetResult<T, R>)>
80where
81 T: KernelLinalgScalar<Real = R>
82 + Conjugate
83 + ComplexFloat<Real = R>
84 + crate::SlogdetDispatch<C>,
85 T: crate::prims_bridge::ScaleTensorByRealSameShape<C>,
86 C: backend::TensorLinalgContextFor<T>
87 + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<R>>
88 + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T>>
89 + tenferro_prims::TensorComplexRealContextFor<T>
90 + tenferro_prims::TensorMetadataContextFor,
91 C::Backend: 'static,
92 R: KernelLinalgScalar<Real = R> + num_traits::Float,
93 <C as tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<R>>>::ScalarBackend:
94 'static
95 + tenferro_prims::TensorAnalyticPrims<tenferro_algebra::Standard<R>, Context = C>
96 + tenferro_prims::TensorMetadataCastPrims<R, Context = C>,
97 <C as tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T>>>::ScalarBackend:
98 tenferro_prims::TensorMetadataCastPrims<T, Context = C>,
99 <C as tenferro_prims::TensorComplexRealContextFor<T>>::ComplexRealBackend:
100 tenferro_prims::TensorComplexRealPrims<T, Context = C, Real = R>,
101 <C as tenferro_prims::TensorMetadataContextFor>::MetadataBackend:
102 tenferro_prims::TensorMetadataPrims<Context = C>,
103{
104 require_linalg_support::<T, C>(backend::LinalgCapabilityOp::Slogdet, "slogdet_frule")
105 .map_err(to_ad_err)?;
106
107 let result = slogdet(ctx, tensor).map_err(to_ad_err)?;
108 let (n, _batch_dims) = validate_square(tensor).map_err(to_ad_err)?;
109 if n == 0 {
110 let dsign = Tensor::zeros(
111 result.sign.dims(),
112 tensor.logical_memory_space(),
113 MemoryOrder::ColumnMajor,
114 )
115 .map_err(to_ad_err)?;
116 let dlog = Tensor::zeros(
117 result.logabsdet.dims(),
118 tensor.logical_memory_space(),
119 MemoryOrder::ColumnMajor,
120 )
121 .map_err(to_ad_err)?;
122 return Ok((
123 result,
124 SlogdetResult {
125 sign: dsign,
126 logabsdet: dlog,
127 },
128 ));
129 }
130
131 let a_inv = inv(ctx, tensor).map_err(to_ad_err)?;
132 let a_inv_da = prims_bridge::batched_gemm_with_semiring_tensors(ctx, &a_inv, tangent, n, n, n)
133 .map_err(to_ad_err)?;
134 let trace = trace_tensor(ctx, &a_inv_da).map_err(to_ad_err)?;
135 let dlog = prims_bridge::complex_real_unary_same_shape(
136 ctx,
137 &trace,
138 tenferro_prims::ComplexRealUnaryOp::Real,
139 )
140 .map_err(to_ad_err)?;
141 let trace_conj =
142 prims_bridge::scalar_unary_same_shape(ctx, &trace, tenferro_prims::ScalarUnaryOp::Conj)
143 .map_err(to_ad_err)?;
144 let trace_diff = prims_bridge::scalar_binary_same_shape(
145 ctx,
146 &trace,
147 &trace_conj,
148 tenferro_prims::ScalarBinaryOp::Sub,
149 )
150 .map_err(to_ad_err)?;
151 let half = prims_bridge::full_like_constant(
152 scalar_from::<R>(0.5).map_err(to_ad_err)?,
153 trace_diff.dims(),
154 tensor.logical_memory_space(),
155 )
156 .map_err(to_ad_err)?;
157 let phase_tangent =
158 prims_bridge::complex_scale_same_shape(ctx, &trace_diff, &half).map_err(to_ad_err)?;
159 let dsign = prims_bridge::scalar_binary_same_shape(
160 ctx,
161 &result.sign,
162 &phase_tangent,
163 tenferro_prims::ScalarBinaryOp::Mul,
164 )
165 .map_err(to_ad_err)?;
166
167 let dresult = SlogdetResult {
168 sign: dsign,
169 logabsdet: dlog,
170 };
171 Ok((result, dresult))
172}
173
174impl<C> SlogdetFruleDispatch<C> for f32
175where
176 f32: crate::SlogdetDispatch<C>,
177 C: backend::TensorLinalgContextFor<f32>
178 + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<f32>>,
179 C::Backend: 'static,
180 <C as tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<f32>>>::ScalarBackend:
181 'static + tenferro_prims::TensorAnalyticPrims<tenferro_algebra::Standard<f32>, Context = C>,
182{
183 fn slogdet_frule_dispatch(
184 ctx: &mut C,
185 tensor: &Tensor<Self>,
186 tangent: &Tensor<Self>,
187 ) -> AdResult<(
188 SlogdetResult<Self, Self::Real>,
189 SlogdetResult<Self, Self::Real>,
190 )> {
191 slogdet_frule_real_impl(ctx, tensor, tangent)
192 }
193}
194
195impl<C> SlogdetFruleDispatch<C> for f64
196where
197 f64: crate::SlogdetDispatch<C>,
198 C: backend::TensorLinalgContextFor<f64>
199 + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<f64>>,
200 C::Backend: 'static,
201 <C as tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<f64>>>::ScalarBackend:
202 'static + tenferro_prims::TensorAnalyticPrims<tenferro_algebra::Standard<f64>, Context = C>,
203{
204 fn slogdet_frule_dispatch(
205 ctx: &mut C,
206 tensor: &Tensor<Self>,
207 tangent: &Tensor<Self>,
208 ) -> AdResult<(
209 SlogdetResult<Self, Self::Real>,
210 SlogdetResult<Self, Self::Real>,
211 )> {
212 slogdet_frule_real_impl(ctx, tensor, tangent)
213 }
214}
215
216impl<C> SlogdetFruleDispatch<C> for Complex32
217where
218 Complex32: crate::SlogdetDispatch<C> + crate::prims_bridge::ScaleTensorByRealSameShape<C>,
219 C: backend::TensorLinalgContextFor<Complex32>
220 + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<f32>>
221 + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<Complex32>>
222 + tenferro_prims::TensorComplexRealContextFor<Complex32>
223 + tenferro_prims::TensorMetadataContextFor,
224 C::Backend: 'static,
225 <C as tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<f32>>>::ScalarBackend:
226 'static
227 + tenferro_prims::TensorAnalyticPrims<tenferro_algebra::Standard<f32>, Context = C>
228 + tenferro_prims::TensorMetadataCastPrims<f32, Context = C>,
229 <C as tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<Complex32>>>::ScalarBackend:
230 tenferro_prims::TensorMetadataCastPrims<Complex32, Context = C>,
231 <C as tenferro_prims::TensorComplexRealContextFor<Complex32>>::ComplexRealBackend:
232 tenferro_prims::TensorComplexRealPrims<Complex32, Context = C, Real = f32>,
233 <C as tenferro_prims::TensorMetadataContextFor>::MetadataBackend:
234 tenferro_prims::TensorMetadataPrims<Context = C>,
235{
236 fn slogdet_frule_dispatch(
237 ctx: &mut C,
238 tensor: &Tensor<Self>,
239 tangent: &Tensor<Self>,
240 ) -> AdResult<(
241 SlogdetResult<Self, Self::Real>,
242 SlogdetResult<Self, Self::Real>,
243 )> {
244 slogdet_frule_complex_impl::<Complex32, f32, C>(ctx, tensor, tangent)
245 }
246}
247
248impl<C> SlogdetFruleDispatch<C> for Complex64
249where
250 Complex64: crate::SlogdetDispatch<C> + crate::prims_bridge::ScaleTensorByRealSameShape<C>,
251 C: backend::TensorLinalgContextFor<Complex64>
252 + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<f64>>
253 + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<Complex64>>
254 + tenferro_prims::TensorComplexRealContextFor<Complex64>
255 + tenferro_prims::TensorMetadataContextFor,
256 C::Backend: 'static,
257 <C as tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<f64>>>::ScalarBackend:
258 'static
259 + tenferro_prims::TensorAnalyticPrims<tenferro_algebra::Standard<f64>, Context = C>
260 + tenferro_prims::TensorMetadataCastPrims<f64, Context = C>,
261 <C as tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<Complex64>>>::ScalarBackend:
262 tenferro_prims::TensorMetadataCastPrims<Complex64, Context = C>,
263 <C as tenferro_prims::TensorComplexRealContextFor<Complex64>>::ComplexRealBackend:
264 tenferro_prims::TensorComplexRealPrims<Complex64, Context = C, Real = f64>,
265 <C as tenferro_prims::TensorMetadataContextFor>::MetadataBackend:
266 tenferro_prims::TensorMetadataPrims<Context = C>,
267{
268 fn slogdet_frule_dispatch(
269 ctx: &mut C,
270 tensor: &Tensor<Self>,
271 tangent: &Tensor<Self>,
272 ) -> AdResult<(
273 SlogdetResult<Self, Self::Real>,
274 SlogdetResult<Self, Self::Real>,
275 )> {
276 slogdet_frule_complex_impl::<Complex64, f64, C>(ctx, tensor, tangent)
277 }
278}
279
280pub fn slogdet_frule<T, C>(
298 ctx: &mut C,
299 tensor: &Tensor<T>,
300 tangent: &Tensor<T>,
301) -> AdResult<(SlogdetResult<T, T::Real>, SlogdetResult<T, T::Real>)>
302where
303 T: SlogdetFruleDispatch<C>,
304{
305 T::slogdet_frule_dispatch(ctx, tensor, tangent)
306}