tenferro_internal_frontend_core/
structured_tensor.rs1use std::ops::{Deref, DerefMut};
2
3use chainrules_core::Differentiable;
4use tenferro_algebra::{Conjugate, Scalar};
5use tenferro_device::{ComputeDevice, LogicalMemorySpace};
6use tenferro_internal_error::Result;
7use tenferro_tensor::Tensor;
8
9#[derive(Debug, Clone)]
25pub struct StructuredTensor<T: Scalar>(pub tenferro_tensor::StructuredTensor<T>);
26
27impl<T: Scalar> StructuredTensor<T> {
28 pub fn with_payload_like(&self, payload: Tensor<T>) -> Result<Self> {
29 Ok(Self(self.0.with_payload_like(payload)?))
30 }
31
32 pub fn into_payload(self) -> Tensor<T> {
33 self.0.into_payload()
34 }
35
36 pub fn permute_logical(&self, perm: &[usize]) -> Result<Self> {
37 Ok(Self(self.0.permute_logical(perm)?))
38 }
39
40 pub fn conj(&self) -> Self
41 where
42 T: Conjugate,
43 {
44 Self(self.0.conj())
45 }
46
47 pub fn to_dense(&self) -> Result<Tensor<T>> {
48 Ok(self.0.to_dense()?)
49 }
50
51 pub fn memory_space(&self) -> LogicalMemorySpace {
52 self.payload().logical_memory_space()
53 }
54
55 pub fn preferred_compute_device(&self) -> Option<ComputeDevice> {
56 self.payload().preferred_compute_device()
57 }
58
59 pub fn set_preferred_compute_device(&mut self, device: Option<ComputeDevice>) {
60 let mut payload = self.payload().clone();
61 payload.set_preferred_compute_device(device);
62 *self = Self(tenferro_tensor::StructuredTensor::from_validated_parts(
63 self.logical_dims().to_vec(),
64 self.axis_classes().to_vec(),
65 payload,
66 ));
67 }
68
69 pub fn to_memory_space_async(&self, target: LogicalMemorySpace) -> Result<Self> {
70 let payload = self.payload().to_memory_space_async(target)?;
71 Ok(Self(self.0.with_payload_like(payload)?))
72 }
73
74 pub fn wait(&self) {
75 self.payload().wait();
76 }
77
78 pub fn is_ready(&self) -> bool {
79 self.payload().is_ready()
80 }
81}
82
83impl<T> Differentiable for StructuredTensor<T>
84where
85 T: Scalar,
86{
87 type Tangent = StructuredTensor<T>;
88
89 fn zero_tangent(&self) -> Self::Tangent {
90 StructuredTensor(tenferro_tensor::StructuredTensor::from_validated_parts(
91 self.logical_dims().to_vec(),
92 self.axis_classes().to_vec(),
93 self.payload().zero_tangent(),
94 ))
95 }
96
97 fn accumulate_tangent(a: Self::Tangent, b: &Self::Tangent) -> Self::Tangent {
98 assert_eq!(
99 a.logical_dims(),
100 b.logical_dims(),
101 "StructuredTensor::accumulate_tangent requires matching logical dims"
102 );
103 assert_eq!(
104 a.axis_classes(),
105 b.axis_classes(),
106 "StructuredTensor::accumulate_tangent requires matching axis classes"
107 );
108 let logical_dims = a.logical_dims().to_vec();
109 let axis_classes = a.axis_classes().to_vec();
110 let payload = Tensor::<T>::accumulate_tangent(a.0.into_payload(), b.payload());
111 StructuredTensor(tenferro_tensor::StructuredTensor::from_validated_parts(
112 logical_dims,
113 axis_classes,
114 payload,
115 ))
116 }
117
118 fn num_elements(&self) -> usize {
119 self.logical_dims().iter().product()
120 }
121
122 fn seed_cotangent(&self) -> Self::Tangent {
123 StructuredTensor(tenferro_tensor::StructuredTensor::from_validated_parts(
124 self.logical_dims().to_vec(),
125 self.axis_classes().to_vec(),
126 self.payload().seed_cotangent(),
127 ))
128 }
129}
130
131impl<T: Scalar> Deref for StructuredTensor<T> {
132 type Target = tenferro_tensor::StructuredTensor<T>;
133
134 fn deref(&self) -> &Self::Target {
135 &self.0
136 }
137}
138
139impl<T: Scalar> DerefMut for StructuredTensor<T> {
140 fn deref_mut(&mut self) -> &mut Self::Target {
141 &mut self.0
142 }
143}
144
145impl<T: Scalar> From<tenferro_tensor::StructuredTensor<T>> for StructuredTensor<T> {
146 fn from(inner: tenferro_tensor::StructuredTensor<T>) -> Self {
147 Self(inner)
148 }
149}
150
151impl<T: Scalar> From<Tensor<T>> for StructuredTensor<T> {
152 fn from(tensor: Tensor<T>) -> Self {
153 Self(tenferro_tensor::StructuredTensor::from_dense(tensor))
154 }
155}
156
157impl<T: Scalar> AsRef<Tensor<T>> for StructuredTensor<T> {
158 fn as_ref(&self) -> &Tensor<T> {
159 self.0.payload()
160 }
161}