1use super::*;
2use crate::primal::linear_systems_sign::*;
3use num_complex::{Complex32, Complex64, ComplexFloat};
4use num_traits::{NumCast, One};
5use tenferro_algebra::Conjugate;
6use tenferro_prims::{TensorMetadataCastPrims, TensorMetadataContextFor, TensorMetadataPrims};
7fn inverse_rhs<T: KernelLinalgScalar>(
8 n: usize,
9 batch_dims: &[usize],
10 memory_space: tenferro_device::LogicalMemorySpace,
11) -> Result<Tensor<T>> {
12 let mut rhs = crate::prims_bridge::identity_matrix(n, memory_space)?;
13 for _ in batch_dims {
14 rhs = rhs.unsqueeze(-1)?;
15 }
16 rhs.broadcast(&output_dims(&[n, n], batch_dims))
17}
18
19pub fn solve<T: KernelLinalgScalar, C>(
36 ctx: &mut C,
37 a: &Tensor<T>,
38 b: &Tensor<T>,
39) -> Result<Tensor<T>>
40where
41 C: backend::TensorLinalgContextFor<T>,
42 C::Backend: 'static,
43{
44 <C::Backend as backend::TensorLinalgBackend<T>>::solve(ctx, a, b)
45}
46
47pub fn solve_ex<T: KernelLinalgScalar, C>(
65 ctx: &mut C,
66 a: &Tensor<T>,
67 b: &Tensor<T>,
68) -> Result<SolveExResult<T>>
69where
70 T: KernelLinalgScalar,
71 C: backend::TensorLinalgContextFor<T>,
72 C::Backend: 'static,
73{
74 require_linalg_support::<T, C>(backend::LinalgCapabilityOp::SolveEx, "solve_ex")?;
75 let result = <C::Backend as backend::TensorLinalgBackend<T>>::solve_ex(ctx, a, b)?;
76 Ok(SolveExResult {
77 solution: result.solution,
78 info: result.info,
79 })
80}
81
82pub fn inv<T: KernelLinalgScalar, C>(ctx: &mut C, tensor: &Tensor<T>) -> Result<Tensor<T>>
98where
99 T: KernelLinalgScalar,
100 C: backend::TensorLinalgContextFor<T>,
101 C::Backend: 'static,
102{
103 require_linalg_support::<T, C>(backend::LinalgCapabilityOp::Inv, "inv")?;
104
105 let (n, batch_dims) = validate_square(tensor)?;
106 let rhs = inverse_rhs::<T>(n, batch_dims, tensor.logical_memory_space())?;
107 if n == 0 {
108 return Ok(rhs);
109 }
110 solve(ctx, tensor, &rhs)
111}
112
113pub fn inv_ex<T: KernelLinalgScalar, C>(ctx: &mut C, tensor: &Tensor<T>) -> Result<InvExResult<T>>
130where
131 T: KernelLinalgScalar,
132 C: backend::TensorLinalgContextFor<T>,
133 C::Backend: 'static,
134{
135 require_linalg_support::<T, C>(backend::LinalgCapabilityOp::Inv, "inv_ex")?;
136 let (n, batch_dims) = validate_square(tensor)?;
137 let rhs = inverse_rhs::<T>(n, batch_dims, tensor.logical_memory_space())?;
138 if n == 0 {
139 return Ok(InvExResult {
140 inverse: rhs,
141 info: crate::backend::tensor_helpers::info_tensor_from_vec_on_space(
142 vec![0; batch_count(batch_dims)],
143 batch_dims,
144 tensor.logical_memory_space(),
145 )?,
146 });
147 }
148 let result = solve_ex(ctx, tensor, &rhs)?;
149 Ok(InvExResult {
150 inverse: result.solution,
151 info: result.info,
152 })
153}
154
155pub fn det<T: KernelLinalgScalar, C>(ctx: &mut C, tensor: &Tensor<T>) -> Result<Tensor<T>>
171where
172 C: backend::TensorLinalgContextFor<T>
173 + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T>>
174 + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T::Real>>
175 + TensorMetadataContextFor,
176 T: crate::prims_bridge::ScaleTensorByRealSameShape<C>,
177 T::Real: Scalar + NumCast + One + 'static,
178 C::MetadataBackend: TensorMetadataPrims<Context = C>,
179 <C as tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T::Real>>>::ScalarBackend:
180 TensorMetadataCastPrims<T::Real, Context = C>,
181 C::Backend: 'static,
182{
183 require_linalg_support::<T, C>(backend::LinalgCapabilityOp::Det, "det")?;
184
185 let (_n, batch_dims) = validate_square(tensor)?;
186 let lu = <C::Backend as backend::TensorLinalgBackend<T>>::lu_factor(ctx, tensor)?;
187 let diagonal = lu.u.diagonal(&[(0, 1)])?;
188 let kept_axes: Vec<usize> = (0..batch_dims.len()).collect();
189 let diagonal_prod = crate::prims_bridge::scalar_reduce_keep_axes(
190 ctx,
191 &diagonal,
192 &kept_axes,
193 tenferro_prims::ScalarReductionOp::Prod,
194 )?;
195 let sign_tensor = lu_permutation_sign_tensor::<T, C>(ctx, &lu.pivots)?;
196
197 <T as crate::prims_bridge::ScaleTensorByRealSameShape<C>>::scale_tensor_by_real_same_shape(
198 ctx,
199 &diagonal_prod,
200 &sign_tensor,
201 )
202}
203
204#[doc(hidden)]
216pub trait SlogdetDispatch<C>: KernelLinalgScalar {
217 fn slogdet_dispatch(
218 ctx: &mut C,
219 tensor: &Tensor<Self>,
220 ) -> Result<SlogdetResult<Self, Self::Real>>;
221}
222
223fn slogdet_real_impl<T, C>(ctx: &mut C, tensor: &Tensor<T>) -> Result<SlogdetResult<T, T::Real>>
224where
225 T: KernelLinalgScalar<Real = T> + num_traits::Float,
226 C: backend::TensorLinalgContextFor<T>
227 + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T>>
228 + TensorMetadataContextFor,
229 C::Backend: 'static,
230 <C as tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T>>>::ScalarBackend:
231 'static
232 + tenferro_prims::TensorAnalyticPrims<tenferro_algebra::Standard<T>, Context = C>
233 + TensorMetadataCastPrims<T, Context = C>,
234 C::MetadataBackend: TensorMetadataPrims<Context = C>,
235{
236 require_linalg_support::<T, C>(backend::LinalgCapabilityOp::Slogdet, "slogdet")?;
237
238 let (_n, batch_dims) = validate_square(tensor)?;
239 let lu = <C::Backend as backend::TensorLinalgBackend<T>>::lu_factor(ctx, tensor)?;
240 let diagonal = lu.u.diagonal(&[(0, 1)])?;
241 let abs_diagonal = crate::prims_bridge::scalar_unary_same_shape(
242 ctx,
243 &diagonal,
244 tenferro_prims::ScalarUnaryOp::Abs,
245 )?;
246 let logabsdet_factor = crate::prims_bridge::analytic_unary_same_shape(
247 ctx,
248 &abs_diagonal,
249 tenferro_prims::AnalyticUnaryOp::Log,
250 )?;
251 let kept_axes: Vec<usize> = (0..batch_dims.len()).collect();
252 let logabsdet = crate::prims_bridge::scalar_reduce_keep_axes(
253 ctx,
254 &logabsdet_factor,
255 &kept_axes,
256 tenferro_prims::ScalarReductionOp::Sum,
257 )?;
258 let sign_perm = lu_permutation_sign_tensor::<T, C>(ctx, &lu.pivots)?;
259
260 let zero_diagonal = crate::prims_bridge::full_like_constant(
261 T::zero(),
262 diagonal.dims(),
263 tensor.logical_memory_space(),
264 )?;
265 let negative_mask = crate::prims_bridge::scalar_binary_same_shape(
266 ctx,
267 &zero_diagonal,
268 &diagonal,
269 tenferro_prims::ScalarBinaryOp::Greater,
270 )?;
271 let double_negative = crate::prims_bridge::scalar_binary_same_shape(
272 ctx,
273 &negative_mask,
274 &negative_mask,
275 tenferro_prims::ScalarBinaryOp::Add,
276 )?;
277 let one = crate::prims_bridge::full_like_constant(
278 T::one(),
279 diagonal.dims(),
280 tensor.logical_memory_space(),
281 )?;
282 let sign_factors = crate::prims_bridge::scalar_binary_same_shape(
283 ctx,
284 &one,
285 &double_negative,
286 tenferro_prims::ScalarBinaryOp::Sub,
287 )?;
288 let sign_from_diag = crate::prims_bridge::scalar_reduce_keep_axes(
289 ctx,
290 &sign_factors,
291 &kept_axes,
292 tenferro_prims::ScalarReductionOp::Prod,
293 )?;
294 let sign = crate::prims_bridge::scalar_binary_same_shape(
295 ctx,
296 &sign_perm,
297 &sign_from_diag,
298 tenferro_prims::ScalarBinaryOp::Mul,
299 )?;
300
301 Ok(SlogdetResult { sign, logabsdet })
302}
303
304fn slogdet_complex_impl<T, R, C>(ctx: &mut C, tensor: &Tensor<T>) -> Result<SlogdetResult<T, R>>
305where
306 T: KernelLinalgScalar<Real = R> + Conjugate + ComplexFloat<Real = R>,
307 T: crate::prims_bridge::ScaleTensorByRealSameShape<C>,
308 C: backend::TensorLinalgContextFor<T>
309 + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<R>>
310 + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T>>
311 + tenferro_prims::TensorComplexRealContextFor<T>,
312 C: TensorMetadataContextFor,
313 C::Backend: 'static,
314 R: KernelLinalgScalar<Real = R> + num_traits::Float,
315 <C as tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<R>>>::ScalarBackend:
316 'static
317 + tenferro_prims::TensorAnalyticPrims<tenferro_algebra::Standard<R>, Context = C>
318 + TensorMetadataCastPrims<R, Context = C>,
319 <C as tenferro_prims::TensorComplexRealContextFor<T>>::ComplexRealBackend:
320 tenferro_prims::TensorComplexRealPrims<T, Context = C, Real = R>,
321 C::MetadataBackend: TensorMetadataPrims<Context = C>,
322{
323 require_linalg_support::<T, C>(backend::LinalgCapabilityOp::Slogdet, "slogdet")?;
324
325 let (_n, batch_dims) = validate_square(tensor)?;
326 let lu = <C::Backend as backend::TensorLinalgBackend<T>>::lu_factor(ctx, tensor)?;
327 let diagonal = lu.u.diagonal(&[(0, 1)])?;
328 let abs_diagonal = crate::prims_bridge::complex_real_unary_same_shape(
329 ctx,
330 &diagonal,
331 tenferro_prims::ComplexRealUnaryOp::Abs,
332 )?;
333 let logabsdet_factor = crate::prims_bridge::analytic_unary_same_shape(
334 ctx,
335 &abs_diagonal,
336 tenferro_prims::AnalyticUnaryOp::Log,
337 )?;
338 let kept_axes: Vec<usize> = (0..batch_dims.len()).collect();
339 let logabsdet = crate::prims_bridge::scalar_reduce_keep_axes(
340 ctx,
341 &logabsdet_factor,
342 &kept_axes,
343 tenferro_prims::ScalarReductionOp::Sum,
344 )?;
345 let sign_perm = lu_permutation_sign_tensor::<T, C>(ctx, &lu.pivots)?;
346
347 let zero_real = crate::prims_bridge::full_like_constant(
348 R::zero(),
349 abs_diagonal.dims(),
350 tensor.logical_memory_space(),
351 )?;
352 let positive_mask = crate::prims_bridge::scalar_binary_same_shape(
353 ctx,
354 &abs_diagonal,
355 &zero_real,
356 tenferro_prims::ScalarBinaryOp::Greater,
357 )?;
358 let reciprocal_abs = crate::prims_bridge::scalar_unary_same_shape(
359 ctx,
360 &abs_diagonal,
361 tenferro_prims::ScalarUnaryOp::Reciprocal,
362 )?;
363 let zero_recip = crate::prims_bridge::full_like_constant(
364 R::zero(),
365 abs_diagonal.dims(),
366 tensor.logical_memory_space(),
367 )?;
368 let safe_recip = crate::prims_bridge::scalar_where_same_shape(
369 ctx,
370 &positive_mask,
371 &reciprocal_abs,
372 &zero_recip,
373 )?;
374 let phase_factors = crate::prims_bridge::complex_scale_same_shape(ctx, &diagonal, &safe_recip)?;
375 let sign_from_diag = crate::prims_bridge::scalar_reduce_keep_axes(
376 ctx,
377 &phase_factors,
378 &kept_axes,
379 tenferro_prims::ScalarReductionOp::Prod,
380 )?;
381 let sign = crate::prims_bridge::complex_scale_same_shape(ctx, &sign_from_diag, &sign_perm)?;
382
383 Ok(SlogdetResult { sign, logabsdet })
384}
385
386impl<C> SlogdetDispatch<C> for f32
387where
388 C: backend::TensorLinalgContextFor<f32>
389 + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<f32>>
390 + TensorMetadataContextFor,
391 C::Backend: 'static,
392 <C as tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<f32>>>::ScalarBackend:
393 'static
394 + tenferro_prims::TensorAnalyticPrims<tenferro_algebra::Standard<f32>, Context = C>
395 + TensorMetadataCastPrims<f32, Context = C>,
396 C::MetadataBackend: TensorMetadataPrims<Context = C>,
397{
398 fn slogdet_dispatch(
399 ctx: &mut C,
400 tensor: &Tensor<Self>,
401 ) -> Result<SlogdetResult<Self, Self::Real>> {
402 slogdet_real_impl(ctx, tensor)
403 }
404}
405
406impl<C> SlogdetDispatch<C> for f64
407where
408 C: backend::TensorLinalgContextFor<f64>
409 + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<f64>>
410 + TensorMetadataContextFor,
411 C::Backend: 'static,
412 <C as tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<f64>>>::ScalarBackend:
413 'static
414 + tenferro_prims::TensorAnalyticPrims<tenferro_algebra::Standard<f64>, Context = C>
415 + TensorMetadataCastPrims<f64, Context = C>,
416 C::MetadataBackend: TensorMetadataPrims<Context = C>,
417{
418 fn slogdet_dispatch(
419 ctx: &mut C,
420 tensor: &Tensor<Self>,
421 ) -> Result<SlogdetResult<Self, Self::Real>> {
422 slogdet_real_impl(ctx, tensor)
423 }
424}
425
426impl<C> SlogdetDispatch<C> for Complex32
427where
428 Complex32: crate::prims_bridge::ScaleTensorByRealSameShape<C>,
429 C: backend::TensorLinalgContextFor<Complex32>
430 + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<f32>>
431 + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<Complex32>>
432 + tenferro_prims::TensorComplexRealContextFor<Complex32>
433 + TensorMetadataContextFor,
434 C::Backend: 'static,
435 <C as tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<f32>>>::ScalarBackend:
436 'static
437 + tenferro_prims::TensorAnalyticPrims<tenferro_algebra::Standard<f32>, Context = C>
438 + TensorMetadataCastPrims<f32, Context = C>,
439 <C as tenferro_prims::TensorComplexRealContextFor<Complex32>>::ComplexRealBackend:
440 tenferro_prims::TensorComplexRealPrims<Complex32, Context = C, Real = f32>,
441 C::MetadataBackend: TensorMetadataPrims<Context = C>,
442{
443 fn slogdet_dispatch(
444 ctx: &mut C,
445 tensor: &Tensor<Self>,
446 ) -> Result<SlogdetResult<Self, Self::Real>> {
447 slogdet_complex_impl::<Complex32, f32, C>(ctx, tensor)
448 }
449}
450
451impl<C> SlogdetDispatch<C> for Complex64
452where
453 Complex64: crate::prims_bridge::ScaleTensorByRealSameShape<C>,
454 C: backend::TensorLinalgContextFor<Complex64>
455 + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<f64>>
456 + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<Complex64>>
457 + tenferro_prims::TensorComplexRealContextFor<Complex64>
458 + TensorMetadataContextFor,
459 C::Backend: 'static,
460 <C as tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<f64>>>::ScalarBackend:
461 'static
462 + tenferro_prims::TensorAnalyticPrims<tenferro_algebra::Standard<f64>, Context = C>
463 + TensorMetadataCastPrims<f64, Context = C>,
464 <C as tenferro_prims::TensorComplexRealContextFor<Complex64>>::ComplexRealBackend:
465 tenferro_prims::TensorComplexRealPrims<Complex64, Context = C, Real = f64>,
466 C::MetadataBackend: TensorMetadataPrims<Context = C>,
467{
468 fn slogdet_dispatch(
469 ctx: &mut C,
470 tensor: &Tensor<Self>,
471 ) -> Result<SlogdetResult<Self, Self::Real>> {
472 slogdet_complex_impl::<Complex64, f64, C>(ctx, tensor)
473 }
474}
475
476pub fn slogdet<T: KernelLinalgScalar, C>(
493 ctx: &mut C,
494 tensor: &Tensor<T>,
495) -> Result<SlogdetResult<T, T::Real>>
496where
497 T: SlogdetDispatch<C>,
498{
499 T::slogdet_dispatch(ctx, tensor)
500}