1use std::any::Any;
2use std::hash::Hasher;
3use std::sync::Arc;
4
5use tenferro_extension_macros::define_extension_runtime;
6use tenferro_ops::SymDim;
7use tenferro_runtime::extension::{ExtensionExecutionContext, ExtensionOp};
8use tenferro_tensor::{
9 DType, DeviceKind, Error, GpuBackendKind, MemoryKind, Placement, Tensor, TensorRead,
10};
11
12use crate::backend::LinalgBackend;
13
// Unit tests are only compiled for builds without the `cuda` feature.
#[cfg(all(test, not(feature = "cuda")))]
mod tests;

/// Family identifier of the linalg extension ops.
///
/// Returned by `LinalgExtensionOp::family_id` and passed as `family_id` to
/// `define_extension_runtime!` below. The `.v1` suffix is part of the identifier string.
pub const LINALG_EXTENSION_FAMILY_ID: &str = "tenferro-linalg.linalg.v1";
18
/// Linear-algebra operation wrapped by [`LinalgExtensionOp`].
///
/// The variant determines the operand/result counts (`input_count` / `output_count`), the
/// backend routine `execute_linalg` calls, and the stable tag written by `payload_hash`.
/// Every field is part of the op's identity (hashed and compared), even where
/// `execute_linalg` does not forward it to the backend.
#[derive(Clone, Copy, Debug, PartialEq)]
#[doc(hidden)]
pub(crate) enum LinalgOp {
    /// `backend.cholesky`: 1 input, 1 output with the input's shape.
    Cholesky,
    /// `backend.lu`: 1 input, 4 outputs (shapes from `lu_meta`).
    Lu,
    /// `backend.lu_factor`: 1 input, 3 outputs; the second is `I32`-typed (see `lu_factor_meta`).
    LuFactor,
    /// `backend.lu_solve_prepared`: 4 inputs, 1 output shaped like the fourth input.
    LuSolvePrepared {
        /// Forwarded unchanged to `lu_solve_prepared`.
        transpose_a: bool,
        /// Forwarded unchanged to `lu_solve_prepared`.
        conjugate_a: bool,
    },
    /// `backend.full_piv_lu`: 1 input, 5 outputs (shapes from `full_piv_lu_meta`).
    FullPivLu,
    /// `backend.full_piv_lu_solve`: 2 inputs, 1 output shaped like the second input.
    FullPivLuSolve {
        /// Forwarded unchanged to `full_piv_lu_solve`.
        transpose_a: bool,
    },
    /// `backend.svd`: 1 input, 3 outputs.
    Svd {
        /// Hashed and compared only; `execute_linalg` does not pass it to the backend.
        /// NOTE(review): presumably consumed by code outside this file — confirm.
        eps: f64,
    },
    /// `backend.svd_values`: 1 input, 1 output.
    SvdVals {
        /// Hashed and compared only; not forwarded by `execute_linalg`.
        eps: f64,
    },
    /// `backend.qr`: 1 input, 2 outputs.
    Qr,
    /// `backend.eigh`: 1 input, 2 outputs.
    Eigh {
        /// Hashed and compared only; not forwarded by `execute_linalg`.
        eps: f64,
    },
    /// `backend.eigh_values`: 1 input, 1 output.
    EighVals {
        /// Hashed and compared only; not forwarded by `execute_linalg`.
        eps: f64,
    },
    /// `backend.eig`: 1 input, 2 outputs.
    Eig {
        /// Determines the output dtype in `infer_output_meta` (via `eig_output_dtype`).
        input_dtype: DType,
    },
    /// `backend.eig_values`: 1 input, 1 output.
    EigVals {
        /// Determines the output dtype in `infer_output_meta` (via `eig_output_dtype`).
        input_dtype: DType,
    },
    /// `backend.triangular_solve`: 2 inputs, 1 output shaped like the second input.
    /// All four flags are forwarded unchanged to the backend.
    TriangularSolve {
        left_side: bool,
        lower: bool,
        transpose_a: bool,
        unit_diagonal: bool,
    },
}
59
60impl LinalgOp {
61 fn output_count(self) -> usize {
62 match self {
63 Self::Cholesky
64 | Self::EighVals { .. }
65 | Self::EigVals { .. }
66 | Self::FullPivLuSolve { .. }
67 | Self::LuSolvePrepared { .. }
68 | Self::SvdVals { .. }
69 | Self::TriangularSolve { .. } => 1,
70 Self::Svd { .. } => 3,
71 Self::Qr | Self::Eigh { .. } | Self::Eig { .. } => 2,
72 Self::LuFactor => 3,
73 Self::Lu => 4,
74 Self::FullPivLu => 5,
75 }
76 }
77
78 fn input_count(self) -> usize {
79 match self {
80 Self::FullPivLuSolve { .. } | Self::TriangularSolve { .. } => 2,
81 Self::LuSolvePrepared { .. } => 4,
82 _ => 1,
83 }
84 }
85
86 fn tag(self) -> u8 {
87 match self {
88 Self::Cholesky => 0,
89 Self::Lu => 1,
90 Self::FullPivLu => 2,
91 Self::FullPivLuSolve { .. } => 3,
92 Self::Svd { .. } => 4,
93 Self::Qr => 5,
94 Self::Eigh { .. } => 6,
95 Self::Eig { .. } => 7,
96 Self::TriangularSolve { .. } => 9,
97 Self::LuFactor => 10,
98 Self::LuSolvePrepared { .. } => 11,
99 Self::SvdVals { .. } => 12,
100 Self::EighVals { .. } => 13,
101 Self::EigVals { .. } => 14,
102 }
103 }
104}
105
106#[derive(Clone, Debug, PartialEq)]
107#[doc(hidden)]
108pub(crate) struct LinalgExtensionOp {
109 op: LinalgOp,
110}
111
112impl LinalgExtensionOp {
113 pub(crate) fn new(op: LinalgOp) -> Self {
114 Self { op }
115 }
116
117 pub(crate) fn op(&self) -> LinalgOp {
118 self.op
119 }
120}
121
/// Execution target selected for an eager linalg call (see `eager_linalg_device`).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum EagerLinalgDevice {
    /// Run on `tenferro_cpu::CpuBackend`; chosen for any placement whose memory kind is not
    /// `MemoryKind::Device`, and when there are no inputs at all.
    Cpu,
    /// Run on the CUDA device with this ordinal, taken from the tensor's placement metadata.
    Cuda(usize),
}
127
128fn tensor_placement(input: &Tensor) -> &Placement {
129 input.placement()
130}
131
132fn input_eager_device(input: &Tensor) -> tenferro_tensor::Result<EagerLinalgDevice> {
133 let placement = tensor_placement(input);
134 match (&placement.memory_kind, placement.device.as_ref()) {
135 (MemoryKind::Device, Some(device)) => match &device.kind {
136 DeviceKind::Gpu(GpuBackendKind::Cuda) => Ok(EagerLinalgDevice::Cuda(device.ordinal)),
137 DeviceKind::Gpu(kind) => Err(Error::backend_failure(
138 "linalg_eager_execute",
139 format!("unsupported GPU backend {kind:?} for eager linalg"),
140 )),
141 kind => Err(Error::backend_failure(
142 "linalg_eager_execute",
143 format!("unsupported device kind {kind:?} for eager linalg"),
144 )),
145 },
146 (MemoryKind::Device, None) => Err(Error::backend_failure(
147 "linalg_eager_execute",
148 "device tensor is missing placement device metadata",
149 )),
150 _ => Ok(EagerLinalgDevice::Cpu),
151 }
152}
153
154fn eager_linalg_device(inputs: &[&Tensor]) -> tenferro_tensor::Result<EagerLinalgDevice> {
155 let mut selected = None;
156 for input in inputs {
157 let device = input_eager_device(input)?;
158 match (selected, device) {
159 (None, next) => selected = Some(next),
160 (Some(EagerLinalgDevice::Cpu), EagerLinalgDevice::Cpu) => {}
161 (Some(EagerLinalgDevice::Cuda(lhs)), EagerLinalgDevice::Cuda(rhs)) if lhs == rhs => {}
162 (Some(lhs), rhs) => {
163 return Err(Error::backend_failure(
164 "linalg_eager_execute",
165 format!("all eager linalg inputs must be on the same device, got {lhs:?} and {rhs:?}"),
166 ));
167 }
168 }
169 }
170 Ok(selected.unwrap_or(EagerLinalgDevice::Cpu))
171}
172
/// Runs `op` eagerly on a freshly created CUDA backend for `device_ordinal`.
///
/// # Errors
/// Propagates failures from `CudaBackend::new` and from the linalg routine itself.
#[cfg(feature = "cuda")]
fn execute_cuda_eager_linalg(
    op: LinalgOp,
    inputs: &[&Tensor],
    device_ordinal: usize,
) -> tenferro_tensor::Result<Vec<Tensor>> {
    let mut cuda_backend = tenferro_gpu::CudaBackend::new(device_ordinal)?;
    let outputs = execute_linalg(op, inputs, &mut cuda_backend);
    outputs
}
182
183#[cfg(not(feature = "cuda"))]
184fn execute_cuda_eager_linalg(
185 _op: LinalgOp,
186 _inputs: &[&Tensor],
187 device_ordinal: usize,
188) -> tenferro_tensor::Result<Vec<Tensor>> {
189 Err(Error::backend_failure(
190 "linalg_eager_execute",
191 format!(
192 "received CUDA tensor on cuda:{device_ordinal}, but tenferro-linalg was built \
193 without the cuda feature; enable the cuda feature or download the tensor to CPU \
194 before eager linalg"
195 ),
196 ))
197}
198
199impl ExtensionOp for LinalgExtensionOp {
200 fn family_id(&self) -> &'static str {
201 LINALG_EXTENSION_FAMILY_ID
202 }
203
204 fn payload_hash(&self, hasher: &mut dyn Hasher) {
205 hasher.write_u8(self.op.tag());
206 match self.op {
207 LinalgOp::Svd { eps }
208 | LinalgOp::SvdVals { eps }
209 | LinalgOp::Eigh { eps }
210 | LinalgOp::EighVals { eps } => hasher.write_u64(eps.to_bits()),
211 LinalgOp::Eig { input_dtype } | LinalgOp::EigVals { input_dtype } => {
212 hash_dtype(hasher, input_dtype);
213 }
214 LinalgOp::FullPivLuSolve { transpose_a } => {
215 hasher.write_u8(u8::from(transpose_a));
216 }
217 LinalgOp::LuSolvePrepared {
218 transpose_a,
219 conjugate_a,
220 } => {
221 hasher.write_u8(u8::from(transpose_a));
222 hasher.write_u8(u8::from(conjugate_a));
223 }
224 LinalgOp::TriangularSolve {
225 left_side,
226 lower,
227 transpose_a,
228 unit_diagonal,
229 } => {
230 hasher.write_u8(u8::from(left_side));
231 hasher.write_u8(u8::from(lower));
232 hasher.write_u8(u8::from(transpose_a));
233 hasher.write_u8(u8::from(unit_diagonal));
234 }
235 LinalgOp::Cholesky
236 | LinalgOp::Lu
237 | LinalgOp::LuFactor
238 | LinalgOp::FullPivLu
239 | LinalgOp::Qr => {}
240 }
241 }
242
243 fn payload_eq(&self, other: &dyn ExtensionOp) -> bool {
244 other
245 .as_any()
246 .downcast_ref::<Self>()
247 .is_some_and(|that| self == that)
248 }
249
250 fn clone_arc(&self) -> Arc<dyn ExtensionOp> {
251 Arc::new(self.clone())
252 }
253
254 fn as_any(&self) -> &dyn Any {
255 self
256 }
257
258 fn input_count(&self) -> usize {
259 self.op.input_count()
260 }
261
262 fn output_count(&self) -> usize {
263 self.op.output_count()
264 }
265
266 fn infer_output_meta(
267 &self,
268 input_dtypes: &[DType],
269 input_shapes: &[&[SymDim]],
270 ) -> Vec<(DType, Vec<SymDim>)> {
271 if input_dtypes.len() != self.input_count() || input_shapes.len() != self.input_count() {
272 return Vec::new();
273 }
274 match self.op {
275 LinalgOp::Cholesky
276 | LinalgOp::FullPivLuSolve { .. }
277 | LinalgOp::TriangularSolve { .. } => {
278 let output_shape = if self.input_count() == 1 {
279 input_shapes[0].to_vec()
280 } else {
281 input_shapes[1].to_vec()
282 };
283 vec![(promote_dtypes(input_dtypes), output_shape)]
284 }
285 LinalgOp::LuSolvePrepared { .. } => {
286 vec![(
287 promote_dtypes(&[input_dtypes[0], input_dtypes[3]]),
288 input_shapes[3].to_vec(),
289 )]
290 }
291 LinalgOp::Lu => lu_meta(input_dtypes[0], input_shapes[0]),
292 LinalgOp::LuFactor => lu_factor_meta(input_dtypes[0], input_shapes[0]),
293 LinalgOp::FullPivLu => full_piv_lu_meta(input_dtypes[0], input_shapes[0]),
294 LinalgOp::Svd { .. } => svd_meta(input_dtypes[0], input_shapes[0]),
295 LinalgOp::SvdVals { .. } => {
296 vec![svd_values_meta(input_dtypes[0], input_shapes[0])]
297 }
298 LinalgOp::Qr => qr_meta(input_dtypes[0], input_shapes[0]),
299 LinalgOp::Eigh { .. } => eigh_meta(input_dtypes[0], input_shapes[0]),
300 LinalgOp::EighVals { .. } => vec![eigh_values_meta(input_dtypes[0], input_shapes[0])],
301 LinalgOp::Eig { input_dtype } => eig_meta(input_dtype, input_shapes[0]),
302 LinalgOp::EigVals { input_dtype } => {
303 vec![eig_values_meta(input_dtype, input_shapes[0])]
304 }
305 }
306 }
307
308 fn eager_execute(&self, inputs: &[&Tensor]) -> tenferro_tensor::Result<Vec<Tensor>> {
309 let expected = self.input_count();
310 if inputs.len() != expected {
311 return Err(Error::InvalidConfig {
312 op: "linalg_eager_execute",
313 message: format!(
314 "expected {expected} inputs for {:?}, got {}",
315 self.op,
316 inputs.len()
317 ),
318 });
319 }
320
321 match eager_linalg_device(inputs)? {
322 EagerLinalgDevice::Cpu => {
323 let mut backend = tenferro_cpu::CpuBackend::new();
324 execute_linalg(self.op, inputs, &mut backend)
325 }
326 EagerLinalgDevice::Cuda(device_ordinal) => {
327 execute_cuda_eager_linalg(self.op, inputs, device_ordinal)
328 }
329 }
330 }
331}
332
333fn execute_linalg_extension<B: LinalgBackend + 'static>(
334 op: &LinalgExtensionOp,
335 inputs: &[&Tensor],
336 ctx: &mut ExtensionExecutionContext<'_, B>,
337) -> tenferro_tensor::Result<Vec<Tensor>> {
338 execute_linalg(op.op(), inputs, ctx.backend_mut())
339}
340
341fn execute_linalg_extension_reads<B: LinalgBackend + 'static>(
342 op: &LinalgExtensionOp,
343 inputs: &[TensorRead<'_>],
344 ctx: &mut ExtensionExecutionContext<'_, B>,
345) -> tenferro_tensor::Result<Vec<Tensor>> {
346 let materialized_inputs: Vec<Tensor> = inputs
349 .iter()
350 .map(TensorRead::to_tensor)
351 .collect::<tenferro_tensor::Result<_>>()?;
352 let input_refs: Vec<&Tensor> = materialized_inputs.iter().collect();
353 execute_linalg_extension(op, &input_refs, ctx)
354}
355
// Wires the linalg ops into the extension runtime under `LINALG_EXTENSION_FAMILY_ID`.
// The arguments name the runtime type (`LinalgRuntime`), its registration function
// (`register_runtime`), the op type, and the two execute entry points above.
// NOTE(review): the exact generated items come from `tenferro_extension_macros` — confirm
// there before relying on their signatures.
define_extension_runtime! {
    runtime = LinalgRuntime,
    family_id = LINALG_EXTENSION_FAMILY_ID,
    op_type = LinalgExtensionOp,
    execute = execute_linalg_extension,
    execute_reads = execute_linalg_extension_reads,
    register_fn = register_runtime,
    backend_bound = LinalgBackend,
}
365
366fn execute_linalg<B: LinalgBackend>(
367 op: LinalgOp,
368 inputs: &[&Tensor],
369 backend: &mut B,
370) -> tenferro_tensor::Result<Vec<Tensor>> {
371 match op {
372 LinalgOp::Cholesky => Ok(vec![backend.cholesky(inputs[0])?]),
373 LinalgOp::Lu => backend.lu(inputs[0]),
374 LinalgOp::LuFactor => backend.lu_factor(inputs[0]),
375 LinalgOp::LuSolvePrepared {
376 transpose_a,
377 conjugate_a,
378 } => Ok(vec![backend.lu_solve_prepared(
379 inputs[0],
380 inputs[1],
381 inputs[2],
382 inputs[3],
383 transpose_a,
384 conjugate_a,
385 )?]),
386 LinalgOp::FullPivLu => backend.full_piv_lu(inputs[0]),
387 LinalgOp::FullPivLuSolve { transpose_a } => Ok(vec![backend.full_piv_lu_solve(
388 inputs[0],
389 inputs[1],
390 transpose_a,
391 )?]),
392 LinalgOp::Svd { .. } => backend.svd(inputs[0]),
393 LinalgOp::SvdVals { .. } => Ok(vec![backend.svd_values(inputs[0])?]),
394 LinalgOp::Qr => backend.qr(inputs[0]),
395 LinalgOp::Eigh { .. } => backend.eigh(inputs[0]),
396 LinalgOp::EighVals { .. } => Ok(vec![backend.eigh_values(inputs[0])?]),
397 LinalgOp::Eig { .. } => backend.eig(inputs[0]),
398 LinalgOp::EigVals { .. } => Ok(vec![backend.eig_values(inputs[0])?]),
399 LinalgOp::TriangularSolve {
400 left_side,
401 lower,
402 transpose_a,
403 unit_diagonal,
404 } => Ok(vec![backend.triangular_solve(
405 inputs[0],
406 inputs[1],
407 left_side,
408 lower,
409 transpose_a,
410 unit_diagonal,
411 )?]),
412 }
413}
414
415fn lu_meta(dtype: DType, shape: &[SymDim]) -> Vec<(DType, Vec<SymDim>)> {
416 let m = shape[0].clone();
417 let n = shape[1].clone();
418 let k = m.clone().min(n.clone());
419 let batch = &shape[2..];
420 vec![
421 (dtype, matrix_shape(m.clone(), m, batch)),
422 (dtype, matrix_shape(shape[0].clone(), k.clone(), batch)),
423 (dtype, matrix_shape(k, n, batch)),
424 (dtype, batch.to_vec()),
425 ]
426}
427
428fn lu_factor_meta(dtype: DType, shape: &[SymDim]) -> Vec<(DType, Vec<SymDim>)> {
429 let m = shape[0].clone();
430 let n = shape[1].clone();
431 let k = m.min(n);
432 let batch = &shape[2..];
433 vec![
434 (dtype, shape.to_vec()),
435 (DType::I32, vector_shape(k, batch)),
436 (dtype, batch.to_vec()),
437 ]
438}
439
440fn full_piv_lu_meta(dtype: DType, shape: &[SymDim]) -> Vec<(DType, Vec<SymDim>)> {
441 let n = shape[0].clone();
442 let batch = &shape[2..];
443 vec![
444 (dtype, matrix_shape(n.clone(), n.clone(), batch)),
445 (dtype, matrix_shape(n.clone(), n.clone(), batch)),
446 (dtype, matrix_shape(n.clone(), n.clone(), batch)),
447 (dtype, matrix_shape(n.clone(), n, batch)),
448 (singular_values_dtype(dtype), batch.to_vec()),
449 ]
450}
451
452fn svd_meta(dtype: DType, shape: &[SymDim]) -> Vec<(DType, Vec<SymDim>)> {
453 let m = shape[0].clone();
454 let n = shape[1].clone();
455 let k = m.clone().min(n.clone());
456 let batch = &shape[2..];
457 vec![
458 (dtype, matrix_shape(m, k.clone(), batch)),
459 (singular_values_dtype(dtype), vector_shape(k.clone(), batch)),
460 (dtype, matrix_shape(k, n, batch)),
461 ]
462}
463
464fn svd_values_meta(dtype: DType, shape: &[SymDim]) -> (DType, Vec<SymDim>) {
465 let m = shape[0].clone();
466 let n = shape[1].clone();
467 let k = m.min(n);
468 let batch = &shape[2..];
469 (singular_values_dtype(dtype), vector_shape(k, batch))
470}
471
472fn qr_meta(dtype: DType, shape: &[SymDim]) -> Vec<(DType, Vec<SymDim>)> {
473 let m = shape[0].clone();
474 let n = shape[1].clone();
475 let k = m.clone().min(n.clone());
476 let batch = &shape[2..];
477 vec![
478 (dtype, matrix_shape(m, k.clone(), batch)),
479 (dtype, matrix_shape(k, n, batch)),
480 ]
481}
482
483fn eigh_meta(dtype: DType, shape: &[SymDim]) -> Vec<(DType, Vec<SymDim>)> {
484 let n = shape[0].clone();
485 let batch = &shape[2..];
486 vec![
487 (singular_values_dtype(dtype), vector_shape(n.clone(), batch)),
488 (dtype, matrix_shape(n.clone(), n, batch)),
489 ]
490}
491
492fn eigh_values_meta(dtype: DType, shape: &[SymDim]) -> (DType, Vec<SymDim>) {
493 let n = shape[0].clone();
494 let batch = &shape[2..];
495 (singular_values_dtype(dtype), vector_shape(n, batch))
496}
497
498fn eig_meta(input_dtype: DType, shape: &[SymDim]) -> Vec<(DType, Vec<SymDim>)> {
499 let dtype = eig_output_dtype(input_dtype);
500 let n = shape[0].clone();
501 let batch = &shape[2..];
502 vec![
503 (dtype, vector_shape(n.clone(), batch)),
504 (dtype, matrix_shape(n.clone(), n, batch)),
505 ]
506}
507
508fn eig_values_meta(input_dtype: DType, shape: &[SymDim]) -> (DType, Vec<SymDim>) {
509 let dtype = eig_output_dtype(input_dtype);
510 let n = shape[0].clone();
511 let batch = &shape[2..];
512 (dtype, vector_shape(n, batch))
513}
514
515fn matrix_shape(rows: SymDim, cols: SymDim, batch: &[SymDim]) -> Vec<SymDim> {
516 let mut shape = vec![rows, cols];
517 shape.extend_from_slice(batch);
518 shape
519}
520
521fn vector_shape(len: SymDim, batch: &[SymDim]) -> Vec<SymDim> {
522 let mut shape = vec![len];
523 shape.extend_from_slice(batch);
524 shape
525}
526
527fn eig_output_dtype(dtype: DType) -> DType {
528 match dtype {
529 DType::F64 | DType::C64 => DType::C64,
530 DType::F32 | DType::C32 => DType::C32,
531 DType::I32 | DType::I64 | DType::Bool => DType::C64,
532 }
533}
534
535fn singular_values_dtype(dtype: DType) -> DType {
536 match dtype {
537 DType::C64 => DType::F64,
538 DType::C32 => DType::F32,
539 other => other,
540 }
541}
542
543fn promote_dtypes(dtypes: &[DType]) -> DType {
544 dtypes
545 .iter()
546 .copied()
547 .reduce(promote_dtype)
548 .unwrap_or(DType::F64)
549}
550
551fn promote_dtype(lhs: DType, rhs: DType) -> DType {
552 use DType::*;
553 match (lhs, rhs) {
554 (Bool, Bool) => Bool,
555 (Bool, other) | (other, Bool) => other,
556 (I32, I32) => I32,
557 (I32, I64) | (I64, I32) | (I64, I64) => I64,
558 (I32 | I64, F32 | F64) | (F32 | F64, I32 | I64) => F64,
559 (I32 | I64, C32 | C64) | (C32 | C64, I32 | I64) => C64,
560 (F32, F32) => F32,
561 (F32, F64) | (F64, F32) | (F64, F64) => F64,
562 (F32, C32) | (C32, F32) | (C32, C32) => C32,
563 (F32, C64) | (C64, F32) => C64,
564 (F64, C32 | C64) | (C32 | C64, F64) => C64,
565 (C32, C64) | (C64, C32) | (C64, C64) => C64,
566 }
567}
568
569fn hash_dtype(hasher: &mut dyn Hasher, dtype: DType) {
570 let tag = match dtype {
571 DType::F64 => 0,
572 DType::F32 => 1,
573 DType::I64 => 2,
574 DType::C64 => 3,
575 DType::C32 => 4,
576 DType::I32 => 5,
577 DType::Bool => 6,
578 };
579 hasher.write_u8(tag);
580}