tenferro_internal_ad_surface/core/dynamic/
tensor.rs1use chainrules_core::AutodiffError;
2use std::fmt;
3use tenferro_internal_ad_linalg::{
4 cholesky_dyn_value, det_dyn_value, eig_dyn_value, eigen_dyn_value, inv_dyn_value,
5 lstsq_dyn_values, lu_dyn_value, matrix_exp_dyn_value, norm_dyn_value, pinv_dyn_value,
6 qr_dyn_value, slogdet_dyn_value, solve_dyn_values, solve_triangular_dyn_value, svd_dyn_value,
7};
8use tenferro_internal_ad_ops::{add_dyn_values, einsum_dyn_values, exp_dyn_value, sum_dyn_value};
9use tenferro_internal_frontend_core::tensor_ops::tensor_element;
10use tenferro_internal_frontend_core::{DynTensor, DynTensorTyped, ScalarType, StructuredTensor};
11use tenferro_linalg::{LuPivot, MatrixNormOrd, NormKind, SvdOptions, VectorNormOrd};
12use tenferro_tensor::{MemoryOrder, Tensor as DenseTensor};
13
14use super::{EigResult, EighResult, LstsqResult, LuResult, QrResult, SlogdetResult, SvdResult};
15use crate::{jvp, Error, Result};
16
17pub struct Tensor {
18 inner: tidu::Value<DynTensor>,
19}
20
21impl Tensor {
22 pub fn new(primal: DynTensor) -> Self {
23 Self {
24 inner: tidu::Value::new(primal),
25 }
26 }
27
28 pub(crate) fn from_value(inner: tidu::Value<DynTensor>) -> Self {
29 Self { inner }
30 }
31
32 pub(crate) fn value(&self) -> &tidu::Value<DynTensor> {
33 &self.inner
34 }
35
36 pub(crate) fn primal(&self) -> &DynTensor {
37 self.inner.primal()
38 }
39
40 pub(crate) fn forward_id(&self) -> usize {
41 jvp::forward_id(self)
42 }
43
44 pub fn from_slice<T>(data: &[T], dims: &[usize]) -> Result<Self>
45 where
46 T: DynTensorTyped + Copy,
47 {
48 let payload = DenseTensor::<T>::from_slice(data, dims, MemoryOrder::ColumnMajor)?;
49 Ok(Self::from(payload))
50 }
51
52 pub fn scalar_type(&self) -> ScalarType {
53 self.primal().scalar_type()
54 }
55
56 pub fn dims(&self) -> &[usize] {
57 self.primal().dims()
58 }
59
60 pub fn ndim(&self) -> usize {
61 self.primal().ndim()
62 }
63
64 pub fn len(&self) -> usize {
65 self.primal().len()
66 }
67
68 pub fn is_empty(&self) -> bool {
69 self.primal().is_empty()
70 }
71
72 pub fn axis_classes(&self) -> &[usize] {
73 self.primal().axis_classes()
74 }
75
76 pub fn is_dense(&self) -> bool {
77 self.primal().is_dense()
78 }
79
80 pub fn is_diag(&self) -> bool {
81 self.primal().is_diag()
82 }
83
84 pub fn requires_grad(&self) -> bool {
85 self.inner.requires_grad()
86 }
87
88 pub fn with_requires_grad(self, enabled: bool) -> Self {
89 Self {
90 inner: self.inner.with_requires_grad(enabled),
91 }
92 }
93
94 pub fn detach(&self) -> Self {
95 Self::new(self.primal().clone())
96 }
97
98 pub fn to_dense(&self) -> Result<Self> {
99 Ok(Self::new(self.primal().to_dense()?))
100 }
101
102 pub fn grad(&self) -> Result<Option<Self>> {
103 Ok(self.inner.grad()?.map(Self::new))
104 }
105
106 pub fn zero_grad(&self) -> Result<()> {
107 Ok(self.inner.zero_grad()?)
108 }
109
110 pub fn backward(&self) -> Result<()> {
111 Ok(self.inner.backward()?)
112 }
113
114 pub fn backward_with_seed(&self, seed: &Self) -> Result<()> {
115 Ok(self.inner.backward_with_seed(seed.primal().clone())?)
116 }
117
118 pub fn shares_reverse_graph(&self, other: &Self) -> bool {
119 self.inner.shares_reverse_graph(&other.inner)
120 }
121
122 pub fn add(&self, rhs: &Self) -> Result<Self> {
123 let output = Self::from_value(add_dyn_values(self.value(), rhs.value())?);
124 jvp::add_tangent(self, rhs, &output)?;
125 Ok(output)
126 }
127
128 pub fn exp(&self) -> Result<Self> {
129 let output = Self::from_value(exp_dyn_value(self.value())?);
130 jvp::exp_tangent(self, &output)?;
131 Ok(output)
132 }
133
134 pub fn sum(&self) -> Result<Self> {
135 let output = Self::from_value(sum_dyn_value(self.value())?);
136 jvp::sum_tangent(self, &output)?;
137 Ok(output)
138 }
139
140 pub fn einsum(subscripts: &str, operands: &[&Self]) -> Result<Self> {
141 let values = operands
142 .iter()
143 .map(|tensor| tensor.value())
144 .collect::<Vec<_>>();
145 let output = Self::from_value(einsum_dyn_values(subscripts, &values)?);
146 jvp::einsum_tangent(subscripts, operands, &output)?;
147 Ok(output)
148 }
149
150 pub fn solve(&self, rhs: &Self) -> Result<Self> {
151 let output = Self::from_value(solve_dyn_values(self.value(), rhs.value())?);
152 jvp::solve_tangent(self, rhs, &output)?;
153 Ok(output)
154 }
155
156 pub fn lstsq(&self, rhs: &Self) -> Result<LstsqResult> {
157 let result = lstsq_dyn_values(self.value(), rhs.value())?;
158 let solution = Self::from_value(result.solution);
159 let residuals = Self::from_value(result.residuals);
160 jvp::lstsq_tangents(self, rhs, &solution, &residuals)?;
161 Ok(LstsqResult {
162 solution,
163 residuals,
164 rank: result.rank,
165 singular_values: Tensor::from(result.singular_values),
166 })
167 }
168
169 pub fn solve_triangular(&self, rhs: &Self, upper: bool) -> Result<Self> {
170 let output = Self::from_value(solve_triangular_dyn_value(
171 self.value(),
172 rhs.value(),
173 upper,
174 )?);
175 jvp::solve_triangular_tangent(self, rhs, &output, upper)?;
176 Ok(output)
177 }
178
179 pub fn det(&self) -> Result<Self> {
180 let output = Self::from_value(det_dyn_value(self.value())?);
181 jvp::det_tangent(self, &output)?;
182 Ok(output)
183 }
184
185 pub fn inv(&self) -> Result<Self> {
186 let output = Self::from_value(inv_dyn_value(self.value())?);
187 jvp::inv_tangent(self, &output)?;
188 Ok(output)
189 }
190
191 pub fn slogdet(&self) -> Result<SlogdetResult> {
192 let result: SlogdetResult = slogdet_dyn_value(self.value())?.into();
193 jvp::slogdet_tangents(self, &result.sign, &result.logabsdet)?;
194 Ok(result)
195 }
196
197 pub fn cholesky(&self) -> Result<Self> {
198 let output = Self::from_value(cholesky_dyn_value(self.value())?);
199 jvp::cholesky_tangent(self, &output)?;
200 Ok(output)
201 }
202
203 pub fn lu(&self, pivot: LuPivot) -> Result<LuResult> {
204 let result: LuResult = lu_dyn_value(self.value(), pivot)?.into();
205 jvp::lu_tangents(self, &result.p, &result.l, &result.u, pivot)?;
206 Ok(result)
207 }
208
209 pub fn norm(&self, kind: NormKind) -> Result<Self> {
210 let output = Self::from_value(norm_dyn_value(self.value(), kind)?);
211 jvp::norm_tangent(self, &output, kind)?;
212 Ok(output)
213 }
214
215 pub fn vector_norm(
216 &self,
217 ord: VectorNormOrd,
218 dim: Option<&[isize]>,
219 keepdim: bool,
220 ) -> Result<Self> {
221 validate_vector_norm_request(self.ndim(), dim, keepdim)?;
222 let kind = jvp::map_vector_norm_ord(ord)?;
223 let output = Self::from_value(norm_dyn_value(self.value(), kind)?);
224 jvp::vector_norm_tangent(self, &output, ord)?;
225 Ok(output)
226 }
227
228 pub fn matrix_norm(
229 &self,
230 ord: MatrixNormOrd,
231 dim: Option<(isize, isize)>,
232 keepdim: bool,
233 ) -> Result<Self> {
234 validate_matrix_norm_request(self.ndim(), dim, keepdim)?;
235 let kind = jvp::map_matrix_norm_ord(ord)?;
236 let output = Self::from_value(norm_dyn_value(self.value(), kind)?);
237 jvp::matrix_norm_tangent(self, &output, ord)?;
238 Ok(output)
239 }
240
241 pub fn qr(&self) -> Result<QrResult> {
242 crate::with_default_runtime(|_| Ok(()))?;
243 let result: QrResult = qr_dyn_value(self.value())?.into();
244 jvp::qr_tangents(self, &result.q, &result.r)?;
245 Ok(result)
246 }
247
248 pub fn svd(&self, options: Option<SvdOptions>) -> Result<SvdResult> {
249 let result: SvdResult = svd_dyn_value(self.value(), options.clone())?.into();
250 jvp::svd_tangents(self, &result.u, &result.s, &result.vt, options)?;
251 Ok(result)
252 }
253
254 pub fn eig(&self) -> Result<EigResult> {
255 let result: EigResult = eig_dyn_value(self.value())?.into();
256 jvp::eig_tangents(self, &result.values, &result.vectors)?;
257 Ok(result)
258 }
259
260 pub fn eigh(&self) -> Result<EighResult> {
261 let result: EighResult = eigen_dyn_value(self.value())?.into();
262 jvp::eigen_tangents(self, &result.values, &result.vectors)?;
263 Ok(result)
264 }
265
266 pub fn pinv(&self, rcond: Option<f64>) -> Result<Self> {
267 let output = Self::from_value(pinv_dyn_value(self.value(), rcond)?);
268 jvp::pinv_tangent(self, &output, rcond)?;
269 Ok(output)
270 }
271
272 pub fn matrix_exp(&self) -> Result<Self> {
273 let output = Self::from_value(matrix_exp_dyn_value(self.value())?);
274 jvp::matrix_exp_tangent(self, &output)?;
275 Ok(output)
276 }
277
278 pub fn try_to_vec<T>(&self) -> Result<Vec<T>>
279 where
280 T: DynTensorTyped + Copy,
281 {
282 let structured = T::structured_ref(self.primal()).ok_or_else(|| {
283 invalid_argument(format!(
284 "dtype mismatch in try_to_vec: tensor={:?}",
285 self.scalar_type()
286 ))
287 })?;
288 let dense = structured.to_dense()?;
289 let contiguous = dense.contiguous(MemoryOrder::ColumnMajor);
290 let slice = contiguous
291 .buffer()
292 .as_slice()
293 .ok_or_else(|| invalid_argument("try_to_vec requires host-accessible dense payload"))?;
294 Ok(slice.to_vec())
295 }
296
297 pub fn try_get<T>(&self, index: &[usize]) -> Result<T>
298 where
299 T: DynTensorTyped + Copy,
300 {
301 let structured = T::structured_ref(self.primal()).ok_or_else(|| {
302 invalid_argument(format!(
303 "dtype mismatch in try_get: tensor={:?}",
304 self.scalar_type()
305 ))
306 })?;
307 tensor_element(structured.payload(), index)
308 }
309}
310
311impl fmt::Debug for Tensor {
312 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
313 f.debug_struct("Tensor")
314 .field("scalar_type", &self.scalar_type())
315 .field("dims", &self.dims())
316 .field("requires_grad", &self.requires_grad())
317 .finish()
318 }
319}
320
321impl<T> From<DenseTensor<T>> for Tensor
322where
323 T: DynTensorTyped + Copy,
324{
325 fn from(value: DenseTensor<T>) -> Self {
326 Self::from(StructuredTensor::from(value))
327 }
328}
329
330impl<T> From<StructuredTensor<T>> for Tensor
331where
332 T: DynTensorTyped + Copy,
333{
334 fn from(value: StructuredTensor<T>) -> Self {
335 Self::new(T::into_dyn(value))
336 }
337}
338
339impl From<DynTensor> for Tensor {
340 fn from(value: DynTensor) -> Self {
341 Self::new(value)
342 }
343}
344
345fn invalid_argument(message: impl Into<String>) -> Error {
346 AutodiffError::InvalidArgument(message.into()).into()
347}
348
349fn validate_vector_norm_request(ndim: usize, dim: Option<&[isize]>, keepdim: bool) -> Result<()> {
350 if keepdim {
351 return Err(invalid_argument(
352 "vector_norm currently supports keepdim=false only",
353 ));
354 }
355 if ndim != 1 {
356 return Err(invalid_argument(format!(
357 "vector_norm currently expects a rank-1 tensor, got ndim={ndim}",
358 )));
359 }
360 if dim.is_some() {
361 return Err(invalid_argument(
362 "vector_norm currently supports dim=None only",
363 ));
364 }
365 Ok(())
366}
367
368fn validate_matrix_norm_request(
369 ndim: usize,
370 dim: Option<(isize, isize)>,
371 keepdim: bool,
372) -> Result<()> {
373 if keepdim {
374 return Err(invalid_argument(
375 "matrix_norm currently supports keepdim=false only",
376 ));
377 }
378 if ndim != 2 {
379 return Err(invalid_argument(format!(
380 "matrix_norm currently expects a rank-2 tensor, got ndim={ndim}",
381 )));
382 }
383 if let Some(dim) = dim {
384 if dim != (0, 1) && dim != (1, 0) {
385 return Err(invalid_argument(format!(
386 "matrix_norm currently supports dim=(0, 1) only, got {dim:?}",
387 )));
388 }
389 }
390 Ok(())
391}