tenferro_linalg/rrules/norms.rs

use super::*;
use num_complex::ComplexFloat;

/// Reverse-mode AD rule for norm (VJP / pullback).
///
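/// For the Frobenius norm the pullback is `dA = dn * A / ||A||_F`, with a
/// zero subgradient at `A = 0`. Vector `L1`/`Lp` rules follow
/// `sign(a_i) * |a_i|^(p-1) / ||a||_p^(p-1)`; the max-based kinds (vector
/// `Inf`, matrix `L1`/`Inf`) put `dn * sign` on the entries attaining the
/// maximum, averaged uniformly over ties. `Nuclear` uses `dn * U V^T` and
/// `Spectral` the rank-1 outer product of the leading singular vectors.
///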
/// # Examples
///
/// ```
/// use tenferro_linalg::{norm_rrule, NormKind};
/// use tenferro_prims::CpuContext;
/// use tenferro_tensor::{Tensor, MemoryOrder};
/// use tenferro_device::LogicalMemorySpace;
///
/// let col = MemoryOrder::ColumnMajor;
/// let mem = LogicalMemorySpace::MainMemory;
/// let mut ctx = CpuContext::new(1);
/// let a = Tensor::<f64>::zeros(&[3, 4], mem, col).unwrap();
/// let cotangent = Tensor::<f64>::ones(&[], mem, col).unwrap();
/// let grad_a = norm_rrule(&mut ctx, &a, &cotangent, NormKind::Fro).unwrap();
/// ```
pub fn norm_rrule<T, C>(
    ctx: &mut C,
    tensor: &Tensor<T>,
    cotangent: &Tensor<T>,
    kind: NormKind,
) -> AdResult<Tensor<T>>
where
    T: KernelLinalgScalar<Real = T> + num_traits::Float,
    C: backend::TensorLinalgContextFor<T>
        + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T>>,
    <C as tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T>>>::ScalarBackend:
        tenferro_prims::TensorAnalyticPrims<tenferro_algebra::Standard<T>, Context = C>,
    C::Backend: 'static,
{
    require_linalg_support::<T, C>(backend::LinalgCapabilityOp::Norm, "norm_rrule")
        .map_err(to_ad_err)?;

    if tensor.ndim() == 1 {
        validate_norm_cotangent(cotangent, &[]).map_err(to_ad_err)?;
        let (a_data, _) = extract_data(tensor)?;
        let (dn_data, _) = extract_data(cotangent)?;
        let dn = dn_data[0];
        let len = tensor.dims()[0];
        let mut grad_a = vec![T::zero(); len];

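        // sign(0) is taken to be zero in every branch below, a standard
        // subgradient convention.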
        match kind {
            NormKind::Fro => {
                // grad_a = dn * a / ||a||_2, with a zero subgradient at a = 0.
                let nrm =
                    crate::primal::norm_real_impl(ctx, tensor, NormKind::Fro).map_err(to_ad_err)?;
                let (nrm_data, _) = extract_data(&nrm)?;
                let nv = nrm_data[0];
                let scale = if nv > T::zero() { dn / nv } else { T::zero() };
                for i in 0..len {
                    grad_a[i] = scale * a_data[i];
                }
            }
            NormKind::L1 => {
                // d||a||_1/da_i = sign(a_i).
                for i in 0..len {
                    let v = a_data[i];
                    let sign = if v > T::zero() {
                        T::one()
                    } else if v < T::zero() {
                        -T::one()
                    } else {
                        T::zero()
                    };
                    grad_a[i] = dn * sign;
                }
            }
            NormKind::Inf => {
                // The max-abs entries carry the gradient; ties are averaged
                // uniformly over all entries attaining the maximum.
                let max_abs = a_data.iter().fold(T::zero(), |acc, &v| acc.max(v.abs()));
                let active: Vec<usize> = a_data
                    .iter()
                    .enumerate()
                    .filter_map(|(i, &v)| if v.abs() == max_abs { Some(i) } else { None })
                    .collect();
                if !active.is_empty() {
                    let active_count = scalar_from::<T>(active.len() as f64).map_err(to_ad_err)?;
                    let scale = dn / active_count;
                    for i in active {
                        let v = a_data[i];
                        let sign = if v > T::zero() {
                            T::one()
                        } else if v < T::zero() {
                            -T::one()
                        } else {
                            T::zero()
                        };
                        grad_a[i] = scale * sign;
                    }
                }
            }
            NormKind::Lp(p) => {
                // d||a||_p/da_i = sign(a_i) * |a_i|^(p-1) / ||a||_p^(p-1);
                // p = 1 falls back to the L1 rule.
                if p < 1.0 {
                    return Err(invalid_vector_lp_exponent_ad_error(p));
                }
                if p == 1.0 {
                    for i in 0..len {
                        let v = a_data[i];
                        let sign = if v > T::zero() {
                            T::one()
                        } else if v < T::zero() {
                            -T::one()
                        } else {
                            T::zero()
                        };
                        grad_a[i] = dn * sign;
                    }
                } else {
                    let nrm =
                        crate::primal::norm_real_impl(ctx, tensor, kind).map_err(to_ad_err)?;
                    let (nrm_data, _) = extract_data(&nrm)?;
                    let nv = nrm_data[0];
                    if nv > T::zero() {
                        let p_minus_one = scalar_from::<T>(p - 1.0).map_err(to_ad_err)?;
                        let scale = dn / nv.powf(p_minus_one);
                        for i in 0..len {
                            let v = a_data[i];
                            let sign = if v > T::zero() {
                                T::one()
                            } else if v < T::zero() {
                                -T::one()
                            } else {
                                T::zero()
                            };
                            grad_a[i] = scale * sign * v.abs().powf(p_minus_one);
                        }
                    }
                }
            }
            NormKind::Nuclear | NormKind::Spectral => {
                return Err(matrix_only_norm_kind_ad_error(kind));
            }
        }

        return tensor_from_data(grad_a, &[len]).map_err(to_ad_err);
    }

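    // Matrix / batched-matrix case: data is column-major with the batch index
    // outermost, so element (i, j) of batch b lives at base + i + j * m with
    // base = b * m * n.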
    let (m, n, batch_dims) = validate_2d(tensor).map_err(to_ad_err)?;
    let bc = batch_count(batch_dims);
    validate_norm_cotangent(cotangent, batch_dims).map_err(to_ad_err)?;

    let (a_data, _) = extract_data(tensor)?;
    let (dn_data, _) = extract_data(cotangent)?;

    let mut grad_a = vec![T::zero(); m * n * bc];

    match kind {
        NormKind::Fro => {
            // dA = dn * A / ||A||_F, with a zero subgradient at A = 0.
            let nrm = crate::primal::norm_real_impl(ctx, tensor, NormKind::Fro)
                .map_err(to_ad_err)?;
            let (nrm_data, _) = extract_data(&nrm)?;
            for batch in 0..bc {
                let dn = dn_data[batch];
                let nv = nrm_data[batch];
                let scale = if nv > T::zero() { dn / nv } else { T::zero() };
                for i in 0..m * n {
                    grad_a[batch * m * n + i] = scale * a_data[batch * m * n + i];
                }
            }
        }
        NormKind::Nuclear => {
            // dA = dn * U V^T (a valid subgradient when A is rank-deficient).
            for batch in 0..bc {
                let a_b = &a_data[batch * m * n..(batch + 1) * m * n];
                let (u, _s, v) = backend_thin_svd(ctx, a_b, m, n)?;
                let k = m.min(n);
                let uv = backend_mat_mul(ctx, &u, m, k, &transpose(&v, n, k), n)?;
                let dn = dn_data[batch];
                for i in 0..m * n {
                    grad_a[batch * m * n + i] = dn * uv[i];
                }
            }
        }
        NormKind::Spectral => {
            // dA = dn * u1 v1^T (rank-1 outer product of leading singular vectors)
            for batch in 0..bc {
                let a_b = &a_data[batch * m * n..(batch + 1) * m * n];
                let (u, _s, v) = backend_thin_svd(ctx, a_b, m, n)?;
                let dn = dn_data[batch];
                for j in 0..n {
                    for i in 0..m {
                        grad_a[batch * m * n + i + j * m] = dn * u[i] * v[j];
                    }
                }
            }
        }
        NormKind::L1 => {
            // dA = dn * sign(A) on columns that attain max absolute column sum.
            // At ties, average uniformly over active columns.
            for (batch, &dn_batch) in dn_data.iter().enumerate().take(bc) {
                if m == 0 || n == 0 {
                    continue;
                }
                let base = batch * m * n;
                let mut col_sums = vec![T::zero(); n];
                for j in 0..n {
                    let mut sum = T::zero();
                    for i in 0..m {
                        sum = sum + a_data[base + i + j * m].abs();
                    }
                    col_sums[j] = sum;
                }
                let mut max_sum = T::neg_infinity();
                for &sum in &col_sums {
                    if sum > max_sum {
                        max_sum = sum;
                    }
                }
                let active_cols: Vec<usize> = col_sums
                    .iter()
                    .enumerate()
                    .filter_map(|(j, &sum)| if sum == max_sum { Some(j) } else { None })
                    .collect();
                if active_cols.is_empty() {
                    continue;
                }
                let active_count = scalar_from::<T>(active_cols.len() as f64).map_err(to_ad_err)?;
                let dn = dn_batch / active_count;
                for j in active_cols {
                    for i in 0..m {
                        let v = a_data[base + i + j * m];
                        let sign = if v > T::zero() {
                            T::one()
                        } else if v < T::zero() {
                            -T::one()
                        } else {
                            T::zero()
                        };
                        grad_a[base + i + j * m] = grad_a[base + i + j * m] + dn * sign;
                    }
                }
            }
        }
        NormKind::Inf => {
            // dA = dn * sign(A) on rows that attain max absolute row sum.
            // At ties, average uniformly over active rows.
            for (batch, &dn_batch) in dn_data.iter().enumerate().take(bc) {
                if m == 0 || n == 0 {
                    continue;
                }
                let base = batch * m * n;
                let mut row_sums = vec![T::zero(); m];
                for i in 0..m {
                    let mut sum = T::zero();
                    for j in 0..n {
                        sum = sum + a_data[base + i + j * m].abs();
                    }
                    row_sums[i] = sum;
                }
                let mut max_sum = T::neg_infinity();
                for &sum in &row_sums {
                    if sum > max_sum {
                        max_sum = sum;
                    }
                }
                let active_rows: Vec<usize> = row_sums
                    .iter()
                    .enumerate()
                    .filter_map(|(i, &sum)| if sum == max_sum { Some(i) } else { None })
                    .collect();
                if active_rows.is_empty() {
                    continue;
                }
                let active_count = scalar_from::<T>(active_rows.len() as f64).map_err(to_ad_err)?;
                let dn = dn_batch / active_count;
                for i in active_rows {
                    for j in 0..n {
                        let v = a_data[base + i + j * m];
                        let sign = if v > T::zero() {
                            T::one()
                        } else if v < T::zero() {
                            -T::one()
                        } else {
                            T::zero()
                        };
                        grad_a[base + i + j * m] = grad_a[base + i + j * m] + dn * sign;
                    }
                }
            }
        }
        _ => {
            return Err(chainrules_core::AutodiffError::ModeNotSupported {
                mode: "norm_rrule".into(),
                reason: format!("norm kind {kind:?} AD not yet implemented"),
            });
        }
    }

    let dims = output_dims(&[m, n], batch_dims);
    tensor_from_data(grad_a, &dims).map_err(to_ad_err)
}

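/// Reverse-mode AD rule for norms of complex tensors (VJP / pullback).
///
/// Supported kinds are those whose pullback is a real rescaling of the input
/// (`Fro` and vector L2): the input is scaled by `cotangent / norm`, with a
/// zero scale wherever the norm is zero.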
#[doc(hidden)]
pub fn norm_rrule_complex<T, R, C>(
    ctx: &mut C,
    tensor: &Tensor<T>,
    cotangent: &Tensor<R>,
    kind: NormKind,
) -> AdResult<Tensor<T>>
where
    T: KernelLinalgScalar<Real = R>
        + ComplexFloat<Real = R>
        + crate::prims_bridge::ScaleTensorByRealSameShape<C>
        + crate::NormPrimal<C>,
    R: LinalgScalar<Real = R> + num_traits::Float,
    C: backend::TensorLinalgContextFor<T>
        + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<R>>
        + tenferro_prims::TensorComplexScaleContextFor<T>,
    <C as tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<R>>>::ScalarBackend:
        tenferro_prims::TensorAnalyticPrims<tenferro_algebra::Standard<R>, Context = C>,
    C::ComplexScaleBackend: tenferro_prims::TensorComplexScalePrims<T, Context = C>,
    C::Backend: 'static,
{
    require_linalg_support::<T, C>(backend::LinalgCapabilityOp::Norm, "norm_rrule_complex")
        .map_err(to_ad_err)?;
    ensure_complex_norm_ad_supported(kind)?;

    validate_norm_cotangent(cotangent, expected_norm_output_dims(tensor)).map_err(to_ad_err)?;

    // scale = cotangent / ||A|| where ||A|| > 0, else 0; then dA = scale * A.
    let nrm = crate::norm(ctx, tensor, kind).map_err(to_ad_err)?;
    let zero =
        crate::prims_bridge::full_like_constant(R::zero(), nrm.dims(), nrm.logical_memory_space())
            .map_err(to_ad_err)?;
    let nonzero = crate::prims_bridge::scalar_binary_same_shape(
        ctx,
        &nrm,
        &zero,
        tenferro_prims::ScalarBinaryOp::Greater,
    )
    .map_err(to_ad_err)?;
    let quotient = crate::prims_bridge::scalar_binary_same_shape(
        ctx,
        cotangent,
        &nrm,
        tenferro_prims::ScalarBinaryOp::Div,
    )
    .map_err(to_ad_err)?;
    let safe_scale = crate::prims_bridge::scalar_where_same_shape(ctx, &nonzero, &quotient, &zero)
        .map_err(to_ad_err)?;
    let expanded_scale =
        broadcast_norm_control_to_input(&safe_scale, tensor.dims()).map_err(to_ad_err)?;
    crate::prims_bridge::complex_scale_same_shape(ctx, tensor, &expanded_scale).map_err(to_ad_err)
}

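/// Complex norm AD is limited to kinds whose pullback reduces to a real
/// rescaling of the input; everything else is rejected up front.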
fn ensure_complex_norm_ad_supported(kind: NormKind) -> AdResult<()> {
    match kind {
        NormKind::Fro => Ok(()),
        NormKind::Lp(p) if p == 2.0 => Ok(()),
        _ => Err(chainrules_core::AutodiffError::InvalidArgument(format!(
            "complex norm AD currently supports Fro and vector L2 only, got {kind:?}"
        ))),
    }
}

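/// Expected dims of the norm output: scalar (`[]`) for vectors, the batch
/// dims past the leading 2-D matrix block otherwise.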
fn expected_norm_output_dims<T: LinalgScalar>(tensor: &Tensor<T>) -> &[usize] {
    if tensor.ndim() <= 1 {
        &[]
    } else {
        &tensor.dims()[2..]
    }
}

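/// Reshape the per-batch scale so it broadcasts over the matrix dims of the
/// input: a leading `[1, 1]` for the matrix block, batch dims passed through.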
fn broadcast_norm_control_to_input<R: LinalgScalar>(
    value_by_batch: &Tensor<R>,
    input_dims: &[usize],
) -> Result<Tensor<R>> {
    if input_dims.len() <= 1 {
        return value_by_batch.reshape(&[1])?.broadcast(input_dims);
    }

    let mut reshape_dims = vec![1, 1];
    reshape_dims.extend_from_slice(&input_dims[2..]);
    value_by_batch.reshape(&reshape_dims)?.broadcast(input_dims)
}
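
#[cfg(test)]
mod tests {
    use super::*;

    // A minimal sanity check for the Frobenius pullback; a sketch only,
    // assuming `CpuContext::new(1)` is usable in unit tests and that
    // `tensor_from_data` accepts the scalar (`&[]`) cotangent shape the same
    // way `norm_rrule` produces its outputs.
    #[test]
    fn fro_rrule_matches_closed_form() {
        let mut ctx = tenferro_prims::CpuContext::new(1);
        let a_data = vec![1.0_f64, -2.0, 3.0, 4.0, -5.0, 6.0];
        let a = tensor_from_data(a_data.clone(), &[2, 3]).unwrap();
        let cotangent = tensor_from_data(vec![1.0_f64], &[]).unwrap();

        let grad = norm_rrule(&mut ctx, &a, &cotangent, NormKind::Fro).unwrap();
        let (grad_data, _) = extract_data(&grad).unwrap();

        // Closed form: dA = A / ||A||_F for a unit cotangent.
        let nrm: f64 = a_data.iter().map(|v| v * v).sum::<f64>().sqrt();
        for (g, v) in grad_data.iter().zip(&a_data) {
            assert!((g - v / nrm).abs() < 1e-12);
        }
    }
}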