// tenferro_linalg/primal/spectral.rs

1use super::*;
2
3/// Compute the eigendecomposition of a general (non-symmetric) square matrix.
4///
5/// Returns complex eigenvalues and eigenvectors even for real inputs.
6///
7/// # Examples
8///
9/// ```
10/// use tenferro_linalg::eig;
11/// use tenferro_prims::CpuContext;
12/// use tenferro_tensor::{MemoryOrder, Tensor};
13///
14/// let mut ctx = CpuContext::new(1);
15/// let col = MemoryOrder::ColumnMajor;
16/// let a = Tensor::<f64>::from_slice(&[1.0, 0.0, 0.0, 2.0], &[2, 2], col).unwrap();
17/// let result = eig(&mut ctx, &a).unwrap();
18/// assert_eq!(result.values.dims(), &[2]);
19/// assert_eq!(result.vectors.dims(), &[2, 2]);
20/// ```
21pub fn eig<
22    T: KernelLinalgScalar<Real = T, Complex = num_complex::Complex<T>> + num_traits::Float,
23    C,
24>(
25    ctx: &mut C,
26    tensor: &Tensor<T>,
27) -> Result<EigResult<T>>
28where
29    C: backend::TensorLinalgContextFor<T>,
30    C::Backend: 'static,
31{
32    let result = <C::Backend as backend::TensorLinalgBackend<T>>::eig(ctx, tensor)?;
33    Ok(EigResult {
34        values: result.values,
35        vectors: result.vectors,
36    })
37}
38
39pub(crate) fn require_linalg_support<T: KernelLinalgScalar, C>(
40    capability: backend::LinalgCapabilityOp,
41    op: &str,
42) -> Result<()>
43where
44    C: backend::TensorLinalgContextFor<T>,
45    C::Backend: 'static,
46{
47    if <C::Backend as backend::TensorLinalgBackend<T>>::has_linalg_support(capability) {
48        return Ok(());
49    }
50
51    Err(Error::DeviceError(format!(
52        "{op} is not supported on the current linalg backend"
53    )))
54}
55
56fn broadcast_batch_control_to_matrix<T: KernelLinalgScalar>(
57    value_by_batch: &Tensor<T::Real>,
58    batch_dims: &[usize],
59    matrix_dims: &[usize],
60) -> Result<Tensor<T::Real>> {
61    let mut reshape_dims = vec![1, 1];
62    reshape_dims.extend_from_slice(batch_dims);
63    value_by_batch
64        .reshape(&reshape_dims)?
65        .broadcast(matrix_dims)
66}
67
/// Compute the Moore-Penrose pseudoinverse of a matrix.
///
/// Implemented via the SVD: `A⁺ = V · Σ⁺ · Uᴴ`, where singular values at or
/// below `rcond * max(σ)` (per batch) are treated as zero. `rcond` defaults
/// to `1e-15` when `None`.
///
/// # Examples
///
/// ```
/// use tenferro_device::LogicalMemorySpace;
/// use tenferro_linalg::pinv;
/// use tenferro_prims::CpuContext;
/// use tenferro_tensor::{MemoryOrder, Tensor};
///
/// let mut ctx = CpuContext::new(1);
/// let a = Tensor::<f64>::from_slice(
///     &[1.0, 2.0, 3.0, 4.0],
///     &[2, 2],
///     MemoryOrder::ColumnMajor,
/// ).unwrap();
/// let ap = pinv(&mut ctx, &a, None).unwrap();
/// assert_eq!(ap.logical_memory_space(), LogicalMemorySpace::MainMemory);
/// ```
pub fn pinv<
    T: KernelLinalgScalar
        + crate::prims_bridge::ScaleTensorByRealSameShape<C>
        + tenferro_algebra::Conjugate,
    C,
>(
    ctx: &mut C,
    tensor: &Tensor<T>,
    rcond: Option<f64>,
) -> Result<Tensor<T>>
where
    C: backend::TensorLinalgContextFor<T>
        + tenferro_prims::TensorResolveConjContextFor<T>
        + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T::Real>>
        + tenferro_prims::TensorSemiringContextFor<tenferro_algebra::Standard<T>>,
    C::Backend: 'static,
    T::Real: num_traits::Float + tenferro_tensor::KeepCountScalar,
{
    // Fail fast when the backend cannot run the underlying SVD-based pinv.
    require_linalg_support::<T, C>(backend::LinalgCapabilityOp::Pinv, "pinv")?;

    let (m, n, batch_dims) = validate_2d(tensor)?;
    let k = m.min(n);
    // Degenerate case: an m×0 or 0×n matrix has an n×m all-zero pseudoinverse.
    if k == 0 {
        let dims = output_dims(&[n, m], batch_dims);
        return Tensor::zeros(
            &dims,
            tensor.logical_memory_space(),
            MemoryOrder::ColumnMajor,
        );
    }

    // A = U Σ Vᴴ; force column-major layouts for the downstream kernels.
    let svd_result = svd(ctx, tensor, None)?;
    let u_input = ensure_col_major(&svd_result.u);
    let s_input = ensure_col_major(&svd_result.s);
    let vt_input = ensure_col_major(&svd_result.vt);
    // Per-batch largest singular value: reduce over axis 0 (singular values),
    // keeping axes 1.. (the batch axes) intact.
    let s_max_axes: Vec<usize> = (1..s_input.ndim()).collect();
    let s_max = crate::prims_bridge::scalar_reduce_keep_axes(
        ctx,
        &s_input,
        &s_max_axes,
        tenferro_prims::ScalarReductionOp::Max,
    )?;
    // Cutoff below which singular values are considered numerically zero:
    // threshold (= rcond, default 1e-15) times the per-batch max.
    let threshold: T::Real = scalar_from(rcond.unwrap_or(1e-15))?;
    let threshold_tensor = crate::prims_bridge::full_like_constant(
        threshold,
        s_max.dims(),
        s_max.logical_memory_space(),
    )?;
    let cutoff = crate::prims_bridge::scalar_binary_same_shape(
        ctx,
        &s_max,
        &threshold_tensor,
        tenferro_prims::ScalarBinaryOp::Mul,
    )?;
    // Re-add the reduced singular-value axis and broadcast to the shape of s.
    let cutoff = cutoff.unsqueeze(0)?.broadcast(s_input.dims())?;
    // keep_mask: 1 where σ > cutoff, 0 otherwise (elementwise, per batch).
    let keep_mask = crate::prims_bridge::scalar_binary_same_shape(
        ctx,
        &s_input,
        &cutoff,
        tenferro_prims::ScalarBinaryOp::Greater,
    )?;
    // drop_mask = 1 - keep_mask: marks the singular values being zeroed out.
    let one_mask = crate::prims_bridge::full_like_constant(
        <T::Real as num_traits::One>::one(),
        keep_mask.dims(),
        keep_mask.logical_memory_space(),
    )?;
    let drop_mask = crate::prims_bridge::scalar_binary_same_shape(
        ctx,
        &one_mask,
        &keep_mask,
        tenferro_prims::ScalarBinaryOp::Sub,
    )?;
    // safe_s equals σ where kept, and exactly 1 where dropped, so the
    // reciprocal below never divides by zero.
    let kept_s = crate::prims_bridge::scalar_binary_same_shape(
        ctx,
        &s_input,
        &keep_mask,
        tenferro_prims::ScalarBinaryOp::Mul,
    )?;
    let safe_s = crate::prims_bridge::scalar_binary_same_shape(
        ctx,
        &kept_s,
        &drop_mask,
        tenferro_prims::ScalarBinaryOp::Add,
    )?;

    // Σ⁺ diagonal: 1/σ for kept values, then re-masked to force dropped
    // entries (whose safe value was 1) back to exactly 0.
    let sinv = crate::prims_bridge::scalar_unary_same_shape(
        ctx,
        &safe_s,
        tenferro_prims::ScalarUnaryOp::Reciprocal,
    )?;
    let sinv = crate::prims_bridge::scalar_binary_same_shape(
        ctx,
        &sinv,
        &keep_mask,
        tenferro_prims::ScalarBinaryOp::Mul,
    )?;

    // Scale the rows of Vᴴ by Σ⁺ (unsqueeze(1) aligns sinv with Vᴴ's shape).
    let sinv_for_vt = sinv.unsqueeze(1)?.broadcast(vt_input.dims())?;
    let vt_scaled = crate::prims_bridge::complex_scale_same_shape(ctx, &vt_input, &sinv_for_vt)?;

    // Conjugate-transpose the two matrix axes (batch axes stay in place):
    // u_t = Uᴴ and vt_t = (Σ⁺Vᴴ)ᴴ = V·Σ⁺.
    let mut perm = vec![1, 0];
    perm.extend(2..u_input.ndim());
    let u_t = crate::prims_bridge::resolve_conj(ctx, &u_input.conj().permute(&perm)?);
    let vt_t = crate::prims_bridge::resolve_conj(ctx, &vt_scaled.conj().permute(&perm)?);

    // A⁺ = (V Σ⁺) · Uᴴ, an n×m result (contracted over the k shared axis).
    crate::prims_bridge::batched_gemm_with_semiring_tensors(ctx, &vt_t, &u_t, n, k, m)
}
194
/// Compute the matrix exponential `exp(A)` of a square matrix.
///
/// Uses scaling and squaring: each batch matrix is scaled by `2^{-s}` (with a
/// per-batch squaring count `s` derived from its 1-norm), exponentiated by a
/// native kernel, then repeatedly squared `s` times to undo the scaling.
///
/// # Examples
///
/// ```
/// use tenferro_device::LogicalMemorySpace;
/// use tenferro_linalg::matrix_exp;
/// use tenferro_prims::CpuContext;
/// use tenferro_tensor::{MemoryOrder, Tensor};
///
/// let mut ctx = CpuContext::new(1);
/// let col = MemoryOrder::ColumnMajor;
/// let mem = LogicalMemorySpace::MainMemory;
/// let a = Tensor::<f64>::zeros(&[3, 3], mem, col).unwrap();
/// let result = matrix_exp(&mut ctx, &a).unwrap();
/// assert_eq!(result.dims(), &[3, 3]);
/// ```
#[allow(private_bounds)]
pub fn matrix_exp<T, C>(ctx: &mut C, tensor: &Tensor<T>) -> Result<Tensor<T>>
where
    T: KernelLinalgScalar
        + crate::prims_bridge::ScaleTensorByRealSameShape<C>
        + MatrixExpAbsTensor<C>,
    T::Real: KernelLinalgScalar<Real = T::Real> + num_traits::Float,
    C: backend::TensorLinalgContextFor<T>,
    C: tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T>>,
    C: tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T::Real>>,
    C: tenferro_prims::TensorSemiringContextFor<tenferro_algebra::Standard<T>>,
    <C as tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T::Real>>>::ScalarBackend:
        tenferro_prims::TensorAnalyticPrims<tenferro_algebra::Standard<T::Real>, Context = C>,
    C::Backend: 'static,
{
    // Fail fast when the backend cannot run matrix_exp.
    require_linalg_support::<T, C>(backend::LinalgCapabilityOp::MatrixExp, "matrix_exp")?;

    let (n, batch_dims) = validate_square(tensor)?;
    let input = ensure_col_major(tensor);
    // Degenerate case: the exponential of a 0×0 matrix is an empty matrix.
    if n == 0 {
        let dims = output_dims(&[n, n], batch_dims);
        return Tensor::zeros(
            &dims,
            input.logical_memory_space(),
            MemoryOrder::ColumnMajor,
        );
    }

    // Per-batch 1-norms drive the per-batch squaring counts s_b
    // (helper semantics live in ad_helpers — confirm there for details).
    let batch_norms = crate::ad_helpers::matrix_exp_batch_1_norms_tensor(ctx, &input)?;
    let s_by_batch_tensor =
        crate::ad_helpers::matrix_exp_batch_squaring_counts_tensor(ctx, &batch_norms)?;
    // Global max of s_b: keep-axes `&[]` reduces over every axis. The 0-dim
    // case is already a scalar, so reuse it directly.
    let s_max_tensor = if s_by_batch_tensor.ndim() == 0 {
        s_by_batch_tensor.clone()
    } else {
        crate::prims_bridge::scalar_reduce_keep_axes(
            ctx,
            &s_by_batch_tensor,
            &[],
            tenferro_prims::ScalarReductionOp::Max,
        )?
    };
    // Bring the max squaring count to the host and convert it to an integer
    // loop bound; errors here mean the tensor was not host-readable.
    let s_max_host =
        s_max_tensor.to_memory_space_async(tenferro_device::LogicalMemorySpace::MainMemory)?;
    let s_max_slice = s_max_host.buffer().as_slice().ok_or_else(|| {
        Error::InvalidArgument("matrix_exp: expected max squaring count tensor on host".into())
    })?;
    let s_max = s_max_slice
        .first()
        .copied()
        .ok_or_else(|| Error::InvalidArgument("matrix_exp: missing max squaring count".into()))
        .and_then(|value| {
            num_traits::NumCast::from(value).ok_or_else(|| {
                Error::InvalidArgument("matrix_exp: cannot convert max squaring count".into())
            })
        })?;

    // Per-batch scale factor 2^{-s_b}, computed as 1 / 2^{s_b}.
    let two_by_batch = crate::prims_bridge::full_like_constant(
        crate::ad_helpers::scalar_from::<T::Real>(2.0)?,
        s_by_batch_tensor.dims(),
        input.logical_memory_space(),
    )?;
    let scale_denom = crate::prims_bridge::analytic_binary_same_shape(
        ctx,
        &two_by_batch,
        &s_by_batch_tensor,
        tenferro_prims::AnalyticBinaryOp::Pow,
    )?;
    let scale_by_batch = crate::prims_bridge::scalar_unary_same_shape(
        ctx,
        &scale_denom,
        tenferro_prims::ScalarUnaryOp::Reciprocal,
    )?;
    // Broadcast the per-batch scale over the matrix axes and apply it.
    let scale_tensor =
        broadcast_batch_control_to_matrix::<T>(&scale_by_batch, batch_dims, input.dims())?;
    let scaled_input =
        <T as crate::prims_bridge::ScaleTensorByRealSameShape<C>>::scale_tensor_by_real_same_shape(
            ctx,
            &input,
            &scale_tensor,
        )?;

    // Base exponential of the scaled input via the native kernel
    // (see ad_helpers for the approximation it uses — TODO confirm).
    let mut result =
        crate::ad_helpers::matrix_exp_tensor_native(ctx, &scaled_input, n, batch_dims, 0)?;
    // Undo the scaling: square s_max times globally, but blend per batch so
    // that a batch with s_b <= round keeps its previous (finished) value.
    for round in 0..s_max {
        let squared = crate::prims_bridge::batched_gemm_with_semiring_tensors(
            ctx, &result, &result, n, n, n,
        )?;
        let round_by_batch = crate::prims_bridge::full_like_constant(
            crate::ad_helpers::scalar_from::<T::Real>(round as f64)?,
            s_by_batch_tensor.dims(),
            result.logical_memory_space(),
        )?;
        // mask = 1 where s_b > round: those batches take the squared value.
        let mask_by_batch = crate::prims_bridge::scalar_binary_same_shape(
            ctx,
            &s_by_batch_tensor,
            &round_by_batch,
            tenferro_prims::ScalarBinaryOp::Greater,
        )?;
        let mask_tensor =
            broadcast_batch_control_to_matrix::<T>(&mask_by_batch, batch_dims, result.dims())?;
        result = crate::ad_helpers::blend_tensor_by_real_mask_same_shape(
            ctx,
            &squared,
            &result,
            &mask_tensor,
        )?;
    }

    Ok(result)
}
321}