// tenferro/traced.rs
1use std::collections::HashMap;
2use std::sync::atomic::{AtomicU64, Ordering};
3use std::sync::Arc;
4
5use computegraph::compile::compile;
6use computegraph::fragment::{Fragment, FragmentBuilder};
7use computegraph::materialize::materialize_merge;
8use computegraph::resolve::resolve;
9use computegraph::types::{GlobalValKey, OpMode, ValRef};
10use computegraph::LocalValId;
11use num_complex::{Complex32, Complex64};
12use tenferro_ops::dim_expr::DimExpr;
13use tenferro_ops::input_key::TensorInputKey;
14use tenferro_ops::std_tensor_op::StdTensorOp;
15use tenferro_ops::ShapeGuardContext;
16use tenferro_tensor::{DType, DotGeneralConfig, Tensor, TensorBackend, TensorScalar, TypedTensor};
17use tidu::{differentiate, transpose};
18
19use super::compiler::compile_std_to_exec;
20use super::engine::Engine;
21use super::error::{Error, Result};
22use super::sym_dim::SymDim;
23use crate::checkpoint::CheckpointNode;
24
/// Monotonically increasing id source for user input keys (see [`next_input_key`]).
static NEXT_INPUT_ID: AtomicU64 = AtomicU64::new(0);
/// Monotonically increasing id source for differentiation passes (see `next_pass_id`).
static NEXT_DIFF_PASS_ID: AtomicU64 = AtomicU64::new(0);
/// Monotonically increasing id source for traced-tensor ids (see [`next_traced_id`]).
static NEXT_TRACED_ID: AtomicU64 = AtomicU64::new(0);

/// Process-wide unique identifier assigned to every [`TracedTensor`].
pub type TracedTensorId = u64;
30
31pub(crate) fn next_input_key() -> TensorInputKey {
32    TensorInputKey::User {
33        id: NEXT_INPUT_ID.fetch_add(1, Ordering::Relaxed),
34    }
35}
36
/// Allocate a fresh id for a differentiation pass (passed to `differentiate`).
fn next_pass_id() -> u64 {
    NEXT_DIFF_PASS_ID.fetch_add(1, Ordering::Relaxed)
}
40
/// Allocate a fresh [`TracedTensorId`] for a newly constructed tensor.
pub(crate) fn next_traced_id() -> TracedTensorId {
    NEXT_TRACED_ID.fetch_add(1, Ordering::Relaxed)
}
44
/// A lazily evaluated tensor: a node in a traced compute graph together with
/// the bookkeeping needed to evaluate and differentiate it.
#[derive(Clone)]
pub struct TracedTensor {
    /// Process-unique id, assigned via [`next_traced_id`].
    pub id: TracedTensorId,
    /// Number of dimensions of the value this node produces.
    pub rank: usize,
    /// Element dtype of the value this node produces.
    pub dtype: DType,
    /// Compute-graph fragment that defines this value.
    pub fragment: Arc<Fragment<StdTensorOp>>,
    /// Local id of this tensor's value within `fragment`.
    pub val: LocalValId,
    /// Concrete result: set for data-carrying leaves and cached after `eval`.
    pub data: Option<Arc<Tensor>>,
    /// `Some` when per-dimension sizes (possibly symbolic `SymDim`s) are known
    /// at build time; `None` for fully shape-symbolic tensors.
    pub(crate) shape_hint: Option<Vec<SymDim>>,
    /// Concrete tensors bound to input keys reachable from this node.
    pub(crate) inputs_map: Arc<HashMap<TensorInputKey, Arc<Tensor>>>,
    /// Extra fragments kept as resolve roots (e.g. primal/linear fragments
    /// added by `jvp`/`vjp`).
    pub(crate) extra_roots: Vec<Arc<Fragment<StdTensorOp>>>,
    /// Chain of checkpoint records retained for later reverse-mode AD.
    pub(crate) checkpoint_chain: Option<Arc<CheckpointNode>>,
}
58
/// Compute a broadcast output shape following NumPy rules.
///
/// Returns `None` when the two shapes are incompatible.
fn broadcast_shape(a: &[usize], b: &[usize]) -> Option<Vec<usize>> {
    let rank = a.len().max(b.len());
    let pad_a = rank - a.len();
    let pad_b = rank - b.len();
    (0..rank)
        .map(|axis| {
            // Right-align both shapes, treating missing leading axes as 1.
            let dim_a = if axis < pad_a { 1 } else { a[axis - pad_a] };
            let dim_b = if axis < pad_b { 1 } else { b[axis - pad_b] };
            match (dim_a, dim_b) {
                (x, y) if x == y => Some(x),
                (1, y) => Some(y),
                (x, 1) => Some(x),
                _ => None,
            }
        })
        .collect()
}
88
89pub(crate) fn try_concrete_shape(tensor: &TracedTensor) -> Option<Vec<usize>> {
90    tensor
91        .shape_hint
92        .as_ref()?
93        .iter()
94        .map(SymDim::constant_value)
95        .collect()
96}
97
98pub(crate) fn concrete_shape(tensor: &TracedTensor) -> Vec<usize> {
99    tensor
100        .shape_hint
101        .as_ref()
102        .unwrap_or_else(|| panic!("missing shape hint for traced tensor {}", tensor.id))
103        .iter()
104        .map(|dim| {
105            dim.constant_value().unwrap_or_else(|| {
106                panic!("symbolic dimension in shape hint for tensor {}", tensor.id)
107            })
108        })
109        .collect()
110}
111
112fn error_shape_hint(tensor: &TracedTensor) -> Vec<usize> {
113    try_concrete_shape(tensor).unwrap_or_else(|| vec![0; tensor.rank])
114}
115
116/// Broadcast a traced tensor to `target_shape`.
117///
118/// Expanding singleton axes are first reshaped away so the existing
119/// `BroadcastInDim` transpose rule reduces them correctly during VJP.
120fn broadcast_to(tensor: &TracedTensor, target_shape: &[usize]) -> TracedTensor {
121    let tensor_shape = concrete_shape(tensor);
122    if tensor_shape == target_shape {
123        return tensor.clone();
124    }
125
126    assert!(
127        tensor.rank <= target_shape.len(),
128        "cannot broadcast higher-rank shape {:?} to {:?}",
129        tensor_shape,
130        target_shape
131    );
132
133    let rank_diff = target_shape.len() - tensor.rank;
134    let mut source_shape = Vec::with_capacity(tensor.rank);
135    let mut dims = Vec::with_capacity(tensor.rank);
136    for (src_axis, &src_dim) in tensor_shape.iter().enumerate() {
137        let dst_axis = src_axis + rank_diff;
138        let dst_dim = target_shape[dst_axis];
139        assert!(
140            src_dim == dst_dim || src_dim == 1,
141            "cannot broadcast shape {:?} to {:?}",
142            tensor_shape,
143            target_shape
144        );
145        if src_dim == 1 && dst_dim != 1 {
146            continue;
147        }
148        source_shape.push(src_dim);
149        dims.push(dst_axis);
150    }
151
152    let source = if source_shape == tensor_shape {
153        tensor.clone()
154    } else {
155        tensor.reshape(&source_shape)
156    };
157    source.broadcast_in_dim(target_shape, &dims)
158}
159
160/// Broadcast two tensors to a common shape.
161fn broadcast_binary(a: &TracedTensor, b: &TracedTensor) -> (TracedTensor, TracedTensor) {
162    if a.shape_hint == b.shape_hint && a.rank == b.rank {
163        return (a.clone(), b.clone());
164    }
165    let a_shape = concrete_shape(a);
166    let b_shape = concrete_shape(b);
167    let target = broadcast_shape(&a_shape, &b_shape).unwrap_or_else(|| {
168        panic!(
169            "incompatible shapes for broadcast: {:?} and {:?}",
170            a_shape, b_shape
171        )
172    });
173    (broadcast_to(a, &target), broadcast_to(b, &target))
174}
175
176fn scale_with_constant(input: &TracedTensor, op: StdTensorOp) -> TracedTensor {
177    let scalar = apply_nullary(op, 0, input.dtype, Some(vec![]));
178    let input_shape = concrete_shape(input);
179    let factor = broadcast_to(&scalar, &input_shape);
180    apply_binary(
181        StdTensorOp::Mul,
182        input,
183        &factor,
184        input.rank,
185        input.shape_hint.clone(),
186    )
187}
188
/// `&a + &b` — elementwise addition; delegates to [`TracedTensor::add`].
impl std::ops::Add for &TracedTensor {
    type Output = TracedTensor;

    fn add(self, rhs: &TracedTensor) -> TracedTensor {
        TracedTensor::add(self, rhs)
    }
}

/// `&a * &b` — elementwise multiplication; delegates to [`TracedTensor::mul`].
impl std::ops::Mul for &TracedTensor {
    type Output = TracedTensor;

    fn mul(self, rhs: &TracedTensor) -> TracedTensor {
        TracedTensor::mul(self, rhs)
    }
}

/// `&a * s` — scaling by a real scalar; delegates to `scale_real`.
impl std::ops::Mul<f64> for &TracedTensor {
    type Output = TracedTensor;

    fn mul(self, rhs: f64) -> TracedTensor {
        self.scale_real(rhs)
    }
}

/// `s * &a` — scalar-on-the-left variant of scaling.
impl std::ops::Mul<&TracedTensor> for f64 {
    type Output = TracedTensor;

    fn mul(self, rhs: &TracedTensor) -> TracedTensor {
        rhs.scale_real(self)
    }
}

/// `-&a` — negation; delegates to `TracedTensor::neg`.
impl std::ops::Neg for &TracedTensor {
    type Output = TracedTensor;

    fn neg(self) -> TracedTensor {
        TracedTensor::neg(self)
    }
}

/// `&a / &b` — elementwise division; delegates to `TracedTensor::div`.
impl std::ops::Div for &TracedTensor {
    type Output = TracedTensor;

    fn div(self, rhs: &TracedTensor) -> TracedTensor {
        TracedTensor::div(self, rhs)
    }
}
236
237impl TracedTensor {
238    /// Build a [`TracedTensor`] leaf from a concrete [`Tensor`], keeping its
239    /// shape as a concrete `shape_hint`.
240    ///
241    /// This is the common constructor when you have concrete tensor data that
242    /// you want to use both for graph building and for evaluation. The
243    /// resulting tensor is treated as a concrete-shape leaf by downstream
244    /// passes (binary einsum decomposition, build-time reshape folding, etc.).
245    ///
246    /// # Examples
247    ///
248    /// ```
249    /// use tenferro::{Tensor, TracedTensor};
250    ///
251    /// let a = TracedTensor::from_tensor_concrete_shape(
252    ///     Tensor::from_vec(vec![2, 3], vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0]),
253    /// );
254    /// assert_eq!(a.rank, 2);
255    /// assert!(a.is_concrete_shape());
256    /// ```
257    pub fn from_tensor_concrete_shape(tensor: Tensor) -> Self {
258        let shape = tensor.shape().to_vec();
259        let rank = shape.len();
260        let dtype = tensor.dtype();
261        let key = next_input_key();
262        let data = Arc::new(tensor);
263
264        let mut builder = FragmentBuilder::new();
265        let val = builder.add_input(key.clone());
266        builder.set_outputs(vec![val]);
267        let fragment = Arc::new(builder.build());
268
269        let mut map = HashMap::new();
270        map.insert(key, Arc::clone(&data));
271
272        Self {
273            id: next_traced_id(),
274            rank,
275            dtype,
276            fragment,
277            val,
278            data: Some(data),
279            shape_hint: Some(shape.into_iter().map(SymDim::from).collect()),
280            inputs_map: Arc::new(map),
281            extra_roots: Vec::new(),
282            checkpoint_chain: None,
283        }
284    }
285
286    /// Build a [`TracedTensor`] leaf from a concrete [`Tensor`] but advertise
287    /// a symbolic shape during graph construction.
288    ///
289    /// The tensor data is still attached (so plain `eval` works without
290    /// bindings), but graph passes see the leaf as shape-symbolic. This is
291    /// useful for building a single traced program that should not bake in
292    /// shape-specific optimizations — e.g. mixing a known-shape tensor into
293    /// an einsum with other `input_symbolic_shape` placeholders forces the
294    /// einsum to be kept as a single `NaryEinsum` op rather than
295    /// decomposing at build time.
296    ///
297    /// # Examples
298    ///
299    /// ```
300    /// use tenferro::{Tensor, TracedTensor};
301    ///
302    /// let t = TracedTensor::from_tensor_symbolic_shape(
303    ///     Tensor::from_vec(vec![2, 3], vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0]),
304    /// );
305    /// assert_eq!(t.rank, 2);
306    /// assert!(!t.is_concrete_shape());
307    /// ```
308    pub fn from_tensor_symbolic_shape(tensor: Tensor) -> Self {
309        let rank = tensor.shape().len();
310        let dtype = tensor.dtype();
311        let key = next_input_key();
312        let data = Arc::new(tensor);
313
314        let mut builder = FragmentBuilder::new();
315        let val = builder.add_input(key.clone());
316        builder.set_outputs(vec![val]);
317        let fragment = Arc::new(builder.build());
318
319        let mut map = HashMap::new();
320        map.insert(key, Arc::clone(&data));
321
322        Self {
323            id: next_traced_id(),
324            rank,
325            dtype,
326            fragment,
327            val,
328            data: Some(data),
329            shape_hint: None,
330            inputs_map: Arc::new(map),
331            extra_roots: Vec::new(),
332            checkpoint_chain: None,
333        }
334    }
335
336    /// Build a data-less placeholder leaf with a fixed (concrete) shape.
337    ///
338    /// Must be bound via [`TracedTensor::eval_with_inputs`] before evaluation.
339    /// Use this when you know the exact shape of the input but want to build
340    /// the graph once and feed different concrete tensors at eval time.
341    ///
342    /// # Examples
343    ///
344    /// ```
345    /// use tenferro_tensor::DType;
346    /// use tenferro::TracedTensor;
347    ///
348    /// let x = TracedTensor::input_concrete_shape(DType::F64, &[2, 3]);
349    /// assert_eq!(x.rank, 2);
350    /// assert!(x.is_concrete_shape());
351    /// ```
352    pub fn input_concrete_shape(dtype: DType, shape: &[usize]) -> Self {
353        let shape = shape.to_vec();
354        let rank = shape.len();
355        let key = next_input_key();
356
357        let mut builder = FragmentBuilder::new();
358        let val = builder.add_input(key.clone());
359        builder.set_outputs(vec![val]);
360        let fragment = Arc::new(builder.build());
361
362        Self {
363            id: next_traced_id(),
364            rank,
365            dtype,
366            fragment,
367            val,
368            data: None,
369            shape_hint: Some(shape.into_iter().map(SymDim::from).collect()),
370            inputs_map: Arc::new(HashMap::new()),
371            extra_roots: Vec::new(),
372            checkpoint_chain: None,
373        }
374    }
375
376    /// Build a data-less placeholder leaf with the given rank but fully
377    /// symbolic shape (every dim is a distinct `SymDim::TensorAxis`).
378    ///
379    /// Must be bound via [`TracedTensor::eval_with_inputs`] before
380    /// evaluation. Use this to build shape-agnostic graphs — in particular,
381    /// einsum calls containing at least one `input_symbolic_shape` input are
382    /// kept as a single `NaryEinsum` op so the contraction path can be
383    /// optimized at eval time against the actual bound shapes.
384    ///
385    /// # Examples
386    ///
387    /// ```
388    /// use tenferro_tensor::DType;
389    /// use tenferro::TracedTensor;
390    ///
391    /// let x = TracedTensor::input_symbolic_shape(DType::F64, 2);
392    /// assert_eq!(x.rank, 2);
393    /// assert!(!x.is_concrete_shape());
394    /// ```
395    pub fn input_symbolic_shape(dtype: DType, rank: usize) -> Self {
396        let key = next_input_key();
397
398        let mut builder = FragmentBuilder::new();
399        let val = builder.add_input(key.clone());
400        builder.set_outputs(vec![val]);
401        let fragment = Arc::new(builder.build());
402
403        Self {
404            id: next_traced_id(),
405            rank,
406            dtype,
407            fragment,
408            val,
409            data: None,
410            shape_hint: None,
411            inputs_map: Arc::new(HashMap::new()),
412            extra_roots: Vec::new(),
413            checkpoint_chain: None,
414        }
415    }
416
417    /// Build a concrete-shape [`TracedTensor`] leaf from typed `Vec<T>`
418    /// data. Equivalent to
419    /// [`TracedTensor::from_tensor_concrete_shape`]`(Tensor::from_vec(shape, data))`.
420    ///
421    /// # Examples
422    ///
423    /// ```
424    /// use tenferro::TracedTensor;
425    ///
426    /// let a = TracedTensor::from_vec(vec![2, 3], vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0]);
427    /// assert_eq!(a.rank, 2);
428    /// ```
429    pub fn from_vec<T: TensorScalar>(shape: Vec<usize>, data: Vec<T>) -> Self {
430        Self::from_tensor_concrete_shape(T::into_tensor(shape, data))
431    }
432
    /// Returns `true` iff every dim of this tensor's `shape_hint` is a
    /// constant `SymDim` (i.e. the shape is fully known at graph-build time).
    ///
    /// # Examples
    ///
    /// ```
    /// use tenferro_tensor::DType;
    /// use tenferro::TracedTensor;
    ///
    /// let a = TracedTensor::from_vec(vec![2, 3], vec![1.0_f64; 6]);
    /// let b = TracedTensor::input_symbolic_shape(DType::F64, 2);
    /// assert!(a.is_concrete_shape());
    /// assert!(!b.is_concrete_shape());
    /// ```
    pub fn is_concrete_shape(&self) -> bool {
        // Concrete iff a shape hint exists and every dim has a constant value.
        try_concrete_shape(self).is_some()
    }
450
451    /// If this `TracedTensor` is a leaf (single-node input fragment),
452    /// return its input key. Computed tensors return `None`.
453    pub fn input_key(&self) -> Option<TensorInputKey> {
454        match &self.fragment.vals()[self.val].key {
455            GlobalValKey::Input(key) => Some(key.clone()),
456            _ => None,
457        }
458    }
459
    /// Evaluate this traced tensor with no external bindings.
    ///
    /// Shorthand for [`TracedTensor::eval_with_inputs`] with an empty binding
    /// list; every placeholder leaf in the graph would therefore surface as
    /// an unbound-placeholder error.
    pub fn eval<B: TensorBackend>(&mut self, engine: &mut Engine<B>) -> Result<&Tensor> {
        self.eval_with_inputs(engine, &[])
    }
463
    /// Evaluate this traced tensor, binding external tensors to any
    /// placeholder leaves present in the graph.
    ///
    /// Each `(placeholder, tensor)` pair maps a placeholder
    /// [`TracedTensor`] (built via
    /// [`TracedTensor::input_concrete_shape`] or
    /// [`TracedTensor::input_symbolic_shape`]) to the concrete
    /// [`Tensor`] that should take its place during execution. Leaves that
    /// carry their own data (from [`TracedTensor::from_vec`],
    /// [`TracedTensor::from_tensor_concrete_shape`], or
    /// [`TracedTensor::from_tensor_symbolic_shape`]) must not appear in
    /// `bindings`.
    ///
    /// # Errors
    ///
    /// - [`Error::UnexpectedBinding`] if a binding's left side is not a
    ///   data-less placeholder leaf.
    /// - [`Error::DuplicateBinding`] if the same placeholder key appears
    ///   more than once in `bindings`.
    /// - [`Error::PlaceholderDtypeMismatch`] if a binding tensor's dtype
    ///   differs from the placeholder's declared dtype.
    /// - [`Error::PlaceholderShapeMismatch`] if a binding tensor's shape
    ///   differs from an `input_concrete_shape` placeholder's fixed shape.
    /// - [`Error::PlaceholderRankMismatch`] if a binding tensor's rank
    ///   differs from an `input_symbolic_shape` placeholder's rank.
    /// - [`Error::UnboundPlaceholder`] if the compiled graph contains a
    ///   placeholder that has no entry in `bindings`.
    ///
    /// # Examples
    ///
    /// ```
    /// use tenferro::{CpuBackend, Engine, Tensor, TracedTensor};
    /// use tenferro_tensor::DType;
    ///
    /// let mut engine = Engine::new(CpuBackend::new());
    /// let x = TracedTensor::input_symbolic_shape(DType::F64, 1);
    /// let mut y = &x + &x;
    /// let concrete = Tensor::from_vec(vec![3], vec![1.0_f64, 2.0, 3.0]);
    /// let out = y.eval_with_inputs(&mut engine, &[(&x, &concrete)]).unwrap();
    /// assert_eq!(out.shape(), &[3]);
    /// ```
    pub fn eval_with_inputs<B: TensorBackend>(
        &mut self,
        engine: &mut Engine<B>,
        bindings: &[(&TracedTensor, &Tensor)],
    ) -> Result<&Tensor> {
        // Build the binding map keyed on `TensorInputKey`, validating each
        // binding as we go. We validate *before* the already-evaluated
        // shortcut so that user mistakes (e.g. binding a data-carrying leaf)
        // surface as errors rather than getting silently ignored when the
        // output tensor happens to be cached.
        let mut binding_map: HashMap<TensorInputKey, &Tensor> = HashMap::new();
        for (index, (placeholder, tensor)) in bindings.iter().enumerate() {
            // Data-carrying leaves must not be re-bound.
            if placeholder.data.is_some() {
                return Err(Error::UnexpectedBinding {
                    binding_index: index,
                });
            }
            // Only leaf tensors expose an input key to bind against.
            let key = placeholder.input_key().ok_or(Error::UnexpectedBinding {
                binding_index: index,
            })?;

            if placeholder.dtype != tensor.dtype() {
                return Err(Error::PlaceholderDtypeMismatch {
                    expected: placeholder.dtype,
                    actual: tensor.dtype(),
                });
            }

            // Concrete-shape placeholders require an exact shape match;
            // symbolic-shape placeholders only pin the rank.
            match try_concrete_shape(placeholder) {
                Some(expected_shape) => {
                    if expected_shape.as_slice() != tensor.shape() {
                        return Err(Error::PlaceholderShapeMismatch {
                            expected: expected_shape,
                            actual: tensor.shape().to_vec(),
                        });
                    }
                }
                None => {
                    if placeholder.rank != tensor.shape().len() {
                        return Err(Error::PlaceholderRankMismatch {
                            expected: placeholder.rank,
                            actual: tensor.shape().len(),
                        });
                    }
                }
            }

            if binding_map.insert(key.clone(), *tensor).is_some() {
                return Err(Error::DuplicateBinding {
                    input_key: format!("{:?}", key),
                });
            }
        }

        // Already-evaluated shortcut — safe to take now that bindings passed
        // validation.
        if self.data.is_some() {
            return Ok(self.data.as_ref().unwrap().as_ref());
        }

        let output_key = self.fragment.vals()[self.val].key.clone();

        // Resolve all reachable fragments, materialize a single graph for the
        // requested output, and compile it.
        let view = resolve(self.resolve_roots());
        let graph = materialize_merge(&view, &[output_key]);
        let compiled = compile(&graph);

        // Gather one concrete tensor per graph input: data attached to this
        // tensor's `inputs_map` wins; otherwise fall back to caller bindings.
        let mut input_tensors = Vec::with_capacity(graph.inputs.len());
        let mut input_dtypes = Vec::with_capacity(graph.inputs.len());
        let mut input_shapes = Vec::with_capacity(graph.inputs.len());
        for key in &graph.inputs {
            match key {
                GlobalValKey::Input(k) => {
                    if let Some(tensor) = self.inputs_map.get(k) {
                        input_tensors.push(tensor.as_ref().clone());
                        input_dtypes.push(tensor.dtype());
                        input_shapes.push(DimExpr::from_concrete(tensor.shape()));
                    } else if let Some(bound) = binding_map.remove(k) {
                        input_tensors.push((*bound).clone());
                        input_dtypes.push(bound.dtype());
                        input_shapes.push(DimExpr::from_concrete(bound.shape()));
                    } else {
                        return Err(Error::UnboundPlaceholder {
                            input_key: format!("{:?}", k),
                        });
                    }
                }
                _ => {
                    return Err(Error::Internal(
                        "expected Input key in graph inputs".to_string(),
                    ));
                }
            }
        }
        // Lower to executable IR against the actual input dtypes/shapes, then
        // run it through the engine's compilation cache.
        let exec = compile_std_to_exec(&compiled, &input_dtypes, &input_shapes);

        let cached_exec = engine.get_or_compile(exec);
        let mut results = engine.eval_exec_ir(&cached_exec, input_tensors)?;
        if results.len() != 1 {
            return Err(Error::Internal(format!(
                "expected 1 output, got {}",
                results.len()
            )));
        }

        // Cache the result so later `eval` calls take the shortcut above.
        self.data = Some(Arc::new(results.remove(0)));
        Ok(self.data.as_ref().unwrap().as_ref())
    }
612
613    pub fn grad(&self, wrt: &TracedTensor) -> Result<TracedTensor> {
614        if self.rank != 0 {
615            return Err(Error::NonScalarGrad {
616                shape: error_shape_hint(self),
617            });
618        }
619
620        let ones = ones_tensor(self.dtype, vec![]);
621        let seed = TracedTensor::from_tensor_concrete_shape(ones);
622        Ok(self.vjp(wrt, &seed))
623    }
624
625    /// Like [`grad`](Self::grad) but returns `None` when the scalar output does
626    /// not depend on `wrt`.
627    ///
628    /// # Examples
629    ///
630    /// ```rust,ignore
631    /// let maybe_dx = loss.try_grad(&x)?;
632    /// ```
633    pub fn try_grad(&self, wrt: &TracedTensor) -> Result<Option<TracedTensor>> {
634        if self.rank != 0 {
635            return Err(Error::NonScalarGrad {
636                shape: error_shape_hint(self),
637            });
638        }
639
640        let ones = ones_tensor(self.dtype, vec![]);
641        let seed = TracedTensor::from_tensor_concrete_shape(ones);
642        Ok(self.try_vjp(wrt, &seed))
643    }
644
    /// Evaluate this tensor and replace its graph with a concrete leaf.
    ///
    /// This keeps downstream forward evaluation rooted at the concrete value
    /// while retaining the original fragment chain for later reverse-mode AD.
    ///
    /// # Errors
    ///
    /// Propagates any evaluation error; returns [`Error::Internal`] if
    /// evaluation did not populate `data`.
    ///
    /// # Examples
    ///
    /// ```
    /// use tenferro::{CpuBackend, Engine, TracedTensor};
    ///
    /// let mut engine = Engine::new(CpuBackend::new());
    /// let x = TracedTensor::from_vec(vec![], vec![3.0_f64]);
    /// let mut y = &x * &x;
    /// y.checkpoint(&mut engine).unwrap();
    /// assert_eq!(y.eval(&mut engine).unwrap().shape(), &[] as &[usize]);
    /// ```
    pub fn checkpoint<B: TensorBackend>(&mut self, engine: &mut Engine<B>) -> Result<()> {
        // Materialize the current value first; it becomes the new leaf data.
        self.eval(engine)?;
        let data = self
            .data
            .clone()
            .ok_or_else(|| Error::Internal("checkpoint eval did not populate data".to_string()))?;
        // The evaluated tensor's shape is now fully known.
        let concrete_shape_hint = Some(data.shape().iter().copied().map(SymDim::from).collect());

        // Snapshot the pre-checkpoint graph so reverse-mode AD can still
        // differentiate through it later.
        let old_fragment = self.fragment.clone();
        let old_output_key = old_fragment.vals()[self.val].key.clone();
        let old_inputs = (*self.inputs_map).clone();

        // Build a fresh single-input leaf fragment to take this tensor's place.
        let new_key = next_input_key();
        let mut builder = FragmentBuilder::new();
        let leaf_val = builder.add_input(new_key.clone());
        builder.set_outputs(vec![leaf_val]);
        let new_fragment = Arc::new(builder.build());

        // Record how the new leaf aliases the old output, chaining onto any
        // earlier checkpoints (`take` moves the previous chain into the node).
        let node = CheckpointNode {
            fragment: old_fragment,
            alias_key: new_key.clone(),
            alias_target: old_output_key,
            old_inputs,
            prev: self.checkpoint_chain.take(),
        };

        self.fragment = new_fragment;
        self.val = leaf_val;
        self.extra_roots.clear();
        self.shape_hint = concrete_shape_hint;
        self.checkpoint_chain = Some(Arc::new(node));

        // Rebuild the inputs map: everything reachable through the checkpoint
        // chain plus the concrete data for the new leaf.
        let mut merged = HashMap::new();
        if let Some(chain) = &self.checkpoint_chain {
            merged.extend(chain.collect_inputs());
        }
        merged.insert(new_key, data);
        self.inputs_map = Arc::new(merged);

        Ok(())
    }
702
703    pub fn jvp(&self, wrt: &TracedTensor, tangent: &TracedTensor) -> TracedTensor {
704        self.try_jvp(wrt, tangent)
705            .unwrap_or_else(|| panic!("jvp output is inactive for {:?}", leaf_input_key(wrt)))
706    }
707
    /// Like [`jvp`](Self::jvp) but returns `None` when the output does not
    /// depend on `wrt` (i.e. the tangent is structurally zero).
    pub fn try_jvp(&self, wrt: &TracedTensor, tangent: &TracedTensor) -> Option<TracedTensor> {
        let wrt_input_key = leaf_input_key(wrt);
        let output_key = self.fragment.vals()[self.val].key.clone();
        // Checkpointed tensors carry alias/fragment records that must be
        // visible to the differentiation pass.
        let aliases = self
            .checkpoint_chain
            .as_ref()
            .map(|chain| chain.collect_aliases())
            .unwrap_or_default();
        let checkpoint_fragments = self
            .checkpoint_chain
            .as_ref()
            .map(|chain| chain.collect_fragments())
            .unwrap_or_default();
        let mut roots = self.resolve_roots();
        roots.extend(checkpoint_fragments.iter().cloned());
        let view = resolve(roots);
        let mut ad_ctx = ShapeGuardContext::default();
        // Build the linearized (forward-mode) fragment of this output with
        // respect to the single `wrt` input.
        let linear = differentiate(
            &view,
            std::slice::from_ref(&output_key),
            std::slice::from_ref(&wrt_input_key),
            next_pass_id(),
            &mut ad_ctx,
            &aliases,
        );
        // `None` here means the output is structurally independent of `wrt`.
        let tangent_output = linear.tangent_outputs[0]?;
        let tangent_input_key = linear_input_key(&linear.fragment, linear.tangent_inputs[0].1);

        // Bind the caller's tangent data to the newly created tangent input.
        let mut inputs_map = (*self.inputs_map).clone();
        if let Some(chain) = &self.checkpoint_chain {
            inputs_map.extend(chain.collect_inputs());
        }
        inputs_map.insert(
            tangent_input_key,
            tangent
                .data
                .clone()
                .unwrap_or_else(|| panic!("jvp tangent must have concrete tensor data")),
        );

        // Keep the primal and checkpoint fragments alive as extra resolve
        // roots for the new tangent tensor.
        let mut extra_roots = vec![self.fragment.clone()];
        extra_roots.extend(checkpoint_fragments);
        extra_roots.extend(self.extra_roots.iter().cloned());

        // The tangent has the same rank/dtype/shape hint as the output.
        Some(TracedTensor {
            id: next_traced_id(),
            rank: self.rank,
            dtype: self.dtype,
            fragment: Arc::new(linear.fragment),
            val: tangent_output,
            data: None,
            shape_hint: self.shape_hint.clone(),
            inputs_map: Arc::new(inputs_map),
            extra_roots,
            checkpoint_chain: self.checkpoint_chain.clone(),
        })
    }
767
768    pub fn vjp(&self, wrt: &TracedTensor, cotangent: &TracedTensor) -> TracedTensor {
769        self.try_vjp(wrt, cotangent)
770            .unwrap_or_else(|| panic!("vjp output is inactive for {:?}", leaf_input_key(wrt)))
771    }
772
    /// Like [`vjp`](Self::vjp) but returns `None` when the output does not
    /// depend on `wrt` (i.e. the cotangent is structurally zero).
    fn try_vjp(&self, wrt: &TracedTensor, cotangent: &TracedTensor) -> Option<TracedTensor> {
        let wrt_input_key = leaf_input_key(wrt);
        let output_key = self.fragment.vals()[self.val].key.clone();
        // Checkpointed tensors carry alias/fragment records that must be
        // visible to the differentiation pass.
        let aliases = self
            .checkpoint_chain
            .as_ref()
            .map(|chain| chain.collect_aliases())
            .unwrap_or_default();
        let checkpoint_fragments = self
            .checkpoint_chain
            .as_ref()
            .map(|chain| chain.collect_fragments())
            .unwrap_or_default();
        let mut roots = self.resolve_roots();
        roots.extend(checkpoint_fragments.iter().cloned());
        let view = resolve(roots);
        let mut ad_ctx = ShapeGuardContext::default();
        // Linearize forward first; reverse mode is obtained by transposing
        // the linearized fragment below.
        let linear = differentiate(
            &view,
            std::slice::from_ref(&output_key),
            std::slice::from_ref(&wrt_input_key),
            next_pass_id(),
            &mut ad_ctx,
            &aliases,
        );
        // Remember the linear fragment's tangent inputs before `linear` is
        // consumed; they get zero-filled further down.
        let linear_tangent_input_ids: Vec<LocalValId> = linear
            .tangent_inputs
            .iter()
            .map(|(_, local_id)| *local_id)
            .collect();
        let transposed = transpose(&linear, &mut ad_ctx);
        let linear_fragment = Arc::new(linear.fragment);
        // `None` here means the output is structurally independent of `wrt`.
        let cotangent_output = transposed.tangent_outputs[0]?;
        let cotangent_input_key =
            linear_input_key(&transposed.fragment, transposed.tangent_inputs[0].1);

        // Bind the caller's cotangent data to the new cotangent input.
        let mut inputs_map = (*self.inputs_map).clone();
        if let Some(chain) = &self.checkpoint_chain {
            inputs_map.extend(chain.collect_inputs());
        }
        inputs_map.insert(
            cotangent_input_key.clone(),
            cotangent
                .data
                .clone()
                .unwrap_or_else(|| panic!("vjp cotangent must have concrete tensor data")),
        );
        // All remaining tangent inputs (of both the transposed and the linear
        // fragments) are seeded with zeros so evaluation is well-defined.
        // NOTE(review): when `wrt` has no concrete shape this falls back to a
        // [0; rank] zero tensor — presumably those inputs are never actually
        // read in that case; confirm.
        let zero_tangent = Arc::new(zeros_tensor(
            wrt.dtype,
            try_concrete_shape(wrt).unwrap_or_else(|| vec![0; wrt.rank]),
        ));
        for (_, local_id) in &transposed.tangent_inputs {
            let tangent_input_key = linear_input_key(&transposed.fragment, *local_id);
            if tangent_input_key != cotangent_input_key {
                inputs_map.insert(tangent_input_key, Arc::clone(&zero_tangent));
            }
        }
        for local_id in linear_tangent_input_ids {
            let tangent_input_key = linear_input_key(&linear_fragment, local_id);
            inputs_map.insert(tangent_input_key, Arc::clone(&zero_tangent));
        }

        // Keep the primal, linear, and checkpoint fragments alive as extra
        // resolve roots for the new cotangent tensor.
        let mut extra_roots = vec![self.fragment.clone(), linear_fragment];
        extra_roots.extend(checkpoint_fragments);
        extra_roots.extend(self.extra_roots.iter().cloned());

        // The cotangent has the rank/dtype/shape hint of `wrt`.
        Some(TracedTensor {
            id: next_traced_id(),
            rank: wrt.rank,
            dtype: wrt.dtype,
            fragment: Arc::new(transposed.fragment),
            val: cotangent_output,
            data: None,
            shape_hint: wrt.shape_hint.clone(),
            inputs_map: Arc::new(inputs_map),
            extra_roots,
            checkpoint_chain: self.checkpoint_chain.clone(),
        })
    }
852
853    /// Elementwise addition with NumPy-style broadcasting.
854    ///
855    /// Prefer using the `+` operator when it reads naturally.
856    ///
857    /// # Examples
858    ///
859    /// ```rust,ignore
860    /// let y = x.add(&z);
861    /// let y2 = &x + &z;
862    /// ```
863    pub fn add(&self, other: &TracedTensor) -> TracedTensor {
864        let (lhs, rhs) = broadcast_binary(self, other);
865        apply_binary(
866            StdTensorOp::Add,
867            &lhs,
868            &rhs,
869            lhs.rank,
870            lhs.shape_hint.clone(),
871        )
872    }
873
874    /// Elementwise multiplication with NumPy-style broadcasting.
875    ///
876    /// Prefer using the `*` operator when it reads naturally.
877    ///
878    /// # Examples
879    ///
880    /// ```rust,ignore
881    /// let y = x.mul(&z);
882    /// let y2 = &x * &z;
883    /// ```
884    pub fn mul(&self, other: &TracedTensor) -> TracedTensor {
885        let (lhs, rhs) = broadcast_binary(self, other);
886        apply_binary(
887            StdTensorOp::Mul,
888            &lhs,
889            &rhs,
890            lhs.rank,
891            lhs.shape_hint.clone(),
892        )
893    }
894
895    /// Elementwise division with NumPy-style broadcasting.
896    ///
897    /// Prefer using the `/` operator when it reads naturally.
898    ///
899    /// # Examples
900    ///
901    /// ```rust,ignore
902    /// let y = x.div(&z);
903    /// let y2 = &x / &z;
904    /// ```
905    pub fn div(&self, other: &TracedTensor) -> TracedTensor {
906        let (lhs, rhs) = broadcast_binary(self, other);
907        apply_binary(
908            StdTensorOp::Div,
909            &lhs,
910            &rhs,
911            lhs.rank,
912            lhs.shape_hint.clone(),
913        )
914    }
915
916    /// Elementwise negation.
917    ///
918    /// Prefer using the unary `-` operator when it reads naturally.
919    ///
920    /// # Examples
921    ///
922    /// ```rust,ignore
923    /// let y = x.neg();
924    /// let y2 = -&x;
925    /// ```
926    pub fn neg(&self) -> TracedTensor {
927        apply_unary(StdTensorOp::Neg, self, self.rank, self.shape_hint.clone())
928    }
929
930    /// Elementwise complex conjugate.
931    ///
932    /// # Examples
933    ///
934    /// ```rust,ignore
935    /// let y = x.conj();
936    /// ```
937    pub fn conj(&self) -> TracedTensor {
938        apply_unary(StdTensorOp::Conj, self, self.rank, self.shape_hint.clone())
939    }
940
941    /// Elementwise absolute value.
942    ///
943    /// # Examples
944    ///
945    /// ```rust,ignore
946    /// let y = x.abs();
947    /// ```
948    pub fn abs(&self) -> TracedTensor {
949        apply_unary(StdTensorOp::Abs, self, self.rank, self.shape_hint.clone())
950    }
951
952    /// Elementwise sign.
953    ///
954    /// # Examples
955    ///
956    /// ```rust,ignore
957    /// let y = x.sign();
958    /// ```
959    pub fn sign(&self) -> TracedTensor {
960        apply_unary(StdTensorOp::Sign, self, self.rank, self.shape_hint.clone())
961    }
962
963    /// Scale by a real scalar: `y = factor * x`.
964    ///
965    /// # Examples
966    ///
967    /// ```rust,ignore
968    /// let y = x.scale_real(2.0);
969    /// ```
970    pub fn scale_real(&self, factor: f64) -> TracedTensor {
971        let op = match self.dtype {
972            DType::F64 => StdTensorOp::constant_f64(factor),
973            DType::F32 => StdTensorOp::constant_f32(factor as f32),
974            DType::C64 => StdTensorOp::constant_c64(Complex64::new(factor, 0.0)),
975            DType::C32 => StdTensorOp::constant_c32(Complex32::new(factor as f32, 0.0)),
976        };
977        scale_with_constant(self, op)
978    }
979
980    /// Scale by a complex scalar: `y = factor * x`.
981    ///
982    /// This currently supports complex tensors only. For real scaling, prefer
983    /// [`scale_real`](Self::scale_real).
984    ///
985    /// # Examples
986    ///
987    /// ```rust,ignore
988    /// use num_complex::Complex64;
989    /// let y = x.scale_complex(Complex64::new(0.0, 1.0)); // multiply by i
990    /// ```
991    pub fn scale_complex(&self, factor: Complex64) -> TracedTensor {
992        match self.dtype {
993            DType::C64 => scale_with_constant(self, StdTensorOp::constant_c64(factor)),
994            DType::C32 => scale_with_constant(
995                self,
996                StdTensorOp::constant_c32(Complex32::new(factor.re as f32, factor.im as f32)),
997            ),
998            DType::F32 | DType::F64 => {
999                panic!(
1000                    "scale_complex only supports complex tensors; use scale_real for real tensors"
1001                )
1002            }
1003        }
1004    }
1005
1006    /// Elementwise exponential.
1007    ///
1008    /// # Examples
1009    ///
1010    /// ```rust,ignore
1011    /// let y = x.exp();
1012    /// ```
1013    pub fn exp(&self) -> TracedTensor {
1014        apply_unary(StdTensorOp::Exp, self, self.rank, self.shape_hint.clone())
1015    }
1016
1017    /// Elementwise natural logarithm.
1018    ///
1019    /// # Examples
1020    ///
1021    /// ```rust,ignore
1022    /// let y = x.log();
1023    /// ```
1024    pub fn log(&self) -> TracedTensor {
1025        apply_unary(StdTensorOp::Log, self, self.rank, self.shape_hint.clone())
1026    }
1027
1028    /// Elementwise sine.
1029    ///
1030    /// # Examples
1031    ///
1032    /// ```rust,ignore
1033    /// let y = x.sin();
1034    /// ```
1035    pub fn sin(&self) -> TracedTensor {
1036        apply_unary(StdTensorOp::Sin, self, self.rank, self.shape_hint.clone())
1037    }
1038
1039    /// Elementwise cosine.
1040    ///
1041    /// # Examples
1042    ///
1043    /// ```rust,ignore
1044    /// let y = x.cos();
1045    /// ```
1046    pub fn cos(&self) -> TracedTensor {
1047        apply_unary(StdTensorOp::Cos, self, self.rank, self.shape_hint.clone())
1048    }
1049
1050    /// Elementwise hyperbolic tangent.
1051    ///
1052    /// # Examples
1053    ///
1054    /// ```rust,ignore
1055    /// let y = x.tanh();
1056    /// ```
1057    pub fn tanh(&self) -> TracedTensor {
1058        apply_unary(StdTensorOp::Tanh, self, self.rank, self.shape_hint.clone())
1059    }
1060
1061    /// Elementwise square root.
1062    ///
1063    /// # Examples
1064    ///
1065    /// ```rust,ignore
1066    /// let y = x.sqrt();
1067    /// ```
1068    pub fn sqrt(&self) -> TracedTensor {
1069        apply_unary(StdTensorOp::Sqrt, self, self.rank, self.shape_hint.clone())
1070    }
1071
1072    /// Elementwise reciprocal square root.
1073    ///
1074    /// # Examples
1075    ///
1076    /// ```rust,ignore
1077    /// let y = x.rsqrt();
1078    /// ```
1079    pub fn rsqrt(&self) -> TracedTensor {
1080        apply_unary(StdTensorOp::Rsqrt, self, self.rank, self.shape_hint.clone())
1081    }
1082
1083    /// Elementwise power with NumPy-style broadcasting.
1084    ///
1085    /// # Examples
1086    ///
1087    /// ```rust,ignore
1088    /// let y = base.pow(&exp);
1089    /// ```
1090    pub fn pow(&self, other: &TracedTensor) -> TracedTensor {
1091        let (lhs, rhs) = broadcast_binary(self, other);
1092        apply_binary(
1093            StdTensorOp::Pow,
1094            &lhs,
1095            &rhs,
1096            lhs.rank,
1097            lhs.shape_hint.clone(),
1098        )
1099    }
1100
1101    /// Elementwise `exp(x) - 1`.
1102    ///
1103    /// # Examples
1104    ///
1105    /// ```rust,ignore
1106    /// let y = x.expm1();
1107    /// ```
1108    pub fn expm1(&self) -> TracedTensor {
1109        apply_unary(StdTensorOp::Expm1, self, self.rank, self.shape_hint.clone())
1110    }
1111
1112    /// Elementwise `log(1 + x)`.
1113    ///
1114    /// # Examples
1115    ///
1116    /// ```rust,ignore
1117    /// let y = x.log1p();
1118    /// ```
1119    pub fn log1p(&self) -> TracedTensor {
1120        apply_unary(StdTensorOp::Log1p, self, self.rank, self.shape_hint.clone())
1121    }
1122
1123    /// Convert the tensor to a different dtype.
1124    ///
1125    /// For real-to-complex conversions this embeds the real values as
1126    /// `x + 0i`. For complex-to-real conversions this extracts the real part.
1127    ///
1128    /// # Examples
1129    ///
1130    /// ```rust,ignore
1131    /// use tenferro::DType;
1132    ///
1133    /// let y = x.convert(DType::C64);
1134    /// ```
1135    pub fn convert(&self, to: DType) -> TracedTensor {
1136        if self.dtype == to {
1137            return self.clone();
1138        }
1139
1140        apply_unary_with_dtype(
1141            StdTensorOp::Convert {
1142                from: self.dtype,
1143                to,
1144            },
1145            self,
1146            self.rank,
1147            self.shape_hint.clone(),
1148            to,
1149        )
1150    }
1151
1152    /// Generalized tensor contraction.
1153    ///
1154    /// # Examples
1155    ///
1156    /// ```rust,ignore
1157    /// let y = a.dot_general(&b, config);
1158    /// ```
1159    pub fn dot_general(&self, other: &TracedTensor, config: DotGeneralConfig) -> TracedTensor {
1160        config
1161            .validate_ranks(self.rank, other.rank)
1162            .expect("DotGeneral config rank validation failed");
1163        config
1164            .validate_dims()
1165            .expect("DotGeneral config dimension validation failed");
1166        let lhs_free: Vec<usize> = (0..config.lhs_rank)
1167            .filter(|d| {
1168                !config.lhs_contracting_dims.contains(d) && !config.lhs_batch_dims.contains(d)
1169            })
1170            .collect();
1171        let rhs_free: Vec<usize> = (0..config.rhs_rank)
1172            .filter(|d| {
1173                !config.rhs_contracting_dims.contains(d) && !config.rhs_batch_dims.contains(d)
1174            })
1175            .collect();
1176        let out_rank = config.lhs_batch_dims.len() + lhs_free.len() + rhs_free.len();
1177        let out_shape_hint = match (&self.shape_hint, &other.shape_hint) {
1178            (Some(lhs_shape), Some(rhs_shape)) => {
1179                let mut out_shape = Vec::with_capacity(out_rank);
1180                for &d in &lhs_free {
1181                    out_shape.push(lhs_shape[d].clone());
1182                }
1183                for &d in &rhs_free {
1184                    out_shape.push(rhs_shape[d].clone());
1185                }
1186                for &d in &config.lhs_batch_dims {
1187                    out_shape.push(lhs_shape[d].clone());
1188                }
1189                Some(out_shape)
1190            }
1191            _ => None,
1192        };
1193
1194        apply_binary(
1195            StdTensorOp::DotGeneral(config),
1196            self,
1197            other,
1198            out_rank,
1199            out_shape_hint,
1200        )
1201    }
1202
1203    /// Sum over the given axes.
1204    ///
1205    /// # Examples
1206    ///
1207    /// ```rust,ignore
1208    /// let y = x.reduce_sum(&[0]);
1209    /// let y2 = x.sum(&[0]);
1210    /// ```
1211    pub fn reduce_sum(&self, axes: &[usize]) -> TracedTensor {
1212        let out_shape_hint = self.shape_hint.as_ref().map(|shape| {
1213            (0..shape.len())
1214                .filter(|d| !axes.contains(d))
1215                .map(|d| shape[d].clone())
1216                .collect()
1217        });
1218        apply_unary(
1219            StdTensorOp::ReduceSum {
1220                axes: axes.to_vec(),
1221                input_shape: DimExpr::input_shape(0, self.rank),
1222            },
1223            self,
1224            self.rank - axes.len(),
1225            out_shape_hint,
1226        )
1227    }
1228
1229    /// Reshape without changing element order.
1230    ///
1231    /// # Examples
1232    ///
1233    /// ```rust,ignore
1234    /// let y = x.reshape(&[2, 2]);
1235    /// ```
1236    pub fn reshape(&self, shape: &[usize]) -> TracedTensor {
1237        apply_unary(
1238            StdTensorOp::Reshape {
1239                from_shape: DimExpr::input_shape(0, self.rank),
1240                to_shape: DimExpr::from_concrete(shape),
1241            },
1242            self,
1243            shape.len(),
1244            Some(shape.iter().copied().map(SymDim::from).collect()),
1245        )
1246    }
1247
1248    /// Return a symbolic expression for the size of one axis.
1249    ///
1250    /// # Examples
1251    ///
1252    /// ```rust,ignore
1253    /// let rows = x.sym_size(0);
1254    /// let cols = x.sym_size(1);
1255    /// let y = x.reshape_sym(&[rows * cols])?;
1256    /// ```
1257    pub fn sym_size(&self, axis: usize) -> SymDim {
1258        assert!(
1259            axis < self.rank,
1260            "axis {axis} out of bounds for rank {}",
1261            self.rank
1262        );
1263        self.shape_hint
1264            .as_ref()
1265            .and_then(|shape| shape.get(axis))
1266            .filter(|dim| dim.constant_value().is_none())
1267            .cloned()
1268            .unwrap_or_else(|| SymDim::tensor_axis(self.id, axis))
1269    }
1270
1271    /// Reshape using symbolic dimensions derived from traced tensor axes.
1272    ///
1273    /// # Examples
1274    ///
1275    /// ```rust,ignore
1276    /// let rows = x.sym_size(0);
1277    /// let cols = x.sym_size(1);
1278    /// let y = x.reshape_sym(&[rows * cols])?;
1279    /// ```
1280    pub fn reshape_sym(&self, shape: &[SymDim]) -> Result<TracedTensor> {
1281        let tensor_map = [(self.id, 0usize)];
1282        let to_shape = shape
1283            .iter()
1284            .map(|dim| dim.to_dim_expr(&tensor_map).map_err(Error::Internal))
1285            .collect::<Result<Vec<_>>>()?;
1286        let out_shape_hint = Some(shape.to_vec());
1287        Ok(apply_unary(
1288            StdTensorOp::Reshape {
1289                from_shape: DimExpr::input_shape(0, self.rank),
1290                to_shape,
1291            },
1292            self,
1293            shape.len(),
1294            out_shape_hint,
1295        ))
1296    }
1297
1298    /// Broadcast into a larger shape with explicit dimension placement.
1299    ///
1300    /// # Examples
1301    ///
1302    /// ```rust,ignore
1303    /// let y = x.broadcast_in_dim(&[2, 3], &[1]);
1304    /// let y2 = x.broadcast(&[2, 3], &[1]);
1305    /// ```
1306    pub fn broadcast_in_dim(&self, shape: &[usize], dims: &[usize]) -> TracedTensor {
1307        apply_unary(
1308            StdTensorOp::BroadcastInDim {
1309                shape: DimExpr::from_concrete(shape),
1310                dims: dims.to_vec(),
1311            },
1312            self,
1313            shape.len(),
1314            Some(shape.iter().copied().map(SymDim::from).collect()),
1315        )
1316    }
1317
1318    /// Permute tensor axes.
1319    ///
1320    /// # Examples
1321    ///
1322    /// ```rust,ignore
1323    /// let y = x.transpose(&[1, 0]);
1324    /// ```
1325    pub fn transpose(&self, perm: &[usize]) -> TracedTensor {
1326        let out_shape_hint = self
1327            .shape_hint
1328            .as_ref()
1329            .map(|shape| perm.iter().map(|&p| shape[p].clone()).collect());
1330        apply_unary(
1331            StdTensorOp::Transpose {
1332                perm: perm.to_vec(),
1333            },
1334            self,
1335            self.rank,
1336            out_shape_hint,
1337        )
1338    }
1339
1340    /// Extract the diagonal along two axes.
1341    ///
1342    /// # Examples
1343    ///
1344    /// ```rust,ignore
1345    /// let y = x.extract_diag(0, 1);
1346    /// ```
1347    pub fn extract_diag(&self, axis_a: usize, axis_b: usize) -> TracedTensor {
1348        assert!(
1349            axis_a < self.rank && axis_b < self.rank && axis_a != axis_b,
1350            "extract_diag: invalid axes"
1351        );
1352        let out_shape_hint = self.shape_hint.as_ref().map(|shape| {
1353            shape
1354                .iter()
1355                .enumerate()
1356                .filter_map(|(axis, dim)| (axis != axis_b).then_some(dim.clone()))
1357                .collect()
1358        });
1359        apply_unary(
1360            StdTensorOp::ExtractDiag { axis_a, axis_b },
1361            self,
1362            self.rank - 1,
1363            out_shape_hint,
1364        )
1365    }
1366
1367    /// Embed a vector or lower-rank tensor along a diagonal.
1368    ///
1369    /// # Examples
1370    ///
1371    /// ```rust,ignore
1372    /// let y = x.embed_diag(0, 1);
1373    /// ```
1374    pub fn embed_diag(&self, axis_a: usize, axis_b: usize) -> TracedTensor {
1375        assert!(
1376            axis_a < self.rank && axis_b <= self.rank,
1377            "embed_diag: invalid axes"
1378        );
1379        let out_shape_hint = self.shape_hint.as_ref().map(|shape| {
1380            let mut out_shape = shape.clone();
1381            out_shape.insert(axis_b, shape[axis_a].clone());
1382            out_shape
1383        });
1384        apply_unary(
1385            StdTensorOp::EmbedDiag { axis_a, axis_b },
1386            self,
1387            self.rank + 1,
1388            out_shape_hint,
1389        )
1390    }
1391
1392    /// Alias for [`Self::reduce_sum`].
1393    ///
1394    /// # Examples
1395    ///
1396    /// ```rust,ignore
1397    /// let y = x.sum(&[0]);
1398    /// ```
1399    pub fn sum(&self, axes: &[usize]) -> TracedTensor {
1400        self.reduce_sum(axes)
1401    }
1402
1403    /// Alias for [`Self::broadcast_in_dim`].
1404    ///
1405    /// # Examples
1406    ///
1407    /// ```rust,ignore
1408    /// let y = x.broadcast(&[2, 3], &[1]);
1409    /// ```
1410    pub fn broadcast(&self, shape: &[usize], dims: &[usize]) -> TracedTensor {
1411        self.broadcast_in_dim(shape, dims)
1412    }
1413
1414    /// Return the runtime size of one axis as a scalar `f64` tensor.
1415    ///
1416    /// The result is metadata-derived and therefore has no gradient.
1417    ///
1418    /// # Examples
1419    ///
1420    /// ```
1421    /// use tenferro::{CpuBackend, Engine, TracedTensor};
1422    ///
1423    /// let mut engine = Engine::new(CpuBackend::new());
1424    /// let x = TracedTensor::from_vec(vec![2, 3], vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0]);
1425    /// let mut cols = x.shape_of(1);
1426    /// assert_eq!(cols.eval(&mut engine).unwrap().shape(), &[] as &[usize]);
1427    /// ```
1428    pub fn shape_of(&self, axis: usize) -> TracedTensor {
1429        assert!(
1430            axis < self.rank,
1431            "axis {axis} out of bounds for rank {}",
1432            self.rank
1433        );
1434        apply_unary_with_dtype(
1435            StdTensorOp::ShapeOf { axis },
1436            self,
1437            0,
1438            Some(vec![]),
1439            DType::F64,
1440        )
1441    }
1442
1443    /// Truncate this tensor along `axis` to the first `size` elements.
1444    ///
1445    /// `size` is read at runtime from a scalar traced tensor. Values are
1446    /// rounded to the nearest integer, clamped to `[0, self.shape[axis]]`,
1447    /// and the output keeps the same element dtype as the input.
1448    ///
1449    /// # Examples
1450    ///
1451    /// ```
1452    /// use tenferro::{CpuBackend, Engine, TracedTensor};
1453    ///
1454    /// let mut engine = Engine::new(CpuBackend::new());
1455    /// let x = TracedTensor::from_vec(vec![4], vec![1.0_f64, 2.0, 3.0, 4.0]);
1456    /// let size = TracedTensor::from_vec(vec![], vec![2.0_f64]);
1457    /// let mut y = x.dynamic_truncate(&size, 0);
1458    /// assert_eq!(y.eval(&mut engine).unwrap().shape(), &[2]);
1459    /// ```
1460    pub fn dynamic_truncate(&self, size: &TracedTensor, axis: usize) -> TracedTensor {
1461        assert!(
1462            axis < self.rank,
1463            "axis {axis} out of bounds for rank {}",
1464            self.rank
1465        );
1466        assert!(
1467            size.rank == 0,
1468            "dynamic_truncate size must be a scalar tensor, got rank {}",
1469            size.rank
1470        );
1471        apply_binary(
1472            StdTensorOp::DynamicTruncate { axis },
1473            self,
1474            size,
1475            self.rank,
1476            None,
1477        )
1478    }
1479
    /// Pad this tensor with zeros along `axis` to match `reference.shape[axis]`.
    ///
    /// If `reference` is smaller along that axis, this is a no-op.
    ///
    /// # Examples
    ///
    /// ```
    /// use tenferro::{CpuBackend, Engine, TracedTensor};
    ///
    /// let mut engine = Engine::new(CpuBackend::new());
    /// let x = TracedTensor::from_vec(vec![2], vec![1.0_f64, 2.0]);
    /// let reference = TracedTensor::from_vec(vec![4], vec![0.0_f64, 0.0, 0.0, 0.0]);
    /// let mut y = x.pad_to_match(&reference, 0);
    /// assert_eq!(y.eval(&mut engine).unwrap().shape(), &[4]);
    /// ```
    pub fn pad_to_match(&self, reference: &TracedTensor, axis: usize) -> TracedTensor {
        assert!(
            axis < self.rank,
            "axis {axis} out of bounds for rank {}",
            self.rank
        );
        assert!(
            axis < reference.rank,
            "reference axis {axis} out of bounds for rank {}",
            reference.rank
        );
        apply_binary(
            StdTensorOp::PadToMatch { axis },
            self,
            reference,
            self.rank,
            // NOTE(review): the hint is reference's entire shape, but the
            // output has rank self.rank and only `axis` is padded. When the
            // two tensors differ in rank or in non-`axis` dims this hint looks
            // wrong (length/value mismatch) — confirm the intended semantics.
            reference.shape_hint.clone(),
        )
    }
1514}
1515
1516pub(crate) fn apply_unary(
1517    op: StdTensorOp,
1518    input: &TracedTensor,
1519    out_rank: usize,
1520    out_shape_hint: Option<Vec<SymDim>>,
1521) -> TracedTensor {
1522    apply_unary_with_dtype(op, input, out_rank, out_shape_hint, input.dtype)
1523}
1524
1525pub(crate) fn apply_unary_with_dtype(
1526    op: StdTensorOp,
1527    input: &TracedTensor,
1528    out_rank: usize,
1529    out_shape_hint: Option<Vec<SymDim>>,
1530    out_dtype: DType,
1531) -> TracedTensor {
1532    let mut builder = FragmentBuilder::new();
1533    builder.add_parent(input.fragment.clone());
1534    let input_ref = ValRef::External(input.fragment.vals()[input.val].key.clone());
1535    let outputs = builder.add_op(op, vec![input_ref], OpMode::Primal);
1536    builder.set_outputs(outputs.clone());
1537    let fragment = Arc::new(builder.build());
1538
1539    TracedTensor {
1540        id: next_traced_id(),
1541        rank: out_rank,
1542        dtype: out_dtype,
1543        fragment,
1544        val: outputs[0],
1545        data: None,
1546        shape_hint: out_shape_hint,
1547        inputs_map: input.inputs_map.clone(),
1548        extra_roots: input.extra_roots.clone(),
1549        checkpoint_chain: input.checkpoint_chain.clone(),
1550    }
1551}
1552
1553pub(crate) fn apply_nullary(
1554    op: StdTensorOp,
1555    rank: usize,
1556    dtype: DType,
1557    shape_hint: Option<Vec<SymDim>>,
1558) -> TracedTensor {
1559    let mut builder = FragmentBuilder::new();
1560    let outputs = builder.add_op(op, vec![], OpMode::Primal);
1561    builder.set_outputs(outputs.clone());
1562    let fragment = Arc::new(builder.build());
1563
1564    TracedTensor {
1565        id: next_traced_id(),
1566        rank,
1567        dtype,
1568        fragment,
1569        val: outputs[0],
1570        data: None,
1571        shape_hint,
1572        inputs_map: Arc::new(HashMap::new()),
1573        extra_roots: Vec::new(),
1574        checkpoint_chain: None,
1575    }
1576}
1577
1578pub(crate) fn apply_binary(
1579    op: StdTensorOp,
1580    lhs: &TracedTensor,
1581    rhs: &TracedTensor,
1582    out_rank: usize,
1583    out_shape_hint: Option<Vec<SymDim>>,
1584) -> TracedTensor {
1585    let mut builder = FragmentBuilder::new();
1586    builder.add_parent(lhs.fragment.clone());
1587    builder.add_parent(rhs.fragment.clone());
1588    let lhs_ref = ValRef::External(lhs.fragment.vals()[lhs.val].key.clone());
1589    let rhs_ref = ValRef::External(rhs.fragment.vals()[rhs.val].key.clone());
1590    let outputs = builder.add_op(op, vec![lhs_ref, rhs_ref], OpMode::Primal);
1591    builder.set_outputs(outputs.clone());
1592    let fragment = Arc::new(builder.build());
1593
1594    let mut merged = (*lhs.inputs_map).clone();
1595    merged.extend(rhs.inputs_map.iter().map(|(k, v)| (k.clone(), v.clone())));
1596    let mut extra_roots = lhs.extra_roots.clone();
1597    extra_roots.extend(rhs.extra_roots.iter().cloned());
1598
1599    TracedTensor {
1600        id: next_traced_id(),
1601        rank: out_rank,
1602        dtype: lhs.dtype,
1603        fragment,
1604        val: outputs[0],
1605        data: None,
1606        shape_hint: out_shape_hint,
1607        inputs_map: Arc::new(merged),
1608        extra_roots,
1609        checkpoint_chain: CheckpointNode::merge_chains(
1610            lhs.checkpoint_chain.clone(),
1611            rhs.checkpoint_chain.clone(),
1612        ),
1613    }
1614}
1615
1616pub(crate) fn apply_multi_output(
1617    op: StdTensorOp,
1618    input: &TracedTensor,
1619    output_shapes: Vec<Vec<SymDim>>,
1620) -> Vec<TracedTensor> {
1621    let mut builder = FragmentBuilder::new();
1622    builder.add_parent(input.fragment.clone());
1623    let input_ref = ValRef::External(input.fragment.vals()[input.val].key.clone());
1624    let outputs = builder.add_op(op, vec![input_ref], OpMode::Primal);
1625    builder.set_outputs(outputs.clone());
1626    let fragment = Arc::new(builder.build());
1627    assert_eq!(
1628        outputs.len(),
1629        output_shapes.len(),
1630        "apply_multi_output: output count must match output_shapes"
1631    );
1632
1633    outputs
1634        .iter()
1635        .zip(output_shapes)
1636        .map(|(&val, shape)| TracedTensor {
1637            id: next_traced_id(),
1638            rank: shape.len(),
1639            dtype: input.dtype,
1640            fragment: fragment.clone(),
1641            val,
1642            data: None,
1643            shape_hint: Some(shape),
1644            inputs_map: input.inputs_map.clone(),
1645            extra_roots: input.extra_roots.clone(),
1646            checkpoint_chain: input.checkpoint_chain.clone(),
1647        })
1648        .collect()
1649}
1650
1651impl TracedTensor {
1652    fn resolve_roots(&self) -> Vec<Arc<Fragment<StdTensorOp>>> {
1653        let mut roots = Vec::with_capacity(1 + self.extra_roots.len());
1654        roots.push(self.fragment.clone());
1655        roots.extend(self.extra_roots.iter().cloned());
1656        roots
1657    }
1658}
1659
1660fn leaf_input_key(tt: &TracedTensor) -> TensorInputKey {
1661    match &tt.fragment.vals()[tt.val].key {
1662        GlobalValKey::Input(key) => key.clone(),
1663        other => panic!("expected traced leaf input, got {:?}", other),
1664    }
1665}
1666
1667fn linear_input_key(fragment: &Fragment<StdTensorOp>, local_id: LocalValId) -> TensorInputKey {
1668    match &fragment.vals()[local_id].key {
1669        GlobalValKey::Input(key) => key.clone(),
1670        other => panic!("expected linear fragment input, got {:?}", other),
1671    }
1672}
1673
1674fn ones_tensor(dtype: DType, shape: Vec<usize>) -> Tensor {
1675    match dtype {
1676        DType::F32 => Tensor::F32(TypedTensor::ones(shape)),
1677        DType::F64 => Tensor::F64(TypedTensor::ones(shape)),
1678        DType::C32 => Tensor::C32(TypedTensor::ones(shape)),
1679        DType::C64 => Tensor::C64(TypedTensor::ones(shape)),
1680    }
1681}
1682
1683fn zeros_tensor(dtype: DType, shape: Vec<usize>) -> Tensor {
1684    match dtype {
1685        DType::F32 => Tensor::F32(TypedTensor::zeros(shape)),
1686        DType::F64 => Tensor::F64(TypedTensor::zeros(shape)),
1687        DType::C32 => Tensor::C32(TypedTensor::zeros(shape)),
1688        DType::C64 => Tensor::C64(TypedTensor::zeros(shape)),
1689    }
1690}
1691
1692pub fn eval_all<B: TensorBackend>(
1693    engine: &mut Engine<B>,
1694    outputs: &mut [&mut TracedTensor],
1695) -> Result<Vec<Tensor>> {
1696    let mut all_fragments = Vec::new();
1697    let mut output_keys = Vec::new();
1698    let mut all_inputs: HashMap<TensorInputKey, Arc<Tensor>> = HashMap::new();
1699
1700    for tt in outputs.iter() {
1701        all_fragments.extend(tt.resolve_roots());
1702        output_keys.push(tt.fragment.vals()[tt.val].key.clone());
1703        all_inputs.extend(tt.inputs_map.iter().map(|(k, v)| (k.clone(), v.clone())));
1704    }
1705
1706    let view = resolve(all_fragments);
1707    let graph = materialize_merge(&view, &output_keys);
1708    let compiled = compile(&graph);
1709
1710    let mut input_tensors = Vec::with_capacity(graph.inputs.len());
1711    let mut input_dtypes = Vec::with_capacity(graph.inputs.len());
1712    let mut input_shapes = Vec::with_capacity(graph.inputs.len());
1713    for key in &graph.inputs {
1714        match key {
1715            GlobalValKey::Input(k) => {
1716                let tensor = all_inputs.get(k).ok_or_else(|| {
1717                    Error::MissingInput(format!("missing input data for key {:?}", k))
1718                })?;
1719                input_tensors.push(tensor.as_ref().clone());
1720                input_dtypes.push(tensor.dtype());
1721                input_shapes.push(DimExpr::from_concrete(tensor.shape()));
1722            }
1723            _ => {
1724                return Err(Error::Internal(
1725                    "expected Input key in graph inputs".to_string(),
1726                ));
1727            }
1728        }
1729    }
1730    let exec = compile_std_to_exec(&compiled, &input_dtypes, &input_shapes);
1731
1732    let cached_exec = engine.get_or_compile(exec);
1733    let results: Vec<Tensor> = engine.eval_exec_ir(&cached_exec, input_tensors)?;
1734
1735    for (tt, result) in outputs.iter_mut().zip(results.iter()) {
1736        tt.data = Some(Arc::new(result.clone()));
1737    }
1738
1739    Ok(results)
1740}