tenferro_linalg/rrules/linear_systems/
slogdet.rs

1use super::*;
2use num_complex::{Complex32, Complex64, ComplexFloat};
3use tenferro_algebra::Conjugate;
4
/// Internal per-scalar dispatch point for the slogdet reverse rule.
///
/// Implemented for `f32`/`f64` (real path) and `Complex32`/`Complex64`
/// (complex path) below; `slogdet_rrule` routes through this trait so the
/// public entry point stays generic over the scalar type.
#[doc(hidden)]
pub trait SlogdetRruleDispatch<C>: crate::SlogdetDispatch<C> {
    /// Computes the pullback `Ā` of slogdet for `tensor` given `cotangent`.
    fn slogdet_rrule_dispatch(
        ctx: &mut C,
        tensor: &Tensor<Self>,
        cotangent: &SlogdetCotangent<Self, Self::Real>,
    ) -> AdResult<Tensor<Self>>;
}
13
14fn slogdet_rrule_real_impl<T, C>(
15    ctx: &mut C,
16    tensor: &Tensor<T>,
17    cotangent: &SlogdetCotangent<T, T>,
18) -> AdResult<Tensor<T>>
19where
20    T: KernelLinalgScalar<Real = T> + num_traits::Float + crate::SlogdetDispatch<C>,
21    C: backend::TensorLinalgContextFor<T>
22        + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T>>,
23    C::Backend: 'static,
24{
25    require_linalg_support::<T, C>(backend::LinalgCapabilityOp::Slogdet, "slogdet_rrule")
26        .map_err(to_ad_err)?;
27
28    let (n, batch_dims) = validate_square(tensor).map_err(to_ad_err)?;
29    if n == 0 {
30        let dims = output_dims(&[n, n], batch_dims);
31        return Tensor::zeros(
32            &dims,
33            tensor.logical_memory_space(),
34            MemoryOrder::ColumnMajor,
35        )
36        .map_err(to_ad_err);
37    }
38
39    if let Some(ref dlog) = cotangent.logabsdet {
40        let a_inv = inv(ctx, tensor).map_err(to_ad_err)?;
41        let a_inv_t = matrix_transpose(&a_inv).map_err(to_ad_err)?;
42        let dlog_expanded = dlog
43            .unsqueeze(0)
44            .map_err(to_ad_err)?
45            .unsqueeze(0)
46            .map_err(to_ad_err)?
47            .broadcast(a_inv_t.dims())
48            .map_err(to_ad_err)?;
49        let grad_a = prims_bridge::scalar_binary_same_shape(
50            ctx,
51            &dlog_expanded,
52            &a_inv_t,
53            tenferro_prims::ScalarBinaryOp::Mul,
54        )
55        .map_err(to_ad_err)?;
56        Ok(grad_a)
57    } else {
58        let dims = output_dims(&[n, n], batch_dims);
59        Tensor::zeros(
60            &dims,
61            tensor.logical_memory_space(),
62            MemoryOrder::ColumnMajor,
63        )
64        .map_err(to_ad_err)
65    }
66}
67
/// Pullback of `slogdet` for complex scalar types.
///
/// Two cotangent channels contribute to `Ā`, both as scalings of `A⁻ᴴ`:
/// * `logabsdet`: `Ā += d(logabsdet) ⊙ A⁻ᴴ` (real-valued scale);
/// * `sign`: unlike the real case the unit-modulus sign varies with `A`,
///   so its cotangent is pulled back through the skew (imaginary) part of
///   `conj(dsign) · sign` — see the inline comments below.
fn slogdet_rrule_complex_impl<T, R, C>(
    ctx: &mut C,
    tensor: &Tensor<T>,
    cotangent: &SlogdetCotangent<T, R>,
) -> AdResult<Tensor<T>>
where
    T: KernelLinalgScalar<Real = R>
        + Conjugate
        + ComplexFloat<Real = R>
        + crate::SlogdetDispatch<C>,
    T: crate::prims_bridge::ScaleTensorByRealSameShape<C>,
    C: backend::TensorLinalgContextFor<T>
        + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<R>>
        + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T>>
        + tenferro_prims::TensorComplexRealContextFor<T>
        + tenferro_prims::TensorResolveConjContextFor<T>
        + tenferro_prims::TensorMetadataContextFor,
    C::Backend: 'static,
    R: KernelLinalgScalar<Real = R> + num_traits::Float,
    <C as tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<R>>>::ScalarBackend:
        'static
            + tenferro_prims::TensorAnalyticPrims<tenferro_algebra::Standard<R>, Context = C>
            + tenferro_prims::TensorMetadataCastPrims<R, Context = C>,
    <C as tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T>>>::ScalarBackend:
        tenferro_prims::TensorMetadataCastPrims<T, Context = C>,
    <C as tenferro_prims::TensorComplexRealContextFor<T>>::ComplexRealBackend:
        tenferro_prims::TensorComplexRealPrims<T, Context = C, Real = R>,
    <C as tenferro_prims::TensorMetadataContextFor>::MetadataBackend:
        tenferro_prims::TensorMetadataPrims<Context = C>,
{
    require_linalg_support::<T, C>(backend::LinalgCapabilityOp::Slogdet, "slogdet_rrule")
        .map_err(to_ad_err)?;

    let (n, batch_dims) = validate_square(tensor).map_err(to_ad_err)?;
    // 0×0 matrices: the gradient is an empty zero tensor.
    if n == 0 {
        let dims = output_dims(&[n, n], batch_dims);
        return Tensor::zeros(
            &dims,
            tensor.logical_memory_space(),
            MemoryOrder::ColumnMajor,
        )
        .map_err(to_ad_err);
    }

    // Both cotangent channels scale A⁻ᴴ, so compute the inverse adjoint once
    // up front. NOTE(review): when both cotangents are None this inverse is
    // computed needlessly (and can fail on singular input) even though the
    // zero `grad_a` below would be returned unchanged — consider an early
    // return; confirm the intended error behavior first.
    let a_inv = inv(ctx, tensor).map_err(to_ad_err)?;
    let a_inv_h = matrix_adjoint_eager(ctx, &a_inv).map_err(to_ad_err)?;
    let dims = output_dims(&[n, n], batch_dims);
    // Accumulator for the pullback; stays zero for absent cotangents.
    let mut grad_a = Tensor::zeros(
        &dims,
        tensor.logical_memory_space(),
        MemoryOrder::ColumnMajor,
    )
    .map_err(to_ad_err)?;

    if let Some(ref dlog) = cotangent.logabsdet {
        // Expand the (batched) real cotangent with two leading singleton
        // axes and broadcast to the matrix shape of A⁻ᴴ.
        let dlog_expanded = dlog
            .unsqueeze(0)
            .map_err(to_ad_err)?
            .unsqueeze(0)
            .map_err(to_ad_err)?
            .broadcast(a_inv_h.dims())
            .map_err(to_ad_err)?;
        // grad_a = dlog ⊙ A⁻ᴴ (real scale of a complex tensor).
        grad_a = prims_bridge::complex_scale_same_shape(ctx, &a_inv_h, &dlog_expanded)
            .map_err(to_ad_err)?;
    }

    if let Some(ref dsign) = cotangent.sign {
        // Re-run the primal slogdet to recover the sign of the determinant.
        let result = slogdet(ctx, tensor).map_err(to_ad_err)?;
        // alpha = conj(dsign) · sign
        let dsign_conj =
            prims_bridge::scalar_unary_same_shape(ctx, dsign, tenferro_prims::ScalarUnaryOp::Conj)
                .map_err(to_ad_err)?;
        let alpha = prims_bridge::scalar_binary_same_shape(
            ctx,
            &dsign_conj,
            &result.sign,
            tenferro_prims::ScalarBinaryOp::Mul,
        )
        .map_err(to_ad_err)?;
        // alpha_skew = conj(alpha) − alpha = −2i·Im(alpha): because the sign
        // has unit modulus, only the phase (imaginary) component of the
        // cotangent propagates — presumably this matches the derivation
        // d(sign) = i·sign·Im(tr(A⁻¹ dA)); verify against the frule.
        let alpha_conj =
            prims_bridge::scalar_unary_same_shape(ctx, &alpha, tenferro_prims::ScalarUnaryOp::Conj)
                .map_err(to_ad_err)?;
        let alpha_skew = prims_bridge::scalar_binary_same_shape(
            ctx,
            &alpha_conj,
            &alpha,
            tenferro_prims::ScalarBinaryOp::Sub,
        )
        .map_err(to_ad_err)?;
        // Halve the skew part: contribution is (conj(alpha) − alpha)/2 · A⁻ᴴ.
        let half = prims_bridge::full_like_constant(
            scalar_from::<R>(0.5).map_err(to_ad_err)?,
            alpha_skew.dims(),
            tensor.logical_memory_space(),
        )
        .map_err(to_ad_err)?;
        let sign_scale =
            prims_bridge::complex_scale_same_shape(ctx, &alpha_skew, &half).map_err(to_ad_err)?;
        // Broadcast the batched scalar factor over the matrix dimensions.
        let sign_scale_expanded = sign_scale
            .unsqueeze(0)
            .map_err(to_ad_err)?
            .unsqueeze(0)
            .map_err(to_ad_err)?
            .broadcast(a_inv_h.dims())
            .map_err(to_ad_err)?;
        let sign_grad = prims_bridge::scalar_binary_same_shape(
            ctx,
            &sign_scale_expanded,
            &a_inv_h,
            tenferro_prims::ScalarBinaryOp::Mul,
        )
        .map_err(to_ad_err)?;
        // Accumulate onto whatever the logabsdet channel already contributed.
        grad_a = prims_bridge::scalar_binary_same_shape(
            ctx,
            &grad_a,
            &sign_grad,
            tenferro_prims::ScalarBinaryOp::Add,
        )
        .map_err(to_ad_err)?;
    }

    Ok(grad_a)
}
189
/// Reverse-mode AD rule for slogdet (VJP / pullback).
///
/// `Ā = cotangent_logabsdet · A⁻ᵀ`.
///
/// For complex scalar types the `sign` cotangent also contributes; see the
/// per-scalar implementations of [`SlogdetRruleDispatch`].
///
/// # Errors
///
/// Returns an error if the backend lacks slogdet support, if `tensor` is not
/// a (batched) square matrix, or if an underlying kernel call fails.
///
/// # Examples
///
/// ```
/// use tenferro_linalg::{slogdet_rrule, SlogdetCotangent};
/// use tenferro_prims::CpuContext;
/// use tenferro_tensor::{Tensor, MemoryOrder};
/// use tenferro_device::LogicalMemorySpace;
///
/// let col = MemoryOrder::ColumnMajor;
/// let mem = LogicalMemorySpace::MainMemory;
/// let mut ctx = CpuContext::new(1);
/// let a = Tensor::<f64>::eye(3, mem, col).unwrap();
/// let cotangent = SlogdetCotangent {
///     sign: None,
///     logabsdet: Some(Tensor::ones(&[], mem, col).unwrap()),
/// };
/// let grad_a = slogdet_rrule(&mut ctx, &a, &cotangent).unwrap();
/// ```
pub fn slogdet_rrule<T, C>(
    ctx: &mut C,
    tensor: &Tensor<T>,
    cotangent: &SlogdetCotangent<T, T::Real>,
) -> AdResult<Tensor<T>>
where
    T: SlogdetRruleDispatch<C>,
{
    // Monomorphic dispatch: each scalar type routes to its real/complex impl.
    T::slogdet_rrule_dispatch(ctx, tensor, cotangent)
}
222
223impl<C> SlogdetRruleDispatch<C> for f32
224where
225    f32: crate::SlogdetDispatch<C>,
226    C: backend::TensorLinalgContextFor<f32>
227        + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<f32>>,
228    C::Backend: 'static,
229{
230    fn slogdet_rrule_dispatch(
231        ctx: &mut C,
232        tensor: &Tensor<Self>,
233        cotangent: &SlogdetCotangent<Self, Self::Real>,
234    ) -> AdResult<Tensor<Self>> {
235        slogdet_rrule_real_impl(ctx, tensor, cotangent)
236    }
237}
238
239impl<C> SlogdetRruleDispatch<C> for f64
240where
241    f64: crate::SlogdetDispatch<C>,
242    C: backend::TensorLinalgContextFor<f64>
243        + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<f64>>,
244    C::Backend: 'static,
245{
246    fn slogdet_rrule_dispatch(
247        ctx: &mut C,
248        tensor: &Tensor<Self>,
249        cotangent: &SlogdetCotangent<Self, Self::Real>,
250    ) -> AdResult<Tensor<Self>> {
251        slogdet_rrule_real_impl(ctx, tensor, cotangent)
252    }
253}
254
// Complex path (single precision): both the logabsdet and sign cotangents
// contribute to the gradient. The where-clause mirrors the bounds of
// `slogdet_rrule_complex_impl` instantiated at T = Complex32, R = f32.
impl<C> SlogdetRruleDispatch<C> for Complex32
where
    Complex32: crate::SlogdetDispatch<C> + crate::prims_bridge::ScaleTensorByRealSameShape<C>,
    C: backend::TensorLinalgContextFor<Complex32>
        + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<f32>>
        + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<Complex32>>
        + tenferro_prims::TensorComplexRealContextFor<Complex32>
        + tenferro_prims::TensorResolveConjContextFor<Complex32>
        + tenferro_prims::TensorMetadataContextFor,
    C::Backend: 'static,
    <C as tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<f32>>>::ScalarBackend:
        'static
            + tenferro_prims::TensorAnalyticPrims<tenferro_algebra::Standard<f32>, Context = C>
            + tenferro_prims::TensorMetadataCastPrims<f32, Context = C>,
    <C as tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<Complex32>>>::ScalarBackend:
        tenferro_prims::TensorMetadataCastPrims<Complex32, Context = C>,
    <C as tenferro_prims::TensorComplexRealContextFor<Complex32>>::ComplexRealBackend:
        tenferro_prims::TensorComplexRealPrims<Complex32, Context = C, Real = f32>,
    <C as tenferro_prims::TensorMetadataContextFor>::MetadataBackend:
        tenferro_prims::TensorMetadataPrims<Context = C>,
{
    fn slogdet_rrule_dispatch(
        ctx: &mut C,
        tensor: &Tensor<Self>,
        cotangent: &SlogdetCotangent<Self, Self::Real>,
    ) -> AdResult<Tensor<Self>> {
        slogdet_rrule_complex_impl::<Complex32, f32, C>(ctx, tensor, cotangent)
    }
}
284
// Complex path (double precision): both the logabsdet and sign cotangents
// contribute to the gradient. The where-clause mirrors the bounds of
// `slogdet_rrule_complex_impl` instantiated at T = Complex64, R = f64.
impl<C> SlogdetRruleDispatch<C> for Complex64
where
    Complex64: crate::SlogdetDispatch<C> + crate::prims_bridge::ScaleTensorByRealSameShape<C>,
    C: backend::TensorLinalgContextFor<Complex64>
        + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<f64>>
        + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<Complex64>>
        + tenferro_prims::TensorComplexRealContextFor<Complex64>
        + tenferro_prims::TensorResolveConjContextFor<Complex64>
        + tenferro_prims::TensorMetadataContextFor,
    C::Backend: 'static,
    <C as tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<f64>>>::ScalarBackend:
        'static
            + tenferro_prims::TensorAnalyticPrims<tenferro_algebra::Standard<f64>, Context = C>
            + tenferro_prims::TensorMetadataCastPrims<f64, Context = C>,
    <C as tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<Complex64>>>::ScalarBackend:
        tenferro_prims::TensorMetadataCastPrims<Complex64, Context = C>,
    <C as tenferro_prims::TensorComplexRealContextFor<Complex64>>::ComplexRealBackend:
        tenferro_prims::TensorComplexRealPrims<Complex64, Context = C, Real = f64>,
    <C as tenferro_prims::TensorMetadataContextFor>::MetadataBackend:
        tenferro_prims::TensorMetadataPrims<Context = C>,
{
    fn slogdet_rrule_dispatch(
        ctx: &mut C,
        tensor: &Tensor<Self>,
        cotangent: &SlogdetCotangent<Self, Self::Real>,
    ) -> AdResult<Tensor<Self>> {
        slogdet_rrule_complex_impl::<Complex64, f64, C>(ctx, tensor, cotangent)
    }
}