use super::*;
use num_complex::ComplexFloat;

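/// Forward-mode rule (JVP) for `norm`: returns the primal norm together with
/// the tangent of the norm in the direction of `tangent`.
///
/// 1-D inputs support `Fro`, `L1`, `Inf`, and `Lp(p)`; 2-D (optionally
/// batched) inputs support `Fro`, `Nuclear`, `Spectral`, `L1` (max column
/// sum), and `Inf` (max row sum). At non-differentiable points (ties, the
/// zero tensor) a subgradient convention is used.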
pub fn norm_frule<T: KernelLinalgScalar<Real = T> + num_traits::Float, C>(
    ctx: &mut C,
    tensor: &Tensor<T>,
    tangent: &Tensor<T>,
    kind: NormKind,
) -> AdResult<(Tensor<T>, Tensor<T>)>
where
    T: KernelLinalgScalar,
    C: backend::TensorLinalgContextFor<T>
        + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T>>,
    <C as tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T>>>::ScalarBackend:
        tenferro_prims::TensorAnalyticPrims<tenferro_algebra::Standard<T>, Context = C>,
    C::Backend: 'static,
{
    require_linalg_support::<T, C>(backend::LinalgCapabilityOp::Norm, "norm_frule")
        .map_err(to_ad_err)?;

    let nrm = crate::primal::norm_real_impl(ctx, tensor, kind)
        .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?;

    if tensor.ndim() == 1 {
        let (a_data, _) = extract_data(tensor)?;
        let (nrm_data, _) = extract_data(&nrm)?;
        let (da_data, _) = extract_data(tangent)?;
        let len = tensor.dims()[0];
        let mut dnrm = T::zero();

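        // Vector-norm tangents (sign(0) is taken as 0, a subgradient choice):
        //   Fro/L2: d|a| = <a, da> / |a|
        //   L1:     d|a| = sum_i sign(a_i) da_i
        //   Inf:    sign(a_i) da_i averaged over the entries attaining max|a_i|
        //   Lp:     d|a| = sum_i sign(a_i) |a_i|^(p-1) da_i / |a|^(p-1)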
        match kind {
            NormKind::Fro => {
                let nv = nrm_data[0];
                if nv > T::zero() {
                    for i in 0..len {
                        dnrm = dnrm + a_data[i] * da_data[i];
                    }
                    dnrm = dnrm / nv;
                }
            }
            NormKind::L1 => {
                for i in 0..len {
                    let v = a_data[i];
                    let sign = if v > T::zero() {
                        T::one()
                    } else if v < T::zero() {
                        -T::one()
                    } else {
                        T::zero()
                    };
                    dnrm = dnrm + sign * da_data[i];
                }
            }
            NormKind::Inf => {
                let max_abs = a_data.iter().fold(T::zero(), |acc, &v| acc.max(v.abs()));
                let active: Vec<usize> = a_data
                    .iter()
                    .enumerate()
                    .filter_map(|(i, &v)| if v.abs() == max_abs { Some(i) } else { None })
                    .collect();
                if !active.is_empty() {
                    for i in active.iter().copied() {
                        let v = a_data[i];
                        let sign = if v > T::zero() {
                            T::one()
                        } else if v < T::zero() {
                            -T::one()
                        } else {
                            T::zero()
                        };
                        dnrm = dnrm + sign * da_data[i];
                    }
                    let active_count = scalar_from::<T>(active.len() as f64).map_err(to_ad_err)?;
                    dnrm = dnrm / active_count;
                }
            }
            NormKind::Lp(p) => {
                if p < 1.0 {
                    return Err(invalid_vector_lp_exponent_ad_error(p));
                }
                if p == 1.0 {
                    for i in 0..len {
                        let v = a_data[i];
                        let sign = if v > T::zero() {
                            T::one()
                        } else if v < T::zero() {
                            -T::one()
                        } else {
                            T::zero()
                        };
                        dnrm = dnrm + sign * da_data[i];
                    }
                } else {
                    let nv = nrm_data[0];
                    if nv > T::zero() {
                        let p_minus_one = scalar_from::<T>(p - 1.0).map_err(to_ad_err)?;
                        for i in 0..len {
                            let v = a_data[i];
                            let sign = if v > T::zero() {
                                T::one()
                            } else if v < T::zero() {
                                -T::one()
                            } else {
                                T::zero()
                            };
                            dnrm = dnrm + sign * v.abs().powf(p_minus_one) * da_data[i];
                        }
                        dnrm = dnrm / nv.powf(p_minus_one);
                    }
                }
            }
            NormKind::Nuclear | NormKind::Spectral => {
                return Err(matrix_only_norm_kind_ad_error(kind));
            }
        }

        let dnrm = tensor_from_data(vec![dnrm], &[]).map_err(to_ad_err)?;
        return Ok((nrm, dnrm));
    }

    let (m, n, batch_dims) = validate_2d(tensor)
        .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?;
    let bc = batch_count(batch_dims);

    let (a_data, _) = extract_data(tensor)?;
    let (nrm_data, _) = extract_data(&nrm)?;
    let (da_data, _) = extract_data(tangent)?;

    let mut dnrm_data = vec![T::zero(); bc];

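    // Matrix-norm tangents (per batch, column-major data):
    //   Fro:      d|A|_F = <A, dA> / |A|_F
    //   Nuclear:  d|A|_* = tr(U^T dA V) from the thin SVD A = U S V^T
    //   Spectral: d|A|_2 = u1^T dA v1 with the leading singular vectors
    //   L1 / Inf: sign-weighted tangent averaged over tied max column/row sums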
    match kind {
        NormKind::Fro => {
            for batch in 0..bc {
                let nv = nrm_data[batch];
                if nv > T::zero() {
                    let mut dot = T::zero();
                    for i in 0..m * n {
                        dot = dot + a_data[batch * m * n + i] * da_data[batch * m * n + i];
                    }
                    dnrm_data[batch] = dot / nv;
                }
            }
        }
        NormKind::Nuclear => {
            for batch in 0..bc {
                let a_b = &a_data[batch * m * n..(batch + 1) * m * n];
                let da_b = &da_data[batch * m * n..(batch + 1) * m * n];
                let (u, _s, v) = backend_thin_svd(ctx, a_b, m, n)?;
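                // d|A|_* = tr(U^T dA V): form U^T dA V and sum its diagonal.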
                let k = m.min(n);
                let ut_da = backend_mat_mul(ctx, &transpose(&u, m, k), k, m, da_b, n)?;
                let ut_da_v = backend_mat_mul(ctx, &ut_da, k, n, &v, k)?;
                let mut trace = T::zero();
                for i in 0..k {
                    trace = trace + ut_da_v[i + i * k];
                }
                dnrm_data[batch] = trace;
            }
        }
        NormKind::Spectral => {
            for batch in 0..bc {
                let a_b = &a_data[batch * m * n..(batch + 1) * m * n];
                let da_b = &da_data[batch * m * n..(batch + 1) * m * n];
                let (u, _s, v) = backend_thin_svd(ctx, a_b, m, n)?;
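                // d|A|_2 = u1^T dA v1: u[0..m] and v[0..n] are the leading
                // left/right singular vectors (first columns, column-major).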
                let mut val = T::zero();
                for i in 0..m {
                    for j in 0..n {
                        val = val + u[i] * da_b[i + j * m] * v[j];
                    }
                }
                dnrm_data[batch] = val;
            }
        }
        NormKind::L1 => {
            for (batch, dn_slot) in dnrm_data.iter_mut().enumerate().take(bc) {
                if m == 0 || n == 0 {
                    continue;
                }
                let base = batch * m * n;
                let mut col_sums = vec![T::zero(); n];
                for j in 0..n {
                    let mut sum = T::zero();
                    for i in 0..m {
                        sum = sum + a_data[base + i + j * m].abs();
                    }
                    col_sums[j] = sum;
                }
                let mut max_sum = T::neg_infinity();
                for &sum in &col_sums {
                    if sum > max_sum {
                        max_sum = sum;
                    }
                }
                let active_cols: Vec<usize> = col_sums
                    .iter()
                    .enumerate()
                    .filter_map(|(j, &sum)| if sum == max_sum { Some(j) } else { None })
                    .collect();
                if active_cols.is_empty() {
                    continue;
                }
                let mut accum = T::zero();
                for j in active_cols.iter().copied() {
                    for i in 0..m {
                        let v = a_data[base + i + j * m];
                        let sign = if v > T::zero() {
                            T::one()
                        } else if v < T::zero() {
                            -T::one()
                        } else {
                            T::zero()
                        };
                        accum = accum + sign * da_data[base + i + j * m];
                    }
                }
                let active_count = scalar_from::<T>(active_cols.len() as f64).map_err(to_ad_err)?;
                *dn_slot = accum / active_count;
            }
        }
        NormKind::Inf => {
            for (batch, dn_slot) in dnrm_data.iter_mut().enumerate().take(bc) {
                if m == 0 || n == 0 {
                    continue;
                }
                let base = batch * m * n;
                let mut row_sums = vec![T::zero(); m];
                for i in 0..m {
                    let mut sum = T::zero();
                    for j in 0..n {
                        sum = sum + a_data[base + i + j * m].abs();
                    }
                    row_sums[i] = sum;
                }
                let mut max_sum = T::neg_infinity();
                for &sum in &row_sums {
                    if sum > max_sum {
                        max_sum = sum;
                    }
                }
                let active_rows: Vec<usize> = row_sums
                    .iter()
                    .enumerate()
                    .filter_map(|(i, &sum)| if sum == max_sum { Some(i) } else { None })
                    .collect();
                if active_rows.is_empty() {
                    continue;
                }
                let mut accum = T::zero();
                for i in active_rows.iter().copied() {
                    for j in 0..n {
                        let v = a_data[base + i + j * m];
                        let sign = if v > T::zero() {
                            T::one()
                        } else if v < T::zero() {
                            -T::one()
                        } else {
                            T::zero()
                        };
                        accum = accum + sign * da_data[base + i + j * m];
                    }
                }
                let active_count = scalar_from::<T>(active_rows.len() as f64).map_err(to_ad_err)?;
                *dn_slot = accum / active_count;
            }
        }
        _ => {
            return Err(chainrules_core::AutodiffError::ModeNotSupported {
                mode: "norm_frule".into(),
                reason: format!("norm kind {kind:?} AD not yet implemented"),
            });
        }
    }

    let dims = output_dims(&[], batch_dims);
    let dnrm = tensor_from_data(dnrm_data, &dims)
        .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?;
    Ok((nrm, dnrm))
}

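/// Forward-mode rule for `norm` on complex tensors, returning the real-valued
/// primal norm and tangent. Only kinds whose tangent reduces to
/// Re<A, dA> / |A| (Frobenius and vector L2) are supported; the tangent is
/// assembled from elementwise prims (conj, multiply, real-part reduction)
/// rather than host-side loops.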
#[doc(hidden)]
pub fn norm_frule_complex<T, R, C>(
    ctx: &mut C,
    tensor: &Tensor<T>,
    tangent: &Tensor<T>,
    kind: NormKind,
) -> AdResult<(Tensor<R>, Tensor<R>)>
where
    T: KernelLinalgScalar<Real = R> + ComplexFloat<Real = R> + crate::NormPrimal<C>,
    R: LinalgScalar<Real = R> + num_traits::Float,
    C: backend::TensorLinalgContextFor<T>
        + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T>>
        + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<R>>
        + tenferro_prims::TensorComplexRealContextFor<T>,
    <C as tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<R>>>::ScalarBackend:
        tenferro_prims::TensorAnalyticPrims<tenferro_algebra::Standard<R>, Context = C>,
    C::ComplexRealBackend: tenferro_prims::TensorComplexRealPrims<T, Context = C, Real = R>,
    C::Backend: 'static,
{
    require_linalg_support::<T, C>(backend::LinalgCapabilityOp::Norm, "norm_frule_complex")
        .map_err(to_ad_err)?;
    ensure_complex_norm_ad_supported(kind)?;

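    // d|A| = Re<A, dA> / |A|: form conj(A) .* dA, reduce its real part over
    // the norm axes, divide by the primal norm, and select 0 where |A| == 0.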
    let nrm = crate::norm(ctx, tensor, kind).map_err(to_ad_err)?;
    let conj_tensor = crate::prims_bridge::scalar_unary_same_shape(
        ctx,
        tensor,
        tenferro_prims::ScalarUnaryOp::Conj,
    )
    .map_err(to_ad_err)?;
    let product = crate::prims_bridge::scalar_binary_same_shape(
        ctx,
        &conj_tensor,
        tangent,
        tenferro_prims::ScalarBinaryOp::Mul,
    )
    .map_err(to_ad_err)?;
    let kept_axes = norm_kept_axes(tensor.ndim());
    let numerator = crate::prims_bridge::complex_real_reduce_keep_axes(
        ctx,
        &product,
        tenferro_prims::ComplexRealUnaryOp::Real,
        &kept_axes,
        tenferro_prims::ScalarReductionOp::Sum,
    )
    .map_err(to_ad_err)?;
    let zero =
        crate::prims_bridge::full_like_constant(R::zero(), nrm.dims(), nrm.logical_memory_space())
            .map_err(to_ad_err)?;
    let nonzero = crate::prims_bridge::scalar_binary_same_shape(
        ctx,
        &nrm,
        &zero,
        tenferro_prims::ScalarBinaryOp::Greater,
    )
    .map_err(to_ad_err)?;
    let quotient = crate::prims_bridge::scalar_binary_same_shape(
        ctx,
        &numerator,
        &nrm,
        tenferro_prims::ScalarBinaryOp::Div,
    )
    .map_err(to_ad_err)?;
    let output_tangent =
        crate::prims_bridge::scalar_where_same_shape(ctx, &nonzero, &quotient, &zero)
            .map_err(to_ad_err)?;
    Ok((nrm, output_tangent))
}

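/// Complex norm AD is restricted to the kinds where the norm is the square
/// root of a sum of |.|^2 terms, so the tangent is exactly Re<A, dA> / |A|.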
fn ensure_complex_norm_ad_supported(kind: NormKind) -> AdResult<()> {
    match kind {
        NormKind::Fro => Ok(()),
        NormKind::Lp(p) if p == 2.0 => Ok(()),
        _ => Err(chainrules_core::AutodiffError::InvalidArgument(format!(
            "complex norm AD currently supports Fro and vector L2 only, got {kind:?}"
        ))),
    }
}

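/// Axes preserved by the norm reduction: none for scalars and vectors, the
/// batch axes (everything past the two matrix axes) otherwise.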
fn norm_kept_axes(ndim: usize) -> Vec<usize> {
    if ndim <= 1 {
        Vec::new()
    } else {
        (2..ndim).collect()
    }
}