tenferro_linalg/frules/norms.rs

use super::*;
use num_complex::ComplexFloat;

/// Forward-mode AD rule for norm (JVP / pushforward).
///
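/// For the Frobenius norm the tangent is `d||A||_F = <A, dA> / ||A||_F`, with
/// a zero tangent at `A = 0`. L1 and Inf norms use the sign subgradient and
/// average uniformly over tied maximizers; nuclear and spectral norms
/// differentiate through a thin SVD (`tr(U^T dA V)` and `u1^T dA v1`).
///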
/// # Examples
///
/// ```
/// use tenferro_linalg::{norm_frule, NormKind};
/// use tenferro_prims::CpuContext;
/// use tenferro_tensor::{Tensor, MemoryOrder};
/// use tenferro_device::LogicalMemorySpace;
///
/// let col = MemoryOrder::ColumnMajor;
/// let mem = LogicalMemorySpace::MainMemory;
/// let mut ctx = CpuContext::new(1);
/// let a = Tensor::<f64>::zeros(&[3, 4], mem, col).unwrap();
/// let da = Tensor::<f64>::ones(&[3, 4], mem, col).unwrap();
/// let (n, dn) = norm_frule(&mut ctx, &a, &da, NormKind::Fro).unwrap();
/// ```
pub fn norm_frule<T: KernelLinalgScalar<Real = T> + num_traits::Float, C>(
    ctx: &mut C,
    tensor: &Tensor<T>,
    tangent: &Tensor<T>,
    kind: NormKind,
) -> AdResult<(Tensor<T>, Tensor<T>)>
where
    C: backend::TensorLinalgContextFor<T>
        + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T>>,
    <C as tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T>>>::ScalarBackend:
        tenferro_prims::TensorAnalyticPrims<tenferro_algebra::Standard<T>, Context = C>,
    C::Backend: 'static,
{
    require_linalg_support::<T, C>(backend::LinalgCapabilityOp::Norm, "norm_frule")
        .map_err(to_ad_err)?;

    let nrm = crate::primal::norm_real_impl(ctx, tensor, kind)
        .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?;

    if tensor.ndim() == 1 {
        let (a_data, _) = extract_data(tensor)?;
        let (nrm_data, _) = extract_data(&nrm)?;
        let (da_data, _) = extract_data(tangent)?;
        let len = tensor.dims()[0];
        let mut dnrm = T::zero();

        match kind {
            NormKind::Fro => {
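                // Vector Fro (= L2): d||a|| = <a, da> / ||a||; zero tangent at the origin.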
                let nv = nrm_data[0];
                if nv > T::zero() {
                    for i in 0..len {
                        dnrm = dnrm + a_data[i] * da_data[i];
                    }
                    dnrm = dnrm / nv;
                }
            }
            NormKind::L1 => {
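                // d||a||_1 = sum_i sign(a_i) da_i, with sign(0) = 0 (subgradient choice).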
                for i in 0..len {
                    let v = a_data[i];
                    let sign = if v > T::zero() {
                        T::one()
                    } else if v < T::zero() {
                        -T::one()
                    } else {
                        T::zero()
                    };
                    dnrm = dnrm + sign * da_data[i];
                }
            }
            NormKind::Inf => {
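                // d||a||_inf = sign(a_i) da_i at the max-|a_i| index; ties averaged uniformly.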
                let max_abs = a_data.iter().fold(T::zero(), |acc, &v| acc.max(v.abs()));
                let active: Vec<usize> = a_data
                    .iter()
                    .enumerate()
                    .filter_map(|(i, &v)| if v.abs() == max_abs { Some(i) } else { None })
                    .collect();
                if !active.is_empty() {
                    for i in active.iter().copied() {
                        let v = a_data[i];
                        let sign = if v > T::zero() {
                            T::one()
                        } else if v < T::zero() {
                            -T::one()
                        } else {
                            T::zero()
                        };
                        dnrm = dnrm + sign * da_data[i];
                    }
                    let active_count = scalar_from::<T>(active.len() as f64).map_err(to_ad_err)?;
                    dnrm = dnrm / active_count;
                }
            }
            NormKind::Lp(p) => {
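                // d||a||_p = sum_i sign(a_i) |a_i|^(p-1) da_i / ||a||_p^(p-1) for p > 1;
                // p == 1.0 falls back to the L1 subgradient rule.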
                if p < 1.0 {
                    return Err(invalid_vector_lp_exponent_ad_error(p));
                }
                if p == 1.0 {
                    for i in 0..len {
                        let v = a_data[i];
                        let sign = if v > T::zero() {
                            T::one()
                        } else if v < T::zero() {
                            -T::one()
                        } else {
                            T::zero()
                        };
                        dnrm = dnrm + sign * da_data[i];
                    }
                } else {
                    let nv = nrm_data[0];
                    if nv > T::zero() {
                        let p_minus_one = scalar_from::<T>(p - 1.0).map_err(to_ad_err)?;
                        for i in 0..len {
                            let v = a_data[i];
                            let sign = if v > T::zero() {
                                T::one()
                            } else if v < T::zero() {
                                -T::one()
                            } else {
                                T::zero()
                            };
                            dnrm = dnrm + sign * v.abs().powf(p_minus_one) * da_data[i];
                        }
                        dnrm = dnrm / nv.powf(p_minus_one);
                    }
                }
            }
            NormKind::Nuclear | NormKind::Spectral => {
                return Err(matrix_only_norm_kind_ad_error(kind));
            }
        }

        let dnrm = tensor_from_data(vec![dnrm], &[]).map_err(to_ad_err)?;
        return Ok((nrm, dnrm));
    }

    let (m, n, batch_dims) = validate_2d(tensor)
        .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?;
    let bc = batch_count(batch_dims);

    let (a_data, _) = extract_data(tensor)?;
    let (nrm_data, _) = extract_data(&nrm)?;
    let (da_data, _) = extract_data(tangent)?;

    let mut dnrm_data = vec![T::zero(); bc];

    match kind {
        NormKind::Fro => {
            // d||A||_F = tr(A^T dA) / ||A||_F
            for batch in 0..bc {
                let nv = nrm_data[batch];
                if nv > T::zero() {
                    let mut dot = T::zero();
                    for i in 0..m * n {
                        dot = dot + a_data[batch * m * n + i] * da_data[batch * m * n + i];
                    }
                    dnrm_data[batch] = dot / nv;
                }
            }
        }
        NormKind::Nuclear => {
            // d||A||_* = tr(U^T dA V)
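            // (U V^T is the exact gradient when A has full rank; otherwise a subgradient element.)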
            for batch in 0..bc {
                let a_b = &a_data[batch * m * n..(batch + 1) * m * n];
                let da_b = &da_data[batch * m * n..(batch + 1) * m * n];
                let (u, _s, v) = backend_thin_svd(ctx, a_b, m, n)?;
                let k = m.min(n);
                let ut_da = backend_mat_mul(ctx, &transpose(&u, m, k), k, m, da_b, n)?;
                let ut_da_v = backend_mat_mul(ctx, &ut_da, k, n, &v, k)?;
                let mut trace = T::zero();
                for i in 0..k {
                    trace = trace + ut_da_v[i + i * k];
                }
                dnrm_data[batch] = trace;
            }
        }
        NormKind::Spectral => {
            // d||A||_2 = u1^T dA v1
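            // (u1, v1 are the leading singular vectors; the gradient is well defined when sigma1 is simple.)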
            for batch in 0..bc {
                let a_b = &a_data[batch * m * n..(batch + 1) * m * n];
                let da_b = &da_data[batch * m * n..(batch + 1) * m * n];
                let (u, _s, v) = backend_thin_svd(ctx, a_b, m, n)?;
                let mut val = T::zero();
                for i in 0..m {
                    for j in 0..n {
                        val = val + u[i] * da_b[i + j * m] * v[j];
                    }
                }
                dnrm_data[batch] = val;
            }
        }
        NormKind::L1 => {
            // Induced 1-norm (max column abs-sum): d||A||_1 = sum_i sign(A_ij*) dA_ij*
            // over the maximizing column j*. At ties, average uniformly over active columns.
            for (batch, dn_slot) in dnrm_data.iter_mut().enumerate().take(bc) {
                if m == 0 || n == 0 {
                    continue;
                }
                let base = batch * m * n;
                let mut col_sums = vec![T::zero(); n];
                for j in 0..n {
                    let mut sum = T::zero();
                    for i in 0..m {
                        sum = sum + a_data[base + i + j * m].abs();
                    }
                    col_sums[j] = sum;
                }
                let mut max_sum = T::neg_infinity();
                for &sum in &col_sums {
                    if sum > max_sum {
                        max_sum = sum;
                    }
                }
                let active_cols: Vec<usize> = col_sums
                    .iter()
                    .enumerate()
                    .filter_map(|(j, &sum)| if sum == max_sum { Some(j) } else { None })
                    .collect();
                if active_cols.is_empty() {
                    continue;
                }
                let mut accum = T::zero();
                for j in active_cols.iter().copied() {
                    for i in 0..m {
                        let v = a_data[base + i + j * m];
                        let sign = if v > T::zero() {
                            T::one()
                        } else if v < T::zero() {
                            -T::one()
                        } else {
                            T::zero()
                        };
                        accum = accum + sign * da_data[base + i + j * m];
                    }
                }
                let active_count = scalar_from::<T>(active_cols.len() as f64).map_err(to_ad_err)?;
                *dn_slot = accum / active_count;
            }
        }
        NormKind::Inf => {
            // Induced inf-norm (max row abs-sum): d||A||_inf = sum_j sign(A_i*j) dA_i*j
            // over the maximizing row i*. At ties, average uniformly over active rows.
            for (batch, dn_slot) in dnrm_data.iter_mut().enumerate().take(bc) {
                if m == 0 || n == 0 {
                    continue;
                }
                let base = batch * m * n;
                let mut row_sums = vec![T::zero(); m];
                for i in 0..m {
                    let mut sum = T::zero();
                    for j in 0..n {
                        sum = sum + a_data[base + i + j * m].abs();
                    }
                    row_sums[i] = sum;
                }
                let mut max_sum = T::neg_infinity();
                for &sum in &row_sums {
                    if sum > max_sum {
                        max_sum = sum;
                    }
                }
                let active_rows: Vec<usize> = row_sums
                    .iter()
                    .enumerate()
                    .filter_map(|(i, &sum)| if sum == max_sum { Some(i) } else { None })
                    .collect();
                if active_rows.is_empty() {
                    continue;
                }
                let mut accum = T::zero();
                for i in active_rows.iter().copied() {
                    for j in 0..n {
                        let v = a_data[base + i + j * m];
                        let sign = if v > T::zero() {
                            T::one()
                        } else if v < T::zero() {
                            -T::one()
                        } else {
                            T::zero()
                        };
                        accum = accum + sign * da_data[base + i + j * m];
                    }
                }
                let active_count = scalar_from::<T>(active_rows.len() as f64).map_err(to_ad_err)?;
                *dn_slot = accum / active_count;
            }
        }
        NormKind::Lp(_) => {
            return Err(chainrules_core::AutodiffError::ModeNotSupported {
                mode: "norm_frule".into(),
                reason: format!("norm kind {kind:?} AD not yet implemented"),
            });
        }
    }

    let dims = output_dims(&[], batch_dims);
    let dnrm = tensor_from_data(dnrm_data, &dims)
        .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?;
    Ok((nrm, dnrm))
}

#[doc(hidden)]
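/// Forward-mode AD rule for norms of complex tensors (JVP / pushforward).
///
/// For the supported kinds (Frobenius and vector L2) the output tangent is
/// `d||A|| = Re(<A, dA>) / ||A||`, with a zero tangent wherever `||A|| = 0`.
///
/// # Examples
///
/// A minimal sketch, assuming a backend context `ctx` that satisfies the
/// complex trait bounds below and that this item is reachable at the crate
/// root; marked `ignore` because constructing such a context is
/// backend-specific and not shown here.
///
/// ```ignore
/// use num_complex::Complex64;
/// use tenferro_linalg::{norm_frule_complex, NormKind};
/// use tenferro_tensor::{Tensor, MemoryOrder};
/// use tenferro_device::LogicalMemorySpace;
///
/// let col = MemoryOrder::ColumnMajor;
/// let mem = LogicalMemorySpace::MainMemory;
/// let a = Tensor::<Complex64>::zeros(&[3, 4], mem, col).unwrap();
/// let da = Tensor::<Complex64>::ones(&[3, 4], mem, col).unwrap();
/// // `ctx` is assumed to implement the complex trait bounds of this function.
/// let (n, dn) = norm_frule_complex(&mut ctx, &a, &da, NormKind::Fro).unwrap();
/// ```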
pub fn norm_frule_complex<T, R, C>(
    ctx: &mut C,
    tensor: &Tensor<T>,
    tangent: &Tensor<T>,
    kind: NormKind,
) -> AdResult<(Tensor<R>, Tensor<R>)>
where
    T: KernelLinalgScalar<Real = R> + ComplexFloat<Real = R> + crate::NormPrimal<C>,
    R: LinalgScalar<Real = R> + num_traits::Float,
    C: backend::TensorLinalgContextFor<T>
        + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T>>
        + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<R>>
        + tenferro_prims::TensorComplexRealContextFor<T>,
    <C as tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<R>>>::ScalarBackend:
        tenferro_prims::TensorAnalyticPrims<tenferro_algebra::Standard<R>, Context = C>,
    C::ComplexRealBackend: tenferro_prims::TensorComplexRealPrims<T, Context = C, Real = R>,
    C::Backend: 'static,
{
    require_linalg_support::<T, C>(backend::LinalgCapabilityOp::Norm, "norm_frule_complex")
        .map_err(to_ad_err)?;
    ensure_complex_norm_ad_supported(kind)?;

    let nrm = crate::norm(ctx, tensor, kind).map_err(to_ad_err)?;
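    // Numerator: Re(sum over the reduced norm axes of conj(A) * dA).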
    let conj_tensor = crate::prims_bridge::scalar_unary_same_shape(
        ctx,
        tensor,
        tenferro_prims::ScalarUnaryOp::Conj,
    )
    .map_err(to_ad_err)?;
    let product = crate::prims_bridge::scalar_binary_same_shape(
        ctx,
        &conj_tensor,
        tangent,
        tenferro_prims::ScalarBinaryOp::Mul,
    )
    .map_err(to_ad_err)?;
    let kept_axes = norm_kept_axes(tensor.ndim());
    let numerator = crate::prims_bridge::complex_real_reduce_keep_axes(
        ctx,
        &product,
        tenferro_prims::ComplexRealUnaryOp::Real,
        &kept_axes,
        tenferro_prims::ScalarReductionOp::Sum,
    )
    .map_err(to_ad_err)?;
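    // Divide by the norm, substituting a zero tangent wherever ||A|| == 0.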
    let zero =
        crate::prims_bridge::full_like_constant(R::zero(), nrm.dims(), nrm.logical_memory_space())
            .map_err(to_ad_err)?;
    let nonzero = crate::prims_bridge::scalar_binary_same_shape(
        ctx,
        &nrm,
        &zero,
        tenferro_prims::ScalarBinaryOp::Greater,
    )
    .map_err(to_ad_err)?;
    let quotient = crate::prims_bridge::scalar_binary_same_shape(
        ctx,
        &numerator,
        &nrm,
        tenferro_prims::ScalarBinaryOp::Div,
    )
    .map_err(to_ad_err)?;
    let output_tangent =
        crate::prims_bridge::scalar_where_same_shape(ctx, &nonzero, &quotient, &zero)
            .map_err(to_ad_err)?;
    Ok((nrm, output_tangent))
}

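/// Rejects norm kinds whose complex AD rule is not implemented; only the
/// Frobenius norm and the vector L2 norm are currently supported.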
fn ensure_complex_norm_ad_supported(kind: NormKind) -> AdResult<()> {
    match kind {
        NormKind::Fro => Ok(()),
        NormKind::Lp(p) if p == 2.0 => Ok(()),
        _ => Err(chainrules_core::AutodiffError::InvalidArgument(format!(
            "complex norm AD currently supports Fro and vector L2 only, got {kind:?}"
        ))),
    }
}

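/// Axes kept (not reduced) by the norm: the batch axes `2..ndim` for matrix
/// inputs; vectors and scalars reduce over all axes.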
fn norm_kept_axes(ndim: usize) -> Vec<usize> {
    if ndim <= 1 {
        Vec::new()
    } else {
        (2..ndim).collect()
    }
}