use super::*;
use tenferro_tensor::MemoryOrder;

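/// Shape of the residual summary returned by `lstsq`: the batch dimensions
/// alone when the right-hand side was a vector, otherwise the batch dimensions
/// combined with `nrhs`.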
pub(crate) fn residual_summary_output_dims(
    batch_dims: &[usize],
    nrhs: usize,
    rhs_is_vector: bool,
) -> Vec<usize> {
    if rhs_is_vector {
        batch_dims.to_vec()
    } else {
        output_dims(&[nrhs], batch_dims)
    }
}

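/// Placeholder residual summary (a zero-length tensor) used when residual
/// norms are not reported.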
pub(crate) fn empty_residual_summary<T: LinalgScalar>(
    memory: tenferro_device::LogicalMemorySpace,
) -> Result<Tensor<T>> {
    Tensor::<T>::zeros(&[0], memory, MemoryOrder::ColumnMajor)
}

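/// Sums of squared residual magnitudes, one per batch and right-hand-side
/// column. Assumes `residual_matrix` is column-major within each batch, i.e.
/// element `(row, col)` of batch `b` lives at `b * m * nrhs + row + col * m`.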
pub(crate) fn full_rank_residual_summaries<T: KernelLinalgScalar>(
    residual_matrix: &Tensor<T>,
    m: usize,
    nrhs: usize,
    batch_dims: &[usize],
    rhs_is_vector: bool,
) -> Result<Tensor<T::Real>>
where
    T::Real: LinalgScalar<Real = T::Real> + num_traits::Float,
{
    let bc = batch_count(batch_dims);
    let (residual_data, _) =
        extract_data(residual_matrix).map_err(|e| Error::InvalidArgument(e.to_string()))?;
    let mut summary = vec![T::Real::zero(); bc * nrhs];
    for batch in 0..bc {
        for col in 0..nrhs {
            let mut acc = T::Real::zero();
            for row in 0..m {
                let value = residual_data[batch * m * nrhs + row + col * m].abs_real();
                acc = acc + value * value;
            }
            summary[batch * nrhs + col] = acc;
        }
    }
    let dims = residual_summary_output_dims(batch_dims, nrhs, rhs_is_vector);
    tensor_from_data(summary, &dims)
}

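/// Returns `true` when every batch entry of the `rank` tensor equals
/// `expected_rank`.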
pub(crate) fn lstsq_has_full_rank<
    T: LinalgScalar<Real = T> + num_traits::Float + tenferro_tensor::KeepCountScalar,
>(
    rank: &Tensor<T>,
    expected_rank: usize,
) -> Result<bool> {
    let (rank_data, _) = extract_data(rank).map_err(|e| Error::InvalidArgument(e.to_string()))?;
    let expected_rank = scalar_from::<T>(expected_rank as f64)?;
    Ok(rank_data.iter().all(|value| *value == expected_rank))
}

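/// Least-squares solution of `a * x = b` for an (optionally batched) matrix
/// `a` with `m >= n`, computed from a QR factorization: `x` solves the
/// triangular system `r * x = q^H * b`.
///
/// The returned `residuals` hold the squared 2-norm of each right-hand-side
/// column's residual when `m > n` and `a` has full column rank; otherwise they
/// are an empty tensor. Returns an error if the backend does not support the
/// operation or if `m < n`.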
pub fn lstsq<T: KernelLinalgScalar, C>(
    ctx: &mut C,
    a: &Tensor<T>,
    b: &Tensor<T>,
) -> Result<LstsqResult<T, T::Real>>
where
    T: KernelLinalgScalar + tenferro_algebra::Conjugate,
    T::Real: LinalgScalar<Real = T::Real> + num_traits::Float + tenferro_tensor::KeepCountScalar,
    C: backend::TensorLinalgContextFor<T>
        + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T>>
        + tenferro_prims::TensorSemiringContextFor<tenferro_algebra::Standard<T>>
        + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T::Real>>,
    C::Backend: 'static,
{
    require_linalg_support::<T, C>(backend::LinalgCapabilityOp::Lstsq, "lstsq")?;

    let (m, n, batch_dims) = validate_2d(a)?;
    if m < n {
        return Err(Error::InvalidArgument(format!(
            "lstsq requires m >= n, got m={m}, n={n}"
        )));
    }
    validate_lstsq_rhs(b, m, batch_dims)?;

    // Factor `a = q * r` and normalize the operands to column-major layout.
    let qr_result = qr(ctx, a)?;
    let q_input = ensure_col_major(&qr_result.q);
    let r_input = ensure_col_major(&qr_result.r);
    let b_input = ensure_col_major(b);
    let rhs_is_vector = b_input.ndim() == 1 + batch_dims.len();

    let k = m.min(n);
    // Treat a vector right-hand side as a single-column matrix.
    let rhs_matrix = if rhs_is_vector {
        b_input.unsqueeze(1)?
    } else {
        b_input.clone()
    };

    // Form `q^H * b` by swapping the two matrix axes of the conjugated `q`.
    let mut q_perm = vec![1, 0];
    q_perm.extend(2..q_input.ndim());
    let q_adj = q_input.conj().permute(&q_perm)?;
    let nrhs = rhs_matrix.dims()[1];
    let qtb = crate::prims_bridge::batched_gemm_with_semiring_tensors(
        ctx,
        &q_adj,
        &rhs_matrix,
        k,
        m,
        nrhs,
    )?;
    // Back-substitute `r * x = q^H * b` for the least-squares solution.
    let x_matrix = solve_triangular(ctx, &r_input, &qtb, true)?;
    let x = if rhs_is_vector {
        x_matrix.squeeze_dim(1)?
    } else {
        x_matrix
    };
    // Residual matrix: `q * (q^H * b) - b`, the difference between the
    // projection of `b` onto the column space of `a` and `b` itself.
    let projected_rhs =
        crate::prims_bridge::batched_gemm_with_semiring_tensors(ctx, &q_input, &qtb, m, k, nrhs)?;
    let residual_matrix = crate::prims_bridge::scalar_binary_same_shape(
        ctx,
        &projected_rhs,
        &rhs_matrix,
        tenferro_prims::ScalarBinaryOp::Sub,
    )?;
    // Squared residual norms are only reported for overdetermined, full-rank systems.
    let aux = lstsq_aux(ctx, a)?;
    let residuals = if m > n && lstsq_has_full_rank(&aux.rank, n)? {
        full_rank_residual_summaries(&residual_matrix, m, nrhs, batch_dims, rhs_is_vector)?
    } else {
        empty_residual_summary::<T::Real>(a.logical_memory_space())?
    };

    Ok(LstsqResult {
        solution: x,
        residuals,
    })
}

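/// Auxiliary quantities for `lstsq`: the singular values of `a` and the
/// numerical rank per batch, counted against the tolerance
/// `max(m, n) * eps * sigma_max`.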
pub fn lstsq_aux<T: KernelLinalgScalar, C>(
    ctx: &mut C,
    a: &Tensor<T>,
) -> Result<LstsqAuxResult<T::Real>>
where
    T: KernelLinalgScalar + tenferro_algebra::Conjugate,
    T::Real: tenferro_tensor::KeepCountScalar + num_traits::Float,
    C: backend::TensorLinalgContextFor<T>
        + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T::Real>>,
    C::Backend: 'static,
{
    let singular_values = svdvals(ctx, a)?;
    let (_, _, batch_dims) = validate_2d(a)?;
    let rank = lstsq_rank_counts_tensor(
        ctx,
        &singular_values,
        a.dims()[0].max(a.dims()[1]),
        batch_dims,
    )?;
    Ok(LstsqAuxResult {
        rank,
        singular_values,
    })
}

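/// Counts, per batch, the singular values greater than the tolerance
/// `scale * eps * sigma_max`, where `sigma_max` is the largest singular value
/// of that batch.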
fn lstsq_rank_counts_tensor<R, C>(
    ctx: &mut C,
    singular_values: &Tensor<R>,
    scale: usize,
    batch_dims: &[usize],
) -> Result<Tensor<R>>
where
    R: LinalgScalar<Real = R> + num_traits::Float + tenferro_tensor::KeepCountScalar,
    C: tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<R>>,
{
    let k = singular_values.dims().first().copied().unwrap_or(0);
    if k == 0 {
        // No singular values: every batch entry has rank zero.
        return crate::prims_bridge::full_like_constant(
            R::zero(),
            batch_dims,
            singular_values.logical_memory_space(),
        );
    }

    // Per-batch tolerance: `scale * eps * sigma_max`.
    let kept_axes: Vec<usize> = (1..singular_values.ndim()).collect();
    let max_sigma = crate::prims_bridge::scalar_reduce_keep_axes(
        ctx,
        singular_values,
        &kept_axes,
        tenferro_prims::ScalarReductionOp::Max,
    )?;
    let scaled_eps = scalar_from::<R>(scale as f64)? * R::epsilon();
    let scaled_eps_tensor = crate::prims_bridge::full_like_constant(
        scaled_eps,
        max_sigma.dims(),
        max_sigma.logical_memory_space(),
    )?;
    let tol_by_batch = crate::prims_bridge::scalar_binary_same_shape(
        ctx,
        &max_sigma,
        &scaled_eps_tensor,
        tenferro_prims::ScalarBinaryOp::Mul,
    )?;
    // Count the singular values above the tolerance in each batch.
    let tol = broadcast_lstsq_batch_control(&tol_by_batch, singular_values.dims())?;
    let active = crate::prims_bridge::scalar_binary_same_shape(
        ctx,
        singular_values,
        &tol,
        tenferro_prims::ScalarBinaryOp::Greater,
    )?;
    crate::prims_bridge::scalar_sum_keep_axes(ctx, &active, &kept_axes)
}

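/// Reshapes a per-batch control tensor (such as the rank tolerance) so it
/// broadcasts against the singular-value tensor, whose leading axis indexes
/// the singular values and whose remaining axes are the batch dimensions.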
fn broadcast_lstsq_batch_control<R: LinalgScalar>(
    value_by_batch: &Tensor<R>,
    singular_dims: &[usize],
) -> Result<Tensor<R>> {
    if singular_dims.len() <= 1 {
        return value_by_batch.reshape(&[1])?.broadcast(singular_dims);
    }

    let mut reshape_dims = vec![1];
    reshape_dims.extend_from_slice(&singular_dims[1..]);
    value_by_batch
        .reshape(&reshape_dims)?
        .broadcast(singular_dims)
}

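/// Cholesky factorization of an (optionally batched) matrix, delegated to the
/// active linear-algebra backend.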
pub fn cholesky<T: KernelLinalgScalar, C>(ctx: &mut C, tensor: &Tensor<T>) -> Result<Tensor<T>>
where
    C: backend::TensorLinalgContextFor<T>,
    C::Backend: 'static,
{
    <C::Backend as backend::TensorLinalgBackend<T>>::cholesky(ctx, tensor)
}

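/// Cholesky factorization variant that also returns the backend's `info`
/// status tensor alongside the factor `l`. Requires backend support for the
/// `CholeskyEx` capability.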
pub fn cholesky_ex<T: KernelLinalgScalar, C>(
    ctx: &mut C,
    tensor: &Tensor<T>,
) -> Result<CholeskyExResult<T>>
where
    T: KernelLinalgScalar,
    C: backend::TensorLinalgContextFor<T>,
    C::Backend: 'static,
{
    require_linalg_support::<T, C>(backend::LinalgCapabilityOp::CholeskyEx, "cholesky_ex")?;
    let result = <C::Backend as backend::TensorLinalgBackend<T>>::cholesky_ex(ctx, tensor)?;
    Ok(CholeskyExResult {
        l: result.l,
        info: result.info,
    })
}