Skip to main content

tenferro_runtime/
tensor.rs

//! Concrete [`TensorOpsExt`] implementation for [`Tensor`].
//!
//! `tenferro-tensor` owns tensor storage and the backend traits. This runtime
//! crate layers backend-parametric operation methods on top through [`TensorOpsExt`].
5
6use tenferro_ops::broadcast::{broadcast_input_plan, broadcast_shape, broadcast_shapes};
7use tenferro_tensor::{CompareDir, DType, DotGeneralConfig, Error, Result, TensorBackend};
8
9use crate::TensorOpsExt;
10use tenferro_tensor::Tensor;
11
12impl TensorOpsExt for Tensor {
13    fn convert<B: TensorBackend>(&self, to: DType, backend: &mut B) -> Result<Tensor> {
14        convert(self, to, backend)
15    }
16
17    fn cast<B: TensorBackend>(&self, to: DType, backend: &mut B) -> Result<Tensor> {
18        cast(self, to, backend)
19    }
20
21    fn add<B: TensorBackend>(&self, rhs: &Tensor, backend: &mut B) -> Result<Tensor> {
22        add(self, rhs, backend)
23    }
24
25    fn sub<B: TensorBackend>(&self, rhs: &Tensor, backend: &mut B) -> Result<Tensor> {
26        sub(self, rhs, backend)
27    }
28
29    fn mul<B: TensorBackend>(&self, rhs: &Tensor, backend: &mut B) -> Result<Tensor> {
30        mul(self, rhs, backend)
31    }
32
33    fn div<B: TensorBackend>(&self, rhs: &Tensor, backend: &mut B) -> Result<Tensor> {
34        div(self, rhs, backend)
35    }
36
37    fn pow<B: TensorBackend>(&self, rhs: &Tensor, backend: &mut B) -> Result<Tensor> {
38        pow(self, rhs, backend)
39    }
40
41    fn maximum<B: TensorBackend>(&self, rhs: &Tensor, backend: &mut B) -> Result<Tensor> {
42        maximum(self, rhs, backend)
43    }
44
45    fn minimum<B: TensorBackend>(&self, rhs: &Tensor, backend: &mut B) -> Result<Tensor> {
46        minimum(self, rhs, backend)
47    }
48
49    fn neg<B: TensorBackend>(&self, backend: &mut B) -> Result<Tensor> {
50        neg(self, backend)
51    }
52
53    fn abs<B: TensorBackend>(&self, backend: &mut B) -> Result<Tensor> {
54        abs(self, backend)
55    }
56
57    fn sign<B: TensorBackend>(&self, backend: &mut B) -> Result<Tensor> {
58        sign(self, backend)
59    }
60
61    fn conj<B: TensorBackend>(&self, backend: &mut B) -> Result<Tensor> {
62        conj(self, backend)
63    }
64
65    fn exp<B: TensorBackend>(&self, backend: &mut B) -> Result<Tensor> {
66        exp(self, backend)
67    }
68
69    fn log<B: TensorBackend>(&self, backend: &mut B) -> Result<Tensor> {
70        log(self, backend)
71    }
72
73    fn sin<B: TensorBackend>(&self, backend: &mut B) -> Result<Tensor> {
74        sin(self, backend)
75    }
76
77    fn cos<B: TensorBackend>(&self, backend: &mut B) -> Result<Tensor> {
78        cos(self, backend)
79    }
80
81    fn tanh<B: TensorBackend>(&self, backend: &mut B) -> Result<Tensor> {
82        tanh(self, backend)
83    }
84
85    fn sqrt<B: TensorBackend>(&self, backend: &mut B) -> Result<Tensor> {
86        sqrt(self, backend)
87    }
88
89    fn rsqrt<B: TensorBackend>(&self, backend: &mut B) -> Result<Tensor> {
90        rsqrt(self, backend)
91    }
92
93    fn expm1<B: TensorBackend>(&self, backend: &mut B) -> Result<Tensor> {
94        expm1(self, backend)
95    }
96
97    fn log1p<B: TensorBackend>(&self, backend: &mut B) -> Result<Tensor> {
98        log1p(self, backend)
99    }
100
101    fn compare<B: TensorBackend>(
102        &self,
103        rhs: &Tensor,
104        dir: CompareDir,
105        backend: &mut B,
106    ) -> Result<Tensor> {
107        compare(self, rhs, dir, backend)
108    }
109
110    fn where_select<B: TensorBackend>(
111        &self,
112        on_true: &Tensor,
113        on_false: &Tensor,
114        backend: &mut B,
115    ) -> Result<Tensor> {
116        where_select(self, on_true, on_false, backend)
117    }
118
119    fn clamp<B: TensorBackend>(
120        &self,
121        lower: &Tensor,
122        upper: &Tensor,
123        backend: &mut B,
124    ) -> Result<Tensor> {
125        clamp(self, lower, upper, backend)
126    }
127
128    fn matmul<B: TensorBackend>(&self, rhs: &Tensor, backend: &mut B) -> Result<Tensor> {
129        matmul(self, rhs, backend)
130    }
131
132    fn reshape<B: TensorBackend>(&self, shape: &[usize], backend: &mut B) -> Result<Tensor> {
133        reshape(self, shape, backend)
134    }
135
136    fn transpose<B: TensorBackend>(&self, perm: &[usize], backend: &mut B) -> Result<Tensor> {
137        transpose(self, perm, backend)
138    }
139
140    fn reduce_sum<B: TensorBackend>(&self, axes: &[usize], backend: &mut B) -> Result<Tensor> {
141        reduce_sum(self, axes, backend)
142    }
143}
144
145/// Convert a tensor to a different dtype using the checked conversion lattice.
146///
147/// Use [`cast`] for explicit lossy dtype projection.
148///
149/// # Examples
150///
151/// ```rust
152/// # use tenferro_cpu::CpuBackend;
153/// use tenferro_runtime::{DType, Tensor, TensorOpsExt};
154/// # let mut backend = CpuBackend::new();
155/// # let x = Tensor::from_vec_col_major(vec![2], vec![1.0_f64, 2.0]).unwrap();
156/// let y = x.convert(DType::C64, &mut backend).unwrap();
157/// assert_eq!(y.dtype(), DType::C64);
158/// ```
159///
160/// # Errors
161///
162/// Returns an error when the requested conversion is outside tenferro's checked
163/// dtype-promotion lattice, or when the backend does not support the requested
164/// conversion.
165fn convert(input: &Tensor, to: DType, backend: &mut impl TensorBackend) -> Result<Tensor> {
166    backend.with_backend_session(|exec| exec.convert(input, to))
167}
168
169/// Cast a tensor to a different dtype using explicit dtype projection.
170///
171/// Unlike [`convert`], `cast` may truncate, narrow precision, project complex
172/// values to their real component, or use boolean truthiness where the backend
173/// supports the requested projection.
174///
175/// # Examples
176///
177/// ```rust
178/// # use tenferro_cpu::CpuBackend;
179/// use tenferro_runtime::{DType, Tensor, TensorOpsExt};
180/// # let mut backend = CpuBackend::new();
181/// # let x = Tensor::from_vec_col_major(vec![2], vec![1.2_f64, -2.8]).unwrap();
182/// let y = x.cast(DType::I32, &mut backend).unwrap();
183/// assert_eq!(y.as_slice::<i32>().unwrap(), &[1, -2]);
184/// ```
185///
186/// # Errors
187///
188/// Returns an error when the backend does not support the requested explicit
189/// dtype projection.
190fn cast(input: &Tensor, to: DType, backend: &mut impl TensorBackend) -> Result<Tensor> {
191    backend.with_backend_session(|exec| exec.cast(input, to))
192}
193
194/// Elementwise addition with NumPy-style broadcasting.
195///
196/// # Examples
197///
198/// ```rust
199/// # use tenferro_cpu::CpuBackend;
200/// use tenferro_runtime::{Tensor, TensorOpsExt};
201/// # let mut backend = CpuBackend::new();
202/// # let x = Tensor::from_vec_col_major(vec![2], vec![1.0_f64, 2.0]).unwrap();
203/// # let y = Tensor::from_vec_col_major(vec![2], vec![3.0_f64, 4.0]).unwrap();
204/// let z = x.add(&y, &mut backend).unwrap();
205/// ```
206fn add(lhs: &Tensor, rhs: &Tensor, backend: &mut impl TensorBackend) -> Result<Tensor> {
207    let (lhs, rhs) = broadcast_binary(lhs, rhs, backend)?;
208    backend.with_backend_session(|exec| exec.add(&lhs, &rhs))
209}
210
// Defines a private elementwise helper `$fn_name(input, backend)` that runs the
// backend session method `$exec_op` on `input`. The generated doc comment
// starts with `$doc` and carries a runnable example calling the
// `TensorOpsExt` method with the same name.
macro_rules! unary_fn {
    ($fn_name:ident, $exec_op:ident, $doc:literal) => {
        #[doc = $doc]
        ///
        /// # Examples
        ///
        /// ```rust
        /// # use tenferro_cpu::CpuBackend;
        /// use tenferro_runtime::{Tensor, TensorOpsExt};
        /// # let mut backend = CpuBackend::new();
        /// # let x = Tensor::from_vec_col_major(vec![2], vec![1.0_f64, 4.0]).unwrap();
        #[doc = concat!("let y = x.", stringify!($fn_name), "(&mut backend).unwrap();")]
        /// ```
        fn $fn_name(input: &Tensor, backend: &mut impl TensorBackend) -> Result<Tensor> {
            backend.with_backend_session(|session| session.$exec_op(input))
        }
    };
}
229
// Defines a private elementwise helper `$fn_name(lhs, rhs, backend)`. The helper
// first broadcasts both operands to a common shape with `broadcast_binary`,
// then runs the backend session method `$exec_op`. The generated doc comment
// starts with `$doc` and includes a runnable example.
macro_rules! binary_fn {
    ($fn_name:ident, $exec_op:ident, $doc:literal) => {
        #[doc = $doc]
        ///
        /// # Examples
        ///
        /// ```rust
        /// # use tenferro_cpu::CpuBackend;
        /// use tenferro_runtime::{Tensor, TensorOpsExt};
        /// # let mut backend = CpuBackend::new();
        /// # let x = Tensor::from_vec_col_major(vec![2], vec![2.0_f64, 4.0]).unwrap();
        /// # let y = Tensor::from_vec_col_major(vec![2], vec![1.0_f64, 8.0]).unwrap();
        #[doc = concat!("let z = x.", stringify!($fn_name), "(&y, &mut backend).unwrap();")]
        /// ```
        fn $fn_name(
            lhs: &Tensor,
            rhs: &Tensor,
            backend: &mut impl TensorBackend,
        ) -> Result<Tensor> {
            let (a, b) = broadcast_binary(lhs, rhs, backend)?;
            backend.with_backend_session(|session| session.$exec_op(&a, &b))
        }
    };
}
250
251binary_fn!(
252    mul,
253    mul,
254    "Elementwise multiplication with NumPy-style broadcasting."
255);
256binary_fn!(
257    div,
258    div,
259    "Elementwise division with NumPy-style broadcasting."
260);
261binary_fn!(pow, pow, "Elementwise power with NumPy-style broadcasting.");
262binary_fn!(
263    maximum,
264    maximum,
265    "Elementwise maximum with NumPy-style broadcasting."
266);
267binary_fn!(
268    minimum,
269    minimum,
270    "Elementwise minimum with NumPy-style broadcasting."
271);
272
273unary_fn!(neg, neg, "Elementwise negation.");
274unary_fn!(abs, abs, "Elementwise absolute value.");
275unary_fn!(sign, sign, "Elementwise sign.");
276unary_fn!(conj, conj, "Elementwise complex conjugate.");
277unary_fn!(exp, exp, "Elementwise exponential.");
278unary_fn!(log, log, "Elementwise natural logarithm.");
279unary_fn!(sin, sin, "Elementwise sine.");
280unary_fn!(cos, cos, "Elementwise cosine.");
281unary_fn!(tanh, tanh, "Elementwise hyperbolic tangent.");
282unary_fn!(sqrt, sqrt, "Elementwise square root.");
283unary_fn!(rsqrt, rsqrt, "Elementwise reciprocal square root.");
284unary_fn!(expm1, expm1, "Elementwise `exp(x) - 1`.");
285unary_fn!(log1p, log1p, "Elementwise `log(1 + x)`.");
286
287/// Elementwise subtraction with NumPy-style broadcasting.
288///
289/// # Examples
290///
291/// ```rust
292/// # use tenferro_cpu::CpuBackend;
293/// use tenferro_runtime::{Tensor, TensorOpsExt};
294/// # let mut backend = CpuBackend::new();
295/// # let x = Tensor::from_vec_col_major(vec![2], vec![2.0_f64, 4.0]).unwrap();
296/// # let y = Tensor::from_vec_col_major(vec![2], vec![1.0_f64, 8.0]).unwrap();
297/// let z = x.sub(&y, &mut backend).unwrap();
298/// ```
299fn sub(lhs: &Tensor, rhs: &Tensor, backend: &mut impl TensorBackend) -> Result<Tensor> {
300    let (lhs, rhs) = broadcast_binary(lhs, rhs, backend)?;
301    let neg_rhs = backend.with_backend_session(|exec| exec.neg(&rhs))?;
302    backend.with_backend_session(|exec| exec.add(&lhs, &neg_rhs))
303}
304
305/// Elementwise comparison with NumPy-style broadcasting.
306///
307/// The result is a bool tensor.
308///
309/// # Examples
310///
311/// ```rust
312/// # use tenferro_cpu::CpuBackend;
313/// use tenferro_runtime::{CompareDir, Tensor, TensorOpsExt};
314/// # let mut backend = CpuBackend::new();
315/// # let x = Tensor::from_vec_col_major(vec![2], vec![2.0_f64, 4.0]).unwrap();
316/// # let y = Tensor::from_vec_col_major(vec![2], vec![1.0_f64, 8.0]).unwrap();
317/// let z = x.compare(&y, CompareDir::Gt, &mut backend).unwrap();
318/// assert_eq!(z.as_slice::<bool>().unwrap(), &[true, false]);
319/// ```
320fn compare(
321    lhs: &Tensor,
322    rhs: &Tensor,
323    dir: CompareDir,
324    backend: &mut impl TensorBackend,
325) -> Result<Tensor> {
326    let (lhs, rhs) = broadcast_binary(lhs, rhs, backend)?;
327    backend.with_backend_session(|exec| exec.compare(&lhs, &rhs, &dir))
328}
329
330/// Select values from `on_true` or `on_false` using a condition tensor.
331///
332/// This corresponds to NumPy `where(condition, x, y)`.
333///
334/// # Examples
335///
336/// ```rust
337/// # use tenferro_cpu::CpuBackend;
338/// use tenferro_runtime::{CompareDir, Tensor, TensorOpsExt};
339/// # let mut backend = CpuBackend::new();
340/// # let x = Tensor::from_vec_col_major(vec![2], vec![2.0_f64, 4.0]).unwrap();
341/// # let y = Tensor::from_vec_col_major(vec![2], vec![1.0_f64, 8.0]).unwrap();
342/// # let condition = x.compare(&y, CompareDir::Gt, &mut backend).unwrap();
343/// let z = condition.where_select(&x, &y, &mut backend).unwrap();
344/// ```
345fn where_select(
346    condition: &Tensor,
347    on_true: &Tensor,
348    on_false: &Tensor,
349    backend: &mut impl TensorBackend,
350) -> Result<Tensor> {
351    let (condition, on_true, on_false) = broadcast_ternary(condition, on_true, on_false, backend)?;
352    backend.with_backend_session(|exec| exec.select(&condition, &on_true, &on_false))
353}
354
355/// Clamp values elementwise between lower and upper bounds.
356///
357/// # Examples
358///
359/// ```rust
360/// # use tenferro_cpu::CpuBackend;
361/// use tenferro_runtime::{Tensor, TensorOpsExt};
362/// # let mut backend = CpuBackend::new();
363/// # let x = Tensor::from_vec_col_major(vec![2], vec![-2.0_f64, 4.0]).unwrap();
364/// # let lower = Tensor::from_vec_col_major(vec![], vec![0.0_f64]).unwrap();
365/// # let upper = Tensor::from_vec_col_major(vec![], vec![3.0_f64]).unwrap();
366/// let z = x.clamp(&lower, &upper, &mut backend).unwrap();
367/// ```
368fn clamp(
369    input: &Tensor,
370    lower: &Tensor,
371    upper: &Tensor,
372    backend: &mut impl TensorBackend,
373) -> Result<Tensor> {
374    let (input, lower, upper) = broadcast_ternary(input, lower, upper, backend)?;
375    backend.with_backend_session(|exec| exec.clamp(&input, &lower, &upper))
376}
377
378/// Matrix multiplication helper for rank-2 tensors.
379///
380/// This contracts the last dimension of `a` with the first dimension of `b`.
381///
382/// # Examples
383///
384/// ```rust
385/// # use tenferro_cpu::CpuBackend;
386/// use tenferro_runtime::{Tensor, TensorOpsExt};
387/// # let mut backend = CpuBackend::new();
388/// # let a = Tensor::from_vec_col_major(vec![2, 3], vec![1.0_f64; 6]).unwrap();
389/// # let b = Tensor::from_vec_col_major(vec![3, 2], vec![1.0_f64; 6]).unwrap();
390/// let c = a.matmul(&b, &mut backend).unwrap();
391/// ```
392fn matmul(a: &Tensor, b: &Tensor, backend: &mut impl TensorBackend) -> Result<Tensor> {
393    let config = DotGeneralConfig {
394        lhs_contracting_dims: vec![a.shape().len() - 1],
395        rhs_contracting_dims: vec![0],
396        lhs_batch_dims: vec![],
397        rhs_batch_dims: vec![],
398    };
399    backend.with_backend_session(|exec| exec.dot_general(a, b, &config))
400}
401
402/// Reshape a tensor without changing element order.
403///
404/// # Examples
405///
406/// ```rust
407/// # use tenferro_cpu::CpuBackend;
408/// use tenferro_runtime::{Tensor, TensorOpsExt};
409/// # let mut backend = CpuBackend::new();
410/// # let x = Tensor::from_vec_col_major(vec![2, 2], vec![1.0_f64, 2.0, 3.0, 4.0]).unwrap();
411/// let y = x.reshape(&[4], &mut backend).unwrap();
412/// assert_eq!(y.shape(), &[4]);
413/// ```
414fn reshape(input: &Tensor, shape: &[usize], backend: &mut impl TensorBackend) -> Result<Tensor> {
415    backend.with_backend_session(|exec| exec.reshape(input, shape))
416}
417
418/// Permute tensor axes.
419///
420/// # Examples
421///
422/// ```rust
423/// # use tenferro_cpu::CpuBackend;
424/// use tenferro_runtime::{Tensor, TensorOpsExt};
425/// # let mut backend = CpuBackend::new();
426/// # let x = Tensor::from_vec_col_major(vec![2, 3], vec![1.0_f64; 6]).unwrap();
427/// let y = x.transpose(&[1, 0], &mut backend).unwrap();
428/// assert_eq!(y.shape(), &[3, 2]);
429/// ```
430fn transpose(input: &Tensor, perm: &[usize], backend: &mut impl TensorBackend) -> Result<Tensor> {
431    backend.with_backend_session(|exec| exec.transpose(input, perm))
432}
433
434/// Sum a tensor over one or more axes.
435///
436/// # Examples
437///
438/// ```rust
439/// # use tenferro_cpu::CpuBackend;
440/// use tenferro_runtime::{Tensor, TensorOpsExt};
441/// # let mut backend = CpuBackend::new();
442/// # let x = Tensor::from_vec_col_major(vec![2, 2], vec![1.0_f64, 2.0, 3.0, 4.0]).unwrap();
443/// let y = x.reduce_sum(&[0], &mut backend).unwrap();
444/// assert_eq!(y.shape(), &[2]);
445/// ```
446fn reduce_sum(input: &Tensor, axes: &[usize], backend: &mut impl TensorBackend) -> Result<Tensor> {
447    backend.with_backend_session(|exec| exec.reduce_sum(input, axes))
448}
449
450fn broadcast_binary(
451    lhs: &Tensor,
452    rhs: &Tensor,
453    backend: &mut impl TensorBackend,
454) -> Result<(Tensor, Tensor)> {
455    let shape = broadcast_shape(lhs.shape(), rhs.shape()).map_err(broadcast_error)?;
456    Ok((
457        broadcast_to(lhs, &shape, backend)?,
458        broadcast_to(rhs, &shape, backend)?,
459    ))
460}
461
462fn broadcast_ternary(
463    first: &Tensor,
464    second: &Tensor,
465    third: &Tensor,
466    backend: &mut impl TensorBackend,
467) -> Result<(Tensor, Tensor, Tensor)> {
468    let shape = broadcast_shapes([first.shape(), second.shape(), third.shape()])
469        .map_err(broadcast_error)?;
470    Ok((
471        broadcast_to(first, &shape, backend)?,
472        broadcast_to(second, &shape, backend)?,
473        broadcast_to(third, &shape, backend)?,
474    ))
475}
476
477fn broadcast_to(
478    input: &Tensor,
479    target_shape: &[usize],
480    backend: &mut impl TensorBackend,
481) -> Result<Tensor> {
482    let input_shape = input.shape();
483    if input_shape == target_shape {
484        return Ok(input.clone());
485    }
486
487    let plan = broadcast_input_plan(input_shape, target_shape).map_err(broadcast_error)?;
488    let source = if plan.source_shape == input_shape {
489        input.clone()
490    } else {
491        backend.with_backend_session(|exec| exec.reshape(input, &plan.source_shape))?
492    };
493    backend.with_backend_session(|exec| exec.broadcast_in_dim(&source, target_shape, &plan.dims))
494}
495
496fn broadcast_error(err: impl std::fmt::Display) -> Error {
497    Error::backend_failure("broadcast", err.to_string())
498}