use super::*;
use num_complex::ComplexFloat;

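/// Sign of `v` with the subgradient convention `sign(0) = 0`, shared by the
/// L1/Inf/Lp rules below. (`num_traits::Float::signum` is not used because it
/// returns `±1.0` for `±0.0`.)
fn subgradient_sign<T: num_traits::Float>(v: T) -> T {
    if v > T::zero() {
        T::one()
    } else if v < T::zero() {
        -T::one()
    } else {
        T::zero()
    }
}

/// Reverse rule for `norm` over real scalars: given the cotangent of the
/// scalar norm output, returns the gradient with respect to the input. 1-D
/// inputs take the vector-norm path; higher-rank inputs are treated as
/// (possibly batched) matrices.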
pub fn norm_rrule<T: KernelLinalgScalar<Real = T> + num_traits::Float, C>(
    ctx: &mut C,
    tensor: &Tensor<T>,
    cotangent: &Tensor<T>,
    kind: NormKind,
) -> AdResult<Tensor<T>>
where
    T: KernelLinalgScalar,
    C: backend::TensorLinalgContextFor<T>
        + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T>>,
    <C as tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T>>>::ScalarBackend:
        tenferro_prims::TensorAnalyticPrims<tenferro_algebra::Standard<T>, Context = C>,
    C::Backend: 'static,
{
    require_linalg_support::<T, C>(backend::LinalgCapabilityOp::Norm, "norm_rrule")
        .map_err(to_ad_err)?;

    if tensor.ndim() == 1 {
        validate_norm_cotangent(cotangent, &[]).map_err(to_ad_err)?;
        let (a_data, _) = extract_data(tensor)?;
        let (dn_data, _) = extract_data(cotangent)?;
        let dn = dn_data[0];
        let len = tensor.dims()[0];
        let mut grad_a = vec![T::zero(); len];

        match kind {
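            // Fro (vector L2): d‖a‖ = ⟨a / ‖a‖, da⟩, so grad = dn * a / ‖a‖;
            // at a = 0 the norm is not differentiable and we return zero.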
            NormKind::Fro => {
                let nrm =
                    crate::primal::norm_real_impl(ctx, tensor, NormKind::Fro).map_err(to_ad_err)?;
                let (nrm_data, _) = extract_data(&nrm)?;
                let nv = nrm_data[0];
                let scale = if nv > T::zero() { dn / nv } else { T::zero() };
                for i in 0..len {
                    grad_a[i] = scale * a_data[i];
                }
            }
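            // L1: grad_i = dn * sign(a_i), with subgradient sign(0) = 0.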
            NormKind::L1 => {
                for i in 0..len {
                    grad_a[i] = dn * subgradient_sign(a_data[i]);
                }
            }
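            // Inf (vector): dn is split evenly across every entry attaining
            // max |a_i|, a standard subgradient choice at ties.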
            NormKind::Inf => {
                let max_abs = a_data.iter().fold(T::zero(), |acc, &v| acc.max(v.abs()));
                let active: Vec<usize> = a_data
                    .iter()
                    .enumerate()
                    .filter_map(|(i, &v)| if v.abs() == max_abs { Some(i) } else { None })
                    .collect();
                if !active.is_empty() {
                    let active_count = scalar_from::<T>(active.len() as f64).map_err(to_ad_err)?;
                    let scale = dn / active_count;
                    for i in active {
                        grad_a[i] = scale * subgradient_sign(a_data[i]);
                    }
                }
            }
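            // Lp (p >= 1): grad_i = dn * sign(a_i) * |a_i|^(p-1) / ‖a‖^(p-1);
            // p == 1 takes the L1 subgradient path so powf is never called
            // with a zero base and zero exponent.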
            NormKind::Lp(p) => {
                if p < 1.0 {
                    return Err(invalid_vector_lp_exponent_ad_error(p));
                }
                if p == 1.0 {
                    for i in 0..len {
                        grad_a[i] = dn * subgradient_sign(a_data[i]);
                    }
                } else {
                    let nrm =
                        crate::primal::norm_real_impl(ctx, tensor, kind).map_err(to_ad_err)?;
                    let (nrm_data, _) = extract_data(&nrm)?;
                    let nv = nrm_data[0];
                    if nv > T::zero() {
                        let p_minus_one = scalar_from::<T>(p - 1.0).map_err(to_ad_err)?;
                        let scale = dn / nv.powf(p_minus_one);
                        for i in 0..len {
                            let v = a_data[i];
                            grad_a[i] = scale * subgradient_sign(v) * v.abs().powf(p_minus_one);
                        }
                    }
                }
            }
            NormKind::Nuclear | NormKind::Spectral => {
                return Err(matrix_only_norm_kind_ad_error(kind));
            }
        }

        return tensor_from_data(grad_a, &[len]).map_err(to_ad_err);
    }

    let (m, n, batch_dims) = validate_2d(tensor)
        .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?;
    let bc = batch_count(batch_dims);
    validate_norm_cotangent(cotangent, batch_dims)
        .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?;

    let (a_data, _) = extract_data(tensor)?;
    let (dn_data, _) = extract_data(cotangent)?;

    let mut grad_a = vec![T::zero(); m * n * bc];

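    // Batched matrices are addressed column-major: entry (i, j) of batch b
    // lives at b * m * n + i + j * m, matching the indexing below.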
    match kind {
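        // Fro: per batch, grad = dn * A / ‖A‖_F, zero where the norm is zero.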
        NormKind::Fro => {
            let nrm = crate::primal::norm_real_impl(ctx, tensor, NormKind::Fro)
                .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))?;
            let (nrm_data, _) = extract_data(&nrm)?;
            for batch in 0..bc {
                let dn = dn_data[batch];
                let nv = nrm_data[batch];
                let scale = if nv > T::zero() { dn / nv } else { T::zero() };
                for i in 0..m * n {
                    grad_a[batch * m * n + i] = scale * a_data[batch * m * n + i];
                }
            }
        }
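        // Nuclear: with thin SVD A = U S Vᵀ, d‖A‖_* = ⟨U Vᵀ, dA⟩, so the
        // per-batch gradient is dn * U Vᵀ.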
        NormKind::Nuclear => {
            for batch in 0..bc {
                let a_b = &a_data[batch * m * n..(batch + 1) * m * n];
                let (u, _s, v) = backend_thin_svd(ctx, a_b, m, n)?;
                let k = m.min(n);
                let uv = backend_mat_mul(ctx, &u, m, k, &transpose(&v, n, k), n)?;
                let dn = dn_data[batch];
                for i in 0..m * n {
                    grad_a[batch * m * n + i] = dn * uv[i];
                }
            }
        }
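        // Spectral: grad = dn * u₁v₁ᵀ, the outer product of the leading
        // singular vectors (first columns of the thin-SVD factors); this is
        // the derivative when the top singular value is simple.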
        NormKind::Spectral => {
            for batch in 0..bc {
                let a_b = &a_data[batch * m * n..(batch + 1) * m * n];
                let (u, _s, v) = backend_thin_svd(ctx, a_b, m, n)?;
                let dn = dn_data[batch];
                for j in 0..n {
                    for i in 0..m {
                        grad_a[batch * m * n + i + j * m] = dn * u[i] * v[j];
                    }
                }
            }
        }
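        // L1 (max absolute column sum): dn is spread evenly over the columns
        // attaining the maximum, then routed through sign(a_ij) within each.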
        NormKind::L1 => {
            for (batch, &dn_batch) in dn_data.iter().enumerate().take(bc) {
                if m == 0 || n == 0 {
                    continue;
                }
                let base = batch * m * n;
                let mut col_sums = vec![T::zero(); n];
                for j in 0..n {
                    let mut sum = T::zero();
                    for i in 0..m {
                        sum = sum + a_data[base + i + j * m].abs();
                    }
                    col_sums[j] = sum;
                }
                let mut max_sum = T::neg_infinity();
                for &sum in &col_sums {
                    if sum > max_sum {
                        max_sum = sum;
                    }
                }
                let active_cols: Vec<usize> = col_sums
                    .iter()
                    .enumerate()
                    .filter_map(|(j, &sum)| if sum == max_sum { Some(j) } else { None })
                    .collect();
                if active_cols.is_empty() {
                    continue;
                }
                let active_count = scalar_from::<T>(active_cols.len() as f64).map_err(to_ad_err)?;
                let dn = dn_batch / active_count;
                for j in active_cols {
                    for i in 0..m {
                        let v = a_data[base + i + j * m];
                        grad_a[base + i + j * m] =
                            grad_a[base + i + j * m] + dn * subgradient_sign(v);
                    }
                }
            }
        }
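        // Inf (max absolute row sum): same scheme as L1, but over rows.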
        NormKind::Inf => {
            for (batch, &dn_batch) in dn_data.iter().enumerate().take(bc) {
                if m == 0 || n == 0 {
                    continue;
                }
                let base = batch * m * n;
                let mut row_sums = vec![T::zero(); m];
                for i in 0..m {
                    let mut sum = T::zero();
                    for j in 0..n {
                        sum = sum + a_data[base + i + j * m].abs();
                    }
                    row_sums[i] = sum;
                }
                let mut max_sum = T::neg_infinity();
                for &sum in &row_sums {
                    if sum > max_sum {
                        max_sum = sum;
                    }
                }
                let active_rows: Vec<usize> = row_sums
                    .iter()
                    .enumerate()
                    .filter_map(|(i, &sum)| if sum == max_sum { Some(i) } else { None })
                    .collect();
                if active_rows.is_empty() {
                    continue;
                }
                let active_count = scalar_from::<T>(active_rows.len() as f64).map_err(to_ad_err)?;
                let dn = dn_batch / active_count;
                for i in active_rows {
                    for j in 0..n {
                        let v = a_data[base + i + j * m];
                        grad_a[base + i + j * m] =
                            grad_a[base + i + j * m] + dn * subgradient_sign(v);
                    }
                }
            }
        }
        _ => {
            return Err(chainrules_core::AutodiffError::ModeNotSupported {
                mode: "norm_rrule".into(),
                reason: format!("norm kind {kind:?} AD not yet implemented"),
            });
        }
    }

    let dims = output_dims(&[m, n], batch_dims);
    tensor_from_data(grad_a, &dims)
        .map_err(|e| chainrules_core::AutodiffError::InvalidArgument(e.to_string()))
}

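/// Reverse rule for `norm` over complex scalars. Only `Fro` and vector L2 are
/// supported (see `ensure_complex_norm_ad_supported`); for those the gradient
/// is `(dn / ‖a‖) * a`, computed elementwise with a select that returns zero
/// wherever the norm itself is zero.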
#[doc(hidden)]
pub fn norm_rrule_complex<T, R, C>(
    ctx: &mut C,
    tensor: &Tensor<T>,
    cotangent: &Tensor<R>,
    kind: NormKind,
) -> AdResult<Tensor<T>>
where
    T: KernelLinalgScalar<Real = R>
        + ComplexFloat<Real = R>
        + crate::prims_bridge::ScaleTensorByRealSameShape<C>,
    T: crate::NormPrimal<C>,
    R: LinalgScalar<Real = R> + num_traits::Float,
    C: backend::TensorLinalgContextFor<T>
        + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<R>>
        + tenferro_prims::TensorComplexScaleContextFor<T>,
    <C as tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<R>>>::ScalarBackend:
        tenferro_prims::TensorAnalyticPrims<tenferro_algebra::Standard<R>, Context = C>,
    C::ComplexScaleBackend: tenferro_prims::TensorComplexScalePrims<T, Context = C>,
    C::Backend: 'static,
{
    require_linalg_support::<T, C>(backend::LinalgCapabilityOp::Norm, "norm_rrule_complex")
        .map_err(to_ad_err)?;
    ensure_complex_norm_ad_supported(kind)?;

    validate_norm_cotangent(cotangent, expected_norm_output_dims(tensor)).map_err(to_ad_err)?;

    let nrm = crate::norm(ctx, tensor, kind).map_err(to_ad_err)?;
    let zero =
        crate::prims_bridge::full_like_constant(R::zero(), nrm.dims(), nrm.logical_memory_space())
            .map_err(to_ad_err)?;
    let nonzero = crate::prims_bridge::scalar_binary_same_shape(
        ctx,
        &nrm,
        &zero,
        tenferro_prims::ScalarBinaryOp::Greater,
    )
    .map_err(to_ad_err)?;
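    // If the norm is zero this division presumably produces non-finite
    // values; the select below replaces those lanes with zero.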
    let quotient = crate::prims_bridge::scalar_binary_same_shape(
        ctx,
        cotangent,
        &nrm,
        tenferro_prims::ScalarBinaryOp::Div,
    )
    .map_err(to_ad_err)?;
    let safe_scale = crate::prims_bridge::scalar_where_same_shape(ctx, &nonzero, &quotient, &zero)
        .map_err(to_ad_err)?;
    let expanded_scale =
        broadcast_norm_control_to_input(&safe_scale, tensor.dims()).map_err(to_ad_err)?;
    crate::prims_bridge::complex_scale_same_shape(ctx, tensor, &expanded_scale).map_err(to_ad_err)
}

fn ensure_complex_norm_ad_supported(kind: NormKind) -> AdResult<()> {
    match kind {
        NormKind::Fro => Ok(()),
        NormKind::Lp(p) if p == 2.0 => Ok(()),
        _ => Err(chainrules_core::AutodiffError::InvalidArgument(format!(
            "complex norm AD currently supports Fro and vector L2 only, got {kind:?}"
        ))),
    }
}

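/// Dims of the norm output for `tensor`: scalar (`&[]`) for vectors and
/// lower-rank inputs, the trailing batch dims for (batched) matrices.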
fn expected_norm_output_dims<T: LinalgScalar>(tensor: &Tensor<T>) -> &[usize] {
    if tensor.ndim() <= 1 {
        &[]
    } else {
        &tensor.dims()[2..]
    }
}

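/// Reshapes a per-batch scalar tensor so it broadcasts over the full input
/// shape: `[1]` for vector inputs, `[1, 1, batch...]` for matrix inputs.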
fn broadcast_norm_control_to_input<R: LinalgScalar>(
    value_by_batch: &Tensor<R>,
    input_dims: &[usize],
) -> Result<Tensor<R>> {
    if input_dims.len() <= 1 {
        return value_by_batch.reshape(&[1])?.broadcast(input_dims);
    }

    let mut reshape_dims = vec![1, 1];
    reshape_dims.extend_from_slice(&input_dims[2..]);
    value_by_batch.reshape(&reshape_dims)?.broadcast(input_dims)
}