Skip to main content

tenferro_ad/
shape_packing.rs

1use tenferro_tensor::{GatherConfig, Tensor, TensorDeviceTransfer, TypedTensor};
2
3use crate::eager::EagerTensor;
4use crate::error::{Error, Result};
5
6fn normalize_existing_axis(op: &'static str, axis: isize, rank: usize) -> Result<usize> {
7    let normalized = if axis < 0 { rank as isize + axis } else { axis };
8    if normalized < 0 || normalized >= rank as isize {
9        return Err(tenferro_tensor::Error::AxisOutOfBounds {
10            op,
11            axis: axis.unsigned_abs(),
12            rank,
13        }
14        .into());
15    }
16    Ok(normalized as usize)
17}
18
19fn normalize_insert_axis(op: &'static str, axis: isize, rank: usize) -> Result<usize> {
20    let normalized = if axis < 0 {
21        rank as isize + 1 + axis
22    } else {
23        axis
24    };
25    if normalized < 0 || normalized > rank as isize {
26        return Err(tenferro_tensor::Error::AxisOutOfBounds {
27            op,
28            axis: axis.unsigned_abs(),
29            rank: rank + 1,
30        }
31        .into());
32    }
33    Ok(normalized as usize)
34}
35
36fn index_select_config(
37    shape: &[usize],
38    axis: isize,
39    positions: &[usize],
40) -> Result<(Tensor, GatherConfig)> {
41    let axis = normalize_existing_axis("index_select", axis, shape.len())?;
42    let axis_extent = shape[axis];
43    for &position in positions {
44        if position >= axis_extent {
45            return Err(tenferro_tensor::Error::InvalidConfig {
46                op: "index_select",
47                message: format!(
48                    "position {position} out of bounds for axis {axis} with extent {axis_extent}"
49                ),
50            }
51            .into());
52        }
53    }
54
55    let mut slice_sizes = shape.to_vec();
56    slice_sizes[axis] = 1;
57
58    let offset_dims = (0..shape.len()).filter(|&dim| dim != axis).collect();
59    let index_data = positions
60        .iter()
61        .map(|&position| {
62            i64::try_from(position).map_err(|_| tenferro_tensor::Error::InvalidConfig {
63                op: "index_select",
64                message: format!("position {position} cannot be represented as i64"),
65            })
66        })
67        .collect::<tenferro_tensor::Result<Vec<_>>>()?;
68    let indices = Tensor::I64(TypedTensor::from_vec_col_major(
69        vec![positions.len(), 1],
70        index_data,
71    )?);
72
73    let config = GatherConfig {
74        offset_dims,
75        collapsed_slice_dims: vec![axis],
76        start_index_map: vec![axis],
77        index_vector_dim: 1,
78        slice_sizes,
79    };
80
81    Ok((indices, config))
82}
83
84fn validate_stack_shapes(op: &'static str, shapes: &[&[usize]]) -> Result<()> {
85    let Some(first) = shapes.first() else {
86        return Err(tenferro_tensor::Error::InvalidConfig {
87            op,
88            message: "stack requires at least one input".into(),
89        }
90        .into());
91    };
92    for shape in shapes.iter().skip(1) {
93        if *shape != *first {
94            return Err(tenferro_tensor::Error::ShapeMismatch {
95                op,
96                lhs: first.to_vec(),
97                rhs: shape.to_vec(),
98            }
99            .into());
100        }
101    }
102    Ok(())
103}
104
105impl EagerTensor {
106    /// Select entries from one axis using host-known indices.
107    ///
108    /// The index list is primal metadata: gradients flow to `self`, including
109    /// accumulation for repeated indices, but not to the selected positions.
110    ///
111    /// # Examples
112    ///
113    /// ```
114    /// use tenferro_cpu::CpuBackend;
115    /// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
116    ///
117    /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
118    /// let x = EagerTensor::from_tensor_in(
119    ///     Tensor::from_vec_col_major(vec![3], vec![10.0_f64, 20.0, 30.0]).unwrap(),
120    ///     ctx,
121    /// ).unwrap();
122    /// let y = x.take_axis(0, &[2, 0]).unwrap();
123    ///
124    /// assert_eq!(y.materialized().unwrap().as_slice::<f64>().unwrap(), &[30.0, 10.0]);
125    /// ```
126    pub fn take_axis(&self, axis: usize, indices: &[usize]) -> Result<Self> {
127        let axis = isize::try_from(axis).map_err(|_| {
128            Error::TensorRuntime(tenferro_tensor::Error::InvalidConfig {
129                op: "take_axis",
130                message: format!("axis {axis} cannot be represented as isize"),
131            })
132        })?;
133        self.index_select(axis, indices)
134    }
135
136    /// Select matrix rows using host-known row indices.
137    ///
138    /// # Examples
139    ///
140    /// ```
141    /// use tenferro_cpu::CpuBackend;
142    /// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
143    ///
144    /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
145    /// let x = EagerTensor::from_tensor_in(
146    ///     Tensor::from_vec_col_major(vec![2, 2], vec![1.0_f64, 2.0, 3.0, 4.0]).unwrap(),
147    ///     ctx,
148    /// ).unwrap();
149    /// let y = x.take_rows(&[1]).unwrap();
150    ///
151    /// assert_eq!(y.shape(), &[1, 2]);
152    /// assert_eq!(y.materialized().unwrap().as_slice::<f64>().unwrap(), &[2.0, 4.0]);
153    /// ```
154    pub fn take_rows(&self, rows: &[usize]) -> Result<Self> {
155        self.take_axis(0, rows)
156    }
157
158    /// Select matrix columns using host-known column indices.
159    ///
160    /// # Examples
161    ///
162    /// ```
163    /// use tenferro_cpu::CpuBackend;
164    /// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
165    ///
166    /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
167    /// let x = EagerTensor::from_tensor_in(
168    ///     Tensor::from_vec_col_major(vec![2, 2], vec![1.0_f64, 2.0, 3.0, 4.0]).unwrap(),
169    ///     ctx,
170    /// ).unwrap();
171    /// let y = x.take_cols(&[1]).unwrap();
172    ///
173    /// assert_eq!(y.shape(), &[2, 1]);
174    /// assert_eq!(y.materialized().unwrap().as_slice::<f64>().unwrap(), &[3.0, 4.0]);
175    /// ```
176    pub fn take_cols(&self, cols: &[usize]) -> Result<Self> {
177        self.take_axis(1, cols)
178    }
179
180    /// Select a matrix block using host-known row and column indices.
181    ///
182    /// This is a convenience wrapper over row selection followed by column
183    /// selection. The row and column lists, plus the approximation rank implied
184    /// by their lengths, are fixed primal metadata.
185    ///
186    /// # Examples
187    ///
188    /// ```
189    /// use tenferro_cpu::CpuBackend;
190    /// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
191    ///
192    /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
193    /// let x = EagerTensor::from_tensor_in(
194    ///     Tensor::from_vec_col_major(vec![2, 2], vec![1.0_f64, 2.0, 3.0, 4.0]).unwrap(),
195    ///     ctx,
196    /// ).unwrap();
197    /// let y = x.take_block(&[1], &[0]).unwrap();
198    ///
199    /// assert_eq!(y.shape(), &[1, 1]);
200    /// assert_eq!(y.materialized().unwrap().as_slice::<f64>().unwrap(), &[2.0]);
201    /// ```
202    pub fn take_block(&self, rows: &[usize], cols: &[usize]) -> Result<Self> {
203        self.take_rows(rows)?.take_cols(cols)
204    }
205
206    /// Select entries from one axis using host-known positions.
207    ///
208    /// # Examples
209    ///
210    /// ```
211    /// use tenferro_cpu::CpuBackend;
212    /// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
213    ///
214    /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
215    /// let x = EagerTensor::from_tensor_in(
216    ///     Tensor::from_vec_col_major(vec![3], vec![10.0_f64, 20.0, 30.0]).unwrap(),
217    ///     ctx,
218    /// ).unwrap();
219    /// let y = x.index_select(-1, &[2, 0]).unwrap();
220    ///
221    /// assert_eq!(y.materialized().unwrap().as_slice::<f64>().unwrap(), &[30.0, 10.0]);
222    /// ```
223    pub fn index_select(&self, axis: isize, positions: &[usize]) -> Result<Self> {
224        let (indices, config) = index_select_config(self.shape(), axis, positions)?;
225        let indices = {
226            let mut backend = self
227                .ctx
228                .backend
229                .lock()
230                .map_err(|_| Error::Internal("backend lock poisoned".to_string()))?;
231            backend.upload_host_tensor(&indices)?
232        };
233        let indices = self.ctx.constant_from(indices)?;
234        self.gather(&indices, config)
235    }
236
237    /// Stack tensors along a newly inserted axis.
238    ///
239    /// The returned tensor uses the context of the first input, matching
240    /// [`Self::concatenate`]. All inputs must belong to that same context.
241    ///
242    /// # Examples
243    ///
244    /// ```
245    /// use tenferro_cpu::CpuBackend;
246    /// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
247    ///
248    /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
249    /// let a = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![], vec![1.0_f64]).unwrap(), ctx.clone()).unwrap();
250    /// let b = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![], vec![2.0_f64]).unwrap(), ctx).unwrap();
251    /// let out = EagerTensor::stack(&[&a, &b], -1).unwrap();
252    ///
253    /// assert_eq!(out.shape(), &[2]);
254    /// assert_eq!(out.materialized().unwrap().as_slice::<f64>().unwrap(), &[1.0, 2.0]);
255    /// ```
256    pub fn stack(tensors: &[&Self], dim: isize) -> Result<Self> {
257        let first = tensors.first().copied().ok_or_else(|| {
258            Error::TensorRuntime(tenferro_tensor::Error::InvalidConfig {
259                op: "stack",
260                message: "stack requires at least one input".into(),
261            })
262        })?;
263        let shapes = tensors
264            .iter()
265            .map(|tensor| tensor.shape())
266            .collect::<Vec<_>>();
267        validate_stack_shapes("stack", &shapes)?;
268
269        let axis = normalize_insert_axis("stack", dim, first.shape().len())?;
270        let mut expanded_shape = first.shape().to_vec();
271        expanded_shape.insert(axis, 1);
272
273        let expanded = tensors
274            .iter()
275            .map(|tensor| tensor.reshape(&expanded_shape))
276            .collect::<Result<Vec<_>>>()?;
277        let refs = expanded.iter().collect::<Vec<_>>();
278        Self::concatenate(&refs, axis)
279    }
280}