tenferro_internal_frontend_core/structured_tensor.rs

1use std::ops::{Deref, DerefMut};
2
3use chainrules_core::Differentiable;
4use tenferro_algebra::{Conjugate, Scalar};
5use tenferro_device::{ComputeDevice, LogicalMemorySpace};
6use tenferro_internal_error::Result;
7use tenferro_tensor::Tensor;
8
9/// AD-capable structured tensor wrapper shared by dynamic tenferro frontends.
10///
11/// This newtype keeps `Differentiable` and placement helpers on top of
12/// [`tenferro_tensor::StructuredTensor<T>`].
13///
14/// # Examples
15///
16/// ```rust
17/// use tenferro_internal_frontend_core::StructuredTensor;
18/// use tenferro_tensor::{MemoryOrder, Tensor};
19///
20/// let payload = Tensor::<f64>::from_slice(&[1.0, 2.0], &[2], MemoryOrder::ColumnMajor).unwrap();
21/// let wrapped = StructuredTensor::from(payload);
22/// assert!(wrapped.is_dense());
23/// ```
24#[derive(Debug, Clone)]
25pub struct StructuredTensor<T: Scalar>(pub tenferro_tensor::StructuredTensor<T>);
26
27impl<T: Scalar> StructuredTensor<T> {
28    pub fn with_payload_like(&self, payload: Tensor<T>) -> Result<Self> {
29        Ok(Self(self.0.with_payload_like(payload)?))
30    }
31
32    pub fn into_payload(self) -> Tensor<T> {
33        self.0.into_payload()
34    }
35
36    pub fn permute_logical(&self, perm: &[usize]) -> Result<Self> {
37        Ok(Self(self.0.permute_logical(perm)?))
38    }
39
40    pub fn conj(&self) -> Self
41    where
42        T: Conjugate,
43    {
44        Self(self.0.conj())
45    }
46
47    pub fn to_dense(&self) -> Result<Tensor<T>> {
48        Ok(self.0.to_dense()?)
49    }
50
51    pub fn memory_space(&self) -> LogicalMemorySpace {
52        self.payload().logical_memory_space()
53    }
54
55    pub fn preferred_compute_device(&self) -> Option<ComputeDevice> {
56        self.payload().preferred_compute_device()
57    }
58
59    pub fn set_preferred_compute_device(&mut self, device: Option<ComputeDevice>) {
60        let mut payload = self.payload().clone();
61        payload.set_preferred_compute_device(device);
62        *self = Self(tenferro_tensor::StructuredTensor::from_validated_parts(
63            self.logical_dims().to_vec(),
64            self.axis_classes().to_vec(),
65            payload,
66        ));
67    }
68
69    pub fn to_memory_space_async(&self, target: LogicalMemorySpace) -> Result<Self> {
70        let payload = self.payload().to_memory_space_async(target)?;
71        Ok(Self(self.0.with_payload_like(payload)?))
72    }
73
74    pub fn wait(&self) {
75        self.payload().wait();
76    }
77
78    pub fn is_ready(&self) -> bool {
79        self.payload().is_ready()
80    }
81}
82
83impl<T> Differentiable for StructuredTensor<T>
84where
85    T: Scalar,
86{
87    type Tangent = StructuredTensor<T>;
88
89    fn zero_tangent(&self) -> Self::Tangent {
90        StructuredTensor(tenferro_tensor::StructuredTensor::from_validated_parts(
91            self.logical_dims().to_vec(),
92            self.axis_classes().to_vec(),
93            self.payload().zero_tangent(),
94        ))
95    }
96
97    fn accumulate_tangent(a: Self::Tangent, b: &Self::Tangent) -> Self::Tangent {
98        assert_eq!(
99            a.logical_dims(),
100            b.logical_dims(),
101            "StructuredTensor::accumulate_tangent requires matching logical dims"
102        );
103        assert_eq!(
104            a.axis_classes(),
105            b.axis_classes(),
106            "StructuredTensor::accumulate_tangent requires matching axis classes"
107        );
108        let logical_dims = a.logical_dims().to_vec();
109        let axis_classes = a.axis_classes().to_vec();
110        let payload = Tensor::<T>::accumulate_tangent(a.0.into_payload(), b.payload());
111        StructuredTensor(tenferro_tensor::StructuredTensor::from_validated_parts(
112            logical_dims,
113            axis_classes,
114            payload,
115        ))
116    }
117
118    fn num_elements(&self) -> usize {
119        self.logical_dims().iter().product()
120    }
121
122    fn seed_cotangent(&self) -> Self::Tangent {
123        StructuredTensor(tenferro_tensor::StructuredTensor::from_validated_parts(
124            self.logical_dims().to_vec(),
125            self.axis_classes().to_vec(),
126            self.payload().seed_cotangent(),
127        ))
128    }
129}
130
131impl<T: Scalar> Deref for StructuredTensor<T> {
132    type Target = tenferro_tensor::StructuredTensor<T>;
133
134    fn deref(&self) -> &Self::Target {
135        &self.0
136    }
137}
138
139impl<T: Scalar> DerefMut for StructuredTensor<T> {
140    fn deref_mut(&mut self) -> &mut Self::Target {
141        &mut self.0
142    }
143}
144
145impl<T: Scalar> From<tenferro_tensor::StructuredTensor<T>> for StructuredTensor<T> {
146    fn from(inner: tenferro_tensor::StructuredTensor<T>) -> Self {
147        Self(inner)
148    }
149}
150
151impl<T: Scalar> From<Tensor<T>> for StructuredTensor<T> {
152    fn from(tensor: Tensor<T>) -> Self {
153        Self(tenferro_tensor::StructuredTensor::from_dense(tensor))
154    }
155}
156
157impl<T: Scalar> AsRef<Tensor<T>> for StructuredTensor<T> {
158    fn as_ref(&self) -> &Tensor<T> {
159        self.0.payload()
160    }
161}