tenferro_ad/shape_packing.rs
1use tenferro_tensor::{GatherConfig, Tensor, TensorDeviceTransfer, TypedTensor};
2
3use crate::eager::EagerTensor;
4use crate::error::{Error, Result};
5
6fn normalize_existing_axis(op: &'static str, axis: isize, rank: usize) -> Result<usize> {
7 let normalized = if axis < 0 { rank as isize + axis } else { axis };
8 if normalized < 0 || normalized >= rank as isize {
9 return Err(tenferro_tensor::Error::AxisOutOfBounds {
10 op,
11 axis: axis.unsigned_abs(),
12 rank,
13 }
14 .into());
15 }
16 Ok(normalized as usize)
17}
18
19fn normalize_insert_axis(op: &'static str, axis: isize, rank: usize) -> Result<usize> {
20 let normalized = if axis < 0 {
21 rank as isize + 1 + axis
22 } else {
23 axis
24 };
25 if normalized < 0 || normalized > rank as isize {
26 return Err(tenferro_tensor::Error::AxisOutOfBounds {
27 op,
28 axis: axis.unsigned_abs(),
29 rank: rank + 1,
30 }
31 .into());
32 }
33 Ok(normalized as usize)
34}
35
36fn index_select_config(
37 shape: &[usize],
38 axis: isize,
39 positions: &[usize],
40) -> Result<(Tensor, GatherConfig)> {
41 let axis = normalize_existing_axis("index_select", axis, shape.len())?;
42 let axis_extent = shape[axis];
43 for &position in positions {
44 if position >= axis_extent {
45 return Err(tenferro_tensor::Error::InvalidConfig {
46 op: "index_select",
47 message: format!(
48 "position {position} out of bounds for axis {axis} with extent {axis_extent}"
49 ),
50 }
51 .into());
52 }
53 }
54
55 let mut slice_sizes = shape.to_vec();
56 slice_sizes[axis] = 1;
57
58 let offset_dims = (0..shape.len()).filter(|&dim| dim != axis).collect();
59 let index_data = positions
60 .iter()
61 .map(|&position| {
62 i64::try_from(position).map_err(|_| tenferro_tensor::Error::InvalidConfig {
63 op: "index_select",
64 message: format!("position {position} cannot be represented as i64"),
65 })
66 })
67 .collect::<tenferro_tensor::Result<Vec<_>>>()?;
68 let indices = Tensor::I64(TypedTensor::from_vec_col_major(
69 vec![positions.len(), 1],
70 index_data,
71 )?);
72
73 let config = GatherConfig {
74 offset_dims,
75 collapsed_slice_dims: vec![axis],
76 start_index_map: vec![axis],
77 index_vector_dim: 1,
78 slice_sizes,
79 };
80
81 Ok((indices, config))
82}
83
84fn validate_stack_shapes(op: &'static str, shapes: &[&[usize]]) -> Result<()> {
85 let Some(first) = shapes.first() else {
86 return Err(tenferro_tensor::Error::InvalidConfig {
87 op,
88 message: "stack requires at least one input".into(),
89 }
90 .into());
91 };
92 for shape in shapes.iter().skip(1) {
93 if *shape != *first {
94 return Err(tenferro_tensor::Error::ShapeMismatch {
95 op,
96 lhs: first.to_vec(),
97 rhs: shape.to_vec(),
98 }
99 .into());
100 }
101 }
102 Ok(())
103}
104
105impl EagerTensor {
106 /// Select entries from one axis using host-known indices.
107 ///
108 /// The index list is primal metadata: gradients flow to `self`, including
109 /// accumulation for repeated indices, but not to the selected positions.
110 ///
111 /// # Examples
112 ///
113 /// ```
114 /// use tenferro_cpu::CpuBackend;
115 /// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
116 ///
117 /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
118 /// let x = EagerTensor::from_tensor_in(
119 /// Tensor::from_vec_col_major(vec![3], vec![10.0_f64, 20.0, 30.0]).unwrap(),
120 /// ctx,
121 /// ).unwrap();
122 /// let y = x.take_axis(0, &[2, 0]).unwrap();
123 ///
124 /// assert_eq!(y.materialized().unwrap().as_slice::<f64>().unwrap(), &[30.0, 10.0]);
125 /// ```
126 pub fn take_axis(&self, axis: usize, indices: &[usize]) -> Result<Self> {
127 let axis = isize::try_from(axis).map_err(|_| {
128 Error::TensorRuntime(tenferro_tensor::Error::InvalidConfig {
129 op: "take_axis",
130 message: format!("axis {axis} cannot be represented as isize"),
131 })
132 })?;
133 self.index_select(axis, indices)
134 }
135
136 /// Select matrix rows using host-known row indices.
137 ///
138 /// # Examples
139 ///
140 /// ```
141 /// use tenferro_cpu::CpuBackend;
142 /// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
143 ///
144 /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
145 /// let x = EagerTensor::from_tensor_in(
146 /// Tensor::from_vec_col_major(vec![2, 2], vec![1.0_f64, 2.0, 3.0, 4.0]).unwrap(),
147 /// ctx,
148 /// ).unwrap();
149 /// let y = x.take_rows(&[1]).unwrap();
150 ///
151 /// assert_eq!(y.shape(), &[1, 2]);
152 /// assert_eq!(y.materialized().unwrap().as_slice::<f64>().unwrap(), &[2.0, 4.0]);
153 /// ```
154 pub fn take_rows(&self, rows: &[usize]) -> Result<Self> {
155 self.take_axis(0, rows)
156 }
157
158 /// Select matrix columns using host-known column indices.
159 ///
160 /// # Examples
161 ///
162 /// ```
163 /// use tenferro_cpu::CpuBackend;
164 /// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
165 ///
166 /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
167 /// let x = EagerTensor::from_tensor_in(
168 /// Tensor::from_vec_col_major(vec![2, 2], vec![1.0_f64, 2.0, 3.0, 4.0]).unwrap(),
169 /// ctx,
170 /// ).unwrap();
171 /// let y = x.take_cols(&[1]).unwrap();
172 ///
173 /// assert_eq!(y.shape(), &[2, 1]);
174 /// assert_eq!(y.materialized().unwrap().as_slice::<f64>().unwrap(), &[3.0, 4.0]);
175 /// ```
176 pub fn take_cols(&self, cols: &[usize]) -> Result<Self> {
177 self.take_axis(1, cols)
178 }
179
180 /// Select a matrix block using host-known row and column indices.
181 ///
182 /// This is a convenience wrapper over row selection followed by column
183 /// selection. The row and column lists, plus the approximation rank implied
184 /// by their lengths, are fixed primal metadata.
185 ///
186 /// # Examples
187 ///
188 /// ```
189 /// use tenferro_cpu::CpuBackend;
190 /// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
191 ///
192 /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
193 /// let x = EagerTensor::from_tensor_in(
194 /// Tensor::from_vec_col_major(vec![2, 2], vec![1.0_f64, 2.0, 3.0, 4.0]).unwrap(),
195 /// ctx,
196 /// ).unwrap();
197 /// let y = x.take_block(&[1], &[0]).unwrap();
198 ///
199 /// assert_eq!(y.shape(), &[1, 1]);
200 /// assert_eq!(y.materialized().unwrap().as_slice::<f64>().unwrap(), &[2.0]);
201 /// ```
202 pub fn take_block(&self, rows: &[usize], cols: &[usize]) -> Result<Self> {
203 self.take_rows(rows)?.take_cols(cols)
204 }
205
206 /// Select entries from one axis using host-known positions.
207 ///
208 /// # Examples
209 ///
210 /// ```
211 /// use tenferro_cpu::CpuBackend;
212 /// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
213 ///
214 /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
215 /// let x = EagerTensor::from_tensor_in(
216 /// Tensor::from_vec_col_major(vec![3], vec![10.0_f64, 20.0, 30.0]).unwrap(),
217 /// ctx,
218 /// ).unwrap();
219 /// let y = x.index_select(-1, &[2, 0]).unwrap();
220 ///
221 /// assert_eq!(y.materialized().unwrap().as_slice::<f64>().unwrap(), &[30.0, 10.0]);
222 /// ```
223 pub fn index_select(&self, axis: isize, positions: &[usize]) -> Result<Self> {
224 let (indices, config) = index_select_config(self.shape(), axis, positions)?;
225 let indices = {
226 let mut backend = self
227 .ctx
228 .backend
229 .lock()
230 .map_err(|_| Error::Internal("backend lock poisoned".to_string()))?;
231 backend.upload_host_tensor(&indices)?
232 };
233 let indices = self.ctx.constant_from(indices)?;
234 self.gather(&indices, config)
235 }
236
237 /// Stack tensors along a newly inserted axis.
238 ///
239 /// The returned tensor uses the context of the first input, matching
240 /// [`Self::concatenate`]. All inputs must belong to that same context.
241 ///
242 /// # Examples
243 ///
244 /// ```
245 /// use tenferro_cpu::CpuBackend;
246 /// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
247 ///
248 /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
249 /// let a = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![], vec![1.0_f64]).unwrap(), ctx.clone()).unwrap();
250 /// let b = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![], vec![2.0_f64]).unwrap(), ctx).unwrap();
251 /// let out = EagerTensor::stack(&[&a, &b], -1).unwrap();
252 ///
253 /// assert_eq!(out.shape(), &[2]);
254 /// assert_eq!(out.materialized().unwrap().as_slice::<f64>().unwrap(), &[1.0, 2.0]);
255 /// ```
256 pub fn stack(tensors: &[&Self], dim: isize) -> Result<Self> {
257 let first = tensors.first().copied().ok_or_else(|| {
258 Error::TensorRuntime(tenferro_tensor::Error::InvalidConfig {
259 op: "stack",
260 message: "stack requires at least one input".into(),
261 })
262 })?;
263 let shapes = tensors
264 .iter()
265 .map(|tensor| tensor.shape())
266 .collect::<Vec<_>>();
267 validate_stack_shapes("stack", &shapes)?;
268
269 let axis = normalize_insert_axis("stack", dim, first.shape().len())?;
270 let mut expanded_shape = first.shape().to_vec();
271 expanded_shape.insert(axis, 1);
272
273 let expanded = tensors
274 .iter()
275 .map(|tensor| tensor.reshape(&expanded_shape))
276 .collect::<Result<Vec<_>>>()?;
277 let refs = expanded.iter().collect::<Vec<_>>();
278 Self::concatenate(&refs, axis)
279 }
280}