1use super::*;
2
3pub fn eig<
22 T: KernelLinalgScalar<Real = T, Complex = num_complex::Complex<T>> + num_traits::Float,
23 C,
24>(
25 ctx: &mut C,
26 tensor: &Tensor<T>,
27) -> Result<EigResult<T>>
28where
29 C: backend::TensorLinalgContextFor<T>,
30 C::Backend: 'static,
31{
32 let result = <C::Backend as backend::TensorLinalgBackend<T>>::eig(ctx, tensor)?;
33 Ok(EigResult {
34 values: result.values,
35 vectors: result.vectors,
36 })
37}
38
39pub(crate) fn require_linalg_support<T: KernelLinalgScalar, C>(
40 capability: backend::LinalgCapabilityOp,
41 op: &str,
42) -> Result<()>
43where
44 C: backend::TensorLinalgContextFor<T>,
45 C::Backend: 'static,
46{
47 if <C::Backend as backend::TensorLinalgBackend<T>>::has_linalg_support(capability) {
48 return Ok(());
49 }
50
51 Err(Error::DeviceError(format!(
52 "{op} is not supported on the current linalg backend"
53 )))
54}
55
56fn broadcast_batch_control_to_matrix<T: KernelLinalgScalar>(
57 value_by_batch: &Tensor<T::Real>,
58 batch_dims: &[usize],
59 matrix_dims: &[usize],
60) -> Result<Tensor<T::Real>> {
61 let mut reshape_dims = vec![1, 1];
62 reshape_dims.extend_from_slice(batch_dims);
63 value_by_batch
64 .reshape(&reshape_dims)?
65 .broadcast(matrix_dims)
66}
67
/// Computes the Moore–Penrose pseudo-inverse of a (possibly batched) matrix
/// tensor via its singular value decomposition.
///
/// With `A = U · diag(s) · Vᴴ` from [`svd`], the result is
/// `V · diag(s⁺) · Uᴴ`, where `s⁺ᵢ = 1 / sᵢ` when `sᵢ > rcond · max(s)` and
/// `0` otherwise. The maximum is broadcast back along the singular-value
/// axis, so it is presumably taken per batch entry. NOTE(review): confirm
/// that `scalar_reduce_keep_axes` keeps the listed axes.
///
/// # Arguments
///
/// * `ctx` - backend context used for the SVD and for all elementwise and
///   GEMM work.
/// * `tensor` - input; `validate_2d` yields its `(m, n)` matrix dims and the
///   batch dims.
/// * `rcond` - relative cutoff for small singular values; defaults to `1e-15`.
///
/// # Returns
///
/// A tensor of dims `[n, m, batch...]`. When `min(m, n) == 0` a column-major
/// zero tensor of that shape is returned without running the SVD.
///
/// # Errors
///
/// * [`Error::DeviceError`] if the backend does not report `Pinv` support.
/// * Any error from `validate_2d`, the SVD, `scalar_from`, or the
///   prims-bridge helpers.
pub fn pinv<
    T: KernelLinalgScalar
        + crate::prims_bridge::ScaleTensorByRealSameShape<C>
        + tenferro_algebra::Conjugate,
    C,
>(
    ctx: &mut C,
    tensor: &Tensor<T>,
    rcond: Option<f64>,
) -> Result<Tensor<T>>
where
    C: backend::TensorLinalgContextFor<T>
        + tenferro_prims::TensorResolveConjContextFor<T>
        + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T::Real>>
        + tenferro_prims::TensorSemiringContextFor<tenferro_algebra::Standard<T>>,
    C::Backend: 'static,
    T::Real: num_traits::Float + tenferro_tensor::KeepCountScalar,
{
    require_linalg_support::<T, C>(backend::LinalgCapabilityOp::Pinv, "pinv")?;

    let (m, n, batch_dims) = validate_2d(tensor)?;
    let k = m.min(n);
    // Degenerate matrices: the pseudo-inverse is the all-zero (n x m) tensor.
    if k == 0 {
        let dims = output_dims(&[n, m], batch_dims);
        return Tensor::zeros(
            &dims,
            tensor.logical_memory_space(),
            MemoryOrder::ColumnMajor,
        );
    }

    // Assumed SVD shapes (inferred from the broadcasts below, confirm against
    // `svd`): u = [m, k, batch..], s = [k, batch..], vt = [k, n, batch..].
    let svd_result = svd(ctx, tensor, None)?;
    let u_input = ensure_col_major(&svd_result.u);
    let s_input = ensure_col_major(&svd_result.s);
    let vt_input = ensure_col_major(&svd_result.vt);
    // Keep axes 1.. (the batch axes) so the maximum is taken along axis 0,
    // the singular-value axis.
    let s_max_axes: Vec<usize> = (1..s_input.ndim()).collect();
    let s_max = crate::prims_bridge::scalar_reduce_keep_axes(
        ctx,
        &s_input,
        &s_max_axes,
        tenferro_prims::ScalarReductionOp::Max,
    )?;
    // cutoff = rcond * max(s), built as a tensor so it stays on the backend.
    let threshold: T::Real = scalar_from(rcond.unwrap_or(1e-15))?;
    let threshold_tensor = crate::prims_bridge::full_like_constant(
        threshold,
        s_max.dims(),
        s_max.logical_memory_space(),
    )?;
    let cutoff = crate::prims_bridge::scalar_binary_same_shape(
        ctx,
        &s_max,
        &threshold_tensor,
        tenferro_prims::ScalarBinaryOp::Mul,
    )?;
    // Re-insert the singular-value axis and broadcast so the cutoff can be
    // compared elementwise with `s`.
    let cutoff = cutoff.unsqueeze(0)?.broadcast(s_input.dims())?;
    // keep_mask is used below as a real-valued 0/1 mask (presumably 1 where
    // s > cutoff); drop_mask = 1 - keep_mask is its complement.
    let keep_mask = crate::prims_bridge::scalar_binary_same_shape(
        ctx,
        &s_input,
        &cutoff,
        tenferro_prims::ScalarBinaryOp::Greater,
    )?;
    let one_mask = crate::prims_bridge::full_like_constant(
        <T::Real as num_traits::One>::one(),
        keep_mask.dims(),
        keep_mask.logical_memory_space(),
    )?;
    let drop_mask = crate::prims_bridge::scalar_binary_same_shape(
        ctx,
        &one_mask,
        &keep_mask,
        tenferro_prims::ScalarBinaryOp::Sub,
    )?;
    // safe_s = s where kept, 1 where dropped, so the reciprocal below never
    // divides by a (near-)zero singular value.
    let kept_s = crate::prims_bridge::scalar_binary_same_shape(
        ctx,
        &s_input,
        &keep_mask,
        tenferro_prims::ScalarBinaryOp::Mul,
    )?;
    let safe_s = crate::prims_bridge::scalar_binary_same_shape(
        ctx,
        &kept_s,
        &drop_mask,
        tenferro_prims::ScalarBinaryOp::Add,
    )?;

    // sinv = keep_mask / safe_s: 1/s for kept values, exactly 0 for dropped.
    let sinv = crate::prims_bridge::scalar_unary_same_shape(
        ctx,
        &safe_s,
        tenferro_prims::ScalarUnaryOp::Reciprocal,
    )?;
    let sinv = crate::prims_bridge::scalar_binary_same_shape(
        ctx,
        &sinv,
        &keep_mask,
        tenferro_prims::ScalarBinaryOp::Mul,
    )?;

    // Scale row i of vt by sinv[i], i.e. diag(s⁺) · Vᴴ.
    let sinv_for_vt = sinv.unsqueeze(1)?.broadcast(vt_input.dims())?;
    let vt_scaled = crate::prims_bridge::complex_scale_same_shape(ctx, &vt_input, &sinv_for_vt)?;

    // Conjugate-transpose the two matrix axes; batch axes keep their order.
    let mut perm = vec![1, 0];
    perm.extend(2..u_input.ndim());
    let u_t = crate::prims_bridge::resolve_conj(ctx, &u_input.conj().permute(&perm)?);
    let vt_t = crate::prims_bridge::resolve_conj(ctx, &vt_scaled.conj().permute(&perm)?);

    // (V · diag(s⁺)) · Uᴴ: presumably an (n x k) by (k x m) batched product.
    crate::prims_bridge::batched_gemm_with_semiring_tensors(ctx, &vt_t, &u_t, n, k, m)
}
194
195#[allow(private_bounds)]
213pub fn matrix_exp<T, C>(ctx: &mut C, tensor: &Tensor<T>) -> Result<Tensor<T>>
214where
215 T: KernelLinalgScalar
216 + crate::prims_bridge::ScaleTensorByRealSameShape<C>
217 + MatrixExpAbsTensor<C>,
218 T::Real: KernelLinalgScalar<Real = T::Real> + num_traits::Float,
219 C: backend::TensorLinalgContextFor<T>,
220 C: tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T>>,
221 C: tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T::Real>>,
222 C: tenferro_prims::TensorSemiringContextFor<tenferro_algebra::Standard<T>>,
223 <C as tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T::Real>>>::ScalarBackend:
224 tenferro_prims::TensorAnalyticPrims<tenferro_algebra::Standard<T::Real>, Context = C>,
225 C::Backend: 'static,
226{
227 require_linalg_support::<T, C>(backend::LinalgCapabilityOp::MatrixExp, "matrix_exp")?;
228
229 let (n, batch_dims) = validate_square(tensor)?;
230 let input = ensure_col_major(tensor);
231 if n == 0 {
232 let dims = output_dims(&[n, n], batch_dims);
233 return Tensor::zeros(
234 &dims,
235 input.logical_memory_space(),
236 MemoryOrder::ColumnMajor,
237 );
238 }
239
240 let batch_norms = crate::ad_helpers::matrix_exp_batch_1_norms_tensor(ctx, &input)?;
241 let s_by_batch_tensor =
242 crate::ad_helpers::matrix_exp_batch_squaring_counts_tensor(ctx, &batch_norms)?;
243 let s_max_tensor = if s_by_batch_tensor.ndim() == 0 {
244 s_by_batch_tensor.clone()
245 } else {
246 crate::prims_bridge::scalar_reduce_keep_axes(
247 ctx,
248 &s_by_batch_tensor,
249 &[],
250 tenferro_prims::ScalarReductionOp::Max,
251 )?
252 };
253 let s_max_host =
254 s_max_tensor.to_memory_space_async(tenferro_device::LogicalMemorySpace::MainMemory)?;
255 let s_max_slice = s_max_host.buffer().as_slice().ok_or_else(|| {
256 Error::InvalidArgument("matrix_exp: expected max squaring count tensor on host".into())
257 })?;
258 let s_max = s_max_slice
259 .first()
260 .copied()
261 .ok_or_else(|| Error::InvalidArgument("matrix_exp: missing max squaring count".into()))
262 .and_then(|value| {
263 num_traits::NumCast::from(value).ok_or_else(|| {
264 Error::InvalidArgument("matrix_exp: cannot convert max squaring count".into())
265 })
266 })?;
267
268 let two_by_batch = crate::prims_bridge::full_like_constant(
269 crate::ad_helpers::scalar_from::<T::Real>(2.0)?,
270 s_by_batch_tensor.dims(),
271 input.logical_memory_space(),
272 )?;
273 let scale_denom = crate::prims_bridge::analytic_binary_same_shape(
274 ctx,
275 &two_by_batch,
276 &s_by_batch_tensor,
277 tenferro_prims::AnalyticBinaryOp::Pow,
278 )?;
279 let scale_by_batch = crate::prims_bridge::scalar_unary_same_shape(
280 ctx,
281 &scale_denom,
282 tenferro_prims::ScalarUnaryOp::Reciprocal,
283 )?;
284 let scale_tensor =
285 broadcast_batch_control_to_matrix::<T>(&scale_by_batch, batch_dims, input.dims())?;
286 let scaled_input =
287 <T as crate::prims_bridge::ScaleTensorByRealSameShape<C>>::scale_tensor_by_real_same_shape(
288 ctx,
289 &input,
290 &scale_tensor,
291 )?;
292
293 let mut result =
294 crate::ad_helpers::matrix_exp_tensor_native(ctx, &scaled_input, n, batch_dims, 0)?;
295 for round in 0..s_max {
296 let squared = crate::prims_bridge::batched_gemm_with_semiring_tensors(
297 ctx, &result, &result, n, n, n,
298 )?;
299 let round_by_batch = crate::prims_bridge::full_like_constant(
300 crate::ad_helpers::scalar_from::<T::Real>(round as f64)?,
301 s_by_batch_tensor.dims(),
302 result.logical_memory_space(),
303 )?;
304 let mask_by_batch = crate::prims_bridge::scalar_binary_same_shape(
305 ctx,
306 &s_by_batch_tensor,
307 &round_by_batch,
308 tenferro_prims::ScalarBinaryOp::Greater,
309 )?;
310 let mask_tensor =
311 broadcast_batch_control_to_matrix::<T>(&mask_by_batch, batch_dims, result.dims())?;
312 result = crate::ad_helpers::blend_tensor_by_real_mask_same_shape(
313 ctx,
314 &squared,
315 &result,
316 &mask_tensor,
317 )?;
318 }
319
320 Ok(result)
321}