tenferro_prims/
tensor_ops.rs

1//! GPU-generic free functions for tensor data operations.
2//!
3//! These functions expose `triu`, `tril`, `cat`, and `stack` as context-bearing
4//! free functions in `tenferro-prims`, which is the correct layer for
5//! GPU-dispatch-aware tensor operations.
6//!
7//! **Design note**: `tenferro-tensor` does NOT depend on `tenferro-prims`
8//! (the dependency goes the other way), so these must be free functions here,
9//! not methods on `Tensor`.
10//!
11//! **Current implementation**: Delegates to existing `Tensor` methods for CPU
12//! (the context parameter is unused in the current CPU path). When GPU backends
13//! are wired, these functions will dispatch through the context to the
14//! appropriate backend kernels.
15//!
16//! # Examples
17//!
18//! ```ignore
19//! use tenferro_device::LogicalMemorySpace;
20//! use tenferro_prims::{tensor_ops, CpuContext};
21//! use tenferro_tensor::{MemoryOrder, Tensor};
22//!
23//! let mut ctx = CpuContext::new(1);
24//! let a = Tensor::<f64>::ones(
25//!     &[3, 3],
26//!     LogicalMemorySpace::MainMemory,
27//!     MemoryOrder::ColumnMajor,
28//! ).unwrap();
29//!
30//! let upper = tensor_ops::triu(&mut ctx, &a, 0).unwrap();
31//! let lower = tensor_ops::tril(&mut ctx, &a, 0).unwrap();
32//! ```
33
34use tenferro_algebra::Scalar;
35use tenferro_device::{Error, Result};
36use tenferro_tensor::Tensor;
37
38use crate::TensorMetadataContextFor;
39
40/// Extract the upper triangular part of a matrix (or batch of matrices).
41///
42/// Elements below the `diagonal`-th diagonal are set to zero. The context
43/// parameter enables future GPU-backend dispatch.
44///
45/// The `diagonal` parameter controls which diagonal is the boundary:
46/// - `0` is the main diagonal
47/// - positive values move above the main diagonal
48/// - negative values move below the main diagonal
49///
50/// # Errors
51///
52/// Returns an error if the tensor has fewer than 2 dimensions or if backend
53/// dispatch fails.
54///
55/// # Examples
56///
57/// ```ignore
58/// use tenferro_device::LogicalMemorySpace;
59/// use tenferro_prims::{tensor_ops, CpuContext};
60/// use tenferro_tensor::{MemoryOrder, Tensor};
61///
62/// let mut ctx = CpuContext::new(1);
63/// let a = Tensor::<f64>::ones(
64///     &[3, 3],
65///     LogicalMemorySpace::MainMemory,
66///     MemoryOrder::ColumnMajor,
67/// ).unwrap();
68/// let upper = tensor_ops::triu(&mut ctx, &a, 0).unwrap();
69/// assert_eq!(upper.dims(), &[3, 3]);
70/// ```
71pub fn triu<T, C>(_ctx: &mut C, tensor: &Tensor<T>, diagonal: isize) -> Result<Tensor<T>>
72where
73    T: Scalar,
74    C: TensorMetadataContextFor,
75{
76    if tensor.ndim() < 2 {
77        return Err(Error::InvalidArgument(format!(
78            "triu requires at least 2 dimensions (got {})",
79            tensor.ndim()
80        )));
81    }
82    Ok(tensor.triu(diagonal))
83}
84
85/// Extract the lower triangular part of a matrix (or batch of matrices).
86///
87/// Elements above the `diagonal`-th diagonal are set to zero. The context
88/// parameter enables future GPU-backend dispatch.
89///
90/// The `diagonal` parameter controls which diagonal is the boundary:
91/// - `0` is the main diagonal
92/// - positive values move above the main diagonal
93/// - negative values move below the main diagonal
94///
95/// # Errors
96///
97/// Returns an error if the tensor has fewer than 2 dimensions or if backend
98/// dispatch fails.
99///
100/// # Examples
101///
102/// ```ignore
103/// use tenferro_device::LogicalMemorySpace;
104/// use tenferro_prims::{tensor_ops, CpuContext};
105/// use tenferro_tensor::{MemoryOrder, Tensor};
106///
107/// let mut ctx = CpuContext::new(1);
108/// let a = Tensor::<f64>::ones(
109///     &[3, 3],
110///     LogicalMemorySpace::MainMemory,
111///     MemoryOrder::ColumnMajor,
112/// ).unwrap();
113/// let lower = tensor_ops::tril(&mut ctx, &a, 0).unwrap();
114/// assert_eq!(lower.dims(), &[3, 3]);
115/// ```
116pub fn tril<T, C>(_ctx: &mut C, tensor: &Tensor<T>, diagonal: isize) -> Result<Tensor<T>>
117where
118    T: Scalar,
119    C: TensorMetadataContextFor,
120{
121    if tensor.ndim() < 2 {
122        return Err(Error::InvalidArgument(format!(
123            "tril requires at least 2 dimensions (got {})",
124            tensor.ndim()
125        )));
126    }
127    Ok(tensor.tril(diagonal))
128}
129
130/// Concatenate tensors along an existing dimension.
131///
132/// The context parameter enables future GPU-backend dispatch. Currently
133/// delegates to `Tensor::cat`.
134///
135/// # Errors
136///
137/// Returns an error if tensors have incompatible shapes, the axis is out of
138/// range, or the tensor list is empty.
139///
140/// # Examples
141///
142/// ```ignore
143/// use tenferro_device::LogicalMemorySpace;
144/// use tenferro_prims::{tensor_ops, CpuContext};
145/// use tenferro_tensor::{MemoryOrder, Tensor};
146///
147/// let mut ctx = CpuContext::new(1);
148/// let a = Tensor::<f64>::zeros(
149///     &[2, 3],
150///     LogicalMemorySpace::MainMemory,
151///     MemoryOrder::ColumnMajor,
152/// ).unwrap();
153/// let b = Tensor::<f64>::zeros(
154///     &[2, 4],
155///     LogicalMemorySpace::MainMemory,
156///     MemoryOrder::ColumnMajor,
157/// ).unwrap();
158/// let result = tensor_ops::cat(&mut ctx, &[&a, &b], 1).unwrap();
159/// assert_eq!(result.dims(), &[2, 7]);
160/// ```
161pub fn cat<T, C>(_ctx: &mut C, tensors: &[&Tensor<T>], axis: usize) -> Result<Tensor<T>>
162where
163    T: Scalar,
164    C: TensorMetadataContextFor,
165{
166    Tensor::cat(tensors, axis as isize)
167}
168
169/// Stack tensors along a new dimension.
170///
171/// Creates a new dimension at `axis` and concatenates the input tensors
172/// along it. All input tensors must have the same shape.
173///
174/// The context parameter enables future GPU-backend dispatch. Currently
175/// delegates to `Tensor::stack`.
176///
177/// # Errors
178///
179/// Returns an error if tensors have different shapes, the axis is out of
180/// range, or the tensor list is empty.
181///
182/// # Examples
183///
184/// ```ignore
185/// use tenferro_device::LogicalMemorySpace;
186/// use tenferro_prims::{tensor_ops, CpuContext};
187/// use tenferro_tensor::{MemoryOrder, Tensor};
188///
189/// let mut ctx = CpuContext::new(1);
190/// let a = Tensor::<f64>::zeros(
191///     &[2, 3],
192///     LogicalMemorySpace::MainMemory,
193///     MemoryOrder::ColumnMajor,
194/// ).unwrap();
195/// let b = Tensor::<f64>::zeros(
196///     &[2, 3],
197///     LogicalMemorySpace::MainMemory,
198///     MemoryOrder::ColumnMajor,
199/// ).unwrap();
200/// let result = tensor_ops::stack(&mut ctx, &[&a, &b], 0).unwrap();
201/// assert_eq!(result.dims(), &[2, 2, 3]);
202/// ```
203pub fn stack<T, C>(_ctx: &mut C, tensors: &[&Tensor<T>], axis: usize) -> Result<Tensor<T>>
204where
205    T: Scalar,
206    C: TensorMetadataContextFor,
207{
208    Tensor::stack(tensors, axis as isize)
209}