Skip to main content

tenferro_linalg/
extension.rs

1use std::any::Any;
2use std::hash::Hasher;
3use std::sync::Arc;
4
5use tenferro_extension_macros::define_extension_runtime;
6use tenferro_ops::SymDim;
7use tenferro_runtime::extension::{ExtensionExecutionContext, ExtensionOp};
8use tenferro_tensor::{
9    DType, DeviceKind, Error, GpuBackendKind, MemoryKind, Placement, Tensor, TensorRead,
10};
11
12use crate::backend::LinalgBackend;
13
14#[cfg(all(test, not(feature = "cuda")))]
15mod tests;
16
/// Family identifier of the linalg extension ops.
///
/// Returned by `LinalgExtensionOp::family_id` and passed as `family_id` to the
/// `define_extension_runtime!` invocation in this file.
pub const LINALG_EXTENSION_FAMILY_ID: &str = "tenferro-linalg.linalg.v1";
18
19#[derive(Clone, Copy, Debug, PartialEq)]
20#[doc(hidden)]
21pub(crate) enum LinalgOp {
22    Cholesky,
23    Lu,
24    LuFactor,
25    LuSolvePrepared {
26        transpose_a: bool,
27        conjugate_a: bool,
28    },
29    FullPivLu,
30    FullPivLuSolve {
31        transpose_a: bool,
32    },
33    Svd {
34        eps: f64,
35    },
36    SvdVals {
37        eps: f64,
38    },
39    Qr,
40    Eigh {
41        eps: f64,
42    },
43    EighVals {
44        eps: f64,
45    },
46    Eig {
47        input_dtype: DType,
48    },
49    EigVals {
50        input_dtype: DType,
51    },
52    TriangularSolve {
53        left_side: bool,
54        lower: bool,
55        transpose_a: bool,
56        unit_diagonal: bool,
57    },
58}
59
60impl LinalgOp {
61    fn output_count(self) -> usize {
62        match self {
63            Self::Cholesky
64            | Self::EighVals { .. }
65            | Self::EigVals { .. }
66            | Self::FullPivLuSolve { .. }
67            | Self::LuSolvePrepared { .. }
68            | Self::SvdVals { .. }
69            | Self::TriangularSolve { .. } => 1,
70            Self::Svd { .. } => 3,
71            Self::Qr | Self::Eigh { .. } | Self::Eig { .. } => 2,
72            Self::LuFactor => 3,
73            Self::Lu => 4,
74            Self::FullPivLu => 5,
75        }
76    }
77
78    fn input_count(self) -> usize {
79        match self {
80            Self::FullPivLuSolve { .. } | Self::TriangularSolve { .. } => 2,
81            Self::LuSolvePrepared { .. } => 4,
82            _ => 1,
83        }
84    }
85
86    fn tag(self) -> u8 {
87        match self {
88            Self::Cholesky => 0,
89            Self::Lu => 1,
90            Self::FullPivLu => 2,
91            Self::FullPivLuSolve { .. } => 3,
92            Self::Svd { .. } => 4,
93            Self::Qr => 5,
94            Self::Eigh { .. } => 6,
95            Self::Eig { .. } => 7,
96            Self::TriangularSolve { .. } => 9,
97            Self::LuFactor => 10,
98            Self::LuSolvePrepared { .. } => 11,
99            Self::SvdVals { .. } => 12,
100            Self::EighVals { .. } => 13,
101            Self::EigVals { .. } => 14,
102        }
103    }
104}
105
106#[derive(Clone, Debug, PartialEq)]
107#[doc(hidden)]
108pub(crate) struct LinalgExtensionOp {
109    op: LinalgOp,
110}
111
112impl LinalgExtensionOp {
113    pub(crate) fn new(op: LinalgOp) -> Self {
114        Self { op }
115    }
116
117    pub(crate) fn op(&self) -> LinalgOp {
118        self.op
119    }
120}
121
/// Device chosen for an eager linalg call.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
enum EagerLinalgDevice {
    /// Run on the CPU backend.
    Cpu,
    /// Run on the CUDA device with the given ordinal.
    Cuda(usize),
}
127
128fn tensor_placement(input: &Tensor) -> &Placement {
129    input.placement()
130}
131
132fn input_eager_device(input: &Tensor) -> tenferro_tensor::Result<EagerLinalgDevice> {
133    let placement = tensor_placement(input);
134    match (&placement.memory_kind, placement.device.as_ref()) {
135        (MemoryKind::Device, Some(device)) => match &device.kind {
136            DeviceKind::Gpu(GpuBackendKind::Cuda) => Ok(EagerLinalgDevice::Cuda(device.ordinal)),
137            DeviceKind::Gpu(kind) => Err(Error::backend_failure(
138                "linalg_eager_execute",
139                format!("unsupported GPU backend {kind:?} for eager linalg"),
140            )),
141            kind => Err(Error::backend_failure(
142                "linalg_eager_execute",
143                format!("unsupported device kind {kind:?} for eager linalg"),
144            )),
145        },
146        (MemoryKind::Device, None) => Err(Error::backend_failure(
147            "linalg_eager_execute",
148            "device tensor is missing placement device metadata",
149        )),
150        _ => Ok(EagerLinalgDevice::Cpu),
151    }
152}
153
154fn eager_linalg_device(inputs: &[&Tensor]) -> tenferro_tensor::Result<EagerLinalgDevice> {
155    let mut selected = None;
156    for input in inputs {
157        let device = input_eager_device(input)?;
158        match (selected, device) {
159            (None, next) => selected = Some(next),
160            (Some(EagerLinalgDevice::Cpu), EagerLinalgDevice::Cpu) => {}
161            (Some(EagerLinalgDevice::Cuda(lhs)), EagerLinalgDevice::Cuda(rhs)) if lhs == rhs => {}
162            (Some(lhs), rhs) => {
163                return Err(Error::backend_failure(
164                    "linalg_eager_execute",
165                    format!("all eager linalg inputs must be on the same device, got {lhs:?} and {rhs:?}"),
166                ));
167            }
168        }
169    }
170    Ok(selected.unwrap_or(EagerLinalgDevice::Cpu))
171}
172
/// Runs `op` eagerly on the CUDA device `device_ordinal`.
///
/// A fresh `CudaBackend` is constructed for every call.
///
/// # Errors
///
/// Propagates failures from creating the backend and from `execute_linalg`.
#[cfg(feature = "cuda")]
fn execute_cuda_eager_linalg(
    op: LinalgOp,
    inputs: &[&Tensor],
    device_ordinal: usize,
) -> tenferro_tensor::Result<Vec<Tensor>> {
    let mut cuda = tenferro_gpu::CudaBackend::new(device_ordinal)?;
    let outputs = execute_linalg(op, inputs, &mut cuda);
    outputs
}
182
183#[cfg(not(feature = "cuda"))]
184fn execute_cuda_eager_linalg(
185    _op: LinalgOp,
186    _inputs: &[&Tensor],
187    device_ordinal: usize,
188) -> tenferro_tensor::Result<Vec<Tensor>> {
189    Err(Error::backend_failure(
190        "linalg_eager_execute",
191        format!(
192            "received CUDA tensor on cuda:{device_ordinal}, but tenferro-linalg was built \
193             without the cuda feature; enable the cuda feature or download the tensor to CPU \
194             before eager linalg"
195        ),
196    ))
197}
198
199impl ExtensionOp for LinalgExtensionOp {
200    fn family_id(&self) -> &'static str {
201        LINALG_EXTENSION_FAMILY_ID
202    }
203
204    fn payload_hash(&self, hasher: &mut dyn Hasher) {
205        hasher.write_u8(self.op.tag());
206        match self.op {
207            LinalgOp::Svd { eps }
208            | LinalgOp::SvdVals { eps }
209            | LinalgOp::Eigh { eps }
210            | LinalgOp::EighVals { eps } => hasher.write_u64(eps.to_bits()),
211            LinalgOp::Eig { input_dtype } | LinalgOp::EigVals { input_dtype } => {
212                hash_dtype(hasher, input_dtype);
213            }
214            LinalgOp::FullPivLuSolve { transpose_a } => {
215                hasher.write_u8(u8::from(transpose_a));
216            }
217            LinalgOp::LuSolvePrepared {
218                transpose_a,
219                conjugate_a,
220            } => {
221                hasher.write_u8(u8::from(transpose_a));
222                hasher.write_u8(u8::from(conjugate_a));
223            }
224            LinalgOp::TriangularSolve {
225                left_side,
226                lower,
227                transpose_a,
228                unit_diagonal,
229            } => {
230                hasher.write_u8(u8::from(left_side));
231                hasher.write_u8(u8::from(lower));
232                hasher.write_u8(u8::from(transpose_a));
233                hasher.write_u8(u8::from(unit_diagonal));
234            }
235            LinalgOp::Cholesky
236            | LinalgOp::Lu
237            | LinalgOp::LuFactor
238            | LinalgOp::FullPivLu
239            | LinalgOp::Qr => {}
240        }
241    }
242
243    fn payload_eq(&self, other: &dyn ExtensionOp) -> bool {
244        other
245            .as_any()
246            .downcast_ref::<Self>()
247            .is_some_and(|that| self == that)
248    }
249
250    fn clone_arc(&self) -> Arc<dyn ExtensionOp> {
251        Arc::new(self.clone())
252    }
253
254    fn as_any(&self) -> &dyn Any {
255        self
256    }
257
258    fn input_count(&self) -> usize {
259        self.op.input_count()
260    }
261
262    fn output_count(&self) -> usize {
263        self.op.output_count()
264    }
265
266    fn infer_output_meta(
267        &self,
268        input_dtypes: &[DType],
269        input_shapes: &[&[SymDim]],
270    ) -> Vec<(DType, Vec<SymDim>)> {
271        if input_dtypes.len() != self.input_count() || input_shapes.len() != self.input_count() {
272            return Vec::new();
273        }
274        match self.op {
275            LinalgOp::Cholesky
276            | LinalgOp::FullPivLuSolve { .. }
277            | LinalgOp::TriangularSolve { .. } => {
278                let output_shape = if self.input_count() == 1 {
279                    input_shapes[0].to_vec()
280                } else {
281                    input_shapes[1].to_vec()
282                };
283                vec![(promote_dtypes(input_dtypes), output_shape)]
284            }
285            LinalgOp::LuSolvePrepared { .. } => {
286                vec![(
287                    promote_dtypes(&[input_dtypes[0], input_dtypes[3]]),
288                    input_shapes[3].to_vec(),
289                )]
290            }
291            LinalgOp::Lu => lu_meta(input_dtypes[0], input_shapes[0]),
292            LinalgOp::LuFactor => lu_factor_meta(input_dtypes[0], input_shapes[0]),
293            LinalgOp::FullPivLu => full_piv_lu_meta(input_dtypes[0], input_shapes[0]),
294            LinalgOp::Svd { .. } => svd_meta(input_dtypes[0], input_shapes[0]),
295            LinalgOp::SvdVals { .. } => {
296                vec![svd_values_meta(input_dtypes[0], input_shapes[0])]
297            }
298            LinalgOp::Qr => qr_meta(input_dtypes[0], input_shapes[0]),
299            LinalgOp::Eigh { .. } => eigh_meta(input_dtypes[0], input_shapes[0]),
300            LinalgOp::EighVals { .. } => vec![eigh_values_meta(input_dtypes[0], input_shapes[0])],
301            LinalgOp::Eig { input_dtype } => eig_meta(input_dtype, input_shapes[0]),
302            LinalgOp::EigVals { input_dtype } => {
303                vec![eig_values_meta(input_dtype, input_shapes[0])]
304            }
305        }
306    }
307
308    fn eager_execute(&self, inputs: &[&Tensor]) -> tenferro_tensor::Result<Vec<Tensor>> {
309        let expected = self.input_count();
310        if inputs.len() != expected {
311            return Err(Error::InvalidConfig {
312                op: "linalg_eager_execute",
313                message: format!(
314                    "expected {expected} inputs for {:?}, got {}",
315                    self.op,
316                    inputs.len()
317                ),
318            });
319        }
320
321        match eager_linalg_device(inputs)? {
322            EagerLinalgDevice::Cpu => {
323                let mut backend = tenferro_cpu::CpuBackend::new();
324                execute_linalg(self.op, inputs, &mut backend)
325            }
326            EagerLinalgDevice::Cuda(device_ordinal) => {
327                execute_cuda_eager_linalg(self.op, inputs, device_ordinal)
328            }
329        }
330    }
331}
332
333fn execute_linalg_extension<B: LinalgBackend + 'static>(
334    op: &LinalgExtensionOp,
335    inputs: &[&Tensor],
336    ctx: &mut ExtensionExecutionContext<'_, B>,
337) -> tenferro_tensor::Result<Vec<Tensor>> {
338    execute_linalg(op.op(), inputs, ctx.backend_mut())
339}
340
341fn execute_linalg_extension_reads<B: LinalgBackend + 'static>(
342    op: &LinalgExtensionOp,
343    inputs: &[TensorRead<'_>],
344    ctx: &mut ExtensionExecutionContext<'_, B>,
345) -> tenferro_tensor::Result<Vec<Tensor>> {
346    // Linalg backends currently operate on compact tensors; materialization is
347    // explicit here so borrowed views cannot silently bypass backend errors.
348    let materialized_inputs: Vec<Tensor> = inputs
349        .iter()
350        .map(TensorRead::to_tensor)
351        .collect::<tenferro_tensor::Result<_>>()?;
352    let input_refs: Vec<&Tensor> = materialized_inputs.iter().collect();
353    execute_linalg_extension(op, &input_refs, ctx)
354}
355
// Wires the linalg extension family into the runtime through
// `define_extension_runtime!`, handing it the family id, the payload type, and
// the two execution entry points defined above.
// NOTE(review): the macro expansion is not visible from this file; presumably it
// declares `LinalgRuntime` and a `register_runtime` function — confirm against
// `tenferro_extension_macros`.
define_extension_runtime! {
    runtime = LinalgRuntime,
    family_id = LINALG_EXTENSION_FAMILY_ID,
    op_type = LinalgExtensionOp,
    execute = execute_linalg_extension,
    execute_reads = execute_linalg_extension_reads,
    register_fn = register_runtime,
    backend_bound = LinalgBackend,
}
365
366fn execute_linalg<B: LinalgBackend>(
367    op: LinalgOp,
368    inputs: &[&Tensor],
369    backend: &mut B,
370) -> tenferro_tensor::Result<Vec<Tensor>> {
371    match op {
372        LinalgOp::Cholesky => Ok(vec![backend.cholesky(inputs[0])?]),
373        LinalgOp::Lu => backend.lu(inputs[0]),
374        LinalgOp::LuFactor => backend.lu_factor(inputs[0]),
375        LinalgOp::LuSolvePrepared {
376            transpose_a,
377            conjugate_a,
378        } => Ok(vec![backend.lu_solve_prepared(
379            inputs[0],
380            inputs[1],
381            inputs[2],
382            inputs[3],
383            transpose_a,
384            conjugate_a,
385        )?]),
386        LinalgOp::FullPivLu => backend.full_piv_lu(inputs[0]),
387        LinalgOp::FullPivLuSolve { transpose_a } => Ok(vec![backend.full_piv_lu_solve(
388            inputs[0],
389            inputs[1],
390            transpose_a,
391        )?]),
392        LinalgOp::Svd { .. } => backend.svd(inputs[0]),
393        LinalgOp::SvdVals { .. } => Ok(vec![backend.svd_values(inputs[0])?]),
394        LinalgOp::Qr => backend.qr(inputs[0]),
395        LinalgOp::Eigh { .. } => backend.eigh(inputs[0]),
396        LinalgOp::EighVals { .. } => Ok(vec![backend.eigh_values(inputs[0])?]),
397        LinalgOp::Eig { .. } => backend.eig(inputs[0]),
398        LinalgOp::EigVals { .. } => Ok(vec![backend.eig_values(inputs[0])?]),
399        LinalgOp::TriangularSolve {
400            left_side,
401            lower,
402            transpose_a,
403            unit_diagonal,
404        } => Ok(vec![backend.triangular_solve(
405            inputs[0],
406            inputs[1],
407            left_side,
408            lower,
409            transpose_a,
410            unit_diagonal,
411        )?]),
412    }
413}
414
415fn lu_meta(dtype: DType, shape: &[SymDim]) -> Vec<(DType, Vec<SymDim>)> {
416    let m = shape[0].clone();
417    let n = shape[1].clone();
418    let k = m.clone().min(n.clone());
419    let batch = &shape[2..];
420    vec![
421        (dtype, matrix_shape(m.clone(), m, batch)),
422        (dtype, matrix_shape(shape[0].clone(), k.clone(), batch)),
423        (dtype, matrix_shape(k, n, batch)),
424        (dtype, batch.to_vec()),
425    ]
426}
427
428fn lu_factor_meta(dtype: DType, shape: &[SymDim]) -> Vec<(DType, Vec<SymDim>)> {
429    let m = shape[0].clone();
430    let n = shape[1].clone();
431    let k = m.min(n);
432    let batch = &shape[2..];
433    vec![
434        (dtype, shape.to_vec()),
435        (DType::I32, vector_shape(k, batch)),
436        (dtype, batch.to_vec()),
437    ]
438}
439
440fn full_piv_lu_meta(dtype: DType, shape: &[SymDim]) -> Vec<(DType, Vec<SymDim>)> {
441    let n = shape[0].clone();
442    let batch = &shape[2..];
443    vec![
444        (dtype, matrix_shape(n.clone(), n.clone(), batch)),
445        (dtype, matrix_shape(n.clone(), n.clone(), batch)),
446        (dtype, matrix_shape(n.clone(), n.clone(), batch)),
447        (dtype, matrix_shape(n.clone(), n, batch)),
448        (singular_values_dtype(dtype), batch.to_vec()),
449    ]
450}
451
452fn svd_meta(dtype: DType, shape: &[SymDim]) -> Vec<(DType, Vec<SymDim>)> {
453    let m = shape[0].clone();
454    let n = shape[1].clone();
455    let k = m.clone().min(n.clone());
456    let batch = &shape[2..];
457    vec![
458        (dtype, matrix_shape(m, k.clone(), batch)),
459        (singular_values_dtype(dtype), vector_shape(k.clone(), batch)),
460        (dtype, matrix_shape(k, n, batch)),
461    ]
462}
463
464fn svd_values_meta(dtype: DType, shape: &[SymDim]) -> (DType, Vec<SymDim>) {
465    let m = shape[0].clone();
466    let n = shape[1].clone();
467    let k = m.min(n);
468    let batch = &shape[2..];
469    (singular_values_dtype(dtype), vector_shape(k, batch))
470}
471
472fn qr_meta(dtype: DType, shape: &[SymDim]) -> Vec<(DType, Vec<SymDim>)> {
473    let m = shape[0].clone();
474    let n = shape[1].clone();
475    let k = m.clone().min(n.clone());
476    let batch = &shape[2..];
477    vec![
478        (dtype, matrix_shape(m, k.clone(), batch)),
479        (dtype, matrix_shape(k, n, batch)),
480    ]
481}
482
483fn eigh_meta(dtype: DType, shape: &[SymDim]) -> Vec<(DType, Vec<SymDim>)> {
484    let n = shape[0].clone();
485    let batch = &shape[2..];
486    vec![
487        (singular_values_dtype(dtype), vector_shape(n.clone(), batch)),
488        (dtype, matrix_shape(n.clone(), n, batch)),
489    ]
490}
491
492fn eigh_values_meta(dtype: DType, shape: &[SymDim]) -> (DType, Vec<SymDim>) {
493    let n = shape[0].clone();
494    let batch = &shape[2..];
495    (singular_values_dtype(dtype), vector_shape(n, batch))
496}
497
498fn eig_meta(input_dtype: DType, shape: &[SymDim]) -> Vec<(DType, Vec<SymDim>)> {
499    let dtype = eig_output_dtype(input_dtype);
500    let n = shape[0].clone();
501    let batch = &shape[2..];
502    vec![
503        (dtype, vector_shape(n.clone(), batch)),
504        (dtype, matrix_shape(n.clone(), n, batch)),
505    ]
506}
507
508fn eig_values_meta(input_dtype: DType, shape: &[SymDim]) -> (DType, Vec<SymDim>) {
509    let dtype = eig_output_dtype(input_dtype);
510    let n = shape[0].clone();
511    let batch = &shape[2..];
512    (dtype, vector_shape(n, batch))
513}
514
515fn matrix_shape(rows: SymDim, cols: SymDim, batch: &[SymDim]) -> Vec<SymDim> {
516    let mut shape = vec![rows, cols];
517    shape.extend_from_slice(batch);
518    shape
519}
520
521fn vector_shape(len: SymDim, batch: &[SymDim]) -> Vec<SymDim> {
522    let mut shape = vec![len];
523    shape.extend_from_slice(batch);
524    shape
525}
526
527fn eig_output_dtype(dtype: DType) -> DType {
528    match dtype {
529        DType::F64 | DType::C64 => DType::C64,
530        DType::F32 | DType::C32 => DType::C32,
531        DType::I32 | DType::I64 | DType::Bool => DType::C64,
532    }
533}
534
535fn singular_values_dtype(dtype: DType) -> DType {
536    match dtype {
537        DType::C64 => DType::F64,
538        DType::C32 => DType::F32,
539        other => other,
540    }
541}
542
543fn promote_dtypes(dtypes: &[DType]) -> DType {
544    dtypes
545        .iter()
546        .copied()
547        .reduce(promote_dtype)
548        .unwrap_or(DType::F64)
549}
550
551fn promote_dtype(lhs: DType, rhs: DType) -> DType {
552    use DType::*;
553    match (lhs, rhs) {
554        (Bool, Bool) => Bool,
555        (Bool, other) | (other, Bool) => other,
556        (I32, I32) => I32,
557        (I32, I64) | (I64, I32) | (I64, I64) => I64,
558        (I32 | I64, F32 | F64) | (F32 | F64, I32 | I64) => F64,
559        (I32 | I64, C32 | C64) | (C32 | C64, I32 | I64) => C64,
560        (F32, F32) => F32,
561        (F32, F64) | (F64, F32) | (F64, F64) => F64,
562        (F32, C32) | (C32, F32) | (C32, C32) => C32,
563        (F32, C64) | (C64, F32) => C64,
564        (F64, C32 | C64) | (C32 | C64, F64) => C64,
565        (C32, C64) | (C64, C32) | (C64, C64) => C64,
566    }
567}
568
569fn hash_dtype(hasher: &mut dyn Hasher, dtype: DType) {
570    let tag = match dtype {
571        DType::F64 => 0,
572        DType::F32 => 1,
573        DType::I64 => 2,
574        DType::C64 => 3,
575        DType::C32 => 4,
576        DType::I32 => 5,
577        DType::Bool => 6,
578    };
579    hasher.write_u8(tag);
580}