//! Core tensor types for `tenferro_tensor` (`types.rs`).

1use num_complex::{Complex, Complex32, Complex64};
2use num_traits::{One, Zero};
3
4use crate::{DotGeneralConfig, TensorBackend};
5
/// Memory location for tensor storage.
///
/// # Examples
///
/// ```ignore
/// use tenferro_tensor::MemoryKind;
///
/// let kind = MemoryKind::UnpinnedHost;
/// ```
#[derive(Clone, Debug)]
pub enum MemoryKind {
    /// Memory resident on a compute device.
    Device,
    /// Page-locked (pinned) host memory.
    PinnedHost,
    /// Ordinary pageable host memory.
    UnpinnedHost,
    /// Backend-specific memory kind identified by name.
    Other(String),
}
22
/// Concrete compute device description.
///
/// # Examples
///
/// ```ignore
/// use tenferro_tensor::ComputeDevice;
///
/// let device = ComputeDevice { kind: "cuda".into(), ordinal: 0 };
/// ```
#[derive(Clone, Debug)]
pub struct ComputeDevice {
    /// Backend kind name, e.g. `"cuda"`.
    pub kind: String,
    /// Device index within devices of the same kind.
    pub ordinal: usize,
}
37
/// Placement metadata for a tensor buffer.
///
/// # Examples
///
/// ```ignore
/// use tenferro_tensor::{ComputeDevice, MemoryKind, Placement};
///
/// let placement = Placement {
///     memory_kind: MemoryKind::Device,
///     resident_device: Some(ComputeDevice { kind: "cuda".into(), ordinal: 0 }),
/// };
/// ```
#[derive(Clone, Debug)]
pub struct Placement {
    /// Kind of memory the buffer lives in.
    pub memory_kind: MemoryKind,
    /// Device the buffer resides on; `None` when not device-resident
    /// (see `default_placement`, which pairs host memory with `None`).
    pub resident_device: Option<ComputeDevice>,
}
55
/// Backend-owned buffer handle.
///
/// The handle stores only an `id`; `T` records the element type without
/// holding any data.
///
/// # Examples
///
/// ```ignore
/// use tenferro_tensor::BufferHandle;
///
/// let handle = BufferHandle::<f64>::new(7);
/// ```
pub struct BufferHandle<T> {
    /// Backend-assigned identifier for the allocation.
    pub id: u64,
    _phantom: std::marker::PhantomData<T>,
}

// Manual `Clone`/`Debug` impls instead of `#[derive(...)]`: the derives
// would require `T: Clone` / `T: Debug` even though `T` only appears in
// `PhantomData`, needlessly restricting which element types get a
// cloneable/printable handle.
impl<T> Clone for BufferHandle<T> {
    fn clone(&self) -> Self {
        Self {
            id: self.id,
            _phantom: std::marker::PhantomData,
        }
    }
}

impl<T> std::fmt::Debug for BufferHandle<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("BufferHandle")
            .field("id", &self.id)
            .finish_non_exhaustive()
    }
}
70
71impl<T> BufferHandle<T> {
72    /// Create a new backend buffer handle.
73    ///
74    /// # Examples
75    ///
76    /// ```ignore
77    /// use tenferro_tensor::BufferHandle;
78    ///
79    /// let handle = BufferHandle::<f64>::new(1);
80    /// assert_eq!(handle.id, 1);
81    /// ```
82    pub fn new(id: u64) -> Self {
83        Self {
84            id,
85            _phantom: std::marker::PhantomData,
86        }
87    }
88}
89
/// Tensor storage.
///
/// # Examples
///
/// ```ignore
/// use tenferro_tensor::Buffer;
///
/// let host = Buffer::Host(vec![1.0_f64, 2.0]);
/// ```
#[derive(Clone, Debug)]
pub enum Buffer<T> {
    /// Data resident in host memory as a plain `Vec`.
    Host(Vec<T>),
    /// Opaque handle to an allocation owned by a backend.
    Backend(BufferHandle<T>),
    /// GPU allocation managed by CubeCL (only with the `cubecl` feature).
    #[cfg(feature = "cubecl")]
    Cubecl(CubeclBuffer<T>),
}
106
/// CubeCL-managed GPU buffer.
///
/// This wraps a CubeCL server handle that owns the underlying GPU allocation.
///
/// # Examples
///
/// ```
/// let _name = core::any::type_name::<tenferro_tensor::CubeclBuffer<f64>>();
/// assert!(_name.contains("CubeclBuffer"));
/// ```
#[cfg(feature = "cubecl")]
#[derive(Clone, Debug)]
pub struct CubeclBuffer<T> {
    /// CubeCL server handle that owns the GPU allocation.
    pub handle: cubecl::server::Handle,
    /// Number of elements stored in the allocation.
    pub len: usize,
    /// Zero-sized marker tying the element type `T` to this allocation.
    pub(crate) _marker: std::marker::PhantomData<T>,
}
126
#[cfg(feature = "cubecl")]
impl<T> CubeclBuffer<T> {
    /// Create a CubeCL buffer wrapper from a handle and element count.
    ///
    /// # Examples
    ///
    /// ```
    /// let _new = tenferro_tensor::CubeclBuffer::<f64>::new;
    /// let _ = _new;
    /// ```
    pub fn new(handle: cubecl::server::Handle, len: usize) -> Self {
        let _marker = std::marker::PhantomData;
        CubeclBuffer { handle, len, _marker }
    }
}
145
/// Contiguous column-major typed tensor storage.
///
/// # Examples
///
/// ```ignore
/// use tenferro_tensor::TypedTensor;
///
/// let t = TypedTensor::<f64>::from_vec(vec![2, 2], vec![1.0, 2.0, 3.0, 4.0]);
/// assert_eq!(t.shape, vec![2, 2]);
/// ```
#[derive(Clone, Debug)]
pub struct TypedTensor<T> {
    /// Underlying element storage (host vector or backend handle).
    pub buffer: Buffer<T>,
    /// Extent of each axis; data is column-major (first axis fastest).
    pub shape: Vec<usize>,
    /// Where the buffer currently resides.
    pub placement: Placement,
}
162
/// Runtime scalar dtype tag.
///
/// # Examples
///
/// ```ignore
/// use tenferro_tensor::DType;
///
/// assert_eq!(DType::F64 as u8, DType::F64 as u8);
/// ```
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum DType {
    /// 32-bit real floating point (`f32`).
    F32,
    /// 64-bit real floating point (`f64`).
    F64,
    /// Complex number with `f32` components.
    C32,
    /// Complex number with `f64` components.
    C64,
}
179
/// Sealed trait for scalar types that can be stored in a [`Tensor`].
///
/// This trait is implemented for `f64`, `f32`, [`Complex64`], and
/// [`Complex32`].
///
/// # Examples
///
/// ```
/// use tenferro_tensor::TensorScalar;
///
/// let tensor = <f64 as TensorScalar>::into_tensor(vec![2], vec![1.0, 2.0]);
/// assert_eq!(tensor.as_slice::<f64>(), Some([1.0, 2.0].as_slice()));
/// ```
pub trait TensorScalar: Copy + Clone + Send + Sync + 'static + private::Sealed {
    /// Real-valued counterpart of this scalar type.
    ///
    /// For real scalars this is `Self`; for complex scalars it is the
    /// component type (e.g. `f64` for [`Complex64`]).
    type Real: TensorScalar;

    /// The [`DType`] tag corresponding to this scalar type.
    ///
    /// # Examples
    ///
    /// ```
    /// use tenferro_tensor::{DType, TensorScalar};
    ///
    /// assert_eq!(f64::dtype(), DType::F64);
    /// assert_eq!(f32::dtype(), DType::F32);
    /// ```
    fn dtype() -> DType;

    /// Wrap typed data into a [`Tensor`] enum variant.
    ///
    /// The variant is selected by `Self`'s [`DType`].
    fn into_tensor(shape: Vec<usize>, data: Vec<Self>) -> Tensor;

    /// Try to borrow the host data from a [`Tensor`].
    ///
    /// Returns `None` when the tensor's dtype does not match `Self`.
    fn try_as_slice(tensor: &Tensor) -> Option<&[Self]>;

    /// Try to extract a [`TypedTensor<Self>`] from a dynamic [`Tensor`].
    ///
    /// Returns `None` if the tensor dtype does not match `Self`.
    ///
    /// # Examples
    ///
    /// ```
    /// use tenferro_tensor::{Tensor, TensorScalar};
    ///
    /// let tensor = Tensor::from_vec(vec![2], vec![1.0_f64, 2.0]);
    /// let typed = <f64 as TensorScalar>::try_into_typed(tensor).unwrap();
    ///
    /// assert_eq!(typed.as_slice(), &[1.0, 2.0]);
    /// ```
    fn try_into_typed(tensor: Tensor) -> Option<TypedTensor<Self>>;
}
231
/// Seals [`TensorScalar`]: only the four scalar types listed here can
/// ever implement it, because `Sealed` is unnameable outside this crate.
mod private {
    /// Marker supertrait for the supported scalars.
    pub trait Sealed {}

    impl Sealed for f64 {}
    impl Sealed for f32 {}
    impl Sealed for num_complex::Complex64 {}
    impl Sealed for num_complex::Complex32 {}
}
240
241impl TensorScalar for f64 {
242    type Real = f64;
243
244    fn dtype() -> DType {
245        DType::F64
246    }
247
248    fn into_tensor(shape: Vec<usize>, data: Vec<Self>) -> Tensor {
249        Tensor::F64(TypedTensor::from_vec(shape, data))
250    }
251
252    fn try_as_slice(tensor: &Tensor) -> Option<&[Self]> {
253        match tensor {
254            Tensor::F64(t) => Some(t.host_data()),
255            _ => None,
256        }
257    }
258
259    fn try_into_typed(tensor: Tensor) -> Option<TypedTensor<Self>> {
260        match tensor {
261            Tensor::F64(inner) => Some(inner),
262            _ => None,
263        }
264    }
265}
266
267impl TensorScalar for f32 {
268    type Real = f32;
269
270    fn dtype() -> DType {
271        DType::F32
272    }
273
274    fn into_tensor(shape: Vec<usize>, data: Vec<Self>) -> Tensor {
275        Tensor::F32(TypedTensor::from_vec(shape, data))
276    }
277
278    fn try_as_slice(tensor: &Tensor) -> Option<&[Self]> {
279        match tensor {
280            Tensor::F32(t) => Some(t.host_data()),
281            _ => None,
282        }
283    }
284
285    fn try_into_typed(tensor: Tensor) -> Option<TypedTensor<Self>> {
286        match tensor {
287            Tensor::F32(inner) => Some(inner),
288            _ => None,
289        }
290    }
291}
292
293impl TensorScalar for Complex64 {
294    type Real = f64;
295
296    fn dtype() -> DType {
297        DType::C64
298    }
299
300    fn into_tensor(shape: Vec<usize>, data: Vec<Self>) -> Tensor {
301        Tensor::C64(TypedTensor::from_vec(shape, data))
302    }
303
304    fn try_as_slice(tensor: &Tensor) -> Option<&[Self]> {
305        match tensor {
306            Tensor::C64(t) => Some(t.host_data()),
307            _ => None,
308        }
309    }
310
311    fn try_into_typed(tensor: Tensor) -> Option<TypedTensor<Self>> {
312        match tensor {
313            Tensor::C64(inner) => Some(inner),
314            _ => None,
315        }
316    }
317}
318
319impl TensorScalar for Complex32 {
320    type Real = f32;
321
322    fn dtype() -> DType {
323        DType::C32
324    }
325
326    fn into_tensor(shape: Vec<usize>, data: Vec<Self>) -> Tensor {
327        Tensor::C32(TypedTensor::from_vec(shape, data))
328    }
329
330    fn try_as_slice(tensor: &Tensor) -> Option<&[Self]> {
331        match tensor {
332            Tensor::C32(t) => Some(t.host_data()),
333            _ => None,
334        }
335    }
336
337    fn try_into_typed(tensor: Tensor) -> Option<TypedTensor<Self>> {
338        match tensor {
339            Tensor::C32(inner) => Some(inner),
340            _ => None,
341        }
342    }
343}
344
/// Dynamic tensor enum over the supported scalar types.
///
/// # Examples
///
/// ```ignore
/// use tenferro_tensor::{Tensor, TypedTensor};
///
/// let t = Tensor::F64(TypedTensor::from_vec(vec![2], vec![1.0, 2.0]));
/// assert_eq!(t.shape(), &[2]);
/// ```
#[derive(Clone, Debug)]
pub enum Tensor {
    /// Real single-precision tensor.
    F32(TypedTensor<f32>),
    /// Real double-precision tensor.
    F64(TypedTensor<f64>),
    /// Complex single-precision tensor.
    C32(TypedTensor<Complex<f32>>),
    /// Complex double-precision tensor.
    C64(TypedTensor<Complex<f64>>),
}
362
363/// Wrap an `f64` [`TypedTensor`] into the corresponding [`Tensor`] variant.
364///
365/// # Examples
366///
367/// ```
368/// use tenferro_tensor::{Tensor, TypedTensor};
369///
370/// let typed = TypedTensor::from_vec(vec![2], vec![1.0_f64, 2.0]);
371/// let tensor: Tensor = typed.into();
372/// assert_eq!(tensor.shape(), &[2]);
373/// ```
374impl From<TypedTensor<f64>> for Tensor {
375    fn from(t: TypedTensor<f64>) -> Self {
376        Tensor::F64(t)
377    }
378}
379
380/// Wrap an `f32` [`TypedTensor`] into the corresponding [`Tensor`] variant.
381///
382/// # Examples
383///
384/// ```
385/// use tenferro_tensor::{Tensor, TypedTensor};
386///
387/// let typed = TypedTensor::from_vec(vec![2], vec![1.0_f32, 2.0]);
388/// let tensor: Tensor = typed.into();
389/// assert_eq!(tensor.shape(), &[2]);
390/// ```
391impl From<TypedTensor<f32>> for Tensor {
392    fn from(t: TypedTensor<f32>) -> Self {
393        Tensor::F32(t)
394    }
395}
396
397/// Wrap a [`Complex64`] [`TypedTensor`] into the corresponding [`Tensor`]
398/// variant.
399///
400/// # Examples
401///
402/// ```
403/// use num_complex::Complex64;
404/// use tenferro_tensor::{Tensor, TypedTensor};
405///
406/// let typed = TypedTensor::from_vec(
407///     vec![1],
408///     vec![Complex64::new(1.0, 2.0)],
409/// );
410/// let tensor: Tensor = typed.into();
411/// assert_eq!(tensor.shape(), &[1]);
412/// ```
413impl From<TypedTensor<Complex<f64>>> for Tensor {
414    fn from(t: TypedTensor<Complex<f64>>) -> Self {
415        Tensor::C64(t)
416    }
417}
418
419/// Wrap a [`Complex32`] [`TypedTensor`] into the corresponding [`Tensor`]
420/// variant.
421///
422/// # Examples
423///
424/// ```
425/// use num_complex::Complex32;
426/// use tenferro_tensor::{Tensor, TypedTensor};
427///
428/// let typed = TypedTensor::from_vec(
429///     vec![1],
430///     vec![Complex32::new(1.0, 2.0)],
431/// );
432/// let tensor: Tensor = typed.into();
433/// assert_eq!(tensor.shape(), &[1]);
434/// ```
435impl From<TypedTensor<Complex<f32>>> for Tensor {
436    fn from(t: TypedTensor<Complex<f32>>) -> Self {
437        Tensor::C32(t)
438    }
439}
440
/// Column-major strides derived from a shape.
///
/// The first axis is fastest-varying: its stride is 1, and each later
/// axis gets the running product of the preceding extents. An empty
/// shape yields an empty vector.
///
/// # Examples
///
/// ```ignore
/// use tenferro_tensor::col_major_strides;
///
/// assert_eq!(col_major_strides(&[2, 3]), vec![1, 2]);
/// ```
pub fn col_major_strides(shape: &[usize]) -> Vec<isize> {
    let mut running = 1isize;
    shape
        .iter()
        .map(|&extent| {
            let stride = running;
            running *= extent as isize;
            stride
        })
        .collect()
}
460
/// Placement used for freshly allocated host tensors: unpinned host
/// memory with no resident device.
pub(crate) fn default_placement() -> Placement {
    Placement {
        memory_kind: MemoryKind::UnpinnedHost,
        resident_device: None,
    }
}
467
468impl<T: Clone + Zero> TypedTensor<T> {
469    /// Allocate a zero-filled tensor.
470    ///
471    /// # Examples
472    ///
473    /// ```ignore
474    /// use tenferro_tensor::TypedTensor;
475    ///
476    /// let t = TypedTensor::<f64>::zeros(vec![2, 3]);
477    /// assert_eq!(t.n_elements(), 6);
478    /// ```
479    pub fn zeros(shape: Vec<usize>) -> Self {
480        let n: usize = shape.iter().product();
481        Self {
482            buffer: Buffer::Host(vec![T::zero(); n]),
483            shape,
484            placement: default_placement(),
485        }
486    }
487}
488
489impl<T: Clone + One + Zero> TypedTensor<T> {
490    /// Allocate a one-filled tensor.
491    ///
492    /// # Examples
493    ///
494    /// ```ignore
495    /// use tenferro_tensor::TypedTensor;
496    ///
497    /// let t = TypedTensor::<f64>::ones(vec![2]);
498    /// assert_eq!(t.host_data(), &[1.0, 1.0]);
499    /// ```
500    pub fn ones(shape: Vec<usize>) -> Self {
501        let n: usize = shape.iter().product();
502        Self {
503            buffer: Buffer::Host(vec![T::one(); n]),
504            shape,
505            placement: default_placement(),
506        }
507    }
508}
509
510impl<T: Clone> TypedTensor<T> {
511    /// Create a tensor from a column-major buffer.
512    ///
513    /// # Examples
514    ///
515    /// ```ignore
516    /// use tenferro_tensor::TypedTensor;
517    ///
518    /// let t = TypedTensor::<f64>::from_vec(vec![2, 2], vec![1.0, 2.0, 3.0, 4.0]);
519    /// assert_eq!(t.get(&[1, 0]), &2.0);
520    /// ```
521    pub fn from_vec(shape: Vec<usize>, data: Vec<T>) -> Self {
522        let n: usize = shape.iter().product();
523        assert_eq!(
524            data.len(),
525            n,
526            "data length {} does not match shape product {}",
527            data.len(),
528            n
529        );
530        Self {
531            buffer: Buffer::Host(data),
532            shape,
533            placement: default_placement(),
534        }
535    }
536
537    /// Number of elements in the tensor.
538    ///
539    /// # Examples
540    ///
541    /// ```ignore
542    /// use tenferro_tensor::TypedTensor;
543    ///
544    /// let t = TypedTensor::<f64>::from_vec(vec![2, 3], vec![0.0; 6]);
545    /// assert_eq!(t.n_elements(), 6);
546    /// ```
547    pub fn n_elements(&self) -> usize {
548        self.shape.iter().product()
549    }
550
551    /// Borrow the host buffer.
552    ///
553    /// # Examples
554    ///
555    /// ```ignore
556    /// use tenferro_tensor::TypedTensor;
557    ///
558    /// let t = TypedTensor::<f64>::from_vec(vec![2], vec![1.0, 2.0]);
559    /// assert_eq!(t.host_data(), &[1.0, 2.0]);
560    /// ```
561    pub fn host_data(&self) -> &[T] {
562        match &self.buffer {
563            Buffer::Host(v) => v,
564            Buffer::Backend(_) => panic!("host_data called on backend buffer"),
565            #[cfg(feature = "cubecl")]
566            Buffer::Cubecl(_) => {
567                panic!(
568                    "Cannot access GPU buffer (Buffer::Cubecl) as host data. \
569                       Use cubecl::download_tensor() to transfer to CPU first."
570                )
571            }
572        }
573    }
574
575    /// View the tensor data as a flat slice.
576    ///
577    /// This is an alias for `host_data()` for API consistency with
578    /// `Tensor::as_slice`.
579    ///
580    /// # Examples
581    ///
582    /// ```
583    /// use tenferro_tensor::TypedTensor;
584    ///
585    /// let t = TypedTensor::<f64>::from_vec(vec![2], vec![1.0, 2.0]);
586    /// assert_eq!(t.as_slice(), &[1.0, 2.0]);
587    /// ```
588    pub fn as_slice(&self) -> &[T] {
589        self.host_data()
590    }
591
592    /// Mutably borrow the host buffer.
593    ///
594    /// # Examples
595    ///
596    /// ```ignore
597    /// use tenferro_tensor::TypedTensor;
598    ///
599    /// let mut t = TypedTensor::<f64>::zeros(vec![2]);
600    /// t.host_data_mut()[0] = 3.0;
601    /// assert_eq!(t.host_data(), &[3.0, 0.0]);
602    /// ```
603    pub fn host_data_mut(&mut self) -> &mut [T] {
604        match &mut self.buffer {
605            Buffer::Host(v) => v,
606            Buffer::Backend(_) => panic!("host_data_mut called on backend buffer"),
607            #[cfg(feature = "cubecl")]
608            Buffer::Cubecl(_) => {
609                panic!(
610                    "Cannot access GPU buffer (Buffer::Cubecl) as host data. \
611                       Use cubecl::download_tensor() to transfer to CPU first."
612                )
613            }
614        }
615    }
616
617    /// Compute the linear column-major offset for an index.
618    ///
619    /// # Examples
620    ///
621    /// ```ignore
622    /// use tenferro_tensor::TypedTensor;
623    ///
624    /// let t = TypedTensor::<f64>::zeros(vec![2, 3]);
625    /// assert_eq!(t.linear_offset(&[1, 2]), 5);
626    /// ```
627    pub fn linear_offset(&self, indices: &[usize]) -> usize {
628        assert_eq!(indices.len(), self.shape.len());
629        let mut offset = 0usize;
630        let mut stride = 1usize;
631        for (i, &idx) in indices.iter().enumerate() {
632            assert!(idx < self.shape[i], "index out of bounds");
633            offset += idx * stride;
634            stride *= self.shape[i];
635        }
636        offset
637    }
638
639    /// Borrow a single element by multi-index.
640    ///
641    /// # Examples
642    ///
643    /// ```ignore
644    /// use tenferro_tensor::TypedTensor;
645    ///
646    /// let t = TypedTensor::<f64>::from_vec(vec![2], vec![1.0, 2.0]);
647    /// assert_eq!(t.get(&[1]), &2.0);
648    /// ```
649    pub fn get(&self, indices: &[usize]) -> &T {
650        let off = self.linear_offset(indices);
651        &self.host_data()[off]
652    }
653
654    /// Mutably borrow a single element by multi-index.
655    ///
656    /// # Examples
657    ///
658    /// ```ignore
659    /// use tenferro_tensor::TypedTensor;
660    ///
661    /// let mut t = TypedTensor::<f64>::zeros(vec![1]);
662    /// *t.get_mut(&[0]) = 7.0;
663    /// assert_eq!(t.host_data(), &[7.0]);
664    /// ```
665    pub fn get_mut(&mut self, indices: &[usize]) -> &mut T {
666        let off = self.linear_offset(indices);
667        &mut self.host_data_mut()[off]
668    }
669}
670
/// Element-wise conjugation helper.
///
/// Identity for real scalars; complex conjugation for complex scalars.
pub trait ConjElem {
    /// Return the conjugate of `self` (a no-op for real types).
    fn conj_elem(self) -> Self;
}
675
676impl ConjElem for f32 {
677    fn conj_elem(self) -> Self {
678        self
679    }
680}
681
682impl ConjElem for f64 {
683    fn conj_elem(self) -> Self {
684        self
685    }
686}
687
688impl ConjElem for Complex<f32> {
689    fn conj_elem(self) -> Self {
690        self.conj()
691    }
692}
693
694impl ConjElem for Complex<f64> {
695    fn conj_elem(self) -> Self {
696        self.conj()
697    }
698}
699
/// Unwrap a [`Tensor`], bind its [`TypedTensor`] to `$inner`, evaluate
/// `$body`, and re-wrap the result in the same dtype variant.
macro_rules! dispatch_tensor {
    ($self:expr, $inner:ident => $body:expr) => {
        match $self {
            Tensor::F32($inner) => Tensor::F32($body),
            Tensor::F64($inner) => Tensor::F64($body),
            Tensor::C32($inner) => Tensor::C32($body),
            Tensor::C64($inner) => Tensor::C64($body),
        }
    };
}
710
/// Unwrap two [`Tensor`]s of the same dtype, bind the typed tensors to
/// `$a`/`$b`, and re-wrap `$body` in that dtype's variant.
///
/// Panics at runtime when the two tensors have different dtypes.
macro_rules! dispatch_binary {
    ($lhs:expr, $rhs:expr, |$a:ident, $b:ident| $body:expr) => {
        match ($lhs, $rhs) {
            (Tensor::F32($a), Tensor::F32($b)) => Tensor::F32($body),
            (Tensor::F64($a), Tensor::F64($b)) => Tensor::F64($body),
            (Tensor::C32($a), Tensor::C32($b)) => Tensor::C32($body),
            (Tensor::C64($a), Tensor::C64($b)) => Tensor::C64($body),
            _ => panic!("dtype mismatch in binary op"),
        }
    };
}
722
723pub(crate) use dispatch_binary;
724pub(crate) use dispatch_tensor;
725
impl Tensor {
    /// Create a tensor from a shape and flat data.
    ///
    /// This is the `Tensor`-level equivalent of `TypedTensor::<T>::from_vec`.
    ///
    /// # Examples
    ///
    /// ```
    /// use tenferro_tensor::Tensor;
    ///
    /// let t = Tensor::from_vec(vec![2, 3], vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0]);
    /// assert_eq!(t.shape(), &[2, 3]);
    /// assert_eq!(t.as_slice::<f64>().unwrap(), &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    /// ```
    pub fn from_vec<T: TensorScalar>(shape: Vec<usize>, data: Vec<T>) -> Self {
        // The scalar type selects the enum variant via `TensorScalar::into_tensor`.
        T::into_tensor(shape, data)
    }

    /// Tensor shape.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// use tenferro_tensor::{Tensor, TypedTensor};
    ///
    /// let t = Tensor::F64(TypedTensor::from_vec(vec![2], vec![1.0, 2.0]));
    /// assert_eq!(t.shape(), &[2]);
    /// ```
    pub fn shape(&self) -> &[usize] {
        match self {
            Tensor::F32(t) => &t.shape,
            Tensor::F64(t) => &t.shape,
            Tensor::C32(t) => &t.shape,
            Tensor::C64(t) => &t.shape,
        }
    }

    /// Tensor dtype tag.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// use tenferro_tensor::{DType, Tensor, TypedTensor};
    ///
    /// let t = Tensor::F64(TypedTensor::from_vec(vec![], vec![1.0]));
    /// assert_eq!(t.dtype(), DType::F64);
    /// ```
    pub fn dtype(&self) -> DType {
        match self {
            Tensor::F32(_) => DType::F32,
            Tensor::F64(_) => DType::F64,
            Tensor::C32(_) => DType::C32,
            Tensor::C64(_) => DType::C64,
        }
    }

    /// Try to borrow the host data as a typed slice.
    ///
    /// Returns `None` if the tensor dtype does not match `T`.
    ///
    /// # Examples
    ///
    /// ```
    /// use tenferro_tensor::{Tensor, TypedTensor};
    ///
    /// let t = Tensor::F64(TypedTensor::from_vec(vec![3], vec![1.0, 2.0, 3.0]));
    /// assert_eq!(t.as_slice::<f64>(), Some([1.0, 2.0, 3.0].as_slice()));
    /// assert_eq!(t.as_slice::<f32>(), None);
    /// ```
    pub fn as_slice<T: TensorScalar>(&self) -> Option<&[T]> {
        T::try_as_slice(self)
    }

    /// Singular value decomposition: `A = U diag(S) Vt`.
    ///
    /// Returns `(U, S, Vt)` using the thin/economy SVD.
    ///
    /// # Examples
    ///
    /// ```
    /// use tenferro_tensor::{Tensor, TensorBackend, cpu::CpuBackend};
    ///
    /// let mut ctx = CpuBackend::new();
    /// let a = Tensor::from_vec(vec![3, 2], vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0]);
    /// let (u, s, vt) = a.svd(&mut ctx).unwrap();
    ///
    /// assert_eq!(u.shape(), &[3, 2]);
    /// assert_eq!(s.shape(), &[2]);
    /// assert_eq!(vt.shape(), &[2, 2]);
    /// ```
    pub fn svd(&self, ctx: &mut impl TensorBackend) -> crate::Result<(Self, Self, Self)> {
        // `unpack_three` validates that the backend returned exactly 3 tensors.
        ctx.with_exec_session(|exec| unpack_three("svd", exec.svd(self)?))
    }

    /// QR decomposition: `A = Q R`.
    ///
    /// Returns `(Q, R)` using the thin/economy QR decomposition.
    ///
    /// # Examples
    ///
    /// ```
    /// use tenferro_tensor::{Tensor, TensorBackend, cpu::CpuBackend};
    ///
    /// let mut ctx = CpuBackend::new();
    /// let a = Tensor::from_vec(vec![3, 2], vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0]);
    /// let (q, r) = a.qr(&mut ctx).unwrap();
    ///
    /// assert_eq!(q.shape(), &[3, 2]);
    /// assert_eq!(r.shape(), &[2, 2]);
    /// ```
    pub fn qr(&self, ctx: &mut impl TensorBackend) -> crate::Result<(Self, Self)> {
        ctx.with_exec_session(|exec| unpack_two("qr", exec.qr(self)?))
    }

    /// LU decomposition with partial pivoting: `P A = L U`.
    ///
    /// Returns `(P, L, U, parity)`.
    ///
    /// # Examples
    ///
    /// ```
    /// use tenferro_tensor::{Tensor, TensorBackend, cpu::CpuBackend};
    ///
    /// let mut ctx = CpuBackend::new();
    /// let a = Tensor::from_vec(vec![2, 2], vec![0.0_f64, 1.0, 1.0, 0.0]);
    /// let (p, l, u, parity) = a.lu(&mut ctx).unwrap();
    ///
    /// assert_eq!(p.shape(), &[2, 2]);
    /// assert_eq!(l.shape(), &[2, 2]);
    /// assert_eq!(u.shape(), &[2, 2]);
    /// assert_eq!(parity.shape(), &[] as &[usize]);
    /// ```
    pub fn lu(&self, ctx: &mut impl TensorBackend) -> crate::Result<(Self, Self, Self, Self)> {
        ctx.with_exec_session(|exec| unpack_four("lu", exec.lu(self)?))
    }

    /// Cholesky decomposition: `A = L L^T` or `A = L L^H` for complex inputs.
    ///
    /// Returns the lower-triangular factor `L`.
    ///
    /// # Examples
    ///
    /// ```
    /// use tenferro_tensor::{Tensor, TensorBackend, cpu::CpuBackend};
    ///
    /// let mut ctx = CpuBackend::new();
    /// let a = Tensor::from_vec(vec![2, 2], vec![4.0_f64, 1.0, 1.0, 3.0]);
    /// let l = a.cholesky(&mut ctx).unwrap();
    ///
    /// assert_eq!(l.shape(), &[2, 2]);
    /// ```
    pub fn cholesky(&self, ctx: &mut impl TensorBackend) -> crate::Result<Self> {
        ctx.with_exec_session(|exec| exec.cholesky(self))
    }

    /// Symmetric or Hermitian eigendecomposition: `A = V diag(W) V^T`.
    ///
    /// Returns `(eigenvalues, eigenvectors)`.
    ///
    /// # Examples
    ///
    /// ```
    /// use tenferro_tensor::{Tensor, TensorBackend, cpu::CpuBackend};
    ///
    /// let mut ctx = CpuBackend::new();
    /// let a = Tensor::from_vec(vec![2, 2], vec![4.0_f64, 1.0, 1.0, 3.0]);
    /// let (w, v) = a.eigh(&mut ctx).unwrap();
    ///
    /// assert_eq!(w.shape(), &[2]);
    /// assert_eq!(v.shape(), &[2, 2]);
    /// ```
    pub fn eigh(&self, ctx: &mut impl TensorBackend) -> crate::Result<(Self, Self)> {
        ctx.with_exec_session(|exec| unpack_two("eigh", exec.eigh(self)?))
    }

    /// General eigendecomposition.
    ///
    /// Returns `(eigenvalues, eigenvectors)`.
    ///
    /// # Examples
    ///
    /// ```
    /// use tenferro_tensor::{Tensor, TensorBackend, cpu::CpuBackend};
    ///
    /// let mut ctx = CpuBackend::new();
    /// let a = Tensor::from_vec(vec![2, 2], vec![1.0_f64, 0.0, 0.0, 3.0]);
    /// let (w, v) = a.eig(&mut ctx).unwrap();
    ///
    /// assert_eq!(w.shape(), &[2]);
    /// assert_eq!(v.shape(), &[2, 2]);
    /// ```
    pub fn eig(&self, ctx: &mut impl TensorBackend) -> crate::Result<(Self, Self)> {
        ctx.with_exec_session(|exec| unpack_two("eig", exec.eig(self)?))
    }

    /// Solve `A x = b` for `x`.
    ///
    /// # Examples
    ///
    /// ```
    /// use tenferro_tensor::{Tensor, TensorBackend, cpu::CpuBackend};
    ///
    /// let mut ctx = CpuBackend::new();
    /// let a = Tensor::from_vec(vec![2, 2], vec![2.0_f64, 1.0, 1.0, 2.0]);
    /// let b = Tensor::from_vec(vec![2, 1], vec![1.0_f64, 0.0]);
    /// let x = a.solve(&b, &mut ctx).unwrap();
    ///
    /// assert_eq!(x.shape(), &[2, 1]);
    /// ```
    pub fn solve(&self, b: &Self, ctx: &mut impl TensorBackend) -> crate::Result<Self> {
        // NOTE(review): unlike the sibling ops, this calls the backend directly
        // instead of going through `with_exec_session` — confirm intentional.
        ctx.solve(self, b)
    }

    /// Solve a triangular system.
    ///
    /// # Examples
    ///
    /// ```
    /// use tenferro_tensor::{Tensor, TensorBackend, cpu::CpuBackend};
    ///
    /// let mut ctx = CpuBackend::new();
    /// let a = Tensor::from_vec(vec![2, 2], vec![2.0_f64, 1.0, 0.0, 3.0]);
    /// let b = Tensor::from_vec(vec![2, 1], vec![2.0_f64, 7.0]);
    /// let x = a
    ///     .triangular_solve(&b, true, true, false, false, &mut ctx)
    ///     .unwrap();
    ///
    /// assert_eq!(x.shape(), &[2, 1]);
    /// ```
    pub fn triangular_solve(
        &self,
        b: &Self,
        left_side: bool,
        lower: bool,
        transpose_a: bool,
        unit_diagonal: bool,
        ctx: &mut impl TensorBackend,
    ) -> crate::Result<Self> {
        ctx.with_exec_session(|exec| {
            exec.triangular_solve(self, b, left_side, lower, transpose_a, unit_diagonal)
        })
    }

    /// Elementwise addition.
    ///
    /// # Examples
    ///
    /// ```
    /// use tenferro_tensor::{Tensor, TensorBackend, cpu::CpuBackend};
    ///
    /// let mut ctx = CpuBackend::new();
    /// let a = Tensor::from_vec(vec![3], vec![1.0_f64, 2.0, 3.0]);
    /// let b = Tensor::from_vec(vec![3], vec![4.0_f64, 5.0, 6.0]);
    /// let c = a.add(&b, &mut ctx).unwrap();
    ///
    /// assert_eq!(c.as_slice::<f64>().unwrap(), &[5.0, 7.0, 9.0]);
    /// ```
    pub fn add(&self, other: &Self, ctx: &mut impl TensorBackend) -> crate::Result<Self> {
        ctx.with_exec_session(|exec| exec.add(self, other))
    }

    /// Elementwise multiplication.
    ///
    /// # Examples
    ///
    /// ```
    /// use tenferro_tensor::{Tensor, TensorBackend, cpu::CpuBackend};
    ///
    /// let mut ctx = CpuBackend::new();
    /// let a = Tensor::from_vec(vec![3], vec![1.0_f64, 2.0, 3.0]);
    /// let b = Tensor::from_vec(vec![3], vec![4.0_f64, 5.0, 6.0]);
    /// let c = a.mul(&b, &mut ctx).unwrap();
    ///
    /// assert_eq!(c.as_slice::<f64>().unwrap(), &[4.0, 10.0, 18.0]);
    /// ```
    pub fn mul(&self, other: &Self, ctx: &mut impl TensorBackend) -> crate::Result<Self> {
        ctx.with_exec_session(|exec| exec.mul(self, other))
    }

    /// Negation.
    ///
    /// # Examples
    ///
    /// ```
    /// use tenferro_tensor::{Tensor, TensorBackend, cpu::CpuBackend};
    ///
    /// let mut ctx = CpuBackend::new();
    /// let a = Tensor::from_vec(vec![3], vec![1.0_f64, -2.0, 3.0]);
    /// let b = a.neg(&mut ctx).unwrap();
    ///
    /// assert_eq!(b.as_slice::<f64>().unwrap(), &[-1.0, 2.0, -3.0]);
    /// ```
    pub fn neg(&self, ctx: &mut impl TensorBackend) -> crate::Result<Self> {
        ctx.with_exec_session(|exec| exec.neg(self))
    }

    /// Transpose with an explicit permutation.
    ///
    /// # Examples
    ///
    /// ```
    /// use tenferro_tensor::{Tensor, TensorBackend, cpu::CpuBackend};
    ///
    /// let mut ctx = CpuBackend::new();
    /// let a = Tensor::from_vec(vec![2, 2], vec![1.0_f64, 2.0, 3.0, 4.0]);
    /// let b = a.transpose(&[1, 0], &mut ctx).unwrap();
    ///
    /// assert_eq!(b.shape(), &[2, 2]);
    /// assert_eq!(b.as_slice::<f64>().unwrap(), &[1.0, 3.0, 2.0, 4.0]);
    /// ```
    pub fn transpose(&self, perm: &[usize], ctx: &mut impl TensorBackend) -> crate::Result<Self> {
        ctx.with_exec_session(|exec| exec.transpose(self, perm))
    }

    /// Reshape to a new shape with the same number of elements.
    ///
    /// # Examples
    ///
    /// ```
    /// use tenferro_tensor::{Tensor, TensorBackend, cpu::CpuBackend};
    ///
    /// let mut ctx = CpuBackend::new();
    /// let a = Tensor::from_vec(vec![2, 3], vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0]);
    /// let b = a.reshape(&[3, 2], &mut ctx).unwrap();
    ///
    /// assert_eq!(b.shape(), &[3, 2]);
    /// assert_eq!(b.as_slice::<f64>().unwrap(), &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    /// ```
    pub fn reshape(&self, shape: &[usize], ctx: &mut impl TensorBackend) -> crate::Result<Self> {
        ctx.with_exec_session(|exec| exec.reshape(self, shape))
    }

    /// Reduce sum over the specified axes.
    ///
    /// # Examples
    ///
    /// ```
    /// use tenferro_tensor::{Tensor, TensorBackend, cpu::CpuBackend};
    ///
    /// let mut ctx = CpuBackend::new();
    /// let a = Tensor::from_vec(vec![2, 3], vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0]);
    /// let b = a.reduce_sum(&[1], &mut ctx).unwrap();
    ///
    /// assert_eq!(b.shape(), &[2]);
    /// assert_eq!(b.as_slice::<f64>().unwrap(), &[9.0, 12.0]);
    /// ```
    pub fn reduce_sum(&self, axes: &[usize], ctx: &mut impl TensorBackend) -> crate::Result<Self> {
        ctx.with_exec_session(|exec| exec.reduce_sum(self, axes))
    }

    /// Matrix multiplication for rank-2 tensors.
    ///
    /// This is a convenience wrapper around `dot_general`.
    ///
    /// # Examples
    ///
    /// ```
    /// use tenferro_tensor::{Tensor, TensorBackend, cpu::CpuBackend};
    ///
    /// let mut ctx = CpuBackend::new();
    /// let a = Tensor::from_vec(vec![2, 3], vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0]);
    /// let b = Tensor::from_vec(vec![3, 2], vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0]);
    /// let c = a.matmul(&b, &mut ctx).unwrap();
    ///
    /// assert_eq!(c.shape(), &[2, 2]);
    /// assert_eq!(c.as_slice::<f64>().unwrap(), &[22.0, 28.0, 49.0, 64.0]);
    /// ```
    pub fn matmul(&self, other: &Self, ctx: &mut impl TensorBackend) -> crate::Result<Self> {
        // Contract axis 1 of `self` with axis 0 of `other`, no batch dims —
        // the standard (m, k) x (k, n) product. The operand ranks are
        // forwarded to the backend unchecked here.
        let config = DotGeneralConfig {
            lhs_contracting_dims: vec![1],
            rhs_contracting_dims: vec![0],
            lhs_batch_dims: vec![],
            rhs_batch_dims: vec![],
            lhs_rank: self.shape().len(),
            rhs_rank: other.shape().len(),
        };
        ctx.with_exec_session(|exec| exec.dot_general(self, other, &config))
    }
}
1105
/// Decompose a flat column-major index into per-axis indices, writing the
/// result into `out` (first axis varies fastest).
pub(crate) fn flat_to_multi(mut flat: usize, shape: &[usize], out: &mut [usize]) {
    for (axis, &extent) in shape.iter().enumerate() {
        out[axis] = flat % extent;
        flat /= extent;
    }
}
1112
/// Build a `BackendFailure` error for a backend op that returned the wrong
/// number of output tensors.
fn invalid_output_count(op: &'static str, expected: usize, actual: usize) -> crate::Error {
    crate::Error::BackendFailure {
        op,
        message: format!("expected {expected} output tensors, got {actual}"),
    }
}
1119
1120fn unpack_two(op: &'static str, results: Vec<Tensor>) -> crate::Result<(Tensor, Tensor)> {
1121    let actual = results.len();
1122    let mut iter = results.into_iter();
1123    match (iter.next(), iter.next(), iter.next()) {
1124        (Some(a), Some(b), None) => Ok((a, b)),
1125        _ => Err(invalid_output_count(op, 2, actual)),
1126    }
1127}
1128
1129fn unpack_three(op: &'static str, results: Vec<Tensor>) -> crate::Result<(Tensor, Tensor, Tensor)> {
1130    let actual = results.len();
1131    let mut iter = results.into_iter();
1132    match (iter.next(), iter.next(), iter.next(), iter.next()) {
1133        (Some(a), Some(b), Some(c), None) => Ok((a, b, c)),
1134        _ => Err(invalid_output_count(op, 3, actual)),
1135    }
1136}
1137
1138fn unpack_four(
1139    op: &'static str,
1140    results: Vec<Tensor>,
1141) -> crate::Result<(Tensor, Tensor, Tensor, Tensor)> {
1142    let actual = results.len();
1143    let mut iter = results.into_iter();
1144    match (
1145        iter.next(),
1146        iter.next(),
1147        iter.next(),
1148        iter.next(),
1149        iter.next(),
1150    ) {
1151        (Some(a), Some(b), Some(c), Some(d), None) => Ok((a, b, c, d)),
1152        _ => Err(invalid_output_count(op, 4, actual)),
1153    }
1154}