1use std::collections::HashMap;
2use std::sync::Arc;
3
4use computegraph::graph::GraphBuilder;
5use computegraph::types::{OperationRole, ValueRef};
6use tenferro_ops::std_tensor_op::StdTensorOp;
7use tenferro_tensor::{GatherConfig, Tensor, TypedTensor};
8
9use crate::checkpoint::CheckpointNode;
10use crate::error::{Error, Result};
11use crate::metadata::{metadata_scopes_with_new, register_scoped_value_metadata, tensor_meta};
12use crate::shape_infer::promote_dtypes;
13use crate::sym_dim::SymDim;
14use crate::traced::{
15 apply_binary_preserve_input_dtypes, infer_traced_single_output_shape, next_traced_id,
16 try_concrete_shape,
17};
18use crate::TracedTensor;
19
20fn normalize_existing_axis(op: &'static str, axis: isize, rank: usize) -> Result<usize> {
21 let normalized = if axis < 0 { rank as isize + axis } else { axis };
22 if normalized < 0 || normalized >= rank as isize {
23 return Err(tenferro_tensor::Error::AxisOutOfBounds {
24 op,
25 axis: axis.unsigned_abs(),
26 rank,
27 }
28 .into());
29 }
30 Ok(normalized as usize)
31}
32
33fn normalize_insert_axis(op: &'static str, axis: isize, rank: usize) -> Result<usize> {
34 let normalized = if axis < 0 {
35 rank as isize + 1 + axis
36 } else {
37 axis
38 };
39 if normalized < 0 || normalized > rank as isize {
40 return Err(tenferro_tensor::Error::AxisOutOfBounds {
41 op,
42 axis: axis.unsigned_abs(),
43 rank: rank + 1,
44 }
45 .into());
46 }
47 Ok(normalized as usize)
48}
49
50fn index_select_config(
51 shape: &[usize],
52 axis: isize,
53 positions: &[usize],
54) -> Result<(Tensor, GatherConfig, Vec<usize>)> {
55 let axis = normalize_existing_axis("index_select", axis, shape.len())?;
56 let axis_extent = shape[axis];
57 for &position in positions {
58 if position >= axis_extent {
59 return Err(tenferro_tensor::Error::InvalidConfig {
60 op: "index_select",
61 message: format!(
62 "position {position} out of bounds for axis {axis} with extent {axis_extent}"
63 ),
64 }
65 .into());
66 }
67 }
68
69 let mut out_shape = shape.to_vec();
70 out_shape[axis] = positions.len();
71
72 let mut slice_sizes = shape.to_vec();
73 slice_sizes[axis] = 1;
74
75 let offset_dims = (0..shape.len()).filter(|&dim| dim != axis).collect();
76 let index_data = positions
77 .iter()
78 .map(|&position| {
79 i64::try_from(position).map_err(|_| tenferro_tensor::Error::InvalidConfig {
80 op: "index_select",
81 message: format!("position {position} cannot be represented as i64"),
82 })
83 })
84 .collect::<tenferro_tensor::Result<Vec<_>>>()?;
85 let indices = Tensor::I64(TypedTensor::from_vec_col_major(
86 vec![positions.len(), 1],
87 index_data,
88 )?);
89
90 let config = GatherConfig {
91 offset_dims,
92 collapsed_slice_dims: vec![axis],
93 start_index_map: vec![axis],
94 index_vector_dim: 1,
95 slice_sizes,
96 };
97
98 Ok((indices, config, out_shape))
99}
100
101fn validate_stack_shapes(op: &'static str, shapes: &[&[usize]]) -> Result<()> {
102 let Some(first) = shapes.first() else {
103 return Err(tenferro_tensor::Error::InvalidConfig {
104 op,
105 message: "stack requires at least one input".into(),
106 }
107 .into());
108 };
109 for shape in shapes.iter().skip(1) {
110 if *shape != *first {
111 return Err(tenferro_tensor::Error::ShapeMismatch {
112 op,
113 lhs: first.to_vec(),
114 rhs: shape.to_vec(),
115 }
116 .into());
117 }
118 }
119 Ok(())
120}
121
122impl TracedTensor {
123 pub fn index_select(&self, axis: isize, positions: &[usize]) -> Result<Self> {
146 let shape = try_concrete_shape(self).ok_or_else(|| Error::InvalidGraphBuild {
147 op: "index_select",
148 message: "index_select requires a concrete shape hint".into(),
149 })?;
150 let (indices_tensor, config, out_shape) = index_select_config(&shape, axis, positions)?;
151 let indices = TracedTensor::from_tensor_concrete_shape(indices_tensor)?;
152 Ok(apply_binary_preserve_input_dtypes(
153 StdTensorOp::Gather(config),
154 self,
155 &indices,
156 out_shape.len(),
157 Some(out_shape.into_iter().map(SymDim::from).collect()),
158 self.dtype,
159 ))
160 }
161
162 pub fn stack(tensors: &[&Self], dim: isize) -> Result<Self> {
183 let first = tensors.first().copied().ok_or_else(|| {
184 Error::TensorRuntime(tenferro_tensor::Error::InvalidConfig {
185 op: "stack",
186 message: "stack requires at least one input".into(),
187 })
188 })?;
189 let first_shape = try_concrete_shape(first).ok_or_else(|| Error::InvalidGraphBuild {
190 op: "stack",
191 message: "stack requires concrete shape hints".into(),
192 })?;
193 let mut shapes = Vec::with_capacity(tensors.len());
194 shapes.push(first_shape.as_slice());
195 let mut owned_shapes = Vec::with_capacity(tensors.len().saturating_sub(1));
196 for tensor in tensors.iter().copied().skip(1) {
197 owned_shapes.push(try_concrete_shape(tensor).ok_or_else(|| {
198 Error::InvalidGraphBuild {
199 op: "stack",
200 message: "stack requires concrete shape hints".into(),
201 }
202 })?);
203 }
204 shapes.extend(owned_shapes.iter().map(Vec::as_slice));
205 validate_stack_shapes("stack", &shapes)?;
206
207 let axis = normalize_insert_axis("stack", dim, first.rank)?;
208 let mut expanded_shape = first_shape;
209 expanded_shape.insert(axis, 1);
210 let mut out_shape = expanded_shape.clone();
211 out_shape[axis] = tensors.len();
212 let expanded = tensors
213 .iter()
214 .map(|tensor| tensor.reshape(&expanded_shape))
215 .collect::<Vec<_>>();
216 let refs = expanded.iter().collect::<Vec<_>>();
217 Ok(apply_nary_concatenate(
218 &refs,
219 axis,
220 out_shape.into_iter().map(SymDim::from).collect(),
221 ))
222 }
223
224 pub fn concatenate(tensors: &[&Self], axis: usize) -> Result<Self> {
226 let first = tensors.first().copied().ok_or_else(|| {
227 Error::TensorRuntime(tenferro_tensor::Error::InvalidConfig {
228 op: "concatenate",
229 message: "concatenate requires at least one input".into(),
230 })
231 })?;
232 if axis >= first.rank {
233 return Err(tenferro_tensor::Error::AxisOutOfBounds {
234 op: "concatenate",
235 axis,
236 rank: first.rank,
237 }
238 .into());
239 }
240 for tensor in tensors.iter().copied().skip(1) {
241 if tensor.rank != first.rank {
242 return Err(tenferro_tensor::Error::RankMismatch {
243 op: "concatenate",
244 expected: first.rank,
245 actual: tensor.rank,
246 }
247 .into());
248 }
249 }
250
251 let op = StdTensorOp::Concatenate {
252 axis,
253 input_count: tensors.len(),
254 };
255 let (_, out_shape_hint) =
256 infer_traced_single_output_shape("TracedTensor::concatenate", &op, tensors)?;
257 let out_shape = out_shape_hint.ok_or_else(|| {
258 Error::Internal("concatenate shape inference returned no shape hint".into())
259 })?;
260 Ok(apply_nary_concatenate(tensors, axis, out_shape))
261 }
262}
263
264fn apply_nary_concatenate(
265 tensors: &[&TracedTensor],
266 axis: usize,
267 out_shape: Vec<SymDim>,
268) -> TracedTensor {
269 let out_dtype = promote_dtypes(tensors.iter().map(|tensor| tensor.dtype));
270 let tensors = tensors
271 .iter()
272 .map(|tensor| {
273 if tensor.dtype != out_dtype {
274 tensor.cast(out_dtype)
275 } else {
276 (*tensor).clone()
277 }
278 })
279 .collect::<Vec<_>>();
280
281 let mut builder = GraphBuilder::new();
282 for tensor in &tensors {
283 builder.add_parent(tensor.graph.clone());
284 }
285 let input_refs = tensors
286 .iter()
287 .map(|tensor| ValueRef::External(tensor.graph.values()[tensor.val].key.clone()))
288 .collect::<Vec<_>>();
289 let outputs = builder.add_operation(
290 StdTensorOp::Concatenate {
291 axis,
292 input_count: tensors.len(),
293 },
294 input_refs,
295 OperationRole::Primary,
296 );
297 builder.set_outputs(outputs.clone());
298 let graph = Arc::new(builder.build());
299 let metadata_scope = register_scoped_value_metadata(
301 graph.values()[outputs[0]].key.clone(),
302 tensor_meta(out_dtype, out_shape.clone()),
303 )
304 .expect("fresh concatenate output metadata registration failed");
305
306 let mut inputs_map = HashMap::new();
307 let mut extra_roots = Vec::new();
308 let mut checkpoint_chain = None;
309 for tensor in &tensors {
310 inputs_map.extend(
311 tensor
312 .inputs_map
313 .iter()
314 .map(|(k, v)| (k.clone(), v.clone())),
315 );
316 extra_roots.extend(tensor.extra_roots.iter().cloned());
317 checkpoint_chain =
318 CheckpointNode::merge_chains(checkpoint_chain, tensor.checkpoint_chain.clone());
319 }
320 let inherited_scopes = tensors
321 .iter()
322 .map(|tensor| tensor.metadata_scopes.as_slice())
323 .collect::<Vec<_>>();
324
325 TracedTensor {
326 id: next_traced_id(),
327 rank: out_shape.len(),
328 dtype: out_dtype,
329 graph,
330 val: outputs[0],
331 data: None,
332 shape_hint: Some(out_shape),
333 inputs_map: Arc::new(inputs_map),
334 extra_roots,
335 checkpoint_chain,
336 metadata_scopes: metadata_scopes_with_new(metadata_scope, inherited_scopes),
337 }
338}