// tenferro_runtime/shape_packing.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3
4use computegraph::graph::GraphBuilder;
5use computegraph::types::{OperationRole, ValueRef};
6use tenferro_ops::std_tensor_op::StdTensorOp;
7use tenferro_tensor::{GatherConfig, Tensor, TypedTensor};
8
9use crate::checkpoint::CheckpointNode;
10use crate::error::{Error, Result};
11use crate::metadata::{metadata_scopes_with_new, register_scoped_value_metadata, tensor_meta};
12use crate::shape_infer::promote_dtypes;
13use crate::sym_dim::SymDim;
14use crate::traced::{
15    apply_binary_preserve_input_dtypes, infer_traced_single_output_shape, next_traced_id,
16    try_concrete_shape,
17};
18use crate::TracedTensor;
19
20fn normalize_existing_axis(op: &'static str, axis: isize, rank: usize) -> Result<usize> {
21    let normalized = if axis < 0 { rank as isize + axis } else { axis };
22    if normalized < 0 || normalized >= rank as isize {
23        return Err(tenferro_tensor::Error::AxisOutOfBounds {
24            op,
25            axis: axis.unsigned_abs(),
26            rank,
27        }
28        .into());
29    }
30    Ok(normalized as usize)
31}
32
33fn normalize_insert_axis(op: &'static str, axis: isize, rank: usize) -> Result<usize> {
34    let normalized = if axis < 0 {
35        rank as isize + 1 + axis
36    } else {
37        axis
38    };
39    if normalized < 0 || normalized > rank as isize {
40        return Err(tenferro_tensor::Error::AxisOutOfBounds {
41            op,
42            axis: axis.unsigned_abs(),
43            rank: rank + 1,
44        }
45        .into());
46    }
47    Ok(normalized as usize)
48}
49
50fn index_select_config(
51    shape: &[usize],
52    axis: isize,
53    positions: &[usize],
54) -> Result<(Tensor, GatherConfig, Vec<usize>)> {
55    let axis = normalize_existing_axis("index_select", axis, shape.len())?;
56    let axis_extent = shape[axis];
57    for &position in positions {
58        if position >= axis_extent {
59            return Err(tenferro_tensor::Error::InvalidConfig {
60                op: "index_select",
61                message: format!(
62                    "position {position} out of bounds for axis {axis} with extent {axis_extent}"
63                ),
64            }
65            .into());
66        }
67    }
68
69    let mut out_shape = shape.to_vec();
70    out_shape[axis] = positions.len();
71
72    let mut slice_sizes = shape.to_vec();
73    slice_sizes[axis] = 1;
74
75    let offset_dims = (0..shape.len()).filter(|&dim| dim != axis).collect();
76    let index_data = positions
77        .iter()
78        .map(|&position| {
79            i64::try_from(position).map_err(|_| tenferro_tensor::Error::InvalidConfig {
80                op: "index_select",
81                message: format!("position {position} cannot be represented as i64"),
82            })
83        })
84        .collect::<tenferro_tensor::Result<Vec<_>>>()?;
85    let indices = Tensor::I64(TypedTensor::from_vec_col_major(
86        vec![positions.len(), 1],
87        index_data,
88    )?);
89
90    let config = GatherConfig {
91        offset_dims,
92        collapsed_slice_dims: vec![axis],
93        start_index_map: vec![axis],
94        index_vector_dim: 1,
95        slice_sizes,
96    };
97
98    Ok((indices, config, out_shape))
99}
100
101fn validate_stack_shapes(op: &'static str, shapes: &[&[usize]]) -> Result<()> {
102    let Some(first) = shapes.first() else {
103        return Err(tenferro_tensor::Error::InvalidConfig {
104            op,
105            message: "stack requires at least one input".into(),
106        }
107        .into());
108    };
109    for shape in shapes.iter().skip(1) {
110        if *shape != *first {
111            return Err(tenferro_tensor::Error::ShapeMismatch {
112                op,
113                lhs: first.to_vec(),
114                rhs: shape.to_vec(),
115            }
116            .into());
117        }
118    }
119    Ok(())
120}
121
122impl TracedTensor {
123    /// Select entries from one axis using host-known positions.
124    ///
125    /// # Examples
126    ///
127    /// ```
128    /// use tenferro_cpu::CpuBackend;
129    /// use tenferro_runtime::{GraphCompiler, GraphExecutor, Tensor, TracedTensor};
130    ///
131    /// let x = TracedTensor::from_tensor_concrete_shape(
132    ///     Tensor::from_vec_col_major(vec![3], vec![10.0_f64, 20.0, 30.0]).unwrap(),
133    /// )
134    /// .unwrap();
135    /// let y = x.index_select(-1, &[2, 0]).unwrap();
136    /// let mut compiler = GraphCompiler::new();
137    /// let program = compiler.compile(&y).unwrap();
138    /// let out = GraphExecutor::new(CpuBackend::new()).run(&program).unwrap();
139    ///
140    /// assert_eq!(
141    ///     out.as_slice::<f64>().unwrap(),
142    ///     &[30.0, 10.0],
143    /// );
144    /// ```
145    pub fn index_select(&self, axis: isize, positions: &[usize]) -> Result<Self> {
146        let shape = try_concrete_shape(self).ok_or_else(|| Error::InvalidGraphBuild {
147            op: "index_select",
148            message: "index_select requires a concrete shape hint".into(),
149        })?;
150        let (indices_tensor, config, out_shape) = index_select_config(&shape, axis, positions)?;
151        let indices = TracedTensor::from_tensor_concrete_shape(indices_tensor)?;
152        Ok(apply_binary_preserve_input_dtypes(
153            StdTensorOp::Gather(config),
154            self,
155            &indices,
156            out_shape.len(),
157            Some(out_shape.into_iter().map(SymDim::from).collect()),
158            self.dtype,
159        ))
160    }
161
162    /// Stack tensors along a newly inserted axis.
163    ///
164    /// # Examples
165    ///
166    /// ```
167    /// use tenferro_cpu::CpuBackend;
168    /// use tenferro_runtime::{GraphCompiler, GraphExecutor, Tensor, TracedTensor};
169    ///
170    /// let a = TracedTensor::from_tensor_concrete_shape(Tensor::from_vec_col_major(vec![], vec![1.0_f64]).unwrap()).unwrap();
171    /// let b = TracedTensor::from_tensor_concrete_shape(Tensor::from_vec_col_major(vec![], vec![2.0_f64]).unwrap()).unwrap();
172    /// let stacked = TracedTensor::stack(&[&a, &b], -1).unwrap();
173    /// let mut compiler = GraphCompiler::new();
174    /// let program = compiler.compile(&stacked).unwrap();
175    /// let out = GraphExecutor::new(CpuBackend::new()).run(&program).unwrap();
176    ///
177    /// assert_eq!(
178    ///     out.as_slice::<f64>().unwrap(),
179    ///     &[1.0, 2.0],
180    /// );
181    /// ```
182    pub fn stack(tensors: &[&Self], dim: isize) -> Result<Self> {
183        let first = tensors.first().copied().ok_or_else(|| {
184            Error::TensorRuntime(tenferro_tensor::Error::InvalidConfig {
185                op: "stack",
186                message: "stack requires at least one input".into(),
187            })
188        })?;
189        let first_shape = try_concrete_shape(first).ok_or_else(|| Error::InvalidGraphBuild {
190            op: "stack",
191            message: "stack requires concrete shape hints".into(),
192        })?;
193        let mut shapes = Vec::with_capacity(tensors.len());
194        shapes.push(first_shape.as_slice());
195        let mut owned_shapes = Vec::with_capacity(tensors.len().saturating_sub(1));
196        for tensor in tensors.iter().copied().skip(1) {
197            owned_shapes.push(try_concrete_shape(tensor).ok_or_else(|| {
198                Error::InvalidGraphBuild {
199                    op: "stack",
200                    message: "stack requires concrete shape hints".into(),
201                }
202            })?);
203        }
204        shapes.extend(owned_shapes.iter().map(Vec::as_slice));
205        validate_stack_shapes("stack", &shapes)?;
206
207        let axis = normalize_insert_axis("stack", dim, first.rank)?;
208        let mut expanded_shape = first_shape;
209        expanded_shape.insert(axis, 1);
210        let mut out_shape = expanded_shape.clone();
211        out_shape[axis] = tensors.len();
212        let expanded = tensors
213            .iter()
214            .map(|tensor| tensor.reshape(&expanded_shape))
215            .collect::<Vec<_>>();
216        let refs = expanded.iter().collect::<Vec<_>>();
217        Ok(apply_nary_concatenate(
218            &refs,
219            axis,
220            out_shape.into_iter().map(SymDim::from).collect(),
221        ))
222    }
223
224    /// Concatenate tensors along one existing axis.
225    pub fn concatenate(tensors: &[&Self], axis: usize) -> Result<Self> {
226        let first = tensors.first().copied().ok_or_else(|| {
227            Error::TensorRuntime(tenferro_tensor::Error::InvalidConfig {
228                op: "concatenate",
229                message: "concatenate requires at least one input".into(),
230            })
231        })?;
232        if axis >= first.rank {
233            return Err(tenferro_tensor::Error::AxisOutOfBounds {
234                op: "concatenate",
235                axis,
236                rank: first.rank,
237            }
238            .into());
239        }
240        for tensor in tensors.iter().copied().skip(1) {
241            if tensor.rank != first.rank {
242                return Err(tenferro_tensor::Error::RankMismatch {
243                    op: "concatenate",
244                    expected: first.rank,
245                    actual: tensor.rank,
246                }
247                .into());
248            }
249        }
250
251        let op = StdTensorOp::Concatenate {
252            axis,
253            input_count: tensors.len(),
254        };
255        let (_, out_shape_hint) =
256            infer_traced_single_output_shape("TracedTensor::concatenate", &op, tensors)?;
257        let out_shape = out_shape_hint.ok_or_else(|| {
258            Error::Internal("concatenate shape inference returned no shape hint".into())
259        })?;
260        Ok(apply_nary_concatenate(tensors, axis, out_shape))
261    }
262}
263
264fn apply_nary_concatenate(
265    tensors: &[&TracedTensor],
266    axis: usize,
267    out_shape: Vec<SymDim>,
268) -> TracedTensor {
269    let out_dtype = promote_dtypes(tensors.iter().map(|tensor| tensor.dtype));
270    let tensors = tensors
271        .iter()
272        .map(|tensor| {
273            if tensor.dtype != out_dtype {
274                tensor.cast(out_dtype)
275            } else {
276                (*tensor).clone()
277            }
278        })
279        .collect::<Vec<_>>();
280
281    let mut builder = GraphBuilder::new();
282    for tensor in &tensors {
283        builder.add_parent(tensor.graph.clone());
284    }
285    let input_refs = tensors
286        .iter()
287        .map(|tensor| ValueRef::External(tensor.graph.values()[tensor.val].key.clone()))
288        .collect::<Vec<_>>();
289    let outputs = builder.add_operation(
290        StdTensorOp::Concatenate {
291            axis,
292            input_count: tensors.len(),
293        },
294        input_refs,
295        OperationRole::Primary,
296    );
297    builder.set_outputs(outputs.clone());
298    let graph = Arc::new(builder.build());
299    // Callers route through shape inference before graph construction.
300    let metadata_scope = register_scoped_value_metadata(
301        graph.values()[outputs[0]].key.clone(),
302        tensor_meta(out_dtype, out_shape.clone()),
303    )
304    .expect("fresh concatenate output metadata registration failed");
305
306    let mut inputs_map = HashMap::new();
307    let mut extra_roots = Vec::new();
308    let mut checkpoint_chain = None;
309    for tensor in &tensors {
310        inputs_map.extend(
311            tensor
312                .inputs_map
313                .iter()
314                .map(|(k, v)| (k.clone(), v.clone())),
315        );
316        extra_roots.extend(tensor.extra_roots.iter().cloned());
317        checkpoint_chain =
318            CheckpointNode::merge_chains(checkpoint_chain, tensor.checkpoint_chain.clone());
319    }
320    let inherited_scopes = tensors
321        .iter()
322        .map(|tensor| tensor.metadata_scopes.as_slice())
323        .collect::<Vec<_>>();
324
325    TracedTensor {
326        id: next_traced_id(),
327        rank: out_shape.len(),
328        dtype: out_dtype,
329        graph,
330        val: outputs[0],
331        data: None,
332        shape_hint: Some(out_shape),
333        inputs_map: Arc::new(inputs_map),
334        extra_roots,
335        checkpoint_chain,
336        metadata_scopes: metadata_scopes_with_new(metadata_scope, inherited_scopes),
337    }
338}