// tenferro_linalg/primal/norms.rs
1use super::*;
2
3mod norm_impl;
4
5pub(crate) use norm_impl::norm_real_impl;
6
/// Dispatch trait for [`norm`] — selects real or complex implementation.
///
/// Implemented per scalar type so that [`norm`] can route each scalar kind
/// to its own implementation while always returning a tensor over the
/// associated real scalar type `Self::Real`.
///
/// This trait is `#[doc(hidden)]` and not intended for external use.
/// Use [`norm`] directly instead.
///
/// # Examples
///
/// ```ignore
/// // Internal dispatch: users should call `norm` instead.
/// use tenferro_linalg::NormPrimal;
/// ```
#[doc(hidden)]
pub trait NormPrimal<C>: KernelLinalgScalar {
    /// Compute the norm of `tensor` selected by `kind` using context `ctx`.
    ///
    /// Returns a tensor over `Self::Real` (so complex inputs yield a
    /// real-valued result).
    fn norm_primal(
        ctx: &mut C,
        tensor: &Tensor<Self>,
        kind: NormKind,
    ) -> Result<Tensor<Self::Real>>;
}
26
27/// Solve a triangular linear system `A x = b`.
28///
29/// # Examples
30///
31/// ```
32/// use tenferro_device::LogicalMemorySpace;
33/// use tenferro_linalg::solve_triangular;
34/// use tenferro_prims::CpuContext;
35/// use tenferro_tensor::{MemoryOrder, Tensor};
36///
37/// let mut ctx = CpuContext::new(1);
38/// let a = Tensor::<f64>::from_slice(
39/// &[1.0, 0.0, 0.0, 2.0],
40/// &[2, 2],
41/// MemoryOrder::ColumnMajor,
42/// ).unwrap();
43/// let b = Tensor::<f64>::from_slice(&[1.0, 4.0], &[2], MemoryOrder::ColumnMajor).unwrap();
44/// let x = solve_triangular(&mut ctx, &a, &b, true).unwrap();
45/// assert_eq!(x.logical_memory_space(), LogicalMemorySpace::MainMemory);
46/// ```
47pub fn solve_triangular<T: KernelLinalgScalar, C>(
48 ctx: &mut C,
49 a: &Tensor<T>,
50 b: &Tensor<T>,
51 upper: bool,
52) -> Result<Tensor<T>>
53where
54 C: backend::TensorLinalgContextFor<T>,
55 C::Backend: 'static,
56{
57 <C::Backend as backend::TensorLinalgBackend<T>>::solve_triangular(ctx, a, b, upper)
58}
59
60/// Compute a norm.
61///
62/// Complex inputs return a tensor over the associated real scalar type.
63///
64/// # Examples
65///
66/// ```
67/// use num_complex::Complex64;
68/// use tenferro_device::LogicalMemorySpace;
69/// use tenferro_linalg::{norm, NormKind};
70/// use tenferro_prims::CpuContext;
71/// use tenferro_tensor::{MemoryOrder, Tensor};
72///
73/// let mut ctx = CpuContext::new(1);
74/// let a = Tensor::from_slice(
75/// &[Complex64::new(3.0, 4.0), Complex64::new(0.0, 0.0)],
76/// &[2],
77/// MemoryOrder::ColumnMajor,
78/// )
79/// .unwrap();
80/// let n: Tensor<f64> = norm(&mut ctx, &a, NormKind::Fro).unwrap();
81/// assert_eq!(n.logical_memory_space(), LogicalMemorySpace::MainMemory);
82/// ```
83#[allow(private_bounds)]
84pub fn norm<T, C>(ctx: &mut C, tensor: &Tensor<T>, kind: NormKind) -> Result<Tensor<T::Real>>
85where
86 T: KernelLinalgScalar + NormPrimal<C>,
87 C: backend::TensorLinalgContextFor<T>,
88 C::Backend: 'static,
89{
90 T::norm_primal(ctx, tensor, kind)
91}
92
93/// Compute the matrix condition number with a selected norm convention.
94///
95/// Complex inputs return a real-valued tensor.
96///
97/// # Examples
98///
99/// ```
100/// use num_complex::Complex64;
101/// use tenferro_device::LogicalMemorySpace;
102/// use tenferro_linalg::{cond, NormKind};
103/// use tenferro_prims::CpuContext;
104/// use tenferro_tensor::{MemoryOrder, Tensor};
105///
106/// let mut ctx = CpuContext::new(1);
107/// let a = Tensor::from_slice(
108/// &[
109/// Complex64::new(3.0, 4.0),
110/// Complex64::new(0.0, 0.0),
111/// Complex64::new(0.0, 0.0),
112/// Complex64::new(2.0, 0.0),
113/// ],
114/// &[2, 2],
115/// MemoryOrder::ColumnMajor,
116/// )
117/// .unwrap();
118/// let c: Tensor<f64> = cond(&mut ctx, &a, NormKind::Fro).unwrap();
119/// assert_eq!(c.logical_memory_space(), LogicalMemorySpace::MainMemory);
120/// ```
121#[allow(private_bounds)]
122pub fn cond<T, C>(ctx: &mut C, tensor: &Tensor<T>, kind: NormKind) -> Result<Tensor<T::Real>>
123where
124 T: KernelLinalgScalar + NormPrimal<C>,
125 C: backend::TensorLinalgContextFor<T>
126 + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T::Real>>,
127 C::Backend: 'static,
128{
129 match kind {
130 NormKind::Fro | NormKind::L1 | NormKind::Inf | NormKind::Spectral | NormKind::Nuclear => {}
131 _ => {
132 return Err(Error::InvalidArgument(format!(
133 "cond only supports Fro, L1, Inf, Spectral, and Nuclear norms, got {kind:?}"
134 )));
135 }
136 }
137
138 validate_square(tensor)?;
139 if matches!(kind, NormKind::Nuclear) {
140 let singular_values = svdvals(ctx, tensor)?;
141 let kept_axes: Vec<usize> = (1..singular_values.ndim()).collect();
142 let nuclear_norm = crate::prims_bridge::scalar_reduce_keep_axes(
143 ctx,
144 &singular_values,
145 &kept_axes,
146 tenferro_prims::ScalarReductionOp::Sum,
147 )?;
148 let reciprocal = crate::prims_bridge::scalar_unary_same_shape(
149 ctx,
150 &singular_values,
151 tenferro_prims::ScalarUnaryOp::Reciprocal,
152 )?;
153 let reciprocal_norm = crate::prims_bridge::scalar_reduce_keep_axes(
154 ctx,
155 &reciprocal,
156 &kept_axes,
157 tenferro_prims::ScalarReductionOp::Sum,
158 )?;
159 return crate::prims_bridge::scalar_binary_same_shape(
160 ctx,
161 &nuclear_norm,
162 &reciprocal_norm,
163 tenferro_prims::ScalarBinaryOp::Mul,
164 );
165 }
166
167 let lhs = norm(ctx, tensor, kind)?;
168 let inverse = inv(ctx, tensor)?;
169 let rhs = norm(ctx, &inverse, kind)?;
170 crate::prims_bridge::scalar_binary_same_shape(
171 ctx,
172 &lhs,
173 &rhs,
174 tenferro_prims::ScalarBinaryOp::Mul,
175 )
176}