// tenferro_linalg/frules/linear_systems/slogdet.rs

1use super::*;
2use num_complex::{Complex32, Complex64, ComplexFloat};
3use tenferro_algebra::Conjugate;
4
5#[doc(hidden)]
6pub trait SlogdetFruleDispatch<C>: crate::SlogdetDispatch<C> {
7    fn slogdet_frule_dispatch(
8        ctx: &mut C,
9        tensor: &Tensor<Self>,
10        tangent: &Tensor<Self>,
11    ) -> AdResult<(
12        SlogdetResult<Self, Self::Real>,
13        SlogdetResult<Self, Self::Real>,
14    )>;
15}
16
17fn slogdet_frule_real_impl<T, C>(
18    ctx: &mut C,
19    tensor: &Tensor<T>,
20    tangent: &Tensor<T>,
21) -> AdResult<(SlogdetResult<T, T::Real>, SlogdetResult<T, T::Real>)>
22where
23    T: KernelLinalgScalar<Real = T> + num_traits::Float + crate::SlogdetDispatch<C>,
24    C: backend::TensorLinalgContextFor<T>
25        + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T>>,
26    C::Backend: 'static,
27    <C as tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T>>>::ScalarBackend:
28        'static + tenferro_prims::TensorAnalyticPrims<tenferro_algebra::Standard<T>, Context = C>,
29{
30    require_linalg_support::<T, C>(backend::LinalgCapabilityOp::Slogdet, "slogdet_frule")
31        .map_err(to_ad_err)?;
32
33    let result = slogdet(ctx, tensor).map_err(to_ad_err)?;
34    let (n, _batch_dims) = validate_square(tensor).map_err(to_ad_err)?;
35    if n == 0 {
36        let dsign = Tensor::zeros(
37            result.sign.dims(),
38            tensor.logical_memory_space(),
39            MemoryOrder::ColumnMajor,
40        )
41        .map_err(to_ad_err)?;
42        let dlog = Tensor::zeros(
43            result.logabsdet.dims(),
44            tensor.logical_memory_space(),
45            MemoryOrder::ColumnMajor,
46        )
47        .map_err(to_ad_err)?;
48        return Ok((
49            result,
50            SlogdetResult {
51                sign: dsign,
52                logabsdet: dlog,
53            },
54        ));
55    }
56
57    let a_inv = inv(ctx, tensor).map_err(to_ad_err)?;
58    let a_inv_da = prims_bridge::batched_gemm_with_semiring_tensors(ctx, &a_inv, tangent, n, n, n)
59        .map_err(to_ad_err)?;
60    let dlog = trace_tensor(ctx, &a_inv_da).map_err(to_ad_err)?;
61    let dsign = Tensor::zeros(
62        result.sign.dims(),
63        tensor.logical_memory_space(),
64        MemoryOrder::ColumnMajor,
65    )
66    .map_err(to_ad_err)?;
67
68    let dresult = SlogdetResult {
69        sign: dsign,
70        logabsdet: dlog,
71    };
72    Ok((result, dresult))
73}
74
75fn slogdet_frule_complex_impl<T, R, C>(
76    ctx: &mut C,
77    tensor: &Tensor<T>,
78    tangent: &Tensor<T>,
79) -> AdResult<(SlogdetResult<T, R>, SlogdetResult<T, R>)>
80where
81    T: KernelLinalgScalar<Real = R>
82        + Conjugate
83        + ComplexFloat<Real = R>
84        + crate::SlogdetDispatch<C>,
85    T: crate::prims_bridge::ScaleTensorByRealSameShape<C>,
86    C: backend::TensorLinalgContextFor<T>
87        + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<R>>
88        + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T>>
89        + tenferro_prims::TensorComplexRealContextFor<T>
90        + tenferro_prims::TensorMetadataContextFor,
91    C::Backend: 'static,
92    R: KernelLinalgScalar<Real = R> + num_traits::Float,
93    <C as tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<R>>>::ScalarBackend:
94        'static
95            + tenferro_prims::TensorAnalyticPrims<tenferro_algebra::Standard<R>, Context = C>
96            + tenferro_prims::TensorMetadataCastPrims<R, Context = C>,
97    <C as tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T>>>::ScalarBackend:
98        tenferro_prims::TensorMetadataCastPrims<T, Context = C>,
99    <C as tenferro_prims::TensorComplexRealContextFor<T>>::ComplexRealBackend:
100        tenferro_prims::TensorComplexRealPrims<T, Context = C, Real = R>,
101    <C as tenferro_prims::TensorMetadataContextFor>::MetadataBackend:
102        tenferro_prims::TensorMetadataPrims<Context = C>,
103{
104    require_linalg_support::<T, C>(backend::LinalgCapabilityOp::Slogdet, "slogdet_frule")
105        .map_err(to_ad_err)?;
106
107    let result = slogdet(ctx, tensor).map_err(to_ad_err)?;
108    let (n, _batch_dims) = validate_square(tensor).map_err(to_ad_err)?;
109    if n == 0 {
110        let dsign = Tensor::zeros(
111            result.sign.dims(),
112            tensor.logical_memory_space(),
113            MemoryOrder::ColumnMajor,
114        )
115        .map_err(to_ad_err)?;
116        let dlog = Tensor::zeros(
117            result.logabsdet.dims(),
118            tensor.logical_memory_space(),
119            MemoryOrder::ColumnMajor,
120        )
121        .map_err(to_ad_err)?;
122        return Ok((
123            result,
124            SlogdetResult {
125                sign: dsign,
126                logabsdet: dlog,
127            },
128        ));
129    }
130
131    let a_inv = inv(ctx, tensor).map_err(to_ad_err)?;
132    let a_inv_da = prims_bridge::batched_gemm_with_semiring_tensors(ctx, &a_inv, tangent, n, n, n)
133        .map_err(to_ad_err)?;
134    let trace = trace_tensor(ctx, &a_inv_da).map_err(to_ad_err)?;
135    let dlog = prims_bridge::complex_real_unary_same_shape(
136        ctx,
137        &trace,
138        tenferro_prims::ComplexRealUnaryOp::Real,
139    )
140    .map_err(to_ad_err)?;
141    let trace_conj =
142        prims_bridge::scalar_unary_same_shape(ctx, &trace, tenferro_prims::ScalarUnaryOp::Conj)
143            .map_err(to_ad_err)?;
144    let trace_diff = prims_bridge::scalar_binary_same_shape(
145        ctx,
146        &trace,
147        &trace_conj,
148        tenferro_prims::ScalarBinaryOp::Sub,
149    )
150    .map_err(to_ad_err)?;
151    let half = prims_bridge::full_like_constant(
152        scalar_from::<R>(0.5).map_err(to_ad_err)?,
153        trace_diff.dims(),
154        tensor.logical_memory_space(),
155    )
156    .map_err(to_ad_err)?;
157    let phase_tangent =
158        prims_bridge::complex_scale_same_shape(ctx, &trace_diff, &half).map_err(to_ad_err)?;
159    let dsign = prims_bridge::scalar_binary_same_shape(
160        ctx,
161        &result.sign,
162        &phase_tangent,
163        tenferro_prims::ScalarBinaryOp::Mul,
164    )
165    .map_err(to_ad_err)?;
166
167    let dresult = SlogdetResult {
168        sign: dsign,
169        logabsdet: dlog,
170    };
171    Ok((result, dresult))
172}
173
174impl<C> SlogdetFruleDispatch<C> for f32
175where
176    f32: crate::SlogdetDispatch<C>,
177    C: backend::TensorLinalgContextFor<f32>
178        + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<f32>>,
179    C::Backend: 'static,
180    <C as tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<f32>>>::ScalarBackend:
181        'static + tenferro_prims::TensorAnalyticPrims<tenferro_algebra::Standard<f32>, Context = C>,
182{
183    fn slogdet_frule_dispatch(
184        ctx: &mut C,
185        tensor: &Tensor<Self>,
186        tangent: &Tensor<Self>,
187    ) -> AdResult<(
188        SlogdetResult<Self, Self::Real>,
189        SlogdetResult<Self, Self::Real>,
190    )> {
191        slogdet_frule_real_impl(ctx, tensor, tangent)
192    }
193}
194
195impl<C> SlogdetFruleDispatch<C> for f64
196where
197    f64: crate::SlogdetDispatch<C>,
198    C: backend::TensorLinalgContextFor<f64>
199        + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<f64>>,
200    C::Backend: 'static,
201    <C as tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<f64>>>::ScalarBackend:
202        'static + tenferro_prims::TensorAnalyticPrims<tenferro_algebra::Standard<f64>, Context = C>,
203{
204    fn slogdet_frule_dispatch(
205        ctx: &mut C,
206        tensor: &Tensor<Self>,
207        tangent: &Tensor<Self>,
208    ) -> AdResult<(
209        SlogdetResult<Self, Self::Real>,
210        SlogdetResult<Self, Self::Real>,
211    )> {
212        slogdet_frule_real_impl(ctx, tensor, tangent)
213    }
214}
215
216impl<C> SlogdetFruleDispatch<C> for Complex32
217where
218    Complex32: crate::SlogdetDispatch<C> + crate::prims_bridge::ScaleTensorByRealSameShape<C>,
219    C: backend::TensorLinalgContextFor<Complex32>
220        + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<f32>>
221        + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<Complex32>>
222        + tenferro_prims::TensorComplexRealContextFor<Complex32>
223        + tenferro_prims::TensorMetadataContextFor,
224    C::Backend: 'static,
225    <C as tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<f32>>>::ScalarBackend:
226        'static
227            + tenferro_prims::TensorAnalyticPrims<tenferro_algebra::Standard<f32>, Context = C>
228            + tenferro_prims::TensorMetadataCastPrims<f32, Context = C>,
229    <C as tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<Complex32>>>::ScalarBackend:
230        tenferro_prims::TensorMetadataCastPrims<Complex32, Context = C>,
231    <C as tenferro_prims::TensorComplexRealContextFor<Complex32>>::ComplexRealBackend:
232        tenferro_prims::TensorComplexRealPrims<Complex32, Context = C, Real = f32>,
233    <C as tenferro_prims::TensorMetadataContextFor>::MetadataBackend:
234        tenferro_prims::TensorMetadataPrims<Context = C>,
235{
236    fn slogdet_frule_dispatch(
237        ctx: &mut C,
238        tensor: &Tensor<Self>,
239        tangent: &Tensor<Self>,
240    ) -> AdResult<(
241        SlogdetResult<Self, Self::Real>,
242        SlogdetResult<Self, Self::Real>,
243    )> {
244        slogdet_frule_complex_impl::<Complex32, f32, C>(ctx, tensor, tangent)
245    }
246}
247
248impl<C> SlogdetFruleDispatch<C> for Complex64
249where
250    Complex64: crate::SlogdetDispatch<C> + crate::prims_bridge::ScaleTensorByRealSameShape<C>,
251    C: backend::TensorLinalgContextFor<Complex64>
252        + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<f64>>
253        + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<Complex64>>
254        + tenferro_prims::TensorComplexRealContextFor<Complex64>
255        + tenferro_prims::TensorMetadataContextFor,
256    C::Backend: 'static,
257    <C as tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<f64>>>::ScalarBackend:
258        'static
259            + tenferro_prims::TensorAnalyticPrims<tenferro_algebra::Standard<f64>, Context = C>
260            + tenferro_prims::TensorMetadataCastPrims<f64, Context = C>,
261    <C as tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<Complex64>>>::ScalarBackend:
262        tenferro_prims::TensorMetadataCastPrims<Complex64, Context = C>,
263    <C as tenferro_prims::TensorComplexRealContextFor<Complex64>>::ComplexRealBackend:
264        tenferro_prims::TensorComplexRealPrims<Complex64, Context = C, Real = f64>,
265    <C as tenferro_prims::TensorMetadataContextFor>::MetadataBackend:
266        tenferro_prims::TensorMetadataPrims<Context = C>,
267{
268    fn slogdet_frule_dispatch(
269        ctx: &mut C,
270        tensor: &Tensor<Self>,
271        tangent: &Tensor<Self>,
272    ) -> AdResult<(
273        SlogdetResult<Self, Self::Real>,
274        SlogdetResult<Self, Self::Real>,
275    )> {
276        slogdet_frule_complex_impl::<Complex64, f64, C>(ctx, tensor, tangent)
277    }
278}
279
280/// Forward-mode AD rule for slogdet (JVP / pushforward).
281///
282/// # Examples
283///
284/// ```
285/// use tenferro_linalg::slogdet_frule;
286/// use tenferro_prims::CpuContext;
287/// use tenferro_tensor::{Tensor, MemoryOrder};
288/// use tenferro_device::LogicalMemorySpace;
289///
290/// let col = MemoryOrder::ColumnMajor;
291/// let mem = LogicalMemorySpace::MainMemory;
292/// let mut ctx = CpuContext::new(1);
293/// let a = Tensor::<f64>::eye(3, mem, col).unwrap();
294/// let da = Tensor::<f64>::ones(&[3, 3], mem, col).unwrap();
295/// let (result, dresult) = slogdet_frule(&mut ctx, &a, &da).unwrap();
296/// ```
297pub fn slogdet_frule<T, C>(
298    ctx: &mut C,
299    tensor: &Tensor<T>,
300    tangent: &Tensor<T>,
301) -> AdResult<(SlogdetResult<T, T::Real>, SlogdetResult<T, T::Real>)>
302where
303    T: SlogdetFruleDispatch<C>,
304{
305    T::slogdet_frule_dispatch(ctx, tensor, tangent)
306}