tenferro_internal_ad_surface/core/dynamic/tensor.rs

use chainrules_core::AutodiffError;
use std::fmt;
use tenferro_internal_ad_linalg::{
    cholesky_dyn_value, det_dyn_value, eig_dyn_value, eigen_dyn_value, inv_dyn_value,
    lstsq_dyn_values, lu_dyn_value, matrix_exp_dyn_value, norm_dyn_value, pinv_dyn_value,
    qr_dyn_value, slogdet_dyn_value, solve_dyn_values, solve_triangular_dyn_value, svd_dyn_value,
};
use tenferro_internal_ad_ops::{add_dyn_values, einsum_dyn_values, exp_dyn_value, sum_dyn_value};
use tenferro_internal_frontend_core::tensor_ops::tensor_element;
use tenferro_internal_frontend_core::{DynTensor, DynTensorTyped, ScalarType, StructuredTensor};
use tenferro_linalg::{LuPivot, MatrixNormOrd, NormKind, SvdOptions, VectorNormOrd};
use tenferro_tensor::{MemoryOrder, Tensor as DenseTensor};

use super::{EigResult, EighResult, LstsqResult, LuResult, QrResult, SlogdetResult, SvdResult};
use crate::{jvp, Error, Result};

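/// A dynamically typed tensor tracked for automatic differentiation.
///
/// Wraps a `tidu::Value<DynTensor>` holding the primal value together with
/// the bookkeeping needed for forward-mode (JVP) tangents and reverse-mode
/// gradients.
///
/// A minimal usage sketch (assumes `f64` implements `DynTensorTyped`; not a
/// doctest):
///
/// ```ignore
/// let a = Tensor::from_slice(&[1.0f64, 2.0], &[2])?.with_requires_grad(true);
/// let b = Tensor::from_slice(&[3.0f64, 4.0], &[2])?;
/// let loss = a.add(&b)?.sum()?;
/// loss.backward()?;
/// let grad = a.grad()?; // gradient of `loss` with respect to `a`
/// ```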
pub struct Tensor {
    inner: tidu::Value<DynTensor>,
}

impl Tensor {
    pub fn new(primal: DynTensor) -> Self {
        Self {
            inner: tidu::Value::new(primal),
        }
    }

    pub(crate) fn from_value(inner: tidu::Value<DynTensor>) -> Self {
        Self { inner }
    }

    pub(crate) fn value(&self) -> &tidu::Value<DynTensor> {
        &self.inner
    }

    pub(crate) fn primal(&self) -> &DynTensor {
        self.inner.primal()
    }

    pub(crate) fn forward_id(&self) -> usize {
        jvp::forward_id(self)
    }

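    /// Builds a dense tensor from a host slice, interpreting `data` in
    /// column-major order with the given `dims`.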
    pub fn from_slice<T>(data: &[T], dims: &[usize]) -> Result<Self>
    where
        T: DynTensorTyped + Copy,
    {
        let payload = DenseTensor::<T>::from_slice(data, dims, MemoryOrder::ColumnMajor)?;
        Ok(Self::from(payload))
    }

    pub fn scalar_type(&self) -> ScalarType {
        self.primal().scalar_type()
    }

    pub fn dims(&self) -> &[usize] {
        self.primal().dims()
    }

    pub fn ndim(&self) -> usize {
        self.primal().ndim()
    }

    pub fn len(&self) -> usize {
        self.primal().len()
    }

    pub fn is_empty(&self) -> bool {
        self.primal().is_empty()
    }

    pub fn axis_classes(&self) -> &[usize] {
        self.primal().axis_classes()
    }

    pub fn is_dense(&self) -> bool {
        self.primal().is_dense()
    }

    pub fn is_diag(&self) -> bool {
        self.primal().is_diag()
    }

    pub fn requires_grad(&self) -> bool {
        self.inner.requires_grad()
    }

    pub fn with_requires_grad(self, enabled: bool) -> Self {
        Self {
            inner: self.inner.with_requires_grad(enabled),
        }
    }

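    /// Returns a copy of the primal value that is disconnected from any
    /// autodiff graph: the result carries no gradient history.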
    pub fn detach(&self) -> Self {
        Self::new(self.primal().clone())
    }

    pub fn to_dense(&self) -> Result<Self> {
        Ok(Self::new(self.primal().to_dense()?))
    }

    pub fn grad(&self) -> Result<Option<Self>> {
        Ok(self.inner.grad()?.map(Self::new))
    }

    pub fn zero_grad(&self) -> Result<()> {
        Ok(self.inner.zero_grad()?)
    }

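    /// Runs reverse-mode accumulation through the recorded graph.
    /// `backward_with_seed` (below) does the same but uses `seed` as the
    /// initial cotangent instead of the default seed.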
    pub fn backward(&self) -> Result<()> {
        Ok(self.inner.backward()?)
    }

    pub fn backward_with_seed(&self, seed: &Self) -> Result<()> {
        Ok(self.inner.backward_with_seed(seed.primal().clone())?)
    }

    pub fn shares_reverse_graph(&self, other: &Self) -> bool {
        self.inner.shares_reverse_graph(&other.inner)
    }

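    // Every differentiable op below follows the same two-step pattern: first
    // compute the primal output (which also records the reverse-mode node),
    // then ask the `jvp` module to propagate any forward-mode tangent onto
    // the output.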
    pub fn add(&self, rhs: &Self) -> Result<Self> {
        let output = Self::from_value(add_dyn_values(self.value(), rhs.value())?);
        jvp::add_tangent(self, rhs, &output)?;
        Ok(output)
    }

    pub fn exp(&self) -> Result<Self> {
        let output = Self::from_value(exp_dyn_value(self.value())?);
        jvp::exp_tangent(self, &output)?;
        Ok(output)
    }

    pub fn sum(&self) -> Result<Self> {
        let output = Self::from_value(sum_dyn_value(self.value())?);
        jvp::sum_tangent(self, &output)?;
        Ok(output)
    }

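    /// Evaluates an Einstein-summation contraction over `operands` using the
    /// conventional subscript notation (e.g. `"ij,jk->ik"` for a matrix
    /// product).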
    pub fn einsum(subscripts: &str, operands: &[&Self]) -> Result<Self> {
        let values = operands
            .iter()
            .map(|tensor| tensor.value())
            .collect::<Vec<_>>();
        let output = Self::from_value(einsum_dyn_values(subscripts, &values)?);
        jvp::einsum_tangent(subscripts, operands, &output)?;
        Ok(output)
    }

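    /// Solves a linear system, treating `self` as the coefficient matrix and
    /// `rhs` as the right-hand side.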
    pub fn solve(&self, rhs: &Self) -> Result<Self> {
        let output = Self::from_value(solve_dyn_values(self.value(), rhs.value())?);
        jvp::solve_tangent(self, rhs, &output)?;
        Ok(output)
    }

    pub fn lstsq(&self, rhs: &Self) -> Result<LstsqResult> {
        let result = lstsq_dyn_values(self.value(), rhs.value())?;
        let solution = Self::from_value(result.solution);
        let residuals = Self::from_value(result.residuals);
        jvp::lstsq_tangents(self, rhs, &solution, &residuals)?;
        Ok(LstsqResult {
            solution,
            residuals,
            rank: result.rank,
            singular_values: Tensor::from(result.singular_values),
        })
    }

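    /// Triangular solve; `upper` selects whether `self` is treated as upper-
    /// or lower-triangular.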
    pub fn solve_triangular(&self, rhs: &Self, upper: bool) -> Result<Self> {
        let output = Self::from_value(solve_triangular_dyn_value(
            self.value(),
            rhs.value(),
            upper,
        )?);
        jvp::solve_triangular_tangent(self, rhs, &output, upper)?;
        Ok(output)
    }

    pub fn det(&self) -> Result<Self> {
        let output = Self::from_value(det_dyn_value(self.value())?);
        jvp::det_tangent(self, &output)?;
        Ok(output)
    }

    pub fn inv(&self) -> Result<Self> {
        let output = Self::from_value(inv_dyn_value(self.value())?);
        jvp::inv_tangent(self, &output)?;
        Ok(output)
    }

    pub fn slogdet(&self) -> Result<SlogdetResult> {
        let result: SlogdetResult = slogdet_dyn_value(self.value())?.into();
        jvp::slogdet_tangents(self, &result.sign, &result.logabsdet)?;
        Ok(result)
    }

    pub fn cholesky(&self) -> Result<Self> {
        let output = Self::from_value(cholesky_dyn_value(self.value())?);
        jvp::cholesky_tangent(self, &output)?;
        Ok(output)
    }

    pub fn lu(&self, pivot: LuPivot) -> Result<LuResult> {
        let result: LuResult = lu_dyn_value(self.value(), pivot)?.into();
        jvp::lu_tangents(self, &result.p, &result.l, &result.u, pivot)?;
        Ok(result)
    }

    pub fn norm(&self, kind: NormKind) -> Result<Self> {
        let output = Self::from_value(norm_dyn_value(self.value(), kind)?);
        jvp::norm_tangent(self, &output, kind)?;
        Ok(output)
    }

    pub fn vector_norm(
        &self,
        ord: VectorNormOrd,
        dim: Option<&[isize]>,
        keepdim: bool,
    ) -> Result<Self> {
        validate_vector_norm_request(self.ndim(), dim, keepdim)?;
        let kind = jvp::map_vector_norm_ord(ord)?;
        let output = Self::from_value(norm_dyn_value(self.value(), kind)?);
        jvp::vector_norm_tangent(self, &output, ord)?;
        Ok(output)
    }

    pub fn matrix_norm(
        &self,
        ord: MatrixNormOrd,
        dim: Option<(isize, isize)>,
        keepdim: bool,
    ) -> Result<Self> {
        validate_matrix_norm_request(self.ndim(), dim, keepdim)?;
        let kind = jvp::map_matrix_norm_ord(ord)?;
        let output = Self::from_value(norm_dyn_value(self.value(), kind)?);
        jvp::matrix_norm_tangent(self, &output, ord)?;
        Ok(output)
    }

    pub fn qr(&self) -> Result<QrResult> {
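        // Presumably forces the default runtime to be initialized before the
        // factorization runs; the closure itself is a deliberate no-op.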
        crate::with_default_runtime(|_| Ok(()))?;
        let result: QrResult = qr_dyn_value(self.value())?.into();
        jvp::qr_tangents(self, &result.q, &result.r)?;
        Ok(result)
    }

    pub fn svd(&self, options: Option<SvdOptions>) -> Result<SvdResult> {
        let result: SvdResult = svd_dyn_value(self.value(), options.clone())?.into();
        jvp::svd_tangents(self, &result.u, &result.s, &result.vt, options)?;
        Ok(result)
    }

    pub fn eig(&self) -> Result<EigResult> {
        let result: EigResult = eig_dyn_value(self.value())?.into();
        jvp::eig_tangents(self, &result.values, &result.vectors)?;
        Ok(result)
    }

    pub fn eigh(&self) -> Result<EighResult> {
        let result: EighResult = eigen_dyn_value(self.value())?.into();
        jvp::eigen_tangents(self, &result.values, &result.vectors)?;
        Ok(result)
    }

    pub fn pinv(&self, rcond: Option<f64>) -> Result<Self> {
        let output = Self::from_value(pinv_dyn_value(self.value(), rcond)?);
        jvp::pinv_tangent(self, &output, rcond)?;
        Ok(output)
    }

    pub fn matrix_exp(&self) -> Result<Self> {
        let output = Self::from_value(matrix_exp_dyn_value(self.value())?);
        jvp::matrix_exp_tangent(self, &output)?;
        Ok(output)
    }

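    /// Copies the tensor into a `Vec<T>` in column-major order. Fails if `T`
    /// does not match the tensor's dtype or if the dense payload is not
    /// host-accessible.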
    pub fn try_to_vec<T>(&self) -> Result<Vec<T>>
    where
        T: DynTensorTyped + Copy,
    {
        let structured = T::structured_ref(self.primal()).ok_or_else(|| {
            invalid_argument(format!(
                "dtype mismatch in try_to_vec: tensor={:?}",
                self.scalar_type()
            ))
        })?;
        let dense = structured.to_dense()?;
        let contiguous = dense.contiguous(MemoryOrder::ColumnMajor);
        let slice = contiguous
            .buffer()
            .as_slice()
            .ok_or_else(|| invalid_argument("try_to_vec requires host-accessible dense payload"))?;
        Ok(slice.to_vec())
    }

    pub fn try_get<T>(&self, index: &[usize]) -> Result<T>
    where
        T: DynTensorTyped + Copy,
    {
        let structured = T::structured_ref(self.primal()).ok_or_else(|| {
            invalid_argument(format!(
                "dtype mismatch in try_get: tensor={:?}",
                self.scalar_type()
            ))
        })?;
        tensor_element(structured.payload(), index)
    }
}

impl fmt::Debug for Tensor {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("Tensor")
            .field("scalar_type", &self.scalar_type())
            .field("dims", &self.dims())
            .field("requires_grad", &self.requires_grad())
            .finish()
    }
}

impl<T> From<DenseTensor<T>> for Tensor
where
    T: DynTensorTyped + Copy,
{
    fn from(value: DenseTensor<T>) -> Self {
        Self::from(StructuredTensor::from(value))
    }
}

impl<T> From<StructuredTensor<T>> for Tensor
where
    T: DynTensorTyped + Copy,
{
    fn from(value: StructuredTensor<T>) -> Self {
        Self::new(T::into_dyn(value))
    }
}

impl From<DynTensor> for Tensor {
    fn from(value: DynTensor) -> Self {
        Self::new(value)
    }
}

fn invalid_argument(message: impl Into<String>) -> Error {
    AutodiffError::InvalidArgument(message.into()).into()
}

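// `vector_norm` and `matrix_norm` accept a general `(ord, dim, keepdim)`
// signature, but the implementation currently only handles the default case;
// these validators reject everything else up front.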
fn validate_vector_norm_request(ndim: usize, dim: Option<&[isize]>, keepdim: bool) -> Result<()> {
    if keepdim {
        return Err(invalid_argument(
            "vector_norm currently supports keepdim=false only",
        ));
    }
    if ndim != 1 {
        return Err(invalid_argument(format!(
            "vector_norm currently expects a rank-1 tensor, got ndim={ndim}",
        )));
    }
    if dim.is_some() {
        return Err(invalid_argument(
            "vector_norm currently supports dim=None only",
        ));
    }
    Ok(())
}

fn validate_matrix_norm_request(
    ndim: usize,
    dim: Option<(isize, isize)>,
    keepdim: bool,
) -> Result<()> {
    if keepdim {
        return Err(invalid_argument(
            "matrix_norm currently supports keepdim=false only",
        ));
    }
    if ndim != 2 {
        return Err(invalid_argument(format!(
            "matrix_norm currently expects a rank-2 tensor, got ndim={ndim}",
        )));
    }
    if let Some(dim) = dim {
        if dim != (0, 1) && dim != (1, 0) {
            return Err(invalid_argument(format!(
                "matrix_norm currently supports dim=(0, 1) or (1, 0) only, got {dim:?}",
            )));
        }
    }
    Ok(())
}