tenferro_linalg/primal/
matrix_functions.rs1use super::*;
2
3fn batched_identity<T: KernelLinalgScalar>(
4 n: usize,
5 batch_dims: &[usize],
6 logical_memory_space: tenferro_device::LogicalMemorySpace,
7) -> Result<Tensor<T>> {
8 let mut reshape_dims = vec![n, n];
9 reshape_dims.extend(std::iter::repeat_n(1, batch_dims.len()));
10 let eye = crate::prims_bridge::identity_matrix(n, logical_memory_space)?;
11 let eye = eye.reshape(&reshape_dims)?;
12 eye.broadcast(&output_dims(&[n, n], batch_dims))
13}
14
15pub fn matrix_power<T: KernelLinalgScalar, C>(
31 ctx: &mut C,
32 tensor: &Tensor<T>,
33 exponent: i64,
34) -> Result<Tensor<T>>
35where
36 T: KernelLinalgScalar,
37 C: backend::TensorLinalgContextFor<T>,
38 C: tenferro_prims::TensorSemiringContextFor<tenferro_algebra::Standard<T>>,
39 C::Backend: 'static,
40{
41 require_linalg_support::<T, C>(backend::LinalgCapabilityOp::MatrixPower, "matrix_power")?;
42
43 let (n, batch_dims) = validate_square(tensor)?;
44
45 if exponent == 0 {
46 return batched_identity::<T>(n, batch_dims, tensor.logical_memory_space());
47 }
48 if exponent == 1 {
49 return Ok(tensor.clone());
50 }
51 if exponent == -1 {
52 return inv(ctx, tensor);
53 }
54
55 let mut positive_exponent = if exponent < 0 {
56 exponent.checked_abs().ok_or_else(|| {
57 Error::InvalidArgument("matrix_power does not support i64::MIN exponent".into())
58 })? as u64
59 } else {
60 exponent as u64
61 };
62 let mut base = if exponent < 0 {
63 inv(ctx, tensor)?
64 } else {
65 tensor.clone()
66 };
67
68 if positive_exponent == 2 {
69 return crate::prims_bridge::batched_gemm_with_semiring_tensors(ctx, &base, &base, n, n, n);
70 }
71 if positive_exponent == 3 {
72 let base_squared =
73 crate::prims_bridge::batched_gemm_with_semiring_tensors(ctx, &base, &base, n, n, n)?;
74 return crate::prims_bridge::batched_gemm_with_semiring_tensors(
75 ctx,
76 &base_squared,
77 &base,
78 n,
79 n,
80 n,
81 );
82 }
83
84 let mut result = batched_identity::<T>(n, batch_dims, tensor.logical_memory_space())?;
85
86 while positive_exponent > 0 {
87 if positive_exponent & 1 == 1 {
88 result = crate::prims_bridge::batched_gemm_with_semiring_tensors(
89 ctx, &result, &base, n, n, n,
90 )?;
91 }
92 let next_exponent = positive_exponent >> 1;
93 if next_exponent > 0 {
94 base = crate::prims_bridge::batched_gemm_with_semiring_tensors(
95 ctx, &base, &base, n, n, n,
96 )?;
97 }
98 positive_exponent = next_exponent;
99 }
100
101 Ok(result)
102}
103
104#[cfg(test)]
105mod tests;