tenferro_linalg/primal/
matrix_functions.rs

1use super::*;
2
3fn batched_identity<T: KernelLinalgScalar>(
4    n: usize,
5    batch_dims: &[usize],
6    logical_memory_space: tenferro_device::LogicalMemorySpace,
7) -> Result<Tensor<T>> {
8    let mut reshape_dims = vec![n, n];
9    reshape_dims.extend(std::iter::repeat_n(1, batch_dims.len()));
10    let eye = crate::prims_bridge::identity_matrix(n, logical_memory_space)?;
11    let eye = eye.reshape(&reshape_dims)?;
12    eye.broadcast(&output_dims(&[n, n], batch_dims))
13}
14
15/// Raise a square matrix to an integer power.
16///
17/// # Examples
18///
19/// ```
20/// use tenferro_linalg::matrix_power;
21/// use tenferro_prims::CpuContext;
22/// use tenferro_tensor::{MemoryOrder, Tensor};
23///
24/// let mut ctx = CpuContext::new(1);
25/// let col = MemoryOrder::ColumnMajor;
26/// let a = Tensor::<f64>::from_slice(&[1.0, 0.0, 0.0, 2.0], &[2, 2], col).unwrap();
27/// let a3 = matrix_power(&mut ctx, &a, 3).unwrap();
28/// assert_eq!(a3.dims(), &[2, 2]);
29/// ```
30pub fn matrix_power<T: KernelLinalgScalar, C>(
31    ctx: &mut C,
32    tensor: &Tensor<T>,
33    exponent: i64,
34) -> Result<Tensor<T>>
35where
36    T: KernelLinalgScalar,
37    C: backend::TensorLinalgContextFor<T>,
38    C: tenferro_prims::TensorSemiringContextFor<tenferro_algebra::Standard<T>>,
39    C::Backend: 'static,
40{
41    require_linalg_support::<T, C>(backend::LinalgCapabilityOp::MatrixPower, "matrix_power")?;
42
43    let (n, batch_dims) = validate_square(tensor)?;
44
45    if exponent == 0 {
46        return batched_identity::<T>(n, batch_dims, tensor.logical_memory_space());
47    }
48    if exponent == 1 {
49        return Ok(tensor.clone());
50    }
51    if exponent == -1 {
52        return inv(ctx, tensor);
53    }
54
55    let mut positive_exponent = if exponent < 0 {
56        exponent.checked_abs().ok_or_else(|| {
57            Error::InvalidArgument("matrix_power does not support i64::MIN exponent".into())
58        })? as u64
59    } else {
60        exponent as u64
61    };
62    let mut base = if exponent < 0 {
63        inv(ctx, tensor)?
64    } else {
65        tensor.clone()
66    };
67
68    if positive_exponent == 2 {
69        return crate::prims_bridge::batched_gemm_with_semiring_tensors(ctx, &base, &base, n, n, n);
70    }
71    if positive_exponent == 3 {
72        let base_squared =
73            crate::prims_bridge::batched_gemm_with_semiring_tensors(ctx, &base, &base, n, n, n)?;
74        return crate::prims_bridge::batched_gemm_with_semiring_tensors(
75            ctx,
76            &base_squared,
77            &base,
78            n,
79            n,
80            n,
81        );
82    }
83
84    let mut result = batched_identity::<T>(n, batch_dims, tensor.logical_memory_space())?;
85
86    while positive_exponent > 0 {
87        if positive_exponent & 1 == 1 {
88            result = crate::prims_bridge::batched_gemm_with_semiring_tensors(
89                ctx, &result, &base, n, n, n,
90            )?;
91        }
92        let next_exponent = positive_exponent >> 1;
93        if next_exponent > 0 {
94            base = crate::prims_bridge::batched_gemm_with_semiring_tensors(
95                ctx, &base, &base, n, n, n,
96            )?;
97        }
98        positive_exponent = next_exponent;
99    }
100
101    Ok(result)
102}
103
104#[cfg(test)]
105mod tests;