Skip to main content

tenferro_runtime/
traced.rs

1use std::collections::HashMap;
2use std::fmt;
3use std::sync::atomic::{AtomicU64, Ordering};
4use std::sync::Arc;
5
6use computegraph::graph::{Graph, GraphBuilder};
7use computegraph::types::{OperationRole, ValueKey, ValueRef};
8use computegraph::LocalValueId;
9use num_complex::{Complex32, Complex64};
10use tenferro_ops::ad::context::GlobalMetadataScope;
11use tenferro_ops::broadcast::{broadcast_input_plan, broadcast_shape, broadcast_shapes};
12use tenferro_ops::dim_expr::DimExpr;
13use tenferro_ops::input_key::TensorInputKey;
14use tenferro_ops::std_tensor_op::StdTensorOp;
15use tenferro_tensor::{
16    CompareDir, DType, DotGeneralConfig, GatherConfig, PadConfig, ScatterConfig, SliceConfig,
17    Tensor, TensorScalar,
18};
19
20use super::error::{Error, Result};
21use super::sym_dim::SymDim;
22use crate::checkpoint::CheckpointNode;
23use crate::metadata::{
24    concrete_tensor_meta, metadata_scopes_for_scope, metadata_scopes_with_new, push_metadata_scope,
25    register_scoped_graph_metadata, register_scoped_value_metadata, symbolic_input_meta,
26    tensor_meta,
27};
28use crate::scalar_semantics::round_real_to_i64;
29
/// Process-wide counter from which every user input key id is drawn.
static NEXT_INPUT_ID: AtomicU64 = AtomicU64::new(0u64);
/// Process-wide counter from which every traced tensor id is drawn.
static NEXT_TRACED_ID: AtomicU64 = AtomicU64::new(0u64);

/// Identifier of a [`TracedTensor`], allocated from a process-wide counter.
pub type TracedTensorId = u64;
34
35pub(crate) fn next_input_key() -> TensorInputKey {
36    TensorInputKey::User {
37        id: NEXT_INPUT_ID.fetch_add(1, Ordering::Relaxed),
38    }
39}
40
41pub(crate) fn next_traced_id() -> TracedTensorId {
42    NEXT_TRACED_ID.fetch_add(1, Ordering::Relaxed)
43}
44
45#[derive(Clone)]
46pub struct TracedTensor {
47    pub id: TracedTensorId,
48    pub rank: usize,
49    pub dtype: DType,
50    pub(crate) graph: Arc<Graph<StdTensorOp>>,
51    pub val: LocalValueId,
52    pub(crate) data: Option<Arc<Tensor>>,
53    pub(crate) shape_hint: Option<Vec<SymDim>>,
54    pub(crate) inputs_map: Arc<HashMap<TensorInputKey, Arc<Tensor>>>,
55    pub(crate) extra_roots: Vec<Arc<Graph<StdTensorOp>>>,
56    pub(crate) checkpoint_chain: Option<Arc<CheckpointNode>>,
57    pub(crate) metadata_scopes: Vec<Arc<GlobalMetadataScope>>,
58}
59
60impl fmt::Debug for TracedTensor {
61    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
62        f.debug_struct("TracedTensor")
63            .field("id", &self.id)
64            .field("rank", &self.rank)
65            .field("dtype", &self.dtype)
66            .field("val", &self.val)
67            .field("shape_hint", &self.shape_hint)
68            .field("has_data", &self.data.is_some())
69            .finish_non_exhaustive()
70    }
71}
72
73pub(crate) fn try_concrete_shape(tensor: &TracedTensor) -> Option<Vec<usize>> {
74    tensor
75        .shape_hint
76        .as_ref()?
77        .iter()
78        .map(SymDim::constant_value)
79        .collect()
80}
81
82pub(crate) fn concrete_shape(tensor: &TracedTensor) -> Result<Vec<usize>> {
83    tensor
84        .shape_hint
85        .as_ref()
86        .ok_or_else(|| Error::InvalidGraphBuild {
87            op: "TracedTensor::concrete_shape",
88            message: format!("missing shape hint for traced tensor {}", tensor.id),
89        })?
90        .iter()
91        .map(|dim| {
92            dim.constant_value()
93                .ok_or_else(|| Error::InvalidGraphBuild {
94                    op: "TracedTensor::concrete_shape",
95                    message: format!("symbolic dimension in shape hint for tensor {}", tensor.id),
96                })
97        })
98        .collect()
99}
100
101/// Broadcast a traced tensor to `target_shape`.
102///
103/// Expanding singleton axes are first reshaped away so the existing
104/// `BroadcastInDim` transpose rule reduces them correctly during VJP.
105pub(crate) fn broadcast_to(tensor: &TracedTensor, target_shape: &[usize]) -> Result<TracedTensor> {
106    let tensor_shape = concrete_shape(tensor)?;
107    if tensor_shape == target_shape {
108        return Ok(tensor.clone());
109    }
110
111    let plan = broadcast_input_plan(&tensor_shape, target_shape).map_err(|err| {
112        Error::InvalidGraphBuild {
113            op: "broadcast_to",
114            message: err.to_string(),
115        }
116    })?;
117
118    let source = if plan.source_shape == tensor_shape {
119        tensor.clone()
120    } else {
121        tensor.reshape(&plan.source_shape)
122    };
123    source.broadcast_in_dim(target_shape, &plan.dims)
124}
125
126/// Broadcast two tensors to a common shape.
127pub(crate) fn broadcast_binary(
128    a: &TracedTensor,
129    b: &TracedTensor,
130) -> Result<(TracedTensor, TracedTensor)> {
131    if a.shape_hint == b.shape_hint && a.rank == b.rank {
132        return Ok((a.clone(), b.clone()));
133    }
134    if (try_concrete_shape(a).is_none() || try_concrete_shape(b).is_none()) && a.rank == b.rank {
135        return Ok((a.clone(), b.clone()));
136    }
137    let a_shape = concrete_shape(a)?;
138    let b_shape = concrete_shape(b)?;
139    let target = broadcast_shape(&a_shape, &b_shape).map_err(|err| Error::InvalidGraphBuild {
140        op: "broadcast_binary",
141        message: err.to_string(),
142    })?;
143    Ok((broadcast_to(a, &target)?, broadcast_to(b, &target)?))
144}
145
146pub(crate) fn broadcast_ternary(
147    a: &TracedTensor,
148    b: &TracedTensor,
149    c: &TracedTensor,
150) -> Result<(TracedTensor, TracedTensor, TracedTensor)> {
151    let a_shape = concrete_shape(a)?;
152    let b_shape = concrete_shape(b)?;
153    let c_shape = concrete_shape(c)?;
154    let target = broadcast_shapes([a_shape.as_slice(), b_shape.as_slice(), c_shape.as_slice()])
155        .map_err(|err| Error::InvalidGraphBuild {
156            op: "broadcast_ternary",
157            message: err.to_string(),
158        })?;
159    Ok((
160        broadcast_to(a, &target)?,
161        broadcast_to(b, &target)?,
162        broadcast_to(c, &target)?,
163    ))
164}
165
166fn scale_with_constant(input: &TracedTensor, op: StdTensorOp) -> TracedTensor {
167    let scalar = apply_nullary(op, 0, input.dtype, Some(vec![]));
168    apply_binary(
169        StdTensorOp::Mul,
170        input,
171        &scalar,
172        input.rank,
173        input.shape_hint.clone(),
174    )
175}
176
177fn inferred_output_dtype_or_fallback(
178    op: &StdTensorOp,
179    inputs: &[DType],
180    fallback: DType,
181    context: &'static str,
182) -> DType {
183    match crate::shape_infer::infer_output_dtype(op, inputs) {
184        Ok(dtype) => dtype,
185        Err(err) => {
186            debug_assert!(
187                false,
188                "{context}: built-in traced dtype inference failed for {op:?}: {err}"
189            );
190            fallback
191        }
192    }
193}
194
195fn traced_input_shape_exprs(input_idx: usize, tensor: &TracedTensor) -> Vec<DimExpr> {
196    match tensor.shape_hint.as_ref() {
197        Some(shape) => shape
198            .iter()
199            .enumerate()
200            .map(|(axis, dim)| {
201                dim.constant_value()
202                    .map_or(DimExpr::InputDim { input_idx, axis }, DimExpr::Const)
203            })
204            .collect(),
205        None => (0..tensor.rank)
206            .map(|axis| DimExpr::InputDim { input_idx, axis })
207            .collect(),
208    }
209}
210
211fn traced_input_sym_shape(tensor: &TracedTensor) -> Vec<SymDim> {
212    tensor.shape_hint.clone().unwrap_or_else(|| {
213        (0..tensor.rank)
214            .map(|axis| SymDim::tensor_axis(tensor.id, axis))
215            .collect()
216    })
217}
218
219pub(crate) fn infer_traced_single_output_shape(
220    op_name: &'static str,
221    op: &StdTensorOp,
222    inputs: &[&TracedTensor],
223) -> Result<(usize, Option<Vec<SymDim>>)> {
224    let input_shape_exprs: Vec<Vec<DimExpr>> = inputs
225        .iter()
226        .enumerate()
227        .map(|(input_idx, tensor)| traced_input_shape_exprs(input_idx, tensor))
228        .collect();
229    let input_shape_refs: Vec<&[DimExpr]> = input_shape_exprs.iter().map(Vec::as_slice).collect();
230    let output_shapes =
231        crate::shape_infer::infer_output_shapes(op, &input_shape_refs).map_err(|err| {
232            Error::InvalidGraphBuild {
233                op: op_name,
234                message: err.to_string(),
235            }
236        })?;
237    let output_shape = output_shapes
238        .first()
239        .ok_or_else(|| Error::InvalidGraphBuild {
240            op: op_name,
241            message: "shape inference returned no outputs".into(),
242        })?;
243    if output_shapes.len() != 1 {
244        return Err(Error::InvalidGraphBuild {
245            op: op_name,
246            message: format!(
247                "expected single-output shape inference, got {} outputs",
248                output_shapes.len()
249            ),
250        });
251    }
252
253    let input_sym_shapes: Vec<Vec<SymDim>> = inputs
254        .iter()
255        .map(|tensor| traced_input_sym_shape(tensor))
256        .collect();
257    let input_sym_refs: Vec<&[SymDim]> = input_sym_shapes.iter().map(Vec::as_slice).collect();
258    let out_shape_hint = output_shape
259        .iter()
260        .map(|dim| SymDim::from_dim_expr(dim, &input_sym_refs))
261        .collect();
262    Ok((output_shape.len(), Some(out_shape_hint)))
263}
264
265fn register_metadata_or_internal(
266    result: std::result::Result<GlobalMetadataScope, impl std::fmt::Display>,
267) -> Result<GlobalMetadataScope> {
268    result.map_err(|err| Error::Internal(format!("metadata registration failed: {err}")))
269}
270
271fn reduction_output_meta(
272    tensor: &TracedTensor,
273    axes: &[usize],
274    op: &'static str,
275) -> Result<(usize, Option<Vec<SymDim>>)> {
276    let mut seen = vec![false; tensor.rank];
277    for &axis in axes {
278        if axis >= tensor.rank {
279            return Err(Error::InvalidGraphBuild {
280                op,
281                message: format!("axis {axis} out of bounds for rank {}", tensor.rank),
282            });
283        }
284        if seen[axis] {
285            return Err(Error::InvalidGraphBuild {
286                op,
287                message: format!("duplicate reduction axis {axis}"),
288            });
289        }
290        seen[axis] = true;
291    }
292
293    let out_shape_hint = tensor.shape_hint.as_ref().map(|shape| {
294        (0..shape.len())
295            .filter(|d| !axes.contains(d))
296            .map(|d| shape[d].clone())
297            .collect()
298    });
299    Ok((tensor.rank - axes.len(), out_shape_hint))
300}
301
302fn validate_traced_axis(tensor: &TracedTensor, axis: usize, op: &'static str) -> Result<()> {
303    if axis >= tensor.rank {
304        return Err(Error::InvalidGraphBuild {
305            op,
306            message: format!("axis {axis} out of bounds for rank {}", tensor.rank),
307        });
308    }
309    Ok(())
310}
311
312fn validate_traced_axes(rank: usize, axes: &[usize], op: &'static str) -> Result<()> {
313    let mut seen = vec![false; rank];
314    for &axis in axes {
315        if axis >= rank {
316            return Err(Error::InvalidGraphBuild {
317                op,
318                message: format!("axis {axis} out of bounds for rank {rank}"),
319            });
320        }
321        if seen[axis] {
322            return Err(Error::InvalidGraphBuild {
323                op,
324                message: format!("duplicate axis {axis}"),
325            });
326        }
327        seen[axis] = true;
328    }
329    Ok(())
330}
331
332fn validate_traced_insert_axis(rank: usize, axis: usize, op: &'static str) -> Result<()> {
333    if axis > rank {
334        return Err(Error::InvalidGraphBuild {
335            op,
336            message: format!("axis {axis} out of bounds for rank {rank} insertion"),
337        });
338    }
339    Ok(())
340}
341
342fn validate_traced_perm(rank: usize, perm: &[usize], op: &'static str) -> Result<()> {
343    if perm.len() != rank {
344        return Err(Error::InvalidGraphBuild {
345            op,
346            message: format!(
347                "permutation length {} does not match rank {rank}",
348                perm.len()
349            ),
350        });
351    }
352    let mut seen = vec![false; rank];
353    for &axis in perm {
354        if axis >= rank {
355            return Err(Error::InvalidGraphBuild {
356                op,
357                message: format!("permutation axis {axis} out of bounds for rank {rank}"),
358            });
359        }
360        if seen[axis] {
361            return Err(Error::InvalidGraphBuild {
362                op,
363                message: format!("duplicate permutation axis {axis}"),
364            });
365        }
366        seen[axis] = true;
367    }
368    Ok(())
369}
370
371fn validate_broadcast_in_dim_args(
372    input: &TracedTensor,
373    output_shape: &[SymDim],
374    dims: &[usize],
375    op: &'static str,
376) -> Result<()> {
377    if dims.len() != input.rank {
378        return Err(Error::InvalidGraphBuild {
379            op,
380            message: format!(
381                "dims length {} must match input rank {}",
382                dims.len(),
383                input.rank
384            ),
385        });
386    }
387
388    let mut seen = vec![false; output_shape.len()];
389    for &dim in dims {
390        if dim >= output_shape.len() {
391            return Err(Error::InvalidGraphBuild {
392                op,
393                message: format!(
394                    "broadcast dim {dim} out of bounds for output rank {}",
395                    output_shape.len()
396                ),
397            });
398        }
399        if seen[dim] {
400            return Err(Error::InvalidGraphBuild {
401                op,
402                message: format!("duplicate broadcast dim {dim}"),
403            });
404        }
405        seen[dim] = true;
406    }
407
408    if let Some(input_shape) = input.shape_hint.as_ref() {
409        for (input_axis, &output_axis) in dims.iter().enumerate() {
410            let input_dim = &input_shape[input_axis];
411            let output_dim = &output_shape[output_axis];
412            if input_dim != output_dim && input_dim.constant_value() != Some(1) {
413                return Err(Error::InvalidGraphBuild {
414                    op,
415                    message: format!(
416                        "input axis {input_axis} with dim {input_dim:?} cannot broadcast to \
417                         output axis {output_axis} with dim {output_dim:?}"
418                    ),
419                });
420            }
421        }
422    }
423
424    Ok(())
425}
426
427impl std::ops::Add for &TracedTensor {
428    type Output = Result<TracedTensor>;
429
430    fn add(self, rhs: &TracedTensor) -> Result<TracedTensor> {
431        TracedTensor::add(self, rhs)
432    }
433}
434
435impl std::ops::Sub for &TracedTensor {
436    type Output = Result<TracedTensor>;
437
438    fn sub(self, rhs: &TracedTensor) -> Result<TracedTensor> {
439        TracedTensor::sub(self, rhs)
440    }
441}
442
443impl std::ops::Mul for &TracedTensor {
444    type Output = Result<TracedTensor>;
445
446    fn mul(self, rhs: &TracedTensor) -> Result<TracedTensor> {
447        TracedTensor::mul(self, rhs)
448    }
449}
450
451impl std::ops::Mul<f64> for &TracedTensor {
452    type Output = TracedTensor;
453
454    fn mul(self, rhs: f64) -> TracedTensor {
455        self.scale_real(rhs)
456    }
457}
458
459impl std::ops::Mul<&TracedTensor> for f64 {
460    type Output = TracedTensor;
461
462    fn mul(self, rhs: &TracedTensor) -> TracedTensor {
463        rhs.scale_real(self)
464    }
465}
466
467impl std::ops::Neg for &TracedTensor {
468    type Output = TracedTensor;
469
470    fn neg(self) -> TracedTensor {
471        TracedTensor::neg(self)
472    }
473}
474
475impl std::ops::Div for &TracedTensor {
476    type Output = Result<TracedTensor>;
477
478    fn div(self, rhs: &TracedTensor) -> Result<TracedTensor> {
479        TracedTensor::div(self, rhs)
480    }
481}
482
483impl TracedTensor {
484    /// Return the graph that owns this traced tensor's current value.
485    ///
486    /// # Examples
487    ///
488    /// ```
489    /// use tenferro_runtime::TracedTensor;
490    ///
491    /// let x = TracedTensor::from_vec_col_major(vec![1], vec![1.0_f64]).unwrap();
492    /// let _graph = x.graph();
493    /// ```
494    pub fn graph(&self) -> &Arc<Graph<StdTensorOp>> {
495        &self.graph
496    }
497
498    /// Return the concrete tensor data attached to this traced value, if any.
499    ///
500    /// Placeholder tensors created with `input_concrete_shape` or
501    /// `input_symbolic_shape` have no attached data until execution bindings
502    /// provide it.
503    ///
504    /// # Examples
505    ///
506    /// ```
507    /// use tenferro_runtime::{DType, TracedTensor};
508    ///
509    /// let concrete = TracedTensor::from_vec_col_major(vec![1], vec![1.0_f64]).unwrap();
510    /// assert!(concrete.attached_data().is_some());
511    ///
512    /// let placeholder = TracedTensor::input_symbolic_shape(DType::F64, 1).unwrap();
513    /// assert!(placeholder.attached_data().is_none());
514    /// ```
515    pub fn attached_data(&self) -> Option<&Arc<Tensor>> {
516        self.data.as_ref()
517    }
518
519    /// Build a [`TracedTensor`] leaf from a concrete [`Tensor`], keeping its
520    /// shape as a concrete `shape_hint`.
521    ///
522    /// This is the common constructor when you have concrete tensor data that
523    /// you want to use both for graph building and for evaluation. The
524    /// resulting tensor is treated as a concrete-shape leaf by downstream
525    /// passes (binary einsum decomposition, build-time reshape folding, etc.).
526    ///
527    /// # Examples
528    ///
529    /// ```
530    /// use tenferro_runtime::{Tensor, TracedTensor};
531    ///
532    /// let a = TracedTensor::from_tensor_concrete_shape(
533    ///     Tensor::from_vec_col_major(vec![2, 3], vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap(),
534    /// )
535    /// .unwrap();
536    /// assert_eq!(a.rank, 2);
537    /// assert!(a.is_concrete_shape());
538    /// ```
539    pub fn from_tensor_concrete_shape(tensor: Tensor) -> Result<Self> {
540        let shape = tensor.shape().to_vec();
541        let rank = shape.len();
542        let dtype = tensor.dtype();
543        let key = next_input_key();
544        let id = next_traced_id();
545        let data = Arc::new(tensor);
546
547        let mut builder = GraphBuilder::new();
548        let val = builder.add_input(key.clone());
549        builder.set_outputs(vec![val]);
550        let graph = Arc::new(builder.build());
551        let metadata_scope = register_metadata_or_internal(register_scoped_value_metadata(
552            graph.values()[val].key.clone(),
553            concrete_tensor_meta(dtype, &shape),
554        ))?;
555
556        let mut map = HashMap::new();
557        map.insert(key, Arc::clone(&data));
558
559        Ok(Self {
560            id,
561            rank,
562            dtype,
563            graph,
564            val,
565            data: Some(data),
566            shape_hint: Some(shape.into_iter().map(SymDim::from).collect()),
567            inputs_map: Arc::new(map),
568            extra_roots: Vec::new(),
569            checkpoint_chain: None,
570            metadata_scopes: metadata_scopes_for_scope(metadata_scope),
571        })
572    }
573
574    /// Build a [`TracedTensor`] leaf from a concrete [`Tensor`] but advertise
575    /// a symbolic shape during graph construction.
576    ///
577    /// The tensor data is still attached (so plain `eval` works without
578    /// bindings), but graph passes see the leaf as shape-symbolic. This is
579    /// useful for building a single traced program that should not bake in
580    /// shape-specific optimizations.
581    ///
582    /// # Examples
583    ///
584    /// ```
585    /// use tenferro_runtime::{Tensor, TracedTensor};
586    ///
587    /// let t = TracedTensor::from_tensor_symbolic_shape(
588    ///     Tensor::from_vec_col_major(vec![2, 3], vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap(),
589    /// )
590    /// .unwrap();
591    /// assert_eq!(t.rank, 2);
592    /// assert!(!t.is_concrete_shape());
593    /// ```
594    pub fn from_tensor_symbolic_shape(tensor: Tensor) -> Result<Self> {
595        let rank = tensor.shape().len();
596        let dtype = tensor.dtype();
597        let key = next_input_key();
598        let id = next_traced_id();
599        let data = Arc::new(tensor);
600
601        let mut builder = GraphBuilder::new();
602        let val = builder.add_input(key.clone());
603        builder.set_outputs(vec![val]);
604        let graph = Arc::new(builder.build());
605        let metadata_scope = register_metadata_or_internal(register_scoped_value_metadata(
606            graph.values()[val].key.clone(),
607            symbolic_input_meta(dtype, id, rank),
608        ))?;
609
610        let mut map = HashMap::new();
611        map.insert(key, Arc::clone(&data));
612
613        Ok(Self {
614            id,
615            rank,
616            dtype,
617            graph,
618            val,
619            data: Some(data),
620            shape_hint: None,
621            inputs_map: Arc::new(map),
622            extra_roots: Vec::new(),
623            checkpoint_chain: None,
624            metadata_scopes: metadata_scopes_for_scope(metadata_scope),
625        })
626    }
627
628    /// Build a data-less placeholder leaf with a fixed (concrete) shape.
629    ///
630    /// Must be bound via [`crate::GraphExecutor::run_with_inputs`] before evaluation.
631    /// Use this when you know the exact shape of the input but want to build
632    /// the graph once and feed different concrete tensors at execution time.
633    ///
634    /// # Examples
635    ///
636    /// ```
637    /// use tenferro_tensor::DType;
638    /// use tenferro_runtime::TracedTensor;
639    ///
640    /// let x = TracedTensor::input_concrete_shape(DType::F64, &[2, 3]).unwrap();
641    /// assert_eq!(x.rank, 2);
642    /// assert!(x.is_concrete_shape());
643    /// ```
644    pub fn input_concrete_shape(dtype: DType, shape: &[usize]) -> Result<Self> {
645        let shape = shape.to_vec();
646        let rank = shape.len();
647        let key = next_input_key();
648        let id = next_traced_id();
649
650        let mut builder = GraphBuilder::new();
651        let val = builder.add_input(key.clone());
652        builder.set_outputs(vec![val]);
653        let graph = Arc::new(builder.build());
654        let metadata_scope = register_metadata_or_internal(register_scoped_value_metadata(
655            graph.values()[val].key.clone(),
656            concrete_tensor_meta(dtype, &shape),
657        ))?;
658
659        Ok(Self {
660            id,
661            rank,
662            dtype,
663            graph,
664            val,
665            data: None,
666            shape_hint: Some(shape.into_iter().map(SymDim::from).collect()),
667            inputs_map: Arc::new(HashMap::new()),
668            extra_roots: Vec::new(),
669            checkpoint_chain: None,
670            metadata_scopes: metadata_scopes_for_scope(metadata_scope),
671        })
672    }
673
674    /// Build a data-less placeholder leaf with the given rank but fully
675    /// symbolic shape (every dim is a distinct `SymDim::TensorAxis`).
676    ///
677    /// Must be bound via [`crate::GraphExecutor::run_with_inputs`] before
678    /// evaluation. Use this to build shape-agnostic graphs.
679    ///
680    /// # Examples
681    ///
682    /// ```
683    /// use tenferro_tensor::DType;
684    /// use tenferro_runtime::TracedTensor;
685    ///
686    /// let x = TracedTensor::input_symbolic_shape(DType::F64, 2).unwrap();
687    /// assert_eq!(x.rank, 2);
688    /// assert!(!x.is_concrete_shape());
689    /// ```
690    pub fn input_symbolic_shape(dtype: DType, rank: usize) -> Result<Self> {
691        let key = next_input_key();
692        let id = next_traced_id();
693
694        let mut builder = GraphBuilder::new();
695        let val = builder.add_input(key.clone());
696        builder.set_outputs(vec![val]);
697        let graph = Arc::new(builder.build());
698        let metadata_scope = register_metadata_or_internal(register_scoped_value_metadata(
699            graph.values()[val].key.clone(),
700            symbolic_input_meta(dtype, id, rank),
701        ))?;
702
703        Ok(Self {
704            id,
705            rank,
706            dtype,
707            graph,
708            val,
709            data: None,
710            shape_hint: None,
711            inputs_map: Arc::new(HashMap::new()),
712            extra_roots: Vec::new(),
713            checkpoint_chain: None,
714            metadata_scopes: metadata_scopes_for_scope(metadata_scope),
715        })
716    }
717
718    /// Build a concrete-shape [`TracedTensor`] leaf from column-major typed
719    /// `Vec<T>` data.
720    ///
721    /// The data must already be in tenferro's physical column-major order.
722    ///
723    /// # Examples
724    ///
725    /// ```
726    /// use tenferro_runtime::TracedTensor;
727    ///
728    /// let a = TracedTensor::from_vec_col_major(
729    ///     vec![2, 3],
730    ///     vec![1.0_f64, 4.0, 2.0, 5.0, 3.0, 6.0],
731    /// )?;
732    /// assert_eq!(a.rank, 2);
733    /// # Ok::<(), tenferro_runtime::Error>(())
734    /// ```
735    pub fn from_vec_col_major<T: TensorScalar>(shape: Vec<usize>, data: Vec<T>) -> Result<Self> {
736        Self::from_tensor_concrete_shape(Tensor::from_vec_col_major(shape, data)?)
737    }
738
739    /// Returns `true` iff every dim of this tensor's `shape_hint` is a
740    /// constant `SymDim` (i.e. the shape is fully known at graph-build time).
741    ///
742    /// # Examples
743    ///
744    /// ```
745    /// use tenferro_tensor::DType;
746    /// use tenferro_runtime::TracedTensor;
747    ///
748    /// let a = TracedTensor::from_vec_col_major(vec![2, 3], vec![1.0_f64; 6]).unwrap();
749    /// let b = TracedTensor::input_symbolic_shape(DType::F64, 2).unwrap();
750    /// assert!(a.is_concrete_shape());
751    /// assert!(!b.is_concrete_shape());
752    /// ```
753    pub fn is_concrete_shape(&self) -> bool {
754        try_concrete_shape(self).is_some()
755    }
756
757    /// Return the fully-concrete shape of this tensor, if every dim of
758    /// its shape-hint is a constant `SymDim`. Returns `None` if any
759    /// dimension is symbolic.
760    ///
761    /// This is the counterpart to [`Self::is_concrete_shape`] for callers
762    /// that need to *use* the concrete shape (e.g. external composition
763    /// wrappers building `broadcast_in_dim` payloads from known shapes).
764    ///
765    /// # Examples
766    ///
767    /// ```
768    /// use tenferro_tensor::DType;
769    /// use tenferro_runtime::TracedTensor;
770    ///
771    /// let a = TracedTensor::from_vec_col_major(vec![2, 3], vec![1.0_f64; 6]).unwrap();
772    /// assert_eq!(a.try_concrete_shape(), Some(vec![2, 3]));
773    ///
774    /// let b = TracedTensor::input_symbolic_shape(DType::F64, 2).unwrap();
775    /// assert!(b.try_concrete_shape().is_none());
776    /// ```
777    pub fn try_concrete_shape(&self) -> Option<Vec<usize>> {
778        try_concrete_shape(self)
779    }
780
781    /// Return the concrete tensor shape.
782    ///
783    /// Returns an error when a shape hint is missing or any dimension is
784    /// symbolic. Composite traced ops that require concrete sizes should
785    /// propagate this error instead of panicking.
786    pub fn concrete_shape(&self) -> Result<Vec<usize>> {
787        concrete_shape(self)
788    }
789
790    /// If this `TracedTensor` is a leaf (single-node input graph),
791    /// return its input key. Computed tensors return `None`.
792    pub fn input_key(&self) -> Option<TensorInputKey> {
793        match &self.graph.values()[self.val].key {
794            ValueKey::Input(key) => Some(key.clone()),
795            _ => None,
796        }
797    }
798
799    /// Elementwise addition with NumPy-style broadcasting.
800    ///
801    /// Prefer using the `+` operator when it reads naturally.
802    ///
803    /// # Examples
804    ///
805    /// ```rust
806    /// # use tenferro_runtime::TracedTensor;
807    /// # let x = TracedTensor::from_vec_col_major(vec![2], vec![1.0_f64, 2.0]).unwrap();
808    /// # let z = TracedTensor::from_vec_col_major(vec![2], vec![3.0_f64, 4.0]).unwrap();
809    /// let y = x.add(&z);
810    /// let y2 = &x + &z;
811    /// ```
812    pub fn add(&self, other: &TracedTensor) -> Result<TracedTensor> {
813        let (lhs, rhs) = broadcast_binary(self, other)?;
814        Ok(apply_binary(
815            StdTensorOp::Add,
816            &lhs,
817            &rhs,
818            lhs.rank,
819            lhs.shape_hint.clone(),
820        ))
821    }
822
823    /// Elementwise subtraction with NumPy-style broadcasting.
824    ///
825    /// Prefer using the `-` operator when it reads naturally.
826    pub fn sub(&self, other: &TracedTensor) -> Result<TracedTensor> {
827        let (lhs, rhs) = broadcast_binary(self, other)?;
828        let rhs = rhs.neg();
829        Ok(apply_binary(
830            StdTensorOp::Add,
831            &lhs,
832            &rhs,
833            lhs.rank,
834            lhs.shape_hint.clone(),
835        ))
836    }
837
838    /// Elementwise multiplication with NumPy-style broadcasting.
839    ///
840    /// Prefer using the `*` operator when it reads naturally.
841    ///
842    /// # Examples
843    ///
844    /// ```rust
845    /// # use tenferro_runtime::TracedTensor;
846    /// # let x = TracedTensor::from_vec_col_major(vec![2], vec![1.0_f64, 2.0]).unwrap();
847    /// # let z = TracedTensor::from_vec_col_major(vec![2], vec![3.0_f64, 4.0]).unwrap();
848    /// let y = x.mul(&z);
849    /// let y2 = &x * &z;
850    /// ```
851    pub fn mul(&self, other: &TracedTensor) -> Result<TracedTensor> {
852        let (lhs, rhs) = broadcast_binary(self, other)?;
853        Ok(apply_binary(
854            StdTensorOp::Mul,
855            &lhs,
856            &rhs,
857            lhs.rank,
858            lhs.shape_hint.clone(),
859        ))
860    }
861
862    /// Elementwise division with NumPy-style broadcasting.
863    ///
864    /// Prefer using the `/` operator when it reads naturally.
865    ///
866    /// # Examples
867    ///
868    /// ```rust
869    /// # use tenferro_runtime::TracedTensor;
870    /// # let x = TracedTensor::from_vec_col_major(vec![2], vec![1.0_f64, 2.0]).unwrap();
871    /// # let z = TracedTensor::from_vec_col_major(vec![2], vec![3.0_f64, 4.0]).unwrap();
872    /// let y = x.div(&z);
873    /// let y2 = &x / &z;
874    /// ```
875    pub fn div(&self, other: &TracedTensor) -> Result<TracedTensor> {
876        let (lhs, rhs) = broadcast_binary(self, other)?;
877        Ok(apply_binary(
878            StdTensorOp::Div,
879            &lhs,
880            &rhs,
881            lhs.rank,
882            lhs.shape_hint.clone(),
883        ))
884    }
885
886    /// Elementwise comparison with NumPy-style broadcasting.
887    pub fn compare(&self, other: &TracedTensor, dir: CompareDir) -> Result<TracedTensor> {
888        let (lhs, rhs) = broadcast_binary(self, other)?;
889        Ok(apply_binary(
890            StdTensorOp::Compare(dir),
891            &lhs,
892            &rhs,
893            lhs.rank,
894            lhs.shape_hint.clone(),
895        ))
896    }
897
898    /// Elementwise maximum with NumPy-style broadcasting.
899    pub fn maximum(&self, other: &TracedTensor) -> Result<TracedTensor> {
900        apply_broadcast_binary_op(StdTensorOp::Maximum, self, other)
901    }
902
903    /// Elementwise minimum with NumPy-style broadcasting.
904    pub fn minimum(&self, other: &TracedTensor) -> Result<TracedTensor> {
905        apply_broadcast_binary_op(StdTensorOp::Minimum, self, other)
906    }
907
908    /// Select values from `on_true` or `on_false` using `condition`.
909    pub fn where_select(
910        condition: &TracedTensor,
911        on_true: &TracedTensor,
912        on_false: &TracedTensor,
913    ) -> Result<TracedTensor> {
914        apply_broadcast_ternary_op(StdTensorOp::Select, condition, on_true, on_false)
915    }
916
917    /// Alias for [`Self::where_select`].
918    pub fn select(
919        condition: &TracedTensor,
920        on_true: &TracedTensor,
921        on_false: &TracedTensor,
922    ) -> Result<TracedTensor> {
923        Self::where_select(condition, on_true, on_false)
924    }
925
926    /// Clamp values elementwise between lower and upper bounds.
927    pub fn clamp(&self, lower: &TracedTensor, upper: &TracedTensor) -> Result<TracedTensor> {
928        apply_broadcast_ternary_op(StdTensorOp::Clamp, self, lower, upper)
929    }
930
931    fn apply_same_shape_unary(&self, op: StdTensorOp) -> TracedTensor {
932        apply_unary(op, self, self.rank, self.shape_hint.clone())
933    }
934
935    /// Elementwise negation.
936    ///
937    /// Prefer using the unary `-` operator when it reads naturally.
938    ///
939    /// # Examples
940    ///
941    /// ```rust
942    /// # use tenferro_runtime::TracedTensor;
943    /// # let x = TracedTensor::from_vec_col_major(vec![2], vec![1.0_f64, 2.0]).unwrap();
944    /// let y = x.neg();
945    /// let y2 = -&x;
946    /// ```
947    pub fn neg(&self) -> TracedTensor {
948        self.apply_same_shape_unary(StdTensorOp::Neg)
949    }
950
951    /// Elementwise complex conjugate.
952    ///
953    /// # Examples
954    ///
955    /// ```rust
956    /// # use num_complex::Complex64;
957    /// # use tenferro_runtime::TracedTensor;
958    /// # let x = TracedTensor::from_vec_col_major(
959    /// #     vec![2],
960    /// #     vec![Complex64::new(1.0, 2.0), Complex64::new(3.0, 4.0)],
961    /// # )
962    /// # .unwrap();
963    /// let y = x.conj();
964    /// ```
965    pub fn conj(&self) -> TracedTensor {
966        self.apply_same_shape_unary(StdTensorOp::Conj)
967    }
968
969    /// Elementwise absolute value.
970    ///
971    /// Complex inputs return real magnitudes (`C32 -> F32`, `C64 -> F64`).
972    ///
973    /// # Examples
974    ///
975    /// ```rust
976    /// # use tenferro_runtime::TracedTensor;
977    /// # let x = TracedTensor::from_vec_col_major(vec![2], vec![-1.0_f64, 2.0]).unwrap();
978    /// let y = x.abs();
979    /// ```
980    pub fn abs(&self) -> TracedTensor {
981        self.apply_same_shape_unary(StdTensorOp::Abs)
982    }
983
984    /// Elementwise sign.
985    ///
986    /// # Examples
987    ///
988    /// ```rust
989    /// # use tenferro_runtime::TracedTensor;
990    /// # let x = TracedTensor::from_vec_col_major(vec![2], vec![-1.0_f64, 2.0]).unwrap();
991    /// let y = x.sign();
992    /// ```
993    pub fn sign(&self) -> TracedTensor {
994        self.apply_same_shape_unary(StdTensorOp::Sign)
995    }
996
997    /// Scale by a real scalar: `y = factor * x`.
998    ///
999    /// # Examples
1000    ///
1001    /// ```rust
1002    /// # use tenferro_runtime::TracedTensor;
1003    /// # let x = TracedTensor::from_vec_col_major(vec![2], vec![1.0_f64, 2.0]).unwrap();
1004    /// let y = x.scale_real(2.0);
1005    /// ```
1006    pub fn scale_real(&self, factor: f64) -> TracedTensor {
1007        let op = match self.dtype {
1008            DType::F64 => StdTensorOp::constant(factor),
1009            DType::F32 => StdTensorOp::constant(factor as f32),
1010            DType::I32 => StdTensorOp::constant(round_real_to_i64(factor) as i32),
1011            DType::I64 => StdTensorOp::constant(round_real_to_i64(factor)),
1012            DType::Bool => StdTensorOp::constant(factor != 0.0),
1013            DType::C64 => StdTensorOp::constant(Complex64::new(factor, 0.0)),
1014            DType::C32 => StdTensorOp::constant(Complex32::new(factor as f32, 0.0)),
1015        };
1016        scale_with_constant(self, op)
1017    }
1018
1019    /// Scale by a complex scalar: `y = factor * x`.
1020    ///
1021    /// Only complex tensors support complex scaling. For a real scalar factor
1022    /// that should preserve the input dtype, prefer [`scale_real`](Self::scale_real).
1023    ///
1024    /// # Examples
1025    ///
1026    /// ```rust
1027    /// use num_complex::Complex64;
1028    /// # use tenferro_runtime::TracedTensor;
1029    /// # let x = TracedTensor::from_vec_col_major(
1030    /// #     vec![2],
1031    /// #     vec![Complex64::new(1.0, 0.0), Complex64::new(2.0, 0.0)],
1032    /// # )
1033    /// # .unwrap();
1034    /// let y = x.scale_complex(Complex64::new(0.0, 1.0)).unwrap(); // multiply by i
1035    /// ```
1036    pub fn scale_complex(&self, factor: Complex64) -> Result<TracedTensor> {
1037        match self.dtype {
1038            DType::C64 => Ok(scale_with_constant(self, StdTensorOp::constant(factor))),
1039            DType::C32 => Ok(scale_with_constant(
1040                self,
1041                StdTensorOp::constant(Complex32::new(factor.re as f32, factor.im as f32)),
1042            )),
1043            DType::F32 | DType::F64 | DType::I32 | DType::I64 | DType::Bool => {
1044                Err(Error::InvalidGraphBuild {
1045                    op: "scale_complex",
1046                    message: format!("requires complex tensor dtype, got {:?}", self.dtype),
1047                })
1048            }
1049        }
1050    }
1051
1052    /// Elementwise exponential.
1053    ///
1054    /// # Examples
1055    ///
1056    /// ```rust
1057    /// # use tenferro_runtime::TracedTensor;
1058    /// # let x = TracedTensor::from_vec_col_major(vec![2], vec![1.0_f64, 2.0]).unwrap();
1059    /// let y = x.exp();
1060    /// ```
1061    pub fn exp(&self) -> TracedTensor {
1062        self.apply_same_shape_unary(StdTensorOp::Exp)
1063    }
1064
1065    /// Elementwise natural logarithm.
1066    ///
1067    /// # Examples
1068    ///
1069    /// ```rust
1070    /// # use tenferro_runtime::TracedTensor;
1071    /// # let x = TracedTensor::from_vec_col_major(vec![2], vec![1.0_f64, 2.0]).unwrap();
1072    /// let y = x.log();
1073    /// ```
1074    pub fn log(&self) -> TracedTensor {
1075        self.apply_same_shape_unary(StdTensorOp::Log)
1076    }
1077
1078    /// Elementwise sine.
1079    ///
1080    /// # Examples
1081    ///
1082    /// ```rust
1083    /// # use tenferro_runtime::TracedTensor;
1084    /// # let x = TracedTensor::from_vec_col_major(vec![2], vec![1.0_f64, 2.0]).unwrap();
1085    /// let y = x.sin();
1086    /// ```
1087    pub fn sin(&self) -> TracedTensor {
1088        self.apply_same_shape_unary(StdTensorOp::Sin)
1089    }
1090
1091    /// Elementwise cosine.
1092    ///
1093    /// # Examples
1094    ///
1095    /// ```rust
1096    /// # use tenferro_runtime::TracedTensor;
1097    /// # let x = TracedTensor::from_vec_col_major(vec![2], vec![1.0_f64, 2.0]).unwrap();
1098    /// let y = x.cos();
1099    /// ```
1100    pub fn cos(&self) -> TracedTensor {
1101        self.apply_same_shape_unary(StdTensorOp::Cos)
1102    }
1103
1104    /// Elementwise hyperbolic tangent.
1105    ///
1106    /// # Examples
1107    ///
1108    /// ```rust
1109    /// # use tenferro_runtime::TracedTensor;
1110    /// # let x = TracedTensor::from_vec_col_major(vec![2], vec![1.0_f64, 2.0]).unwrap();
1111    /// let y = x.tanh();
1112    /// ```
1113    pub fn tanh(&self) -> TracedTensor {
1114        self.apply_same_shape_unary(StdTensorOp::Tanh)
1115    }
1116
1117    /// Elementwise square root.
1118    ///
1119    /// # Examples
1120    ///
1121    /// ```rust
1122    /// # use tenferro_runtime::TracedTensor;
1123    /// # let x = TracedTensor::from_vec_col_major(vec![2], vec![1.0_f64, 4.0]).unwrap();
1124    /// let y = x.sqrt();
1125    /// ```
1126    pub fn sqrt(&self) -> TracedTensor {
1127        self.apply_same_shape_unary(StdTensorOp::Sqrt)
1128    }
1129
1130    /// Elementwise reciprocal square root.
1131    ///
1132    /// # Examples
1133    ///
1134    /// ```rust
1135    /// # use tenferro_runtime::TracedTensor;
1136    /// # let x = TracedTensor::from_vec_col_major(vec![2], vec![1.0_f64, 4.0]).unwrap();
1137    /// let y = x.rsqrt();
1138    /// ```
1139    pub fn rsqrt(&self) -> TracedTensor {
1140        self.apply_same_shape_unary(StdTensorOp::Rsqrt)
1141    }
1142
1143    /// Elementwise power with NumPy-style broadcasting.
1144    ///
1145    /// # Examples
1146    ///
1147    /// ```rust
1148    /// # use tenferro_runtime::TracedTensor;
1149    /// # let base = TracedTensor::from_vec_col_major(vec![2], vec![2.0_f64, 3.0]).unwrap();
1150    /// # let exp = TracedTensor::from_vec_col_major(vec![2], vec![3.0_f64, 2.0]).unwrap();
1151    /// let y = base.pow(&exp);
1152    /// ```
1153    pub fn pow(&self, other: &TracedTensor) -> Result<TracedTensor> {
1154        let (lhs, rhs) = broadcast_binary(self, other)?;
1155        Ok(apply_binary(
1156            StdTensorOp::Pow,
1157            &lhs,
1158            &rhs,
1159            lhs.rank,
1160            lhs.shape_hint.clone(),
1161        ))
1162    }
1163
1164    /// Elementwise `exp(x) - 1`.
1165    ///
1166    /// # Examples
1167    ///
1168    /// ```rust
1169    /// # use tenferro_runtime::TracedTensor;
1170    /// # let x = TracedTensor::from_vec_col_major(vec![2], vec![1.0_f64, 2.0]).unwrap();
1171    /// let y = x.expm1();
1172    /// ```
1173    pub fn expm1(&self) -> TracedTensor {
1174        self.apply_same_shape_unary(StdTensorOp::Expm1)
1175    }
1176
1177    /// Elementwise `log(1 + x)`.
1178    ///
1179    /// # Examples
1180    ///
1181    /// ```rust
1182    /// # use tenferro_runtime::TracedTensor;
1183    /// # let x = TracedTensor::from_vec_col_major(vec![2], vec![1.0_f64, 2.0]).unwrap();
1184    /// let y = x.log1p();
1185    /// ```
1186    pub fn log1p(&self) -> TracedTensor {
1187        self.apply_same_shape_unary(StdTensorOp::Log1p)
1188    }
1189
1190    /// Convert the tensor to a different dtype using checked conversion.
1191    ///
1192    /// Use [`cast`](Self::cast) when a lossy dtype projection is intended.
1193    ///
1194    /// # Examples
1195    ///
1196    /// ```rust
1197    /// use tenferro_runtime::DType;
1198    /// # use tenferro_runtime::TracedTensor;
1199    /// # let x = TracedTensor::from_vec_col_major(vec![2], vec![1.0_f64, 2.0]).unwrap();
1200    ///
1201    /// let y = x.convert(DType::C64)?;
1202    /// # Ok::<(), tenferro_runtime::Error>(())
1203    /// ```
1204    ///
1205    /// # Errors
1206    ///
1207    /// Returns an error when the requested conversion is outside tenferro's
1208    /// checked dtype-promotion lattice. Use [`cast`](Self::cast) for explicit
1209    /// lossy dtype projection.
1210    pub fn convert(&self, to: DType) -> Result<TracedTensor> {
1211        tenferro_tensor::validate::validate_convert_dtype("TracedTensor::convert", self.dtype, to)?;
1212        Ok(self.cast(to))
1213    }
1214
1215    /// Cast the tensor to a different dtype using explicit dtype projection.
1216    ///
1217    /// `cast` may truncate, narrow precision, project complex values to their
1218    /// real component, or use boolean truthiness where the backend supports the
1219    /// requested projection.
1220    ///
1221    /// # Examples
1222    ///
1223    /// ```rust
1224    /// use tenferro_runtime::DType;
1225    /// # use tenferro_runtime::TracedTensor;
1226    /// # let x = TracedTensor::from_vec_col_major(vec![2], vec![1.2_f64, -2.8]).unwrap();
1227    ///
1228    /// let y = x.cast(DType::I32);
1229    /// ```
1230    pub fn cast(&self, to: DType) -> TracedTensor {
1231        if self.dtype == to {
1232            return self.clone();
1233        }
1234
1235        apply_unary_with_dtype(
1236            StdTensorOp::Convert {
1237                from: self.dtype,
1238                to,
1239            },
1240            self,
1241            self.rank,
1242            self.shape_hint.clone(),
1243            to,
1244        )
1245    }
1246
1247    /// Generalized tensor contraction.
1248    ///
1249    /// # Examples
1250    ///
1251    /// ```rust
1252    /// # use tenferro_runtime::{DotGeneralConfig, TracedTensor};
1253    /// # let a = TracedTensor::from_vec_col_major(vec![2, 3], vec![1.0_f64; 6]).unwrap();
1254    /// # let b = TracedTensor::from_vec_col_major(vec![3, 4], vec![1.0_f64; 12]).unwrap();
1255    /// # let config = DotGeneralConfig {
1256    /// #     lhs_contracting_dims: vec![1],
1257    /// #     rhs_contracting_dims: vec![0],
1258    /// #     lhs_batch_dims: vec![],
1259    /// #     rhs_batch_dims: vec![],
1260    /// # };
1261    /// let y = a.dot_general(&b, config)?;
1262    /// # Ok::<(), tenferro_runtime::Error>(())
1263    /// ```
1264    ///
1265    /// # Errors
1266    ///
1267    /// Returns an error when the dimension-numbering configuration is invalid
1268    /// for the operand ranks.
1269    pub fn dot_general(
1270        &self,
1271        other: &TracedTensor,
1272        config: DotGeneralConfig,
1273    ) -> Result<TracedTensor> {
1274        config
1275            .validate_dims_with_ranks(self.rank, other.rank)
1276            .map_err(|message| Error::InvalidGraphBuild {
1277                op: "dot_general",
1278                message,
1279            })?;
1280        let lhs_free: Vec<usize> = (0..self.rank)
1281            .filter(|d| {
1282                !config.lhs_contracting_dims.contains(d) && !config.lhs_batch_dims.contains(d)
1283            })
1284            .collect();
1285        let rhs_free: Vec<usize> = (0..other.rank)
1286            .filter(|d| {
1287                !config.rhs_contracting_dims.contains(d) && !config.rhs_batch_dims.contains(d)
1288            })
1289            .collect();
1290        let out_rank = config.lhs_batch_dims.len() + lhs_free.len() + rhs_free.len();
1291        let out_shape_hint = match (&self.shape_hint, &other.shape_hint) {
1292            (Some(lhs_shape), Some(rhs_shape)) => {
1293                let mut out_shape = Vec::with_capacity(out_rank);
1294                for &d in &lhs_free {
1295                    out_shape.push(lhs_shape[d].clone());
1296                }
1297                for &d in &rhs_free {
1298                    out_shape.push(rhs_shape[d].clone());
1299                }
1300                for &d in &config.lhs_batch_dims {
1301                    out_shape.push(lhs_shape[d].clone());
1302                }
1303                Some(out_shape)
1304            }
1305            _ => None,
1306        };
1307
1308        Ok(apply_binary(
1309            StdTensorOp::DotGeneral { config },
1310            self,
1311            other,
1312            out_rank,
1313            out_shape_hint,
1314        ))
1315    }
1316
1317    /// Matrix multiplication for rank-2 tensors.
1318    pub fn matmul(&self, other: &TracedTensor) -> Result<TracedTensor> {
1319        if self.rank != 2 {
1320            return Err(Error::InvalidGraphBuild {
1321                op: "TracedTensor::matmul",
1322                message: format!("matmul requires rank-2 inputs, got lhs rank {}", self.rank),
1323            });
1324        }
1325        if other.rank != 2 {
1326            return Err(Error::InvalidGraphBuild {
1327                op: "TracedTensor::matmul",
1328                message: format!("matmul requires rank-2 inputs, got rhs rank {}", other.rank),
1329            });
1330        }
1331        if let (Some(lhs_shape), Some(rhs_shape)) = (&self.shape_hint, &other.shape_hint) {
1332            if let (Some(lhs_cols), Some(rhs_rows)) =
1333                (lhs_shape[1].constant_value(), rhs_shape[0].constant_value())
1334            {
1335                if lhs_cols != rhs_rows {
1336                    return Err(Error::InvalidGraphBuild {
1337                        op: "TracedTensor::matmul",
1338                        message: format!(
1339                            "matmul dimension mismatch: lhs columns {lhs_cols} != rhs rows {rhs_rows}"
1340                        ),
1341                    });
1342                }
1343            }
1344        }
1345        self.dot_general(
1346            other,
1347            DotGeneralConfig {
1348                lhs_contracting_dims: vec![1],
1349                rhs_contracting_dims: vec![0],
1350                lhs_batch_dims: vec![],
1351                rhs_batch_dims: vec![],
1352            },
1353        )
1354    }
1355
1356    /// Sum over the given axes.
1357    ///
1358    /// # Examples
1359    ///
1360    /// ```rust
1361    /// # use tenferro_runtime::TracedTensor;
1362    /// # let x = TracedTensor::from_vec_col_major(vec![2, 2], vec![1.0_f64; 4]).unwrap();
1363    /// let y = x.reduce_sum(&[0])?;
1364    /// let y2 = x.reduce_sum(&[0])?;
1365    /// # Ok::<(), tenferro_runtime::Error>(())
1366    /// ```
1367    ///
1368    /// # Errors
1369    ///
1370    /// Returns an error when an axis is out of bounds or duplicated.
1371    pub fn reduce_sum(&self, axes: &[usize]) -> Result<TracedTensor> {
1372        let (out_rank, out_shape_hint) =
1373            reduction_output_meta(self, axes, "TracedTensor::reduce_sum")?;
1374        Ok(apply_unary(
1375            StdTensorOp::ReduceSum {
1376                axes: axes.to_vec(),
1377            },
1378            self,
1379            out_rank,
1380            out_shape_hint,
1381        ))
1382    }
1383
1384    /// Reduce by taking the maximum along the given axes.
1385    ///
1386    /// Used by tropical (max-plus) compositions: a max-plus reduction over
1387    /// an axis is `ReduceMax` on that axis.
1388    ///
1389    /// # Examples
1390    ///
1391    /// ```rust
1392    /// # use tenferro_runtime::TracedTensor;
1393    /// # let x = TracedTensor::from_vec_col_major(vec![2, 2], vec![1.0_f64; 4]).unwrap();
1394    /// let y = x.reduce_max(&[0])?;
1395    /// # Ok::<(), tenferro_runtime::Error>(())
1396    /// ```
1397    ///
1398    /// # Errors
1399    ///
1400    /// Returns an error when an axis is out of bounds or duplicated.
1401    pub fn reduce_max(&self, axes: &[usize]) -> Result<TracedTensor> {
1402        let (out_rank, out_shape_hint) =
1403            reduction_output_meta(self, axes, "TracedTensor::reduce_max")?;
1404        Ok(apply_unary(
1405            StdTensorOp::ReduceMax {
1406                axes: axes.to_vec(),
1407            },
1408            self,
1409            out_rank,
1410            out_shape_hint,
1411        ))
1412    }
1413
1414    /// Reduce by taking the minimum along the given axes.
1415    ///
1416    /// Used by tropical (min-plus) compositions: a min-plus reduction over
1417    /// an axis is `ReduceMin` on that axis.
1418    ///
1419    /// # Examples
1420    ///
1421    /// ```rust
1422    /// # use tenferro_runtime::TracedTensor;
1423    /// # let x = TracedTensor::from_vec_col_major(vec![2, 2], vec![1.0_f64; 4]).unwrap();
1424    /// let y = x.reduce_min(&[0])?;
1425    /// # Ok::<(), tenferro_runtime::Error>(())
1426    /// ```
1427    ///
1428    /// # Errors
1429    ///
1430    /// Returns an error when an axis is out of bounds or duplicated.
1431    pub fn reduce_min(&self, axes: &[usize]) -> Result<TracedTensor> {
1432        let (out_rank, out_shape_hint) =
1433            reduction_output_meta(self, axes, "TracedTensor::reduce_min")?;
1434        Ok(apply_unary(
1435            StdTensorOp::ReduceMin {
1436                axes: axes.to_vec(),
1437            },
1438            self,
1439            out_rank,
1440            out_shape_hint,
1441        ))
1442    }
1443
1444    /// Reduce by taking the product along the given axes.
1445    ///
1446    /// # Examples
1447    ///
1448    /// ```rust
1449    /// # use tenferro_runtime::TracedTensor;
1450    /// # let x = TracedTensor::from_vec_col_major(vec![2, 2], vec![1.0_f64; 4]).unwrap();
1451    /// let y = x.reduce_prod(&[0])?;
1452    /// # Ok::<(), tenferro_runtime::Error>(())
1453    /// ```
1454    ///
1455    /// # Errors
1456    ///
1457    /// Returns an error when an axis is out of bounds or duplicated.
1458    pub fn reduce_prod(&self, axes: &[usize]) -> Result<TracedTensor> {
1459        let (out_rank, out_shape_hint) =
1460            reduction_output_meta(self, axes, "TracedTensor::reduce_prod")?;
1461        Ok(apply_unary(
1462            StdTensorOp::ReduceProd {
1463                axes: axes.to_vec(),
1464            },
1465            self,
1466            out_rank,
1467            out_shape_hint,
1468        ))
1469    }
1470
1471    /// Reshape without changing element order.
1472    ///
1473    /// # Examples
1474    ///
1475    /// ```rust
1476    /// # use tenferro_runtime::TracedTensor;
1477    /// # let x = TracedTensor::from_vec_col_major(vec![4], vec![1.0_f64; 4]).unwrap();
1478    /// let y = x.reshape(&[2, 2]);
1479    /// ```
1480    pub fn reshape(&self, shape: &[usize]) -> TracedTensor {
1481        apply_unary(
1482            StdTensorOp::Reshape {
1483                to_shape: DimExpr::from_concrete(shape),
1484            },
1485            self,
1486            shape.len(),
1487            Some(shape.iter().copied().map(SymDim::from).collect()),
1488        )
1489    }
1490
1491    /// Return a symbolic expression for the size of one axis, suitable as
1492    /// an `InputDim`-style reference when composing with
1493    /// [`TracedTensor::reshape_sym`].
1494    ///
1495    /// Semantics: if this tensor's `shape_hint` has a symbolic
1496    /// (non-constant) entry for `axis`, that entry is returned
1497    /// verbatim. Otherwise — including when `shape_hint[axis]` is a
1498    /// concrete `SymDim::Concrete(n)` — a
1499    /// `SymDim::tensor_axis(self.id, axis)` reference is returned so the
1500    /// resulting graph remains shape-polymorphic if the same graph is
1501    /// later evaluated against a differently-shaped binding.
1502    ///
1503    /// For a canonical "what is the size of this axis?" query that
1504    /// reports the concrete size when it is known, prefer
1505    /// [`Self::axis_sym_dim`].
1506    ///
1507    /// # Examples
1508    ///
1509    /// ```rust
1510    /// # use tenferro_runtime::TracedTensor;
1511    /// # let x = TracedTensor::from_vec_col_major(vec![2, 3], vec![1.0_f64; 6]).unwrap();
1512    /// let rows = x.sym_size(0)?;
1513    /// let cols = x.sym_size(1)?;
1514    /// let y = x.reshape_sym(&[rows * cols]).unwrap();
1515    /// # Ok::<(), tenferro_runtime::Error>(())
1516    /// ```
1517    ///
1518    /// # Errors
1519    ///
1520    /// Returns an error when `axis` is out of bounds.
1521    pub fn sym_size(&self, axis: usize) -> Result<SymDim> {
1522        validate_traced_axis(self, axis, "TracedTensor::sym_size")?;
1523        Ok(self
1524            .shape_hint
1525            .as_ref()
1526            .and_then(|shape| shape.get(axis))
1527            .filter(|dim| dim.constant_value().is_none())
1528            .cloned()
1529            .unwrap_or_else(|| SymDim::tensor_axis(self.id, axis)))
1530    }
1531
1532    /// Return the canonical `SymDim` for `axis` — the concrete
1533    /// `SymDim::Concrete(n)` when the size is known, otherwise a symbolic
1534    /// expression identifying this tensor's axis.
1535    ///
1536    /// Unlike [`Self::sym_size`], this method does **not** rewrite
1537    /// concrete axes into `TensorAxis` references. It is the accessor
1538    /// external composition wrappers should use when building mixed
1539    /// concrete/symbolic target shapes for operations like
1540    /// [`Self::broadcast_in_dim_sym`].
1541    ///
1542    /// # Examples
1543    ///
1544    /// ```
1545    /// use tenferro_tensor::DType;
1546    /// use tenferro_runtime::TracedTensor;
1547    ///
1548    /// let a = TracedTensor::from_vec_col_major(vec![2, 3], vec![1.0_f64; 6]).unwrap();
1549    /// // Concrete axis: reports the constant size.
1550    /// assert_eq!(a.axis_sym_dim(0).unwrap().constant_value(), Some(2));
1551    ///
1552    /// let b = TracedTensor::input_symbolic_shape(DType::F64, 2).unwrap();
1553    /// // Fully symbolic leaf: reports a TensorAxis reference.
1554    /// assert!(b.axis_sym_dim(0).unwrap().constant_value().is_none());
1555    /// ```
1556    ///
1557    /// # Errors
1558    ///
1559    /// Returns an error when `axis` is out of bounds.
1560    pub fn axis_sym_dim(&self, axis: usize) -> Result<SymDim> {
1561        validate_traced_axis(self, axis, "TracedTensor::axis_sym_dim")?;
1562        match self.shape_hint.as_ref().and_then(|shape| shape.get(axis)) {
1563            Some(dim) => Ok(dim.clone()),
1564            None => Ok(SymDim::tensor_axis(self.id, axis)),
1565        }
1566    }
1567
1568    /// Return the full symbolic shape of this tensor when a `shape_hint`
1569    /// is present.
1570    ///
1571    /// Returns `None` for fully-symbolic placeholders produced via
1572    /// [`Self::input_symbolic_shape`] (where `shape_hint` is intentionally
1573    /// absent). For those, build the shape axis-by-axis via
1574    /// [`Self::axis_sym_dim`].
1575    ///
1576    /// # Examples
1577    ///
1578    /// ```
1579    /// use tenferro_tensor::DType;
1580    /// use tenferro_runtime::TracedTensor;
1581    ///
1582    /// let a = TracedTensor::from_vec_col_major(vec![2, 3], vec![1.0_f64; 6]).unwrap();
1583    /// assert!(a.sym_shape().is_some());
1584    /// assert_eq!(a.sym_shape().unwrap().len(), 2);
1585    ///
1586    /// let b = TracedTensor::input_symbolic_shape(DType::F64, 2).unwrap();
1587    /// assert!(b.sym_shape().is_none());
1588    /// ```
1589    pub fn sym_shape(&self) -> Option<&[SymDim]> {
1590        self.shape_hint.as_deref()
1591    }
1592
1593    /// Reshape using symbolic dimensions derived from traced tensor axes.
1594    ///
1595    /// # Examples
1596    ///
1597    /// ```rust
1598    /// # use tenferro_runtime::TracedTensor;
1599    /// # let x = TracedTensor::from_vec_col_major(vec![2, 3], vec![1.0_f64; 6]).unwrap();
1600    /// let rows = x.sym_size(0)?;
1601    /// let cols = x.sym_size(1)?;
1602    /// let y = x.reshape_sym(&[rows * cols]).unwrap();
1603    /// # Ok::<(), tenferro_runtime::Error>(())
1604    /// ```
1605    pub fn reshape_sym(&self, shape: &[SymDim]) -> Result<TracedTensor> {
1606        let tensor_map = [(self.id, 0usize)];
1607        let to_shape = shape
1608            .iter()
1609            .map(|dim| dim.to_dim_expr(&tensor_map).map_err(Error::Internal))
1610            .collect::<Result<Vec<_>>>()?;
1611        let out_shape_hint = Some(shape.to_vec());
1612        Ok(apply_unary(
1613            StdTensorOp::Reshape { to_shape },
1614            self,
1615            shape.len(),
1616            out_shape_hint,
1617        ))
1618    }
1619
1620    /// Broadcast into a larger shape with explicit dimension placement.
1621    ///
1622    /// # Examples
1623    ///
1624    /// ```rust
1625    /// # use tenferro_runtime::TracedTensor;
1626    /// # let x = TracedTensor::from_vec_col_major(vec![3], vec![1.0_f64; 3]).unwrap();
1627    /// let y = x.broadcast_in_dim(&[2, 3], &[1])?;
1628    /// # Ok::<(), tenferro_runtime::Error>(())
1629    /// ```
1630    ///
1631    /// # Errors
1632    ///
1633    /// Returns an error when `dims` is not a duplicate-free mapping from every
1634    /// input axis into the output rank, or when a known input dimension cannot
1635    /// broadcast to the corresponding output dimension.
1636    pub fn broadcast_in_dim(&self, shape: &[usize], dims: &[usize]) -> Result<TracedTensor> {
1637        let out_shape_hint: Vec<SymDim> = shape.iter().copied().map(SymDim::from).collect();
1638        validate_broadcast_in_dim_args(
1639            self,
1640            &out_shape_hint,
1641            dims,
1642            "TracedTensor::broadcast_in_dim",
1643        )?;
1644        Ok(apply_unary(
1645            StdTensorOp::BroadcastInDim {
1646                shape: DimExpr::from_concrete(shape),
1647                dims: dims.to_vec(),
1648            },
1649            self,
1650            shape.len(),
1651            Some(out_shape_hint),
1652        ))
1653    }
1654
1655    /// Broadcast into a symbolic target shape with explicit dimension
1656    /// placement.
1657    ///
1658    /// Unlike [`Self::broadcast_in_dim`], each axis of `shape` is a
1659    /// [`SymDim`], so the target shape can mix concrete sizes (via
1660    /// `SymDim::from(n)`) with symbolic references to this tensor's axes
1661    /// (via [`Self::axis_sym_dim`]) or to axes of other traced tensors.
1662    ///
1663    /// When `shape` contains a `SymDim` that references a traced tensor
1664    /// other than `self`, the referenced tensor(s) must be supplied in
1665    /// `shape_refs`. They are wired into the built op as auxiliary
1666    /// shape-reference inputs — the op does not read their data, only
1667    /// their runtime shape. `shape_refs` must be listed in the same order
1668    /// in which their tensor IDs first appear when walking `shape` after
1669    /// any references to `self`. Usually the simplest correct thing is to
1670    /// pass each unique non-self reference tensor once.
1671    ///
1672    /// # Examples
1673    ///
1674    /// ```
1675    /// use tenferro_runtime::TracedTensor;
1676    ///
1677    /// let a = TracedTensor::from_vec_col_major(vec![2, 3], vec![1.0_f64; 6]).unwrap();
1678    /// let b = TracedTensor::from_vec_col_major(vec![3, 4], vec![1.0_f64; 12]).unwrap();
1679    /// let m = a.axis_sym_dim(0)?;
1680    /// let k = a.axis_sym_dim(1)?;
1681    /// let n = b.axis_sym_dim(1)?;
1682    /// // Broadcast `a[m, k]` to `[m, k, n]`, placing `a`'s axes at 0, 1
1683    /// // and taking `n` from `b` as an auxiliary shape reference.
1684    /// let a_b = a.broadcast_in_dim_sym(&[m, k, n], &[0, 1], &[&b])?;
1685    /// assert_eq!(a_b.rank, 3);
1686    /// # Ok::<(), tenferro_runtime::Error>(())
1687    /// ```
1688    pub fn broadcast_in_dim_sym(
1689        &self,
1690        shape: &[SymDim],
1691        dims: &[usize],
1692        shape_refs: &[&TracedTensor],
1693    ) -> Result<TracedTensor> {
1694        validate_broadcast_in_dim_args(self, shape, dims, "TracedTensor::broadcast_in_dim_sym")?;
1695
1696        // Build a dedup'd list of shape-reference tensors (first occurrence
1697        // wins) and index them starting at 1 — the primary input `self`
1698        // is at 0.
1699        let mut dedup_refs: Vec<&TracedTensor> = Vec::with_capacity(shape_refs.len());
1700        let mut tensor_map: Vec<(u64, usize)> = vec![(self.id, 0)];
1701        for &t in shape_refs {
1702            if !tensor_map.iter().any(|(id, _)| *id == t.id) {
1703                let idx = tensor_map.len();
1704                tensor_map.push((t.id, idx));
1705                dedup_refs.push(t);
1706            }
1707        }
1708
1709        let to_shape: Vec<DimExpr> = shape
1710            .iter()
1711            .map(|dim| {
1712                dim.to_dim_expr(&tensor_map)
1713                    .map_err(|err| Error::InvalidGraphBuild {
1714                        op: "broadcast_in_dim_sym",
1715                        message: format!(
1716                            "unresolved symbolic dimension: {err}; \
1717                             pass every referenced tensor via `shape_refs`"
1718                        ),
1719                    })
1720            })
1721            .collect::<Result<Vec<_>>>()?;
1722
1723        // Trim auxiliary shape-reference inputs down to those actually
1724        // used by the generated `DimExpr`s. If the target shape resolved
1725        // to all constants (the concrete-shape case) the op is a plain
1726        // unary broadcast with no extra parents. Otherwise the op needs
1727        // a contiguous prefix of shape-ref inputs covering every
1728        // referenced `input_idx`.
1729        let max_used_idx = DimExpr::max_input_idx_all(&to_shape).unwrap_or(0);
1730        let used_refs: Vec<&TracedTensor> = dedup_refs.into_iter().take(max_used_idx).collect();
1731
1732        let out_shape_hint = Some(shape.to_vec());
1733        Ok(apply_unary_with_shape_refs(
1734            StdTensorOp::BroadcastInDim {
1735                shape: to_shape,
1736                dims: dims.to_vec(),
1737            },
1738            self,
1739            &used_refs,
1740            shape.len(),
1741            out_shape_hint,
1742        ))
1743    }
1744
1745    /// Slice with explicit start, limit, and stride per axis.
1746    pub fn slice(&self, config: SliceConfig) -> Result<TracedTensor> {
1747        let op = StdTensorOp::Slice(config);
1748        let (out_rank, out_shape_hint) =
1749            infer_traced_single_output_shape("TracedTensor::slice", &op, &[self])?;
1750        Ok(apply_unary(op, self, out_rank, out_shape_hint))
1751    }
1752
1753    /// Pad with zeros using StableHLO-style edge and interior padding.
1754    pub fn pad(&self, config: PadConfig) -> Result<TracedTensor> {
1755        let op = StdTensorOp::Pad(config);
1756        let (out_rank, out_shape_hint) =
1757            infer_traced_single_output_shape("TracedTensor::pad", &op, &[self])?;
1758        Ok(apply_unary(op, self, out_rank, out_shape_hint))
1759    }
1760
1761    /// Reverse the order of elements along the requested axes.
1762    pub fn reverse(&self, axes: &[usize]) -> Result<TracedTensor> {
1763        validate_traced_axes(self.rank, axes, "TracedTensor::reverse")?;
1764        Ok(apply_unary(
1765            StdTensorOp::Reverse {
1766                axes: axes.to_vec(),
1767            },
1768            self,
1769            self.rank,
1770            self.shape_hint.clone(),
1771        ))
1772    }
1773
1774    /// Gather slices from `self` using integer start indices.
1775    pub fn gather(&self, indices: &TracedTensor, config: GatherConfig) -> Result<TracedTensor> {
1776        let op = StdTensorOp::Gather(config);
1777        let (out_rank, out_shape_hint) =
1778            infer_traced_single_output_shape("TracedTensor::gather", &op, &[self, indices])?;
1779        Ok(apply_binary_preserve_input_dtypes(
1780            op,
1781            self,
1782            indices,
1783            out_rank,
1784            out_shape_hint,
1785            self.dtype,
1786        ))
1787    }
1788
1789    /// Scatter updates into `self` using StableHLO scatter semantics.
1790    pub fn scatter(
1791        &self,
1792        indices: &TracedTensor,
1793        updates: &TracedTensor,
1794        config: ScatterConfig,
1795    ) -> Result<TracedTensor> {
1796        let op = StdTensorOp::Scatter(config);
1797        let (out_rank, out_shape_hint) = infer_traced_single_output_shape(
1798            "TracedTensor::scatter",
1799            &op,
1800            &[self, indices, updates],
1801        )?;
1802        let out_dtype = crate::shape_infer::promote_dtype(self.dtype, updates.dtype);
1803        let operand = if self.dtype != out_dtype {
1804            self.cast(out_dtype)
1805        } else {
1806            self.clone()
1807        };
1808        let updates = if updates.dtype != out_dtype {
1809            updates.cast(out_dtype)
1810        } else {
1811            updates.clone()
1812        };
1813        Ok(apply_ternary_with_output_dtype(
1814            op,
1815            &operand,
1816            indices,
1817            &updates,
1818            out_rank,
1819            out_shape_hint,
1820            out_dtype,
1821        ))
1822    }
1823
1824    /// Slice using runtime start indices.
1825    pub fn dynamic_slice(&self, starts: &TracedTensor, sizes: &[usize]) -> Result<TracedTensor> {
1826        let op = StdTensorOp::DynamicSlice {
1827            slice_sizes: sizes.to_vec(),
1828        };
1829        let (out_rank, out_shape_hint) =
1830            infer_traced_single_output_shape("TracedTensor::dynamic_slice", &op, &[self, starts])?;
1831        Ok(apply_binary_preserve_input_dtypes(
1832            op,
1833            self,
1834            starts,
1835            out_rank,
1836            out_shape_hint,
1837            self.dtype,
1838        ))
1839    }
1840
1841    /// Keep the lower triangle and zero the rest.
1842    pub fn tril(&self, k: i64) -> TracedTensor {
1843        apply_unary(
1844            StdTensorOp::Tril { k },
1845            self,
1846            self.rank,
1847            self.shape_hint.clone(),
1848        )
1849    }
1850
1851    /// Keep the upper triangle and zero the rest.
1852    pub fn triu(&self, k: i64) -> TracedTensor {
1853        apply_unary(
1854            StdTensorOp::Triu { k },
1855            self,
1856            self.rank,
1857            self.shape_hint.clone(),
1858        )
1859    }
1860
1861    /// Permute tensor axes.
1862    ///
1863    /// # Examples
1864    ///
1865    /// ```rust
1866    /// # use tenferro_runtime::TracedTensor;
1867    /// # let x = TracedTensor::from_vec_col_major(vec![2, 3], vec![1.0_f64; 6]).unwrap();
1868    /// let y = x.transpose(&[1, 0])?;
1869    /// # Ok::<(), tenferro_runtime::Error>(())
1870    /// ```
1871    ///
1872    /// # Errors
1873    ///
1874    /// Returns an error when `perm` is not a valid permutation of the tensor
1875    /// axes.
1876    pub fn transpose(&self, perm: &[usize]) -> Result<TracedTensor> {
1877        validate_traced_perm(self.rank, perm, "TracedTensor::transpose")?;
1878        let out_shape_hint = self
1879            .shape_hint
1880            .as_ref()
1881            .map(|shape| perm.iter().map(|&p| shape[p].clone()).collect());
1882        Ok(apply_unary(
1883            StdTensorOp::Transpose {
1884                perm: perm.to_vec(),
1885            },
1886            self,
1887            self.rank,
1888            out_shape_hint,
1889        ))
1890    }
1891
1892    /// Extract the diagonal along two axes.
1893    ///
1894    /// # Examples
1895    ///
1896    /// ```rust
1897    /// # use tenferro_runtime::TracedTensor;
1898    /// # let x = TracedTensor::from_vec_col_major(vec![2, 2], vec![1.0_f64; 4]).unwrap();
1899    /// let y = x.extract_diag(0, 1)?;
1900    /// # Ok::<(), tenferro_runtime::Error>(())
1901    /// ```
1902    ///
1903    /// # Errors
1904    ///
1905    /// Returns an error when either axis is out of bounds or the two axes are
1906    /// equal.
1907    pub fn extract_diag(&self, axis_a: usize, axis_b: usize) -> Result<TracedTensor> {
1908        validate_traced_axis(self, axis_a, "TracedTensor::extract_diag")?;
1909        validate_traced_axis(self, axis_b, "TracedTensor::extract_diag")?;
1910        if axis_a == axis_b {
1911            return Err(Error::InvalidGraphBuild {
1912                op: "TracedTensor::extract_diag",
1913                message: "diagonal axes must be distinct".into(),
1914            });
1915        }
1916        let out_shape_hint = self.shape_hint.as_ref().map(|shape| {
1917            shape
1918                .iter()
1919                .enumerate()
1920                .filter_map(|(axis, dim)| (axis != axis_b).then_some(dim.clone()))
1921                .collect()
1922        });
1923        Ok(apply_unary(
1924            StdTensorOp::ExtractDiag { axis_a, axis_b },
1925            self,
1926            self.rank - 1,
1927            out_shape_hint,
1928        ))
1929    }
1930
1931    /// Embed a vector or lower-rank tensor along a diagonal.
1932    ///
1933    /// # Examples
1934    ///
1935    /// ```rust
1936    /// # use tenferro_runtime::TracedTensor;
1937    /// # let x = TracedTensor::from_vec_col_major(vec![2], vec![1.0_f64; 2]).unwrap();
1938    /// let y = x.embed_diag(0, 1)?;
1939    /// # Ok::<(), tenferro_runtime::Error>(())
1940    /// ```
1941    ///
1942    /// # Errors
1943    ///
1944    /// Returns an error when `axis_a` is out of bounds or `axis_b` is not a
1945    /// valid insertion axis.
1946    pub fn embed_diag(&self, axis_a: usize, axis_b: usize) -> Result<TracedTensor> {
1947        validate_traced_axis(self, axis_a, "TracedTensor::embed_diag")?;
1948        validate_traced_insert_axis(self.rank, axis_b, "TracedTensor::embed_diag")?;
1949        let out_shape_hint = self.shape_hint.as_ref().map(|shape| {
1950            let mut out_shape = shape.clone();
1951            out_shape.insert(axis_b, shape[axis_a].clone());
1952            out_shape
1953        });
1954        Ok(apply_unary(
1955            StdTensorOp::EmbedDiag { axis_a, axis_b },
1956            self,
1957            self.rank + 1,
1958            out_shape_hint,
1959        ))
1960    }
1961
1962    /// Return the runtime size of one axis as a scalar `f64` tensor.
1963    ///
1964    /// The result is metadata-derived and therefore has no gradient.
1965    ///
1966    /// # Examples
1967    ///
1968    /// ```
1969    /// use tenferro_cpu::CpuBackend;
1970    /// use tenferro_runtime::{GraphCompiler, GraphExecutor, TracedTensor};
1971    ///
1972    /// let x = TracedTensor::from_vec_col_major(vec![2, 3], vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
1973    /// let cols = x.shape_of(1)?;
1974    /// let mut compiler = GraphCompiler::new();
1975    /// let program = compiler.compile(&cols).unwrap();
1976    /// let out = GraphExecutor::new(CpuBackend::new()).run(&program).unwrap();
1977    /// assert_eq!(out.shape(), &[] as &[usize]);
1978    /// # Ok::<(), tenferro_runtime::Error>(())
1979    /// ```
1980    ///
1981    /// # Errors
1982    ///
1983    /// Returns an error when `axis` is out of bounds.
1984    pub fn shape_of(&self, axis: usize) -> Result<TracedTensor> {
1985        validate_traced_axis(self, axis, "TracedTensor::shape_of")?;
1986        Ok(apply_unary_with_dtype(
1987            StdTensorOp::ShapeOf { axis },
1988            self,
1989            0,
1990            Some(vec![]),
1991            DType::F64,
1992        ))
1993    }
1994
1995    /// Truncate this tensor along `axis` to the first `size` elements.
1996    ///
1997    /// `size` is read at runtime from a scalar traced tensor. Values are
1998    /// rounded to the nearest integer, clamped to `[0, self.shape[axis]]`,
1999    /// and the output keeps the same element dtype as the input.
2000    ///
2001    /// # Examples
2002    ///
2003    /// ```
2004    /// use tenferro_cpu::CpuBackend;
2005    /// use tenferro_runtime::{GraphCompiler, GraphExecutor, TracedTensor};
2006    ///
2007    /// let x = TracedTensor::from_vec_col_major(vec![4], vec![1.0_f64, 2.0, 3.0, 4.0]).unwrap();
2008    /// let size = TracedTensor::from_vec_col_major(vec![], vec![2.0_f64]).unwrap();
2009    /// let y = x.dynamic_truncate(&size, 0)?;
2010    /// let mut compiler = GraphCompiler::new();
2011    /// let program = compiler.compile(&y).unwrap();
2012    /// let out = GraphExecutor::new(CpuBackend::new()).run(&program).unwrap();
2013    /// assert_eq!(out.shape(), &[2]);
2014    /// # Ok::<(), tenferro_runtime::Error>(())
2015    /// ```
2016    ///
2017    /// # Errors
2018    ///
2019    /// Returns an error when `axis` is out of bounds or `size` is not scalar.
2020    pub fn dynamic_truncate(&self, size: &TracedTensor, axis: usize) -> Result<TracedTensor> {
2021        validate_traced_axis(self, axis, "TracedTensor::dynamic_truncate")?;
2022        if size.rank != 0 {
2023            return Err(Error::InvalidGraphBuild {
2024                op: "TracedTensor::dynamic_truncate",
2025                message: format!("size must be a scalar tensor, got rank {}", size.rank),
2026            });
2027        }
2028        Ok(apply_binary(
2029            StdTensorOp::DynamicTruncate { axis },
2030            self,
2031            size,
2032            self.rank,
2033            None,
2034        ))
2035    }
2036
2037    /// Pad this tensor with zeros along `axis` to match `reference.shape[axis]`.
2038    ///
2039    /// If `reference` is smaller along that axis, this is a no-op.
2040    ///
2041    /// # Examples
2042    ///
2043    /// ```
2044    /// use tenferro_cpu::CpuBackend;
2045    /// use tenferro_runtime::{GraphCompiler, GraphExecutor, TracedTensor};
2046    ///
2047    /// let x = TracedTensor::from_vec_col_major(vec![2], vec![1.0_f64, 2.0]).unwrap();
2048    /// let reference = TracedTensor::from_vec_col_major(vec![4], vec![0.0_f64, 0.0, 0.0, 0.0]).unwrap();
2049    /// let y = x.pad_to_match(&reference, 0)?;
2050    /// let mut compiler = GraphCompiler::new();
2051    /// let program = compiler.compile(&y).unwrap();
2052    /// let out = GraphExecutor::new(CpuBackend::new()).run(&program).unwrap();
2053    /// assert_eq!(out.shape(), &[4]);
2054    /// # Ok::<(), tenferro_runtime::Error>(())
2055    /// ```
2056    ///
2057    /// # Errors
2058    ///
2059    /// Returns an error when `axis` is out of bounds for either tensor.
2060    pub fn pad_to_match(&self, reference: &TracedTensor, axis: usize) -> Result<TracedTensor> {
2061        validate_traced_axis(self, axis, "TracedTensor::pad_to_match")?;
2062        validate_traced_axis(reference, axis, "TracedTensor::pad_to_match")?;
2063        Ok(apply_binary(
2064            StdTensorOp::PadToMatch { axis },
2065            self,
2066            reference,
2067            self.rank,
2068            reference.shape_hint.clone(),
2069        ))
2070    }
2071}
2072
2073pub(crate) fn apply_unary(
2074    op: StdTensorOp,
2075    input: &TracedTensor,
2076    out_rank: usize,
2077    out_shape_hint: Option<Vec<SymDim>>,
2078) -> TracedTensor {
2079    let out_dtype =
2080        inferred_output_dtype_or_fallback(&op, &[input.dtype], input.dtype, "apply_unary");
2081    apply_unary_with_dtype(op, input, out_rank, out_shape_hint, out_dtype)
2082}
2083
2084pub(crate) fn apply_unary_with_dtype(
2085    op: StdTensorOp,
2086    input: &TracedTensor,
2087    out_rank: usize,
2088    out_shape_hint: Option<Vec<SymDim>>,
2089    out_dtype: DType,
2090) -> TracedTensor {
2091    let mut builder = GraphBuilder::new();
2092    builder.add_parent(input.graph.clone());
2093    let input_ref = ValueRef::External(input.graph.values()[input.val].key.clone());
2094    let outputs = builder.add_operation(op, vec![input_ref], OperationRole::Primary);
2095    builder.set_outputs(outputs.clone());
2096    let graph = Arc::new(builder.build());
2097    let metadata_scope =
2098        register_single_output_metadata(graph.as_ref(), outputs[0], out_dtype, &out_shape_hint);
2099
2100    TracedTensor {
2101        id: next_traced_id(),
2102        rank: out_rank,
2103        dtype: out_dtype,
2104        graph,
2105        val: outputs[0],
2106        data: None,
2107        shape_hint: out_shape_hint,
2108        inputs_map: input.inputs_map.clone(),
2109        extra_roots: input.extra_roots.clone(),
2110        checkpoint_chain: input.checkpoint_chain.clone(),
2111        metadata_scopes: metadata_scopes_with_new(
2112            metadata_scope,
2113            [input.metadata_scopes.as_slice()],
2114        ),
2115    }
2116}
2117
2118/// Apply a unary-primary op that additionally references one or more
2119/// tensors for shape resolution only.
2120///
2121/// The primary `input` becomes op input 0; each tensor in `shape_refs`
2122/// becomes op input 1, 2, … in order. Used by
2123/// [`TracedTensor::broadcast_in_dim_sym`] when the target shape
2124/// references axes of tensors other than the primary input; the op
2125/// reads only their runtime shape, not their data.
2126pub(crate) fn apply_unary_with_shape_refs(
2127    op: StdTensorOp,
2128    input: &TracedTensor,
2129    shape_refs: &[&TracedTensor],
2130    out_rank: usize,
2131    out_shape_hint: Option<Vec<SymDim>>,
2132) -> TracedTensor {
2133    let mut builder = GraphBuilder::new();
2134    builder.add_parent(input.graph.clone());
2135    for t in shape_refs {
2136        builder.add_parent(t.graph.clone());
2137    }
2138    let mut op_inputs: Vec<ValueRef<StdTensorOp>> = Vec::with_capacity(1 + shape_refs.len());
2139    op_inputs.push(ValueRef::External(
2140        input.graph.values()[input.val].key.clone(),
2141    ));
2142    for t in shape_refs {
2143        op_inputs.push(ValueRef::External(t.graph.values()[t.val].key.clone()));
2144    }
2145    let outputs = builder.add_operation(op, op_inputs, OperationRole::Primary);
2146    builder.set_outputs(outputs.clone());
2147    let graph = Arc::new(builder.build());
2148    let metadata_scope =
2149        register_single_output_metadata(graph.as_ref(), outputs[0], input.dtype, &out_shape_hint);
2150
2151    let mut merged = (*input.inputs_map).clone();
2152    for t in shape_refs {
2153        merged.extend(t.inputs_map.iter().map(|(k, v)| (k.clone(), v.clone())));
2154    }
2155
2156    let mut extra_roots = input.extra_roots.clone();
2157    for t in shape_refs {
2158        extra_roots.extend(t.extra_roots.iter().cloned());
2159    }
2160
2161    let mut checkpoint_chain = input.checkpoint_chain.clone();
2162    for t in shape_refs {
2163        checkpoint_chain =
2164            CheckpointNode::merge_chains(checkpoint_chain, t.checkpoint_chain.clone());
2165    }
2166
2167    TracedTensor {
2168        id: next_traced_id(),
2169        rank: out_rank,
2170        dtype: input.dtype,
2171        graph,
2172        val: outputs[0],
2173        data: None,
2174        shape_hint: out_shape_hint,
2175        inputs_map: Arc::new(merged),
2176        extra_roots,
2177        checkpoint_chain,
2178        metadata_scopes: {
2179            let mut scopes =
2180                metadata_scopes_with_new(metadata_scope, [input.metadata_scopes.as_slice()]);
2181            for t in shape_refs {
2182                for scope in &t.metadata_scopes {
2183                    push_metadata_scope(&mut scopes, Arc::clone(scope));
2184                }
2185            }
2186            scopes
2187        },
2188    }
2189}
2190
2191pub(crate) fn apply_nullary(
2192    op: StdTensorOp,
2193    rank: usize,
2194    dtype: DType,
2195    shape_hint: Option<Vec<SymDim>>,
2196) -> TracedTensor {
2197    let mut builder = GraphBuilder::new();
2198    let outputs = builder.add_operation(op, vec![], OperationRole::Primary);
2199    builder.set_outputs(outputs.clone());
2200    let graph = Arc::new(builder.build());
2201    let metadata_scope =
2202        register_single_output_metadata(graph.as_ref(), outputs[0], dtype, &shape_hint);
2203
2204    TracedTensor {
2205        id: next_traced_id(),
2206        rank,
2207        dtype,
2208        graph,
2209        val: outputs[0],
2210        data: None,
2211        shape_hint,
2212        inputs_map: Arc::new(HashMap::new()),
2213        extra_roots: Vec::new(),
2214        checkpoint_chain: None,
2215        metadata_scopes: metadata_scopes_for_scope(metadata_scope),
2216    }
2217}
2218
2219pub(crate) fn apply_binary(
2220    op: StdTensorOp,
2221    lhs: &TracedTensor,
2222    rhs: &TracedTensor,
2223    out_rank: usize,
2224    out_shape_hint: Option<Vec<SymDim>>,
2225) -> TracedTensor {
2226    let input_dtype = crate::shape_infer::promote_dtype_for_binary_op(&op, lhs.dtype, rhs.dtype);
2227    let out_dtype = inferred_output_dtype_or_fallback(
2228        &op,
2229        &[lhs.dtype, rhs.dtype],
2230        input_dtype,
2231        "apply_binary",
2232    );
2233
2234    // Insert Convert ops when an input dtype differs from the primitive input dtype.
2235    let lhs = if lhs.dtype != input_dtype {
2236        lhs.cast(input_dtype)
2237    } else {
2238        lhs.clone()
2239    };
2240    let rhs = if rhs.dtype != input_dtype {
2241        rhs.cast(input_dtype)
2242    } else {
2243        rhs.clone()
2244    };
2245
2246    apply_binary_with_output_dtype(op, &lhs, &rhs, out_rank, out_shape_hint, out_dtype)
2247}
2248
2249pub(crate) fn apply_binary_preserve_input_dtypes(
2250    op: StdTensorOp,
2251    lhs: &TracedTensor,
2252    rhs: &TracedTensor,
2253    out_rank: usize,
2254    out_shape_hint: Option<Vec<SymDim>>,
2255    out_dtype: DType,
2256) -> TracedTensor {
2257    apply_binary_with_output_dtype(op, lhs, rhs, out_rank, out_shape_hint, out_dtype)
2258}
2259
2260pub(crate) fn apply_broadcast_binary_op(
2261    op: StdTensorOp,
2262    lhs: &TracedTensor,
2263    rhs: &TracedTensor,
2264) -> Result<TracedTensor> {
2265    let (lhs, rhs) = broadcast_binary(lhs, rhs)?;
2266    Ok(apply_binary(
2267        op,
2268        &lhs,
2269        &rhs,
2270        lhs.rank,
2271        lhs.shape_hint.clone(),
2272    ))
2273}
2274
2275pub(crate) fn apply_broadcast_ternary_op(
2276    op: StdTensorOp,
2277    first: &TracedTensor,
2278    second: &TracedTensor,
2279    third: &TracedTensor,
2280) -> Result<TracedTensor> {
2281    let (first, second, third) = broadcast_ternary(first, second, third)?;
2282    Ok(apply_ternary(
2283        op,
2284        &first,
2285        &second,
2286        &third,
2287        first.rank,
2288        first.shape_hint.clone(),
2289    ))
2290}
2291
2292pub(crate) fn apply_ternary(
2293    op: StdTensorOp,
2294    first: &TracedTensor,
2295    second: &TracedTensor,
2296    third: &TracedTensor,
2297    out_rank: usize,
2298    out_shape_hint: Option<Vec<SymDim>>,
2299) -> TracedTensor {
2300    let fallback_dtype =
2301        crate::shape_infer::promote_dtypes([first.dtype, second.dtype, third.dtype]);
2302    let out_dtype = inferred_output_dtype_or_fallback(
2303        &op,
2304        &[first.dtype, second.dtype, third.dtype],
2305        fallback_dtype,
2306        "apply_ternary",
2307    );
2308    let (first, second, third) = match op {
2309        StdTensorOp::Select => {
2310            let value_dtype = crate::shape_infer::promote_dtype(second.dtype, third.dtype);
2311            let second = if second.dtype != value_dtype {
2312                second.cast(value_dtype)
2313            } else {
2314                second.clone()
2315            };
2316            let third = if third.dtype != value_dtype {
2317                third.cast(value_dtype)
2318            } else {
2319                third.clone()
2320            };
2321            (first.clone(), second, third)
2322        }
2323        _ => {
2324            let input_dtype =
2325                crate::shape_infer::promote_dtypes([first.dtype, second.dtype, third.dtype]);
2326            let first = if first.dtype != input_dtype {
2327                first.cast(input_dtype)
2328            } else {
2329                first.clone()
2330            };
2331            let second = if second.dtype != input_dtype {
2332                second.cast(input_dtype)
2333            } else {
2334                second.clone()
2335            };
2336            let third = if third.dtype != input_dtype {
2337                third.cast(input_dtype)
2338            } else {
2339                third.clone()
2340            };
2341            (first, second, third)
2342        }
2343    };
2344    apply_ternary_with_output_dtype(
2345        op,
2346        &first,
2347        &second,
2348        &third,
2349        out_rank,
2350        out_shape_hint,
2351        out_dtype,
2352    )
2353}
2354
2355fn apply_binary_with_output_dtype(
2356    op: StdTensorOp,
2357    lhs: &TracedTensor,
2358    rhs: &TracedTensor,
2359    out_rank: usize,
2360    out_shape_hint: Option<Vec<SymDim>>,
2361    out_dtype: DType,
2362) -> TracedTensor {
2363    let lhs_ref = ValueRef::External(lhs.graph.values()[lhs.val].key.clone());
2364    let rhs_ref = ValueRef::External(rhs.graph.values()[rhs.val].key.clone());
2365
2366    let mut builder = GraphBuilder::new();
2367    builder.add_parent(lhs.graph.clone());
2368    builder.add_parent(rhs.graph.clone());
2369    let outputs = builder.add_operation(op, vec![lhs_ref, rhs_ref], OperationRole::Primary);
2370    builder.set_outputs(outputs.clone());
2371    let graph = Arc::new(builder.build());
2372    let metadata_scope =
2373        register_single_output_metadata(graph.as_ref(), outputs[0], out_dtype, &out_shape_hint);
2374
2375    let mut merged = (*lhs.inputs_map).clone();
2376    merged.extend(rhs.inputs_map.iter().map(|(k, v)| (k.clone(), v.clone())));
2377    let mut extra_roots = lhs.extra_roots.clone();
2378    extra_roots.extend(rhs.extra_roots.iter().cloned());
2379
2380    TracedTensor {
2381        id: next_traced_id(),
2382        rank: out_rank,
2383        dtype: out_dtype,
2384        graph,
2385        val: outputs[0],
2386        data: None,
2387        shape_hint: out_shape_hint,
2388        inputs_map: Arc::new(merged),
2389        extra_roots,
2390        checkpoint_chain: CheckpointNode::merge_chains(
2391            lhs.checkpoint_chain.clone(),
2392            rhs.checkpoint_chain.clone(),
2393        ),
2394        metadata_scopes: metadata_scopes_with_new(
2395            metadata_scope,
2396            [
2397                lhs.metadata_scopes.as_slice(),
2398                rhs.metadata_scopes.as_slice(),
2399            ],
2400        ),
2401    }
2402}
2403
2404fn apply_ternary_with_output_dtype(
2405    op: StdTensorOp,
2406    first: &TracedTensor,
2407    second: &TracedTensor,
2408    third: &TracedTensor,
2409    out_rank: usize,
2410    out_shape_hint: Option<Vec<SymDim>>,
2411    out_dtype: DType,
2412) -> TracedTensor {
2413    let first_ref = ValueRef::External(first.graph.values()[first.val].key.clone());
2414    let second_ref = ValueRef::External(second.graph.values()[second.val].key.clone());
2415    let third_ref = ValueRef::External(third.graph.values()[third.val].key.clone());
2416
2417    let mut builder = GraphBuilder::new();
2418    builder.add_parent(first.graph.clone());
2419    builder.add_parent(second.graph.clone());
2420    builder.add_parent(third.graph.clone());
2421    let outputs = builder.add_operation(
2422        op,
2423        vec![first_ref, second_ref, third_ref],
2424        OperationRole::Primary,
2425    );
2426    builder.set_outputs(outputs.clone());
2427    let graph = Arc::new(builder.build());
2428    let metadata_scope =
2429        register_single_output_metadata(graph.as_ref(), outputs[0], out_dtype, &out_shape_hint);
2430
2431    let mut merged = (*first.inputs_map).clone();
2432    merged.extend(
2433        second
2434            .inputs_map
2435            .iter()
2436            .map(|(k, v)| (k.clone(), v.clone())),
2437    );
2438    merged.extend(third.inputs_map.iter().map(|(k, v)| (k.clone(), v.clone())));
2439
2440    let mut extra_roots = first.extra_roots.clone();
2441    extra_roots.extend(second.extra_roots.iter().cloned());
2442    extra_roots.extend(third.extra_roots.iter().cloned());
2443
2444    let checkpoint_chain = CheckpointNode::merge_chains(
2445        CheckpointNode::merge_chains(
2446            first.checkpoint_chain.clone(),
2447            second.checkpoint_chain.clone(),
2448        ),
2449        third.checkpoint_chain.clone(),
2450    );
2451
2452    TracedTensor {
2453        id: next_traced_id(),
2454        rank: out_rank,
2455        dtype: out_dtype,
2456        graph,
2457        val: outputs[0],
2458        data: None,
2459        shape_hint: out_shape_hint,
2460        inputs_map: Arc::new(merged),
2461        extra_roots,
2462        checkpoint_chain,
2463        metadata_scopes: metadata_scopes_with_new(
2464            metadata_scope,
2465            [
2466                first.metadata_scopes.as_slice(),
2467                second.metadata_scopes.as_slice(),
2468                third.metadata_scopes.as_slice(),
2469            ],
2470        ),
2471    }
2472}
2473
2474fn register_single_output_metadata(
2475    graph: &Graph<StdTensorOp>,
2476    output: LocalValueId,
2477    dtype: DType,
2478    shape_hint: &Option<Vec<SymDim>>,
2479) -> GlobalMetadataScope {
2480    if let Some(shape) = shape_hint {
2481        // Fresh graph output keys are generated in this builder, so metadata
2482        // registration failure would indicate a global metadata invariant bug.
2483        register_scoped_value_metadata(
2484            graph.values()[output].key.clone(),
2485            tensor_meta(dtype, shape.clone()),
2486        )
2487        .expect("fresh traced graph output metadata registration failed")
2488    } else {
2489        // Fresh graph output keys are generated in this builder, so metadata
2490        // registration failure would indicate a global metadata invariant bug.
2491        register_scoped_graph_metadata(graph, std::iter::empty())
2492            .expect("fresh traced graph metadata registration failed")
2493    }
2494}
2495
2496impl TracedTensor {
2497    pub(crate) fn resolve_roots(&self) -> Vec<Arc<Graph<StdTensorOp>>> {
2498        let mut roots = Vec::with_capacity(1 + self.extra_roots.len());
2499        roots.push(self.graph.clone());
2500        roots.extend(self.extra_roots.iter().cloned());
2501        roots
2502    }
2503}
2504
2505#[cfg(test)]
2506mod tests;