tenferro_linalg/primal/
least_squares.rs

1use super::*;
2use tenferro_tensor::MemoryOrder;
3
4pub(crate) fn residual_summary_output_dims(
5    batch_dims: &[usize],
6    nrhs: usize,
7    rhs_is_vector: bool,
8) -> Vec<usize> {
9    if rhs_is_vector {
10        batch_dims.to_vec()
11    } else {
12        output_dims(&[nrhs], batch_dims)
13    }
14}
15
16pub(crate) fn empty_residual_summary<T: LinalgScalar>(
17    memory: tenferro_device::LogicalMemorySpace,
18) -> Result<Tensor<T>> {
19    Tensor::<T>::zeros(&[0], memory, MemoryOrder::ColumnMajor)
20}
21
22pub(crate) fn full_rank_residual_summaries<T: KernelLinalgScalar>(
23    residual_matrix: &Tensor<T>,
24    m: usize,
25    nrhs: usize,
26    batch_dims: &[usize],
27    rhs_is_vector: bool,
28) -> Result<Tensor<T::Real>>
29where
30    T::Real: LinalgScalar<Real = T::Real> + num_traits::Float,
31{
32    let bc = batch_count(batch_dims);
33    let (residual_data, _) =
34        extract_data(residual_matrix).map_err(|e| Error::InvalidArgument(e.to_string()))?;
35    let mut summary = vec![T::Real::zero(); bc * nrhs];
36    for batch in 0..bc {
37        for col in 0..nrhs {
38            let mut acc = T::Real::zero();
39            for row in 0..m {
40                let value = residual_data[batch * m * nrhs + row + col * m].abs_real();
41                acc = acc + value * value;
42            }
43            summary[batch * nrhs + col] = acc;
44        }
45    }
46    let dims = residual_summary_output_dims(batch_dims, nrhs, rhs_is_vector);
47    tensor_from_data(summary, &dims)
48}
49
50pub(crate) fn lstsq_has_full_rank<
51    T: LinalgScalar<Real = T> + num_traits::Float + tenferro_tensor::KeepCountScalar,
52>(
53    rank: &Tensor<T>,
54    expected_rank: usize,
55) -> Result<bool> {
56    let (rank_data, _) = extract_data(rank).map_err(|e| Error::InvalidArgument(e.to_string()))?;
57    let expected_rank = scalar_from::<T>(expected_rank as f64)?;
58    Ok(rank_data.iter().all(|value| *value == expected_rank))
59}
60
/// Solve the least squares problem: `x = argmin ||Ax - b||²`.
///
/// Uses a QR-based solve: with `A = QR`, the minimizer is obtained from the
/// triangular system `R x = Qᴴ b`. `a` must be (batched) `m x n` with
/// `m >= n`; `b` may be a vector or a matrix of right-hand sides.
///
/// The returned `residuals` holds the squared column norms of the residual
/// (one per batch and right-hand side) only when `m > n` and `a` has full
/// numerical rank; otherwise it is an empty tensor.
///
/// # Errors
///
/// Returns an error when the backend lacks `lstsq` support, when `m < n`,
/// or when the shapes of `a` and `b` are incompatible.
///
/// # Examples
///
/// ```
/// use tenferro_linalg::lstsq;
/// use tenferro_prims::CpuContext;
/// use tenferro_tensor::{MemoryOrder, Tensor};
///
/// let mut ctx = CpuContext::new(1);
/// let col = MemoryOrder::ColumnMajor;
/// let a = Tensor::<f64>::from_slice(&[1.0, 0.0, 1.0, 1.0], &[2, 2], col).unwrap();
/// let b = Tensor::<f64>::from_slice(&[1.0, 2.0], &[2], col).unwrap();
/// let result = lstsq(&mut ctx, &a, &b).unwrap();
/// assert_eq!(result.solution.dims(), &[2]);
/// ```
pub fn lstsq<T: KernelLinalgScalar, C>(
    ctx: &mut C,
    a: &Tensor<T>,
    b: &Tensor<T>,
) -> Result<LstsqResult<T, T::Real>>
where
    T: KernelLinalgScalar + tenferro_algebra::Conjugate,
    T::Real: LinalgScalar<Real = T::Real> + num_traits::Float + tenferro_tensor::KeepCountScalar,
    C: backend::TensorLinalgContextFor<T>
        + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T>>
        + tenferro_prims::TensorSemiringContextFor<tenferro_algebra::Standard<T>>
        + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T::Real>>,
    C::Backend: 'static,
{
    require_linalg_support::<T, C>(backend::LinalgCapabilityOp::Lstsq, "lstsq")?;

    // Only overdetermined (or square) systems are supported by the QR path.
    let (m, n, batch_dims) = validate_2d(a)?;
    if m < n {
        return Err(Error::InvalidArgument(format!(
            "lstsq requires m >= n, got m={m}, n={n}"
        )));
    }
    validate_lstsq_rhs(b, m, batch_dims)?;

    let qr_result = qr(ctx, a)?;
    let q_input = ensure_col_major(&qr_result.q);
    let r_input = ensure_col_major(&qr_result.r);
    let b_input = ensure_col_major(b);
    // `b` is a vector RHS when it has exactly one non-batch dimension.
    let rhs_is_vector = b_input.ndim() == 1 + batch_dims.len();

    let k = m.min(n);
    // Work with a matrix RHS internally; a vector becomes a single column.
    let rhs_matrix = if rhs_is_vector {
        b_input.unsqueeze(1)?
    } else {
        b_input.clone()
    };

    // Adjoint of Q: conjugate, then swap the two matrix axes (batch axes kept).
    let mut q_perm = vec![1, 0];
    q_perm.extend(2..q_input.ndim());
    let q_adj = q_input.conj().permute(&q_perm)?;
    let nrhs = rhs_matrix.dims()[1];
    // Qᴴ b, a (k x nrhs) panel per batch entry.
    let qtb = crate::prims_bridge::batched_gemm_with_semiring_tensors(
        ctx,
        &q_adj,
        &rhs_matrix,
        k,
        m,
        nrhs,
    )?;
    // Back-substitute the triangular system R x = Qᴴ b.
    let x_matrix = solve_triangular(ctx, &r_input, &qtb, true)?;
    let x = if rhs_is_vector {
        x_matrix.squeeze_dim(1)?
    } else {
        x_matrix
    };
    // Q (Qᴴ b) projects b onto the column space of A, so the difference below
    // is ±(Ax - b); the sign cancels in the squared norms.
    let projected_rhs =
        crate::prims_bridge::batched_gemm_with_semiring_tensors(ctx, &q_input, &qtb, m, k, nrhs)?;
    let residual_matrix = crate::prims_bridge::scalar_binary_same_shape(
        ctx,
        &projected_rhs,
        &rhs_matrix,
        tenferro_prims::ScalarBinaryOp::Sub,
    )?;
    // Residual summaries are only meaningful for strictly overdetermined,
    // full-rank systems; otherwise return the empty placeholder.
    let aux = lstsq_aux(ctx, a)?;
    let residuals = if m > n && lstsq_has_full_rank(&aux.rank, n)? {
        full_rank_residual_summaries(&residual_matrix, m, nrhs, batch_dims, rhs_is_vector)?
    } else {
        empty_residual_summary::<T::Real>(a.logical_memory_space())?
    };

    Ok(LstsqResult {
        solution: x,
        residuals,
    })
}
152
153/// Compute least-squares auxiliary metadata.
154///
155/// This returns the singular values used for numerical rank estimation together
156/// with a batch-shaped count tensor containing the effective rank.
157pub fn lstsq_aux<T: KernelLinalgScalar, C>(
158    ctx: &mut C,
159    a: &Tensor<T>,
160) -> Result<LstsqAuxResult<T::Real>>
161where
162    T: KernelLinalgScalar + tenferro_algebra::Conjugate,
163    T::Real: tenferro_tensor::KeepCountScalar + num_traits::Float,
164    C: backend::TensorLinalgContextFor<T>
165        + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T::Real>>,
166    C::Backend: 'static,
167{
168    let singular_values = svdvals(ctx, a)?;
169    let (_, _, batch_dims) = validate_2d(a)?;
170    let rank = lstsq_rank_counts_tensor(
171        ctx,
172        &singular_values,
173        a.dims()[0].max(a.dims()[1]),
174        batch_dims,
175    )?;
176    Ok(LstsqAuxResult {
177        rank,
178        singular_values,
179    })
180}
181
182fn lstsq_rank_counts_tensor<R, C>(
183    ctx: &mut C,
184    singular_values: &Tensor<R>,
185    scale: usize,
186    batch_dims: &[usize],
187) -> Result<Tensor<R>>
188where
189    R: LinalgScalar<Real = R> + num_traits::Float + tenferro_tensor::KeepCountScalar,
190    C: tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<R>>,
191{
192    let k = singular_values.dims().first().copied().unwrap_or(0);
193    if k == 0 {
194        return crate::prims_bridge::full_like_constant(
195            R::zero(),
196            batch_dims,
197            singular_values.logical_memory_space(),
198        );
199    }
200
201    let kept_axes: Vec<usize> = (1..singular_values.ndim()).collect();
202    let max_sigma = crate::prims_bridge::scalar_reduce_keep_axes(
203        ctx,
204        singular_values,
205        &kept_axes,
206        tenferro_prims::ScalarReductionOp::Max,
207    )?;
208    let scaled_eps = scalar_from::<R>(scale as f64)? * R::epsilon();
209    let scaled_eps_tensor = crate::prims_bridge::full_like_constant(
210        scaled_eps,
211        max_sigma.dims(),
212        max_sigma.logical_memory_space(),
213    )?;
214    let tol_by_batch = crate::prims_bridge::scalar_binary_same_shape(
215        ctx,
216        &max_sigma,
217        &scaled_eps_tensor,
218        tenferro_prims::ScalarBinaryOp::Mul,
219    )?;
220    let tol = broadcast_lstsq_batch_control(&tol_by_batch, singular_values.dims())?;
221    let active = crate::prims_bridge::scalar_binary_same_shape(
222        ctx,
223        singular_values,
224        &tol,
225        tenferro_prims::ScalarBinaryOp::Greater,
226    )?;
227    crate::prims_bridge::scalar_sum_keep_axes(ctx, &active, &kept_axes)
228}
229
230fn broadcast_lstsq_batch_control<R: LinalgScalar>(
231    value_by_batch: &Tensor<R>,
232    singular_dims: &[usize],
233) -> Result<Tensor<R>> {
234    if singular_dims.len() <= 1 {
235        return value_by_batch.reshape(&[1])?.broadcast(singular_dims);
236    }
237
238    let mut reshape_dims = vec![1];
239    reshape_dims.extend_from_slice(&singular_dims[1..]);
240    value_by_batch
241        .reshape(&reshape_dims)?
242        .broadcast(singular_dims)
243}
244
/// Compute the Cholesky decomposition of a Hermitian positive-definite matrix.
///
/// Thin wrapper that delegates directly to the active backend's `cholesky`
/// kernel without additional validation.
///
/// NOTE(review): unlike [`cholesky_ex`] and [`lstsq`], this entry point does
/// not call `require_linalg_support` first — confirm whether skipping the
/// capability check for plain Cholesky is intentional.
///
/// # Examples
///
/// ```
/// use tenferro_linalg::cholesky;
/// use tenferro_prims::CpuContext;
/// use tenferro_tensor::{MemoryOrder, Tensor};
///
/// let mut ctx = CpuContext::new(1);
/// let col = MemoryOrder::ColumnMajor;
/// let a = Tensor::<f64>::from_slice(&[4.0, 2.0, 2.0, 3.0], &[2, 2], col).unwrap();
/// let l = cholesky(&mut ctx, &a).unwrap();
/// assert_eq!(l.dims(), &[2, 2]);
/// ```
pub fn cholesky<T: KernelLinalgScalar, C>(ctx: &mut C, tensor: &Tensor<T>) -> Result<Tensor<T>>
where
    C: backend::TensorLinalgContextFor<T>,
    C::Backend: 'static,
{
    // Fully-qualified syntax disambiguates the backend trait implementation.
    <C::Backend as backend::TensorLinalgBackend<T>>::cholesky(ctx, tensor)
}
267
268/// Compute the Cholesky decomposition with numerical status information.
269///
270/// # Examples
271///
272/// ```
273/// use tenferro_linalg::cholesky_ex;
274/// use tenferro_prims::CpuContext;
275/// use tenferro_tensor::{MemoryOrder, Tensor};
276///
277/// let mut ctx = CpuContext::new(1);
278/// let col = MemoryOrder::ColumnMajor;
279/// let a = Tensor::<f64>::from_slice(&[4.0, 2.0, 2.0, 3.0], &[2, 2], col).unwrap();
280/// let result = cholesky_ex(&mut ctx, &a).unwrap();
281/// assert_eq!(result.l.dims(), &[2, 2]);
282/// assert_eq!(result.info.len(), 1);
283/// ```
284pub fn cholesky_ex<T: KernelLinalgScalar, C>(
285    ctx: &mut C,
286    tensor: &Tensor<T>,
287) -> Result<CholeskyExResult<T>>
288where
289    T: KernelLinalgScalar,
290    C: backend::TensorLinalgContextFor<T>,
291    C::Backend: 'static,
292{
293    require_linalg_support::<T, C>(backend::LinalgCapabilityOp::CholeskyEx, "cholesky_ex")?;
294    let result = <C::Backend as backend::TensorLinalgBackend<T>>::cholesky_ex(ctx, tensor)?;
295    Ok(CholeskyExResult {
296        l: result.l,
297        info: result.info,
298    })
299}