// tenferro_tensor/types.rs
1use num_complex::{Complex, Complex32, Complex64};
2use num_traits::{One, Zero};
3
4use crate::{DotGeneralConfig, TensorBackend};
5
/// Memory location for tensor storage.
///
/// Distinguishes device-resident memory from host memory (pinned or
/// pageable), with an escape hatch for backend-specific locations.
///
/// # Examples
///
/// ```ignore
/// use tenferro_tensor::MemoryKind;
///
/// let kind = MemoryKind::UnpinnedHost;
/// ```
#[derive(Clone, Debug)]
pub enum MemoryKind {
    /// Memory resident on a compute device (see [`ComputeDevice`]).
    Device,
    /// Page-locked (pinned) host memory.
    PinnedHost,
    /// Ordinary pageable host memory.
    UnpinnedHost,
    /// Backend-specific memory kind identified by name.
    Other(String),
}
22
/// Concrete compute device description.
///
/// # Examples
///
/// ```ignore
/// use tenferro_tensor::ComputeDevice;
///
/// let device = ComputeDevice { kind: "cuda".into(), ordinal: 0 };
/// ```
#[derive(Clone, Debug)]
pub struct ComputeDevice {
    /// Backend-specific device kind name (e.g. `"cuda"`).
    pub kind: String,
    /// Device index within its kind.
    pub ordinal: usize,
}
37
/// Placement metadata for a tensor buffer.
///
/// # Examples
///
/// ```ignore
/// use tenferro_tensor::{ComputeDevice, MemoryKind, Placement};
///
/// let placement = Placement {
///     memory_kind: MemoryKind::Device,
///     resident_device: Some(ComputeDevice { kind: "cuda".into(), ordinal: 0 }),
/// };
/// ```
#[derive(Clone, Debug)]
pub struct Placement {
    /// Where the buffer's bytes live (host vs. device).
    pub memory_kind: MemoryKind,
    /// Device the buffer is resident on, if any; `None` for plain host data.
    pub resident_device: Option<ComputeDevice>,
}
55
/// Backend-owned buffer handle.
///
/// The backend owns the actual allocation; this is just an opaque id
/// tagged with the element type.
///
/// # Examples
///
/// ```ignore
/// use tenferro_tensor::BufferHandle;
///
/// let handle = BufferHandle::<f64>::new(7);
/// ```
#[derive(Clone, Debug)]
pub struct BufferHandle<T> {
    /// Opaque backend-assigned identifier for the allocation.
    pub id: u64,
    // Ties the handle to its element type without storing any data.
    _phantom: std::marker::PhantomData<T>,
}
70
71impl<T> BufferHandle<T> {
72 /// Create a new backend buffer handle.
73 ///
74 /// # Examples
75 ///
76 /// ```ignore
77 /// use tenferro_tensor::BufferHandle;
78 ///
79 /// let handle = BufferHandle::<f64>::new(1);
80 /// assert_eq!(handle.id, 1);
81 /// ```
82 pub fn new(id: u64) -> Self {
83 Self {
84 id,
85 _phantom: std::marker::PhantomData,
86 }
87 }
88}
89
/// Tensor storage.
///
/// # Examples
///
/// ```ignore
/// use tenferro_tensor::Buffer;
///
/// let host = Buffer::Host(vec![1.0_f64, 2.0]);
/// ```
#[derive(Clone, Debug)]
pub enum Buffer<T> {
    /// Data held directly in host memory.
    Host(Vec<T>),
    /// Data owned by a backend, referenced through a [`BufferHandle`].
    Backend(BufferHandle<T>),
    /// Data owned by a CubeCL runtime (GPU allocation).
    #[cfg(feature = "cubecl")]
    Cubecl(CubeclBuffer<T>),
}
106
/// CubeCL-managed GPU buffer.
///
/// This wraps a CubeCL server handle that owns the underlying GPU allocation.
///
/// # Examples
///
/// ```
/// let _name = core::any::type_name::<tenferro_tensor::CubeclBuffer<f64>>();
/// assert!(_name.contains("CubeclBuffer"));
/// ```
#[cfg(feature = "cubecl")]
#[derive(Clone, Debug)]
pub struct CubeclBuffer<T> {
    /// CubeCL server handle that owns the GPU allocation.
    pub handle: cubecl::server::Handle,
    /// Number of elements stored in the allocation.
    pub len: usize,
    // Element-type tag; no data stored.
    pub(crate) _marker: std::marker::PhantomData<T>,
}
126
#[cfg(feature = "cubecl")]
impl<T> CubeclBuffer<T> {
    /// Wrap a CubeCL server handle together with its element count.
    ///
    /// # Examples
    ///
    /// ```
    /// let _new = tenferro_tensor::CubeclBuffer::<f64>::new;
    /// let _ = _new;
    /// ```
    pub fn new(handle: cubecl::server::Handle, len: usize) -> Self {
        // The marker only carries the element type.
        let _marker = std::marker::PhantomData;
        Self { handle, len, _marker }
    }
}
145
/// Contiguous column-major typed tensor storage.
///
/// # Examples
///
/// ```ignore
/// use tenferro_tensor::TypedTensor;
///
/// let t = TypedTensor::<f64>::from_vec(vec![2, 2], vec![1.0, 2.0, 3.0, 4.0]);
/// assert_eq!(t.shape, vec![2, 2]);
/// ```
#[derive(Clone, Debug)]
pub struct TypedTensor<T> {
    /// Element storage (host vector or backend-owned handle).
    pub buffer: Buffer<T>,
    /// Extent of each dimension; elements are laid out column-major.
    pub shape: Vec<usize>,
    /// Where the buffer currently resides.
    pub placement: Placement,
}
162
/// Runtime scalar dtype tag.
///
/// # Examples
///
/// ```ignore
/// use tenferro_tensor::DType;
///
/// assert_eq!(DType::F64 as u8, DType::F64 as u8);
/// ```
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum DType {
    /// `f32`
    F32,
    /// `f64`
    F64,
    /// `Complex<f32>` ([`Complex32`])
    C32,
    /// `Complex<f64>` ([`Complex64`])
    C64,
}
179
/// Sealed trait for scalar types that can be stored in a [`Tensor`].
///
/// This trait is implemented for `f64`, `f32`, [`Complex64`], and
/// [`Complex32`]. It cannot be implemented outside this crate (see the
/// private `Sealed` supertrait).
///
/// # Examples
///
/// ```
/// use tenferro_tensor::TensorScalar;
///
/// let tensor = <f64 as TensorScalar>::into_tensor(vec![2], vec![1.0, 2.0]);
/// assert_eq!(tensor.as_slice::<f64>(), Some([1.0, 2.0].as_slice()));
/// ```
pub trait TensorScalar: Copy + Clone + Send + Sync + 'static + private::Sealed {
    /// Real-valued counterpart of this scalar type.
    ///
    /// For real scalars this is `Self`; for complex scalars it is the
    /// underlying float type.
    type Real: TensorScalar;

    /// The [`DType`] tag corresponding to this scalar type.
    ///
    /// # Examples
    ///
    /// ```
    /// use tenferro_tensor::{DType, TensorScalar};
    ///
    /// assert_eq!(f64::dtype(), DType::F64);
    /// assert_eq!(f32::dtype(), DType::F32);
    /// ```
    fn dtype() -> DType;

    /// Wrap typed data into a [`Tensor`] enum variant.
    fn into_tensor(shape: Vec<usize>, data: Vec<Self>) -> Tensor;

    /// Try to borrow the host data from a [`Tensor`].
    ///
    /// Returns `None` if the tensor dtype does not match `Self`.
    fn try_as_slice(tensor: &Tensor) -> Option<&[Self]>;

    /// Try to extract a [`TypedTensor<Self>`] from a dynamic [`Tensor`].
    ///
    /// Returns `None` if the tensor dtype does not match `Self`.
    ///
    /// # Examples
    ///
    /// ```
    /// use tenferro_tensor::{Tensor, TensorScalar};
    ///
    /// let tensor = Tensor::from_vec(vec![2], vec![1.0_f64, 2.0]);
    /// let typed = <f64 as TensorScalar>::try_into_typed(tensor).unwrap();
    ///
    /// assert_eq!(typed.as_slice(), &[1.0, 2.0]);
    /// ```
    fn try_into_typed(tensor: Tensor) -> Option<TypedTensor<Self>>;
}
231
// Sealing module: keeps `TensorScalar` closed to the four supported
// scalar types by requiring a supertrait downstream crates cannot name.
mod private {
    /// Marker supertrait implemented only for the supported scalars.
    pub trait Sealed {}

    impl Sealed for f64 {}
    impl Sealed for f32 {}
    impl Sealed for num_complex::Complex64 {}
    impl Sealed for num_complex::Complex32 {}
}
240
241impl TensorScalar for f64 {
242 type Real = f64;
243
244 fn dtype() -> DType {
245 DType::F64
246 }
247
248 fn into_tensor(shape: Vec<usize>, data: Vec<Self>) -> Tensor {
249 Tensor::F64(TypedTensor::from_vec(shape, data))
250 }
251
252 fn try_as_slice(tensor: &Tensor) -> Option<&[Self]> {
253 match tensor {
254 Tensor::F64(t) => Some(t.host_data()),
255 _ => None,
256 }
257 }
258
259 fn try_into_typed(tensor: Tensor) -> Option<TypedTensor<Self>> {
260 match tensor {
261 Tensor::F64(inner) => Some(inner),
262 _ => None,
263 }
264 }
265}
266
267impl TensorScalar for f32 {
268 type Real = f32;
269
270 fn dtype() -> DType {
271 DType::F32
272 }
273
274 fn into_tensor(shape: Vec<usize>, data: Vec<Self>) -> Tensor {
275 Tensor::F32(TypedTensor::from_vec(shape, data))
276 }
277
278 fn try_as_slice(tensor: &Tensor) -> Option<&[Self]> {
279 match tensor {
280 Tensor::F32(t) => Some(t.host_data()),
281 _ => None,
282 }
283 }
284
285 fn try_into_typed(tensor: Tensor) -> Option<TypedTensor<Self>> {
286 match tensor {
287 Tensor::F32(inner) => Some(inner),
288 _ => None,
289 }
290 }
291}
292
293impl TensorScalar for Complex64 {
294 type Real = f64;
295
296 fn dtype() -> DType {
297 DType::C64
298 }
299
300 fn into_tensor(shape: Vec<usize>, data: Vec<Self>) -> Tensor {
301 Tensor::C64(TypedTensor::from_vec(shape, data))
302 }
303
304 fn try_as_slice(tensor: &Tensor) -> Option<&[Self]> {
305 match tensor {
306 Tensor::C64(t) => Some(t.host_data()),
307 _ => None,
308 }
309 }
310
311 fn try_into_typed(tensor: Tensor) -> Option<TypedTensor<Self>> {
312 match tensor {
313 Tensor::C64(inner) => Some(inner),
314 _ => None,
315 }
316 }
317}
318
319impl TensorScalar for Complex32 {
320 type Real = f32;
321
322 fn dtype() -> DType {
323 DType::C32
324 }
325
326 fn into_tensor(shape: Vec<usize>, data: Vec<Self>) -> Tensor {
327 Tensor::C32(TypedTensor::from_vec(shape, data))
328 }
329
330 fn try_as_slice(tensor: &Tensor) -> Option<&[Self]> {
331 match tensor {
332 Tensor::C32(t) => Some(t.host_data()),
333 _ => None,
334 }
335 }
336
337 fn try_into_typed(tensor: Tensor) -> Option<TypedTensor<Self>> {
338 match tensor {
339 Tensor::C32(inner) => Some(inner),
340 _ => None,
341 }
342 }
343}
344
/// Dynamic tensor enum over the supported scalar types.
///
/// # Examples
///
/// ```ignore
/// use tenferro_tensor::{Tensor, TypedTensor};
///
/// let t = Tensor::F64(TypedTensor::from_vec(vec![2], vec![1.0, 2.0]));
/// assert_eq!(t.shape(), &[2]);
/// ```
#[derive(Clone, Debug)]
pub enum Tensor {
    /// Single-precision real tensor.
    F32(TypedTensor<f32>),
    /// Double-precision real tensor.
    F64(TypedTensor<f64>),
    /// Single-precision complex tensor.
    C32(TypedTensor<Complex<f32>>),
    /// Double-precision complex tensor.
    C64(TypedTensor<Complex<f64>>),
}
362
363/// Wrap an `f64` [`TypedTensor`] into the corresponding [`Tensor`] variant.
364///
365/// # Examples
366///
367/// ```
368/// use tenferro_tensor::{Tensor, TypedTensor};
369///
370/// let typed = TypedTensor::from_vec(vec![2], vec![1.0_f64, 2.0]);
371/// let tensor: Tensor = typed.into();
372/// assert_eq!(tensor.shape(), &[2]);
373/// ```
374impl From<TypedTensor<f64>> for Tensor {
375 fn from(t: TypedTensor<f64>) -> Self {
376 Tensor::F64(t)
377 }
378}
379
380/// Wrap an `f32` [`TypedTensor`] into the corresponding [`Tensor`] variant.
381///
382/// # Examples
383///
384/// ```
385/// use tenferro_tensor::{Tensor, TypedTensor};
386///
387/// let typed = TypedTensor::from_vec(vec![2], vec![1.0_f32, 2.0]);
388/// let tensor: Tensor = typed.into();
389/// assert_eq!(tensor.shape(), &[2]);
390/// ```
391impl From<TypedTensor<f32>> for Tensor {
392 fn from(t: TypedTensor<f32>) -> Self {
393 Tensor::F32(t)
394 }
395}
396
397/// Wrap a [`Complex64`] [`TypedTensor`] into the corresponding [`Tensor`]
398/// variant.
399///
400/// # Examples
401///
402/// ```
403/// use num_complex::Complex64;
404/// use tenferro_tensor::{Tensor, TypedTensor};
405///
406/// let typed = TypedTensor::from_vec(
407/// vec![1],
408/// vec![Complex64::new(1.0, 2.0)],
409/// );
410/// let tensor: Tensor = typed.into();
411/// assert_eq!(tensor.shape(), &[1]);
412/// ```
413impl From<TypedTensor<Complex<f64>>> for Tensor {
414 fn from(t: TypedTensor<Complex<f64>>) -> Self {
415 Tensor::C64(t)
416 }
417}
418
419/// Wrap a [`Complex32`] [`TypedTensor`] into the corresponding [`Tensor`]
420/// variant.
421///
422/// # Examples
423///
424/// ```
425/// use num_complex::Complex32;
426/// use tenferro_tensor::{Tensor, TypedTensor};
427///
428/// let typed = TypedTensor::from_vec(
429/// vec![1],
430/// vec![Complex32::new(1.0, 2.0)],
431/// );
432/// let tensor: Tensor = typed.into();
433/// assert_eq!(tensor.shape(), &[1]);
434/// ```
435impl From<TypedTensor<Complex<f32>>> for Tensor {
436 fn from(t: TypedTensor<Complex<f32>>) -> Self {
437 Tensor::C32(t)
438 }
439}
440
/// Column-major strides derived from a shape.
///
/// Dimension `i` gets the product of all extents before it; an empty
/// shape yields an empty stride vector.
///
/// # Examples
///
/// ```ignore
/// use tenferro_tensor::col_major_strides;
///
/// assert_eq!(col_major_strides(&[2, 3]), vec![1, 2]);
/// ```
pub fn col_major_strides(shape: &[usize]) -> Vec<isize> {
    // Running product of the extents seen so far; each dimension's
    // stride is the product of everything before it.
    let mut running = 1isize;
    shape
        .iter()
        .map(|&extent| {
            let stride = running;
            running *= extent as isize;
            stride
        })
        .collect()
}
460
461pub(crate) fn default_placement() -> Placement {
462 Placement {
463 memory_kind: MemoryKind::UnpinnedHost,
464 resident_device: None,
465 }
466}
467
468impl<T: Clone + Zero> TypedTensor<T> {
469 /// Allocate a zero-filled tensor.
470 ///
471 /// # Examples
472 ///
473 /// ```ignore
474 /// use tenferro_tensor::TypedTensor;
475 ///
476 /// let t = TypedTensor::<f64>::zeros(vec![2, 3]);
477 /// assert_eq!(t.n_elements(), 6);
478 /// ```
479 pub fn zeros(shape: Vec<usize>) -> Self {
480 let n: usize = shape.iter().product();
481 Self {
482 buffer: Buffer::Host(vec![T::zero(); n]),
483 shape,
484 placement: default_placement(),
485 }
486 }
487}
488
489impl<T: Clone + One + Zero> TypedTensor<T> {
490 /// Allocate a one-filled tensor.
491 ///
492 /// # Examples
493 ///
494 /// ```ignore
495 /// use tenferro_tensor::TypedTensor;
496 ///
497 /// let t = TypedTensor::<f64>::ones(vec![2]);
498 /// assert_eq!(t.host_data(), &[1.0, 1.0]);
499 /// ```
500 pub fn ones(shape: Vec<usize>) -> Self {
501 let n: usize = shape.iter().product();
502 Self {
503 buffer: Buffer::Host(vec![T::one(); n]),
504 shape,
505 placement: default_placement(),
506 }
507 }
508}
509
impl<T: Clone> TypedTensor<T> {
    /// Create a tensor from a column-major buffer.
    ///
    /// # Panics
    ///
    /// Panics if `data.len()` does not equal the product of `shape`.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// use tenferro_tensor::TypedTensor;
    ///
    /// let t = TypedTensor::<f64>::from_vec(vec![2, 2], vec![1.0, 2.0, 3.0, 4.0]);
    /// assert_eq!(t.get(&[1, 0]), &2.0);
    /// ```
    pub fn from_vec(shape: Vec<usize>, data: Vec<T>) -> Self {
        let n: usize = shape.iter().product();
        assert_eq!(
            data.len(),
            n,
            "data length {} does not match shape product {}",
            data.len(),
            n
        );
        Self {
            buffer: Buffer::Host(data),
            shape,
            placement: default_placement(),
        }
    }

    /// Number of elements in the tensor (product of the shape extents).
    ///
    /// # Examples
    ///
    /// ```ignore
    /// use tenferro_tensor::TypedTensor;
    ///
    /// let t = TypedTensor::<f64>::from_vec(vec![2, 3], vec![0.0; 6]);
    /// assert_eq!(t.n_elements(), 6);
    /// ```
    pub fn n_elements(&self) -> usize {
        self.shape.iter().product()
    }

    /// Borrow the host buffer.
    ///
    /// # Panics
    ///
    /// Panics if the buffer is not host-resident (i.e. it is a
    /// `Buffer::Backend` handle or, with the `cubecl` feature, a
    /// `Buffer::Cubecl` GPU allocation).
    ///
    /// # Examples
    ///
    /// ```ignore
    /// use tenferro_tensor::TypedTensor;
    ///
    /// let t = TypedTensor::<f64>::from_vec(vec![2], vec![1.0, 2.0]);
    /// assert_eq!(t.host_data(), &[1.0, 2.0]);
    /// ```
    pub fn host_data(&self) -> &[T] {
        match &self.buffer {
            Buffer::Host(v) => v,
            Buffer::Backend(_) => panic!("host_data called on backend buffer"),
            #[cfg(feature = "cubecl")]
            Buffer::Cubecl(_) => {
                panic!(
                    "Cannot access GPU buffer (Buffer::Cubecl) as host data. \
                     Use cubecl::download_tensor() to transfer to CPU first."
                )
            }
        }
    }

    /// View the tensor data as a flat slice.
    ///
    /// This is an alias for `host_data()` for API consistency with
    /// `Tensor::as_slice`.
    ///
    /// # Panics
    ///
    /// Panics if the buffer is not host-resident (see [`Self::host_data`]).
    ///
    /// # Examples
    ///
    /// ```
    /// use tenferro_tensor::TypedTensor;
    ///
    /// let t = TypedTensor::<f64>::from_vec(vec![2], vec![1.0, 2.0]);
    /// assert_eq!(t.as_slice(), &[1.0, 2.0]);
    /// ```
    pub fn as_slice(&self) -> &[T] {
        self.host_data()
    }

    /// Mutably borrow the host buffer.
    ///
    /// # Panics
    ///
    /// Panics if the buffer is not host-resident (see [`Self::host_data`]).
    ///
    /// # Examples
    ///
    /// ```ignore
    /// use tenferro_tensor::TypedTensor;
    ///
    /// let mut t = TypedTensor::<f64>::zeros(vec![2]);
    /// t.host_data_mut()[0] = 3.0;
    /// assert_eq!(t.host_data(), &[3.0, 0.0]);
    /// ```
    pub fn host_data_mut(&mut self) -> &mut [T] {
        match &mut self.buffer {
            Buffer::Host(v) => v,
            Buffer::Backend(_) => panic!("host_data_mut called on backend buffer"),
            #[cfg(feature = "cubecl")]
            Buffer::Cubecl(_) => {
                panic!(
                    "Cannot access GPU buffer (Buffer::Cubecl) as host data. \
                     Use cubecl::download_tensor() to transfer to CPU first."
                )
            }
        }
    }

    /// Compute the linear column-major offset for an index.
    ///
    /// # Panics
    ///
    /// Panics if `indices.len()` differs from the tensor rank, or if any
    /// index is out of bounds for its dimension.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// use tenferro_tensor::TypedTensor;
    ///
    /// let t = TypedTensor::<f64>::zeros(vec![2, 3]);
    /// assert_eq!(t.linear_offset(&[1, 2]), 5);
    /// ```
    pub fn linear_offset(&self, indices: &[usize]) -> usize {
        assert_eq!(indices.len(), self.shape.len());
        let mut offset = 0usize;
        // Column-major: the first dimension varies fastest, so the
        // stride grows as we move to later dimensions.
        let mut stride = 1usize;
        for (i, &idx) in indices.iter().enumerate() {
            assert!(idx < self.shape[i], "index out of bounds");
            offset += idx * stride;
            stride *= self.shape[i];
        }
        offset
    }

    /// Borrow a single element by multi-index.
    ///
    /// # Panics
    ///
    /// Panics on a rank mismatch, an out-of-bounds index, or a
    /// non-host-resident buffer.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// use tenferro_tensor::TypedTensor;
    ///
    /// let t = TypedTensor::<f64>::from_vec(vec![2], vec![1.0, 2.0]);
    /// assert_eq!(t.get(&[1]), &2.0);
    /// ```
    pub fn get(&self, indices: &[usize]) -> &T {
        let off = self.linear_offset(indices);
        &self.host_data()[off]
    }

    /// Mutably borrow a single element by multi-index.
    ///
    /// # Panics
    ///
    /// Panics on a rank mismatch, an out-of-bounds index, or a
    /// non-host-resident buffer.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// use tenferro_tensor::TypedTensor;
    ///
    /// let mut t = TypedTensor::<f64>::zeros(vec![1]);
    /// *t.get_mut(&[0]) = 7.0;
    /// assert_eq!(t.host_data(), &[7.0]);
    /// ```
    pub fn get_mut(&mut self, indices: &[usize]) -> &mut T {
        let off = self.linear_offset(indices);
        &mut self.host_data_mut()[off]
    }
}
670
/// Element-wise conjugation helper.
///
/// For real scalars this is the identity; for complex scalars it is the
/// complex conjugate. Lets generic code conjugate elements without
/// branching on dtype.
pub trait ConjElem {
    /// Return the conjugate of `self` (identity for real types).
    fn conj_elem(self) -> Self;
}

impl ConjElem for f32 {
    fn conj_elem(self) -> Self {
        self
    }
}

impl ConjElem for f64 {
    fn conj_elem(self) -> Self {
        self
    }
}

impl ConjElem for Complex<f32> {
    fn conj_elem(self) -> Self {
        self.conj()
    }
}

impl ConjElem for Complex<f64> {
    fn conj_elem(self) -> Self {
        self.conj()
    }
}
699
/// Apply `$body` to the inner [`TypedTensor`] of a [`Tensor`], rebuilding
/// the same variant around the result.
///
/// `$inner` is the binding name for the matched [`TypedTensor`] inside
/// `$body`; the body must evaluate to a `TypedTensor` of the same element
/// type so the variant can be reconstructed.
macro_rules! dispatch_tensor {
    ($self:expr, $inner:ident => $body:expr) => {
        match $self {
            Tensor::F32($inner) => Tensor::F32($body),
            Tensor::F64($inner) => Tensor::F64($body),
            Tensor::C32($inner) => Tensor::C32($body),
            Tensor::C64($inner) => Tensor::C64($body),
        }
    };
}
710
/// Apply `$body` to the inner [`TypedTensor`]s of two [`Tensor`]s of the
/// same dtype, rebuilding that variant around the result.
///
/// `$a` and `$b` are the binding names for the two matched
/// [`TypedTensor`]s inside `$body`.
///
/// Panics with "dtype mismatch in binary op" when the two tensors are
/// different variants.
macro_rules! dispatch_binary {
    ($lhs:expr, $rhs:expr, |$a:ident, $b:ident| $body:expr) => {
        match ($lhs, $rhs) {
            (Tensor::F32($a), Tensor::F32($b)) => Tensor::F32($body),
            (Tensor::F64($a), Tensor::F64($b)) => Tensor::F64($body),
            (Tensor::C32($a), Tensor::C32($b)) => Tensor::C32($body),
            (Tensor::C64($a), Tensor::C64($b)) => Tensor::C64($body),
            _ => panic!("dtype mismatch in binary op"),
        }
    };
}
722
723pub(crate) use dispatch_binary;
724pub(crate) use dispatch_tensor;
725
impl Tensor {
    /// Create a tensor from a shape and flat data.
    ///
    /// This is the `Tensor`-level equivalent of `TypedTensor::<T>::from_vec`.
    ///
    /// # Panics
    ///
    /// Panics if `data.len()` does not equal the product of `shape`
    /// (via `TypedTensor::from_vec`).
    ///
    /// # Examples
    ///
    /// ```
    /// use tenferro_tensor::Tensor;
    ///
    /// let t = Tensor::from_vec(vec![2, 3], vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0]);
    /// assert_eq!(t.shape(), &[2, 3]);
    /// assert_eq!(t.as_slice::<f64>().unwrap(), &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    /// ```
    pub fn from_vec<T: TensorScalar>(shape: Vec<usize>, data: Vec<T>) -> Self {
        T::into_tensor(shape, data)
    }

    /// Tensor shape.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// use tenferro_tensor::{Tensor, TypedTensor};
    ///
    /// let t = Tensor::F64(TypedTensor::from_vec(vec![2], vec![1.0, 2.0]));
    /// assert_eq!(t.shape(), &[2]);
    /// ```
    pub fn shape(&self) -> &[usize] {
        match self {
            Tensor::F32(t) => &t.shape,
            Tensor::F64(t) => &t.shape,
            Tensor::C32(t) => &t.shape,
            Tensor::C64(t) => &t.shape,
        }
    }

    /// Tensor dtype tag.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// use tenferro_tensor::{DType, Tensor, TypedTensor};
    ///
    /// let t = Tensor::F64(TypedTensor::from_vec(vec![], vec![1.0]));
    /// assert_eq!(t.dtype(), DType::F64);
    /// ```
    pub fn dtype(&self) -> DType {
        match self {
            Tensor::F32(_) => DType::F32,
            Tensor::F64(_) => DType::F64,
            Tensor::C32(_) => DType::C32,
            Tensor::C64(_) => DType::C64,
        }
    }

    /// Try to borrow the host data as a typed slice.
    ///
    /// Returns `None` if the tensor dtype does not match `T`.
    ///
    /// # Examples
    ///
    /// ```
    /// use tenferro_tensor::{Tensor, TypedTensor};
    ///
    /// let t = Tensor::F64(TypedTensor::from_vec(vec![3], vec![1.0, 2.0, 3.0]));
    /// assert_eq!(t.as_slice::<f64>(), Some([1.0, 2.0, 3.0].as_slice()));
    /// assert_eq!(t.as_slice::<f32>(), None);
    /// ```
    pub fn as_slice<T: TensorScalar>(&self) -> Option<&[T]> {
        T::try_as_slice(self)
    }

    /// Singular value decomposition: `A = U diag(S) Vt`.
    ///
    /// Returns `(U, S, Vt)` using the thin/economy SVD.
    ///
    /// # Errors
    ///
    /// Returns an error if the backend `svd` fails or does not return
    /// exactly three output tensors.
    ///
    /// # Examples
    ///
    /// ```
    /// use tenferro_tensor::{Tensor, TensorBackend, cpu::CpuBackend};
    ///
    /// let mut ctx = CpuBackend::new();
    /// let a = Tensor::from_vec(vec![3, 2], vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0]);
    /// let (u, s, vt) = a.svd(&mut ctx).unwrap();
    ///
    /// assert_eq!(u.shape(), &[3, 2]);
    /// assert_eq!(s.shape(), &[2]);
    /// assert_eq!(vt.shape(), &[2, 2]);
    /// ```
    pub fn svd(&self, ctx: &mut impl TensorBackend) -> crate::Result<(Self, Self, Self)> {
        ctx.with_exec_session(|exec| unpack_three("svd", exec.svd(self)?))
    }

    /// QR decomposition: `A = Q R`.
    ///
    /// Returns `(Q, R)` using the thin/economy QR decomposition.
    ///
    /// # Errors
    ///
    /// Returns an error if the backend `qr` fails or does not return
    /// exactly two output tensors.
    ///
    /// # Examples
    ///
    /// ```
    /// use tenferro_tensor::{Tensor, TensorBackend, cpu::CpuBackend};
    ///
    /// let mut ctx = CpuBackend::new();
    /// let a = Tensor::from_vec(vec![3, 2], vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0]);
    /// let (q, r) = a.qr(&mut ctx).unwrap();
    ///
    /// assert_eq!(q.shape(), &[3, 2]);
    /// assert_eq!(r.shape(), &[2, 2]);
    /// ```
    pub fn qr(&self, ctx: &mut impl TensorBackend) -> crate::Result<(Self, Self)> {
        ctx.with_exec_session(|exec| unpack_two("qr", exec.qr(self)?))
    }

    /// LU decomposition with partial pivoting: `P A = L U`.
    ///
    /// Returns `(P, L, U, parity)`.
    ///
    /// # Errors
    ///
    /// Returns an error if the backend `lu` fails or does not return
    /// exactly four output tensors.
    ///
    /// # Examples
    ///
    /// ```
    /// use tenferro_tensor::{Tensor, TensorBackend, cpu::CpuBackend};
    ///
    /// let mut ctx = CpuBackend::new();
    /// let a = Tensor::from_vec(vec![2, 2], vec![0.0_f64, 1.0, 1.0, 0.0]);
    /// let (p, l, u, parity) = a.lu(&mut ctx).unwrap();
    ///
    /// assert_eq!(p.shape(), &[2, 2]);
    /// assert_eq!(l.shape(), &[2, 2]);
    /// assert_eq!(u.shape(), &[2, 2]);
    /// assert_eq!(parity.shape(), &[] as &[usize]);
    /// ```
    pub fn lu(&self, ctx: &mut impl TensorBackend) -> crate::Result<(Self, Self, Self, Self)> {
        ctx.with_exec_session(|exec| unpack_four("lu", exec.lu(self)?))
    }

    /// Cholesky decomposition: `A = L L^T` or `A = L L^H` for complex inputs.
    ///
    /// Returns the lower-triangular factor `L`.
    ///
    /// # Errors
    ///
    /// Propagates any backend failure.
    ///
    /// # Examples
    ///
    /// ```
    /// use tenferro_tensor::{Tensor, TensorBackend, cpu::CpuBackend};
    ///
    /// let mut ctx = CpuBackend::new();
    /// let a = Tensor::from_vec(vec![2, 2], vec![4.0_f64, 1.0, 1.0, 3.0]);
    /// let l = a.cholesky(&mut ctx).unwrap();
    ///
    /// assert_eq!(l.shape(), &[2, 2]);
    /// ```
    pub fn cholesky(&self, ctx: &mut impl TensorBackend) -> crate::Result<Self> {
        ctx.with_exec_session(|exec| exec.cholesky(self))
    }

    /// Symmetric or Hermitian eigendecomposition: `A = V diag(W) V^T`.
    ///
    /// Returns `(eigenvalues, eigenvectors)`.
    ///
    /// # Errors
    ///
    /// Returns an error if the backend `eigh` fails or does not return
    /// exactly two output tensors.
    ///
    /// # Examples
    ///
    /// ```
    /// use tenferro_tensor::{Tensor, TensorBackend, cpu::CpuBackend};
    ///
    /// let mut ctx = CpuBackend::new();
    /// let a = Tensor::from_vec(vec![2, 2], vec![4.0_f64, 1.0, 1.0, 3.0]);
    /// let (w, v) = a.eigh(&mut ctx).unwrap();
    ///
    /// assert_eq!(w.shape(), &[2]);
    /// assert_eq!(v.shape(), &[2, 2]);
    /// ```
    pub fn eigh(&self, ctx: &mut impl TensorBackend) -> crate::Result<(Self, Self)> {
        ctx.with_exec_session(|exec| unpack_two("eigh", exec.eigh(self)?))
    }

    /// General eigendecomposition.
    ///
    /// Returns `(eigenvalues, eigenvectors)`.
    ///
    /// # Errors
    ///
    /// Returns an error if the backend `eig` fails or does not return
    /// exactly two output tensors.
    ///
    /// # Examples
    ///
    /// ```
    /// use tenferro_tensor::{Tensor, TensorBackend, cpu::CpuBackend};
    ///
    /// let mut ctx = CpuBackend::new();
    /// let a = Tensor::from_vec(vec![2, 2], vec![1.0_f64, 0.0, 0.0, 3.0]);
    /// let (w, v) = a.eig(&mut ctx).unwrap();
    ///
    /// assert_eq!(w.shape(), &[2]);
    /// assert_eq!(v.shape(), &[2, 2]);
    /// ```
    pub fn eig(&self, ctx: &mut impl TensorBackend) -> crate::Result<(Self, Self)> {
        ctx.with_exec_session(|exec| unpack_two("eig", exec.eig(self)?))
    }

    /// Solve `A x = b` for `x`.
    ///
    /// # Errors
    ///
    /// Propagates any backend failure.
    ///
    /// # Examples
    ///
    /// ```
    /// use tenferro_tensor::{Tensor, TensorBackend, cpu::CpuBackend};
    ///
    /// let mut ctx = CpuBackend::new();
    /// let a = Tensor::from_vec(vec![2, 2], vec![2.0_f64, 1.0, 1.0, 2.0]);
    /// let b = Tensor::from_vec(vec![2, 1], vec![1.0_f64, 0.0]);
    /// let x = a.solve(&b, &mut ctx).unwrap();
    ///
    /// assert_eq!(x.shape(), &[2, 1]);
    /// ```
    pub fn solve(&self, b: &Self, ctx: &mut impl TensorBackend) -> crate::Result<Self> {
        // NOTE(review): unlike the other ops here, this calls `ctx.solve`
        // directly rather than going through `with_exec_session` — confirm
        // this asymmetry is intentional.
        ctx.solve(self, b)
    }

    /// Solve a triangular system.
    ///
    /// # Errors
    ///
    /// Propagates any backend failure.
    ///
    /// # Examples
    ///
    /// ```
    /// use tenferro_tensor::{Tensor, TensorBackend, cpu::CpuBackend};
    ///
    /// let mut ctx = CpuBackend::new();
    /// let a = Tensor::from_vec(vec![2, 2], vec![2.0_f64, 1.0, 0.0, 3.0]);
    /// let b = Tensor::from_vec(vec![2, 1], vec![2.0_f64, 7.0]);
    /// let x = a
    ///     .triangular_solve(&b, true, true, false, false, &mut ctx)
    ///     .unwrap();
    ///
    /// assert_eq!(x.shape(), &[2, 1]);
    /// ```
    pub fn triangular_solve(
        &self,
        b: &Self,
        left_side: bool,
        lower: bool,
        transpose_a: bool,
        unit_diagonal: bool,
        ctx: &mut impl TensorBackend,
    ) -> crate::Result<Self> {
        ctx.with_exec_session(|exec| {
            exec.triangular_solve(self, b, left_side, lower, transpose_a, unit_diagonal)
        })
    }

    /// Elementwise addition.
    ///
    /// # Errors
    ///
    /// Propagates any backend failure.
    ///
    /// # Examples
    ///
    /// ```
    /// use tenferro_tensor::{Tensor, TensorBackend, cpu::CpuBackend};
    ///
    /// let mut ctx = CpuBackend::new();
    /// let a = Tensor::from_vec(vec![3], vec![1.0_f64, 2.0, 3.0]);
    /// let b = Tensor::from_vec(vec![3], vec![4.0_f64, 5.0, 6.0]);
    /// let c = a.add(&b, &mut ctx).unwrap();
    ///
    /// assert_eq!(c.as_slice::<f64>().unwrap(), &[5.0, 7.0, 9.0]);
    /// ```
    pub fn add(&self, other: &Self, ctx: &mut impl TensorBackend) -> crate::Result<Self> {
        ctx.with_exec_session(|exec| exec.add(self, other))
    }

    /// Elementwise multiplication.
    ///
    /// # Errors
    ///
    /// Propagates any backend failure.
    ///
    /// # Examples
    ///
    /// ```
    /// use tenferro_tensor::{Tensor, TensorBackend, cpu::CpuBackend};
    ///
    /// let mut ctx = CpuBackend::new();
    /// let a = Tensor::from_vec(vec![3], vec![1.0_f64, 2.0, 3.0]);
    /// let b = Tensor::from_vec(vec![3], vec![4.0_f64, 5.0, 6.0]);
    /// let c = a.mul(&b, &mut ctx).unwrap();
    ///
    /// assert_eq!(c.as_slice::<f64>().unwrap(), &[4.0, 10.0, 18.0]);
    /// ```
    pub fn mul(&self, other: &Self, ctx: &mut impl TensorBackend) -> crate::Result<Self> {
        ctx.with_exec_session(|exec| exec.mul(self, other))
    }

    /// Negation.
    ///
    /// # Errors
    ///
    /// Propagates any backend failure.
    ///
    /// # Examples
    ///
    /// ```
    /// use tenferro_tensor::{Tensor, TensorBackend, cpu::CpuBackend};
    ///
    /// let mut ctx = CpuBackend::new();
    /// let a = Tensor::from_vec(vec![3], vec![1.0_f64, -2.0, 3.0]);
    /// let b = a.neg(&mut ctx).unwrap();
    ///
    /// assert_eq!(b.as_slice::<f64>().unwrap(), &[-1.0, 2.0, -3.0]);
    /// ```
    pub fn neg(&self, ctx: &mut impl TensorBackend) -> crate::Result<Self> {
        ctx.with_exec_session(|exec| exec.neg(self))
    }

    /// Transpose with an explicit permutation.
    ///
    /// # Errors
    ///
    /// Propagates any backend failure.
    ///
    /// # Examples
    ///
    /// ```
    /// use tenferro_tensor::{Tensor, TensorBackend, cpu::CpuBackend};
    ///
    /// let mut ctx = CpuBackend::new();
    /// let a = Tensor::from_vec(vec![2, 2], vec![1.0_f64, 2.0, 3.0, 4.0]);
    /// let b = a.transpose(&[1, 0], &mut ctx).unwrap();
    ///
    /// assert_eq!(b.shape(), &[2, 2]);
    /// assert_eq!(b.as_slice::<f64>().unwrap(), &[1.0, 3.0, 2.0, 4.0]);
    /// ```
    pub fn transpose(&self, perm: &[usize], ctx: &mut impl TensorBackend) -> crate::Result<Self> {
        ctx.with_exec_session(|exec| exec.transpose(self, perm))
    }

    /// Reshape to a new shape with the same number of elements.
    ///
    /// # Errors
    ///
    /// Propagates any backend failure.
    ///
    /// # Examples
    ///
    /// ```
    /// use tenferro_tensor::{Tensor, TensorBackend, cpu::CpuBackend};
    ///
    /// let mut ctx = CpuBackend::new();
    /// let a = Tensor::from_vec(vec![2, 3], vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0]);
    /// let b = a.reshape(&[3, 2], &mut ctx).unwrap();
    ///
    /// assert_eq!(b.shape(), &[3, 2]);
    /// assert_eq!(b.as_slice::<f64>().unwrap(), &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    /// ```
    pub fn reshape(&self, shape: &[usize], ctx: &mut impl TensorBackend) -> crate::Result<Self> {
        ctx.with_exec_session(|exec| exec.reshape(self, shape))
    }

    /// Reduce sum over the specified axes.
    ///
    /// # Errors
    ///
    /// Propagates any backend failure.
    ///
    /// # Examples
    ///
    /// ```
    /// use tenferro_tensor::{Tensor, TensorBackend, cpu::CpuBackend};
    ///
    /// let mut ctx = CpuBackend::new();
    /// let a = Tensor::from_vec(vec![2, 3], vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0]);
    /// let b = a.reduce_sum(&[1], &mut ctx).unwrap();
    ///
    /// assert_eq!(b.shape(), &[2]);
    /// assert_eq!(b.as_slice::<f64>().unwrap(), &[9.0, 12.0]);
    /// ```
    pub fn reduce_sum(&self, axes: &[usize], ctx: &mut impl TensorBackend) -> crate::Result<Self> {
        ctx.with_exec_session(|exec| exec.reduce_sum(self, axes))
    }

    /// Matrix multiplication for rank-2 tensors.
    ///
    /// This is a convenience wrapper around `dot_general` that contracts
    /// axis 1 of `self` with axis 0 of `other`, with no batch dimensions.
    ///
    /// # Errors
    ///
    /// Propagates any backend failure.
    ///
    /// # Examples
    ///
    /// ```
    /// use tenferro_tensor::{Tensor, TensorBackend, cpu::CpuBackend};
    ///
    /// let mut ctx = CpuBackend::new();
    /// let a = Tensor::from_vec(vec![2, 3], vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0]);
    /// let b = Tensor::from_vec(vec![3, 2], vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0]);
    /// let c = a.matmul(&b, &mut ctx).unwrap();
    ///
    /// assert_eq!(c.shape(), &[2, 2]);
    /// assert_eq!(c.as_slice::<f64>().unwrap(), &[22.0, 28.0, 49.0, 64.0]);
    /// ```
    pub fn matmul(&self, other: &Self, ctx: &mut impl TensorBackend) -> crate::Result<Self> {
        let config = DotGeneralConfig {
            lhs_contracting_dims: vec![1],
            rhs_contracting_dims: vec![0],
            lhs_batch_dims: vec![],
            rhs_batch_dims: vec![],
            lhs_rank: self.shape().len(),
            rhs_rank: other.shape().len(),
        };
        ctx.with_exec_session(|exec| exec.dot_general(self, other, &config))
    }
}
1105
/// Decompose a flat column-major offset into per-dimension indices,
/// writing one index per dimension of `shape` into `out`.
pub(crate) fn flat_to_multi(mut flat: usize, shape: &[usize], out: &mut [usize]) {
    // Column-major: the first dimension varies fastest, so peel off the
    // remainder for each extent in order.
    for (i, &extent) in shape.iter().enumerate() {
        out[i] = flat % extent;
        flat /= extent;
    }
}
1112
1113fn invalid_output_count(op: &'static str, expected: usize, actual: usize) -> crate::Error {
1114 crate::Error::BackendFailure {
1115 op,
1116 message: format!("expected {expected} output tensors, got {actual}"),
1117 }
1118}
1119
1120fn unpack_two(op: &'static str, results: Vec<Tensor>) -> crate::Result<(Tensor, Tensor)> {
1121 let actual = results.len();
1122 let mut iter = results.into_iter();
1123 match (iter.next(), iter.next(), iter.next()) {
1124 (Some(a), Some(b), None) => Ok((a, b)),
1125 _ => Err(invalid_output_count(op, 2, actual)),
1126 }
1127}
1128
1129fn unpack_three(op: &'static str, results: Vec<Tensor>) -> crate::Result<(Tensor, Tensor, Tensor)> {
1130 let actual = results.len();
1131 let mut iter = results.into_iter();
1132 match (iter.next(), iter.next(), iter.next(), iter.next()) {
1133 (Some(a), Some(b), Some(c), None) => Ok((a, b, c)),
1134 _ => Err(invalid_output_count(op, 3, actual)),
1135 }
1136}
1137
1138fn unpack_four(
1139 op: &'static str,
1140 results: Vec<Tensor>,
1141) -> crate::Result<(Tensor, Tensor, Tensor, Tensor)> {
1142 let actual = results.len();
1143 let mut iter = results.into_iter();
1144 match (
1145 iter.next(),
1146 iter.next(),
1147 iter.next(),
1148 iter.next(),
1149 iter.next(),
1150 ) {
1151 (Some(a), Some(b), Some(c), Some(d), None) => Ok((a, b, c, d)),
1152 _ => Err(invalid_output_count(op, 4, actual)),
1153 }
1154}