1use std::collections::HashMap;
2use std::sync::atomic::{AtomicU64, Ordering};
3use std::sync::Arc;
4
5use computegraph::compile::compile;
6use computegraph::fragment::{Fragment, FragmentBuilder};
7use computegraph::materialize::materialize_merge;
8use computegraph::resolve::resolve;
9use computegraph::types::{GlobalValKey, OpMode, ValRef};
10use computegraph::LocalValId;
11use num_complex::{Complex32, Complex64};
12use tenferro_ops::dim_expr::DimExpr;
13use tenferro_ops::input_key::TensorInputKey;
14use tenferro_ops::std_tensor_op::StdTensorOp;
15use tenferro_ops::ShapeGuardContext;
16use tenferro_tensor::{DType, DotGeneralConfig, Tensor, TensorBackend, TensorScalar, TypedTensor};
17use tidu::{differentiate, transpose};
18
19use super::compiler::compile_std_to_exec;
20use super::engine::Engine;
21use super::error::{Error, Result};
22use super::sym_dim::SymDim;
23use crate::checkpoint::CheckpointNode;
24
// Process-wide monotonic counters used to mint unique ids. `Relaxed`
// ordering suffices: the ids only need to be distinct, not ordered
// relative to other memory operations.
static NEXT_INPUT_ID: AtomicU64 = AtomicU64::new(0);
static NEXT_DIFF_PASS_ID: AtomicU64 = AtomicU64::new(0);
static NEXT_TRACED_ID: AtomicU64 = AtomicU64::new(0);

/// Process-unique identifier for a [`TracedTensor`].
pub type TracedTensorId = u64;
30
31pub(crate) fn next_input_key() -> TensorInputKey {
32 TensorInputKey::User {
33 id: NEXT_INPUT_ID.fetch_add(1, Ordering::Relaxed),
34 }
35}
36
/// Mints a fresh id for a single differentiation pass.
fn next_pass_id() -> u64 {
    NEXT_DIFF_PASS_ID.fetch_add(1, Ordering::Relaxed)
}

/// Mints a fresh process-unique id for a new [`TracedTensor`].
pub(crate) fn next_traced_id() -> TracedTensorId {
    NEXT_TRACED_ID.fetch_add(1, Ordering::Relaxed)
}
44
/// A lazily-evaluated tensor tracked as a value inside a compute-graph
/// fragment.
///
/// Cloning is cheap: the fragment, bound inputs, and any concrete data are
/// shared behind `Arc`s.
#[derive(Clone)]
pub struct TracedTensor {
    /// Process-unique id (see [`next_traced_id`]).
    pub id: TracedTensorId,
    /// Number of axes, even when the shape itself is symbolic.
    pub rank: usize,
    /// Element dtype this node evaluates to.
    pub dtype: DType,
    /// Graph fragment that `val` refers into.
    pub fragment: Arc<Fragment<StdTensorOp>>,
    /// The value inside `fragment` that this tensor denotes.
    pub val: LocalValId,
    /// Concrete tensor data once evaluated (or for constant leaf inputs).
    pub data: Option<Arc<Tensor>>,
    // Per-axis sizes when known; `None` for fully symbolic shapes.
    pub(crate) shape_hint: Option<Vec<SymDim>>,
    // Concrete tensors bound to the fragment's input keys.
    pub(crate) inputs_map: Arc<HashMap<TensorInputKey, Arc<Tensor>>>,
    // Additional fragments that must be resolved together with `fragment`
    // (e.g. primal/linear fragments referenced by derivative fragments).
    pub(crate) extra_roots: Vec<Arc<Fragment<StdTensorOp>>>,
    // Chain of checkpoint records, newest first (each node links to `prev`).
    pub(crate) checkpoint_chain: Option<Arc<CheckpointNode>>,
}
58
/// Computes the NumPy-style broadcast of two shapes, or `None` when they
/// are incompatible. Shapes are aligned on their trailing axes; a size-1
/// axis expands to the other shape's size.
fn broadcast_shape(a: &[usize], b: &[usize]) -> Option<Vec<usize>> {
    let rank = a.len().max(b.len());
    // Start from all-ones (the implicit left-padding), then merge each
    // input shape in from the right.
    let mut merged = vec![1usize; rank];
    for (slot, &dim) in merged.iter_mut().rev().zip(a.iter().rev()) {
        *slot = dim;
    }
    for (slot, &dim) in merged.iter_mut().rev().zip(b.iter().rev()) {
        if *slot == dim || dim == 1 {
            // Sizes agree, or `b` contributes a broadcastable 1: keep `slot`.
        } else if *slot == 1 {
            *slot = dim;
        } else {
            return None;
        }
    }
    Some(merged)
}
88
89pub(crate) fn try_concrete_shape(tensor: &TracedTensor) -> Option<Vec<usize>> {
90 tensor
91 .shape_hint
92 .as_ref()?
93 .iter()
94 .map(SymDim::constant_value)
95 .collect()
96}
97
98pub(crate) fn concrete_shape(tensor: &TracedTensor) -> Vec<usize> {
99 tensor
100 .shape_hint
101 .as_ref()
102 .unwrap_or_else(|| panic!("missing shape hint for traced tensor {}", tensor.id))
103 .iter()
104 .map(|dim| {
105 dim.constant_value().unwrap_or_else(|| {
106 panic!("symbolic dimension in shape hint for tensor {}", tensor.id)
107 })
108 })
109 .collect()
110}
111
112fn error_shape_hint(tensor: &TracedTensor) -> Vec<usize> {
113 try_concrete_shape(tensor).unwrap_or_else(|| vec![0; tensor.rank])
114}
115
/// Broadcasts `tensor` to `target_shape` (axes aligned on the right;
/// size-1 axes may expand). Requires a concrete shape hint.
///
/// # Panics
/// Panics when the shapes are incompatible or `tensor`'s rank exceeds the
/// target's.
fn broadcast_to(tensor: &TracedTensor, target_shape: &[usize]) -> TracedTensor {
    let tensor_shape = concrete_shape(tensor);
    if tensor_shape == target_shape {
        return tensor.clone();
    }

    assert!(
        tensor.rank <= target_shape.len(),
        "cannot broadcast higher-rank shape {:?} to {:?}",
        tensor_shape,
        target_shape
    );

    // Axes that actually expand (size 1 -> larger) are dropped from the
    // intermediate `source_shape`; each surviving axis is mapped to its
    // target position in `dims` for `broadcast_in_dim`.
    let rank_diff = target_shape.len() - tensor.rank;
    let mut source_shape = Vec::with_capacity(tensor.rank);
    let mut dims = Vec::with_capacity(tensor.rank);
    for (src_axis, &src_dim) in tensor_shape.iter().enumerate() {
        let dst_axis = src_axis + rank_diff;
        let dst_dim = target_shape[dst_axis];
        assert!(
            src_dim == dst_dim || src_dim == 1,
            "cannot broadcast shape {:?} to {:?}",
            tensor_shape,
            target_shape
        );
        if src_dim == 1 && dst_dim != 1 {
            continue;
        }
        source_shape.push(src_dim);
        dims.push(dst_axis);
    }

    // Reshape away the dropped size-1 axes first (if any), then broadcast.
    let source = if source_shape == tensor_shape {
        tensor.clone()
    } else {
        tensor.reshape(&source_shape)
    };
    source.broadcast_in_dim(target_shape, &dims)
}
159
160fn broadcast_binary(a: &TracedTensor, b: &TracedTensor) -> (TracedTensor, TracedTensor) {
162 if a.shape_hint == b.shape_hint && a.rank == b.rank {
163 return (a.clone(), b.clone());
164 }
165 let a_shape = concrete_shape(a);
166 let b_shape = concrete_shape(b);
167 let target = broadcast_shape(&a_shape, &b_shape).unwrap_or_else(|| {
168 panic!(
169 "incompatible shapes for broadcast: {:?} and {:?}",
170 a_shape, b_shape
171 )
172 });
173 (broadcast_to(a, &target), broadcast_to(b, &target))
174}
175
176fn scale_with_constant(input: &TracedTensor, op: StdTensorOp) -> TracedTensor {
177 let scalar = apply_nullary(op, 0, input.dtype, Some(vec![]));
178 let input_shape = concrete_shape(input);
179 let factor = broadcast_to(&scalar, &input_shape);
180 apply_binary(
181 StdTensorOp::Mul,
182 input,
183 &factor,
184 input.rank,
185 input.shape_hint.clone(),
186 )
187}
188
189impl std::ops::Add for &TracedTensor {
190 type Output = TracedTensor;
191
192 fn add(self, rhs: &TracedTensor) -> TracedTensor {
193 TracedTensor::add(self, rhs)
194 }
195}
196
197impl std::ops::Mul for &TracedTensor {
198 type Output = TracedTensor;
199
200 fn mul(self, rhs: &TracedTensor) -> TracedTensor {
201 TracedTensor::mul(self, rhs)
202 }
203}
204
205impl std::ops::Mul<f64> for &TracedTensor {
206 type Output = TracedTensor;
207
208 fn mul(self, rhs: f64) -> TracedTensor {
209 self.scale_real(rhs)
210 }
211}
212
213impl std::ops::Mul<&TracedTensor> for f64 {
214 type Output = TracedTensor;
215
216 fn mul(self, rhs: &TracedTensor) -> TracedTensor {
217 rhs.scale_real(self)
218 }
219}
220
impl std::ops::Neg for &TracedTensor {
    type Output = TracedTensor;

    /// Element-wise negation; delegates to [`TracedTensor::neg`].
    fn neg(self) -> TracedTensor {
        TracedTensor::neg(self)
    }
}
228
229impl std::ops::Div for &TracedTensor {
230 type Output = TracedTensor;
231
232 fn div(self, rhs: &TracedTensor) -> TracedTensor {
233 TracedTensor::div(self, rhs)
234 }
235}
236
237impl TracedTensor {
238 pub fn from_tensor_concrete_shape(tensor: Tensor) -> Self {
258 let shape = tensor.shape().to_vec();
259 let rank = shape.len();
260 let dtype = tensor.dtype();
261 let key = next_input_key();
262 let data = Arc::new(tensor);
263
264 let mut builder = FragmentBuilder::new();
265 let val = builder.add_input(key.clone());
266 builder.set_outputs(vec![val]);
267 let fragment = Arc::new(builder.build());
268
269 let mut map = HashMap::new();
270 map.insert(key, Arc::clone(&data));
271
272 Self {
273 id: next_traced_id(),
274 rank,
275 dtype,
276 fragment,
277 val,
278 data: Some(data),
279 shape_hint: Some(shape.into_iter().map(SymDim::from).collect()),
280 inputs_map: Arc::new(map),
281 extra_roots: Vec::new(),
282 checkpoint_chain: None,
283 }
284 }
285
286 pub fn from_tensor_symbolic_shape(tensor: Tensor) -> Self {
309 let rank = tensor.shape().len();
310 let dtype = tensor.dtype();
311 let key = next_input_key();
312 let data = Arc::new(tensor);
313
314 let mut builder = FragmentBuilder::new();
315 let val = builder.add_input(key.clone());
316 builder.set_outputs(vec![val]);
317 let fragment = Arc::new(builder.build());
318
319 let mut map = HashMap::new();
320 map.insert(key, Arc::clone(&data));
321
322 Self {
323 id: next_traced_id(),
324 rank,
325 dtype,
326 fragment,
327 val,
328 data: Some(data),
329 shape_hint: None,
330 inputs_map: Arc::new(map),
331 extra_roots: Vec::new(),
332 checkpoint_chain: None,
333 }
334 }
335
336 pub fn input_concrete_shape(dtype: DType, shape: &[usize]) -> Self {
353 let shape = shape.to_vec();
354 let rank = shape.len();
355 let key = next_input_key();
356
357 let mut builder = FragmentBuilder::new();
358 let val = builder.add_input(key.clone());
359 builder.set_outputs(vec![val]);
360 let fragment = Arc::new(builder.build());
361
362 Self {
363 id: next_traced_id(),
364 rank,
365 dtype,
366 fragment,
367 val,
368 data: None,
369 shape_hint: Some(shape.into_iter().map(SymDim::from).collect()),
370 inputs_map: Arc::new(HashMap::new()),
371 extra_roots: Vec::new(),
372 checkpoint_chain: None,
373 }
374 }
375
376 pub fn input_symbolic_shape(dtype: DType, rank: usize) -> Self {
396 let key = next_input_key();
397
398 let mut builder = FragmentBuilder::new();
399 let val = builder.add_input(key.clone());
400 builder.set_outputs(vec![val]);
401 let fragment = Arc::new(builder.build());
402
403 Self {
404 id: next_traced_id(),
405 rank,
406 dtype,
407 fragment,
408 val,
409 data: None,
410 shape_hint: None,
411 inputs_map: Arc::new(HashMap::new()),
412 extra_roots: Vec::new(),
413 checkpoint_chain: None,
414 }
415 }
416
417 pub fn from_vec<T: TensorScalar>(shape: Vec<usize>, data: Vec<T>) -> Self {
430 Self::from_tensor_concrete_shape(T::into_tensor(shape, data))
431 }
432
433 pub fn is_concrete_shape(&self) -> bool {
448 try_concrete_shape(self).is_some()
449 }
450
451 pub fn input_key(&self) -> Option<TensorInputKey> {
454 match &self.fragment.vals()[self.val].key {
455 GlobalValKey::Input(key) => Some(key.clone()),
456 _ => None,
457 }
458 }
459
    /// Evaluates this tensor with no placeholder bindings, caching the result.
    ///
    /// # Errors
    /// Propagates any error from [`Self::eval_with_inputs`].
    pub fn eval<B: TensorBackend>(&mut self, engine: &mut Engine<B>) -> Result<&Tensor> {
        self.eval_with_inputs(engine, &[])
    }
463
    /// Evaluates this tensor, binding the given placeholder tensors, and
    /// caches the result in `self.data`.
    ///
    /// Each binding pairs a placeholder (an unbound leaf input) with a
    /// concrete tensor; dtype plus shape (or rank, for symbolic shapes) are
    /// validated against the placeholder. A previously cached result is
    /// returned as-is, without consulting the bindings.
    ///
    /// # Errors
    /// Returns binding-validation errors (`UnexpectedBinding`, dtype/shape/
    /// rank mismatches, `DuplicateBinding`), `UnboundPlaceholder` when a
    /// graph input has no tensor, and `Internal` for unexpected compiler
    /// output.
    pub fn eval_with_inputs<B: TensorBackend>(
        &mut self,
        engine: &mut Engine<B>,
        bindings: &[(&TracedTensor, &Tensor)],
    ) -> Result<&Tensor> {
        // Validate the caller's bindings and index them by input key.
        let mut binding_map: HashMap<TensorInputKey, &Tensor> = HashMap::new();
        for (index, (placeholder, tensor)) in bindings.iter().enumerate() {
            // A placeholder must be unbound and must be a leaf graph input.
            if placeholder.data.is_some() {
                return Err(Error::UnexpectedBinding {
                    binding_index: index,
                });
            }
            let key = placeholder.input_key().ok_or(Error::UnexpectedBinding {
                binding_index: index,
            })?;

            if placeholder.dtype != tensor.dtype() {
                return Err(Error::PlaceholderDtypeMismatch {
                    expected: placeholder.dtype,
                    actual: tensor.dtype(),
                });
            }

            // Concrete hints must match exactly; symbolic ones only by rank.
            match try_concrete_shape(placeholder) {
                Some(expected_shape) => {
                    if expected_shape.as_slice() != tensor.shape() {
                        return Err(Error::PlaceholderShapeMismatch {
                            expected: expected_shape,
                            actual: tensor.shape().to_vec(),
                        });
                    }
                }
                None => {
                    if placeholder.rank != tensor.shape().len() {
                        return Err(Error::PlaceholderRankMismatch {
                            expected: placeholder.rank,
                            actual: tensor.shape().len(),
                        });
                    }
                }
            }

            if binding_map.insert(key.clone(), *tensor).is_some() {
                return Err(Error::DuplicateBinding {
                    input_key: format!("{:?}", key),
                });
            }
        }

        // Already evaluated: return the cached result.
        if self.data.is_some() {
            return Ok(self.data.as_ref().unwrap().as_ref());
        }

        let output_key = self.fragment.vals()[self.val].key.clone();

        // Lower: resolve fragments -> materialize a graph for the single
        // output -> compile it.
        let view = resolve(self.resolve_roots());
        let graph = materialize_merge(&view, &[output_key]);
        let compiled = compile(&graph);

        // Gather a concrete tensor for every graph input, preferring values
        // already bound on this tensor over the caller's bindings.
        let mut input_tensors = Vec::with_capacity(graph.inputs.len());
        let mut input_dtypes = Vec::with_capacity(graph.inputs.len());
        let mut input_shapes = Vec::with_capacity(graph.inputs.len());
        for key in &graph.inputs {
            match key {
                GlobalValKey::Input(k) => {
                    if let Some(tensor) = self.inputs_map.get(k) {
                        input_tensors.push(tensor.as_ref().clone());
                        input_dtypes.push(tensor.dtype());
                        input_shapes.push(DimExpr::from_concrete(tensor.shape()));
                    } else if let Some(bound) = binding_map.remove(k) {
                        input_tensors.push((*bound).clone());
                        input_dtypes.push(bound.dtype());
                        input_shapes.push(DimExpr::from_concrete(bound.shape()));
                    } else {
                        return Err(Error::UnboundPlaceholder {
                            input_key: format!("{:?}", k),
                        });
                    }
                }
                _ => {
                    return Err(Error::Internal(
                        "expected Input key in graph inputs".to_string(),
                    ));
                }
            }
        }
        let exec = compile_std_to_exec(&compiled, &input_dtypes, &input_shapes);

        // Execute (executables are cached by the engine) and keep the single
        // output as this tensor's concrete data.
        let cached_exec = engine.get_or_compile(exec);
        let mut results = engine.eval_exec_ir(&cached_exec, input_tensors)?;
        if results.len() != 1 {
            return Err(Error::Internal(format!(
                "expected 1 output, got {}",
                results.len()
            )));
        }

        self.data = Some(Arc::new(results.remove(0)));
        Ok(self.data.as_ref().unwrap().as_ref())
    }
612
613 pub fn grad(&self, wrt: &TracedTensor) -> Result<TracedTensor> {
614 if self.rank != 0 {
615 return Err(Error::NonScalarGrad {
616 shape: error_shape_hint(self),
617 });
618 }
619
620 let ones = ones_tensor(self.dtype, vec![]);
621 let seed = TracedTensor::from_tensor_concrete_shape(ones);
622 Ok(self.vjp(wrt, &seed))
623 }
624
625 pub fn try_grad(&self, wrt: &TracedTensor) -> Result<Option<TracedTensor>> {
634 if self.rank != 0 {
635 return Err(Error::NonScalarGrad {
636 shape: error_shape_hint(self),
637 });
638 }
639
640 let ones = ones_tensor(self.dtype, vec![]);
641 let seed = TracedTensor::from_tensor_concrete_shape(ones);
642 Ok(self.try_vjp(wrt, &seed))
643 }
644
    /// Evaluates this tensor and replaces its graph with a single leaf input
    /// holding the computed data, recording the replaced sub-graph in the
    /// checkpoint chain so derivatives can still see through it.
    ///
    /// # Errors
    /// Propagates evaluation errors; returns `Internal` if evaluation did
    /// not populate `self.data`.
    pub fn checkpoint<B: TensorBackend>(&mut self, engine: &mut Engine<B>) -> Result<()> {
        self.eval(engine)?;
        let data = self
            .data
            .clone()
            .ok_or_else(|| Error::Internal("checkpoint eval did not populate data".to_string()))?;
        // After eval the shape is concrete; adopt it as the new hint.
        let concrete_shape_hint = Some(data.shape().iter().copied().map(SymDim::from).collect());

        // Snapshot the graph being replaced before rewiring `self`.
        let old_fragment = self.fragment.clone();
        let old_output_key = old_fragment.vals()[self.val].key.clone();
        let old_inputs = (*self.inputs_map).clone();

        // New one-input fragment that stands in for the old output.
        let new_key = next_input_key();
        let mut builder = FragmentBuilder::new();
        let leaf_val = builder.add_input(new_key.clone());
        builder.set_outputs(vec![leaf_val]);
        let new_fragment = Arc::new(builder.build());

        // Record the replacement: the new key aliases the old output, and
        // `prev` links to any earlier checkpoints (newest first).
        let node = CheckpointNode {
            fragment: old_fragment,
            alias_key: new_key.clone(),
            alias_target: old_output_key,
            old_inputs,
            prev: self.checkpoint_chain.take(),
        };

        self.fragment = new_fragment;
        self.val = leaf_val;
        self.extra_roots.clear();
        self.shape_hint = concrete_shape_hint;
        self.checkpoint_chain = Some(Arc::new(node));

        // Rebuild the input map: all inputs recorded along the chain, plus
        // the checkpointed data bound to the new leaf key.
        let mut merged = HashMap::new();
        if let Some(chain) = &self.checkpoint_chain {
            merged.extend(chain.collect_inputs());
        }
        merged.insert(new_key, data);
        self.inputs_map = Arc::new(merged);

        Ok(())
    }
702
703 pub fn jvp(&self, wrt: &TracedTensor, tangent: &TracedTensor) -> TracedTensor {
704 self.try_jvp(wrt, tangent)
705 .unwrap_or_else(|| panic!("jvp output is inactive for {:?}", leaf_input_key(wrt)))
706 }
707
    /// Forward-mode derivative: the tangent of `self` for a `tangent`
    /// perturbation of the leaf `wrt`, or `None` when the output is
    /// inactive (produces no tangent value).
    ///
    /// # Panics
    /// Panics when `wrt` is not a leaf input or `tangent` carries no
    /// concrete tensor data.
    pub fn try_jvp(&self, wrt: &TracedTensor, tangent: &TracedTensor) -> Option<TracedTensor> {
        let wrt_input_key = leaf_input_key(wrt);
        let output_key = self.fragment.vals()[self.val].key.clone();
        // Checkpoints contribute alias mappings and their saved fragments so
        // differentiation can see through checkpoint boundaries.
        let aliases = self
            .checkpoint_chain
            .as_ref()
            .map(|chain| chain.collect_aliases())
            .unwrap_or_default();
        let checkpoint_fragments = self
            .checkpoint_chain
            .as_ref()
            .map(|chain| chain.collect_fragments())
            .unwrap_or_default();
        let mut roots = self.resolve_roots();
        roots.extend(checkpoint_fragments.iter().cloned());
        let view = resolve(roots);
        let mut ad_ctx = ShapeGuardContext::default();
        let linear = differentiate(
            &view,
            std::slice::from_ref(&output_key),
            std::slice::from_ref(&wrt_input_key),
            next_pass_id(),
            &mut ad_ctx,
            &aliases,
        );
        // Inactive output: the linearization produced no tangent value.
        let tangent_output = linear.tangent_outputs[0]?;
        let tangent_input_key = linear_input_key(&linear.fragment, linear.tangent_inputs[0].1);

        let mut inputs_map = (*self.inputs_map).clone();
        if let Some(chain) = &self.checkpoint_chain {
            inputs_map.extend(chain.collect_inputs());
        }
        // Bind the caller's tangent data to the linear fragment's tangent input.
        inputs_map.insert(
            tangent_input_key,
            tangent
                .data
                .clone()
                .unwrap_or_else(|| panic!("jvp tangent must have concrete tensor data")),
        );

        // Keep the primal and checkpoint fragments reachable for later resolves.
        let mut extra_roots = vec![self.fragment.clone()];
        extra_roots.extend(checkpoint_fragments);
        extra_roots.extend(self.extra_roots.iter().cloned());

        Some(TracedTensor {
            id: next_traced_id(),
            rank: self.rank,
            dtype: self.dtype,
            fragment: Arc::new(linear.fragment),
            val: tangent_output,
            data: None,
            shape_hint: self.shape_hint.clone(),
            inputs_map: Arc::new(inputs_map),
            extra_roots,
            checkpoint_chain: self.checkpoint_chain.clone(),
        })
    }
767
768 pub fn vjp(&self, wrt: &TracedTensor, cotangent: &TracedTensor) -> TracedTensor {
769 self.try_vjp(wrt, cotangent)
770 .unwrap_or_else(|| panic!("vjp output is inactive for {:?}", leaf_input_key(wrt)))
771 }
772
    /// Reverse-mode derivative: the cotangent of the leaf `wrt` given a
    /// `cotangent` on `self`, or `None` when the output is inactive.
    /// Implemented as forward linearization followed by transposition.
    ///
    /// # Panics
    /// Panics when `wrt` is not a leaf input or `cotangent` carries no
    /// concrete tensor data.
    fn try_vjp(&self, wrt: &TracedTensor, cotangent: &TracedTensor) -> Option<TracedTensor> {
        let wrt_input_key = leaf_input_key(wrt);
        let output_key = self.fragment.vals()[self.val].key.clone();
        // Checkpoint aliases/fragments let differentiation see through
        // checkpoint boundaries.
        let aliases = self
            .checkpoint_chain
            .as_ref()
            .map(|chain| chain.collect_aliases())
            .unwrap_or_default();
        let checkpoint_fragments = self
            .checkpoint_chain
            .as_ref()
            .map(|chain| chain.collect_fragments())
            .unwrap_or_default();
        let mut roots = self.resolve_roots();
        roots.extend(checkpoint_fragments.iter().cloned());
        let view = resolve(roots);
        let mut ad_ctx = ShapeGuardContext::default();
        let linear = differentiate(
            &view,
            std::slice::from_ref(&output_key),
            std::slice::from_ref(&wrt_input_key),
            next_pass_id(),
            &mut ad_ctx,
            &aliases,
        );
        // Remember the linear fragment's tangent inputs before it is moved;
        // they must be zero-bound below.
        let linear_tangent_input_ids: Vec<LocalValId> = linear
            .tangent_inputs
            .iter()
            .map(|(_, local_id)| *local_id)
            .collect();
        let transposed = transpose(&linear, &mut ad_ctx);
        let linear_fragment = Arc::new(linear.fragment);
        // Inactive output: transposition produced no cotangent value.
        let cotangent_output = transposed.tangent_outputs[0]?;
        let cotangent_input_key =
            linear_input_key(&transposed.fragment, transposed.tangent_inputs[0].1);

        let mut inputs_map = (*self.inputs_map).clone();
        if let Some(chain) = &self.checkpoint_chain {
            inputs_map.extend(chain.collect_inputs());
        }
        // Bind the caller's cotangent to the transposed fragment's input.
        inputs_map.insert(
            cotangent_input_key.clone(),
            cotangent
                .data
                .clone()
                .unwrap_or_else(|| panic!("vjp cotangent must have concrete tensor data")),
        );
        // Every remaining tangent input is fed zeros so the executable graph
        // has no unbound inputs.
        // NOTE(review): the zeros are shaped like `wrt`; confirm all residual
        // tangent inputs indeed share that shape.
        let zero_tangent = Arc::new(zeros_tensor(
            wrt.dtype,
            try_concrete_shape(wrt).unwrap_or_else(|| vec![0; wrt.rank]),
        ));
        for (_, local_id) in &transposed.tangent_inputs {
            let tangent_input_key = linear_input_key(&transposed.fragment, *local_id);
            if tangent_input_key != cotangent_input_key {
                inputs_map.insert(tangent_input_key, Arc::clone(&zero_tangent));
            }
        }
        for local_id in linear_tangent_input_ids {
            let tangent_input_key = linear_input_key(&linear_fragment, local_id);
            inputs_map.insert(tangent_input_key, Arc::clone(&zero_tangent));
        }

        // Keep the primal, linear, and checkpoint fragments reachable.
        let mut extra_roots = vec![self.fragment.clone(), linear_fragment];
        extra_roots.extend(checkpoint_fragments);
        extra_roots.extend(self.extra_roots.iter().cloned());

        Some(TracedTensor {
            id: next_traced_id(),
            rank: wrt.rank,
            dtype: wrt.dtype,
            fragment: Arc::new(transposed.fragment),
            val: cotangent_output,
            data: None,
            shape_hint: wrt.shape_hint.clone(),
            inputs_map: Arc::new(inputs_map),
            extra_roots,
            checkpoint_chain: self.checkpoint_chain.clone(),
        })
    }
852
853 pub fn add(&self, other: &TracedTensor) -> TracedTensor {
864 let (lhs, rhs) = broadcast_binary(self, other);
865 apply_binary(
866 StdTensorOp::Add,
867 &lhs,
868 &rhs,
869 lhs.rank,
870 lhs.shape_hint.clone(),
871 )
872 }
873
874 pub fn mul(&self, other: &TracedTensor) -> TracedTensor {
885 let (lhs, rhs) = broadcast_binary(self, other);
886 apply_binary(
887 StdTensorOp::Mul,
888 &lhs,
889 &rhs,
890 lhs.rank,
891 lhs.shape_hint.clone(),
892 )
893 }
894
895 pub fn div(&self, other: &TracedTensor) -> TracedTensor {
906 let (lhs, rhs) = broadcast_binary(self, other);
907 apply_binary(
908 StdTensorOp::Div,
909 &lhs,
910 &rhs,
911 lhs.rank,
912 lhs.shape_hint.clone(),
913 )
914 }
915
    /// Element-wise negation.
    pub fn neg(&self) -> TracedTensor {
        apply_unary(StdTensorOp::Neg, self, self.rank, self.shape_hint.clone())
    }

    /// Element-wise complex conjugate.
    pub fn conj(&self) -> TracedTensor {
        apply_unary(StdTensorOp::Conj, self, self.rank, self.shape_hint.clone())
    }

    /// Element-wise absolute value.
    pub fn abs(&self) -> TracedTensor {
        apply_unary(StdTensorOp::Abs, self, self.rank, self.shape_hint.clone())
    }

    /// Element-wise sign.
    pub fn sign(&self) -> TracedTensor {
        apply_unary(StdTensorOp::Sign, self, self.rank, self.shape_hint.clone())
    }
962
963 pub fn scale_real(&self, factor: f64) -> TracedTensor {
971 let op = match self.dtype {
972 DType::F64 => StdTensorOp::constant_f64(factor),
973 DType::F32 => StdTensorOp::constant_f32(factor as f32),
974 DType::C64 => StdTensorOp::constant_c64(Complex64::new(factor, 0.0)),
975 DType::C32 => StdTensorOp::constant_c32(Complex32::new(factor as f32, 0.0)),
976 };
977 scale_with_constant(self, op)
978 }
979
980 pub fn scale_complex(&self, factor: Complex64) -> TracedTensor {
992 match self.dtype {
993 DType::C64 => scale_with_constant(self, StdTensorOp::constant_c64(factor)),
994 DType::C32 => scale_with_constant(
995 self,
996 StdTensorOp::constant_c32(Complex32::new(factor.re as f32, factor.im as f32)),
997 ),
998 DType::F32 | DType::F64 => {
999 panic!(
1000 "scale_complex only supports complex tensors; use scale_real for real tensors"
1001 )
1002 }
1003 }
1004 }
1005
    /// Element-wise exponential.
    pub fn exp(&self) -> TracedTensor {
        apply_unary(StdTensorOp::Exp, self, self.rank, self.shape_hint.clone())
    }

    /// Element-wise natural logarithm.
    pub fn log(&self) -> TracedTensor {
        apply_unary(StdTensorOp::Log, self, self.rank, self.shape_hint.clone())
    }

    /// Element-wise sine.
    pub fn sin(&self) -> TracedTensor {
        apply_unary(StdTensorOp::Sin, self, self.rank, self.shape_hint.clone())
    }

    /// Element-wise cosine.
    pub fn cos(&self) -> TracedTensor {
        apply_unary(StdTensorOp::Cos, self, self.rank, self.shape_hint.clone())
    }

    /// Element-wise hyperbolic tangent.
    pub fn tanh(&self) -> TracedTensor {
        apply_unary(StdTensorOp::Tanh, self, self.rank, self.shape_hint.clone())
    }

    /// Element-wise square root.
    pub fn sqrt(&self) -> TracedTensor {
        apply_unary(StdTensorOp::Sqrt, self, self.rank, self.shape_hint.clone())
    }

    /// Element-wise reciprocal square root.
    pub fn rsqrt(&self) -> TracedTensor {
        apply_unary(StdTensorOp::Rsqrt, self, self.rank, self.shape_hint.clone())
    }
1082
1083 pub fn pow(&self, other: &TracedTensor) -> TracedTensor {
1091 let (lhs, rhs) = broadcast_binary(self, other);
1092 apply_binary(
1093 StdTensorOp::Pow,
1094 &lhs,
1095 &rhs,
1096 lhs.rank,
1097 lhs.shape_hint.clone(),
1098 )
1099 }
1100
    /// Element-wise `exp(x) - 1`.
    pub fn expm1(&self) -> TracedTensor {
        apply_unary(StdTensorOp::Expm1, self, self.rank, self.shape_hint.clone())
    }

    /// Element-wise `ln(1 + x)`.
    pub fn log1p(&self) -> TracedTensor {
        apply_unary(StdTensorOp::Log1p, self, self.rank, self.shape_hint.clone())
    }
1122
1123 pub fn convert(&self, to: DType) -> TracedTensor {
1136 if self.dtype == to {
1137 return self.clone();
1138 }
1139
1140 apply_unary_with_dtype(
1141 StdTensorOp::Convert {
1142 from: self.dtype,
1143 to,
1144 },
1145 self,
1146 self.rank,
1147 self.shape_hint.clone(),
1148 to,
1149 )
1150 }
1151
    /// Generalized contraction between `self` (lhs) and `other` (rhs) with
    /// batch, contracting, and free dimensions given by `config`.
    ///
    /// # Panics
    /// Panics when `config` fails rank or dimension validation.
    pub fn dot_general(&self, other: &TracedTensor, config: DotGeneralConfig) -> TracedTensor {
        config
            .validate_ranks(self.rank, other.rank)
            .expect("DotGeneral config rank validation failed");
        config
            .validate_dims()
            .expect("DotGeneral config dimension validation failed");
        // Free dims: neither contracted nor batched.
        let lhs_free: Vec<usize> = (0..config.lhs_rank)
            .filter(|d| {
                !config.lhs_contracting_dims.contains(d) && !config.lhs_batch_dims.contains(d)
            })
            .collect();
        let rhs_free: Vec<usize> = (0..config.rhs_rank)
            .filter(|d| {
                !config.rhs_contracting_dims.contains(d) && !config.rhs_batch_dims.contains(d)
            })
            .collect();
        let out_rank = config.lhs_batch_dims.len() + lhs_free.len() + rhs_free.len();
        // Shape hint requires both operands' hints.
        // NOTE(review): the hint orders axes [lhs_free, rhs_free, batch] with
        // batch dims LAST — confirm this matches StdTensorOp::DotGeneral's
        // output layout (XLA's dot_general puts batch dims first).
        let out_shape_hint = match (&self.shape_hint, &other.shape_hint) {
            (Some(lhs_shape), Some(rhs_shape)) => {
                let mut out_shape = Vec::with_capacity(out_rank);
                for &d in &lhs_free {
                    out_shape.push(lhs_shape[d].clone());
                }
                for &d in &rhs_free {
                    out_shape.push(rhs_shape[d].clone());
                }
                for &d in &config.lhs_batch_dims {
                    out_shape.push(lhs_shape[d].clone());
                }
                Some(out_shape)
            }
            _ => None,
        };

        apply_binary(
            StdTensorOp::DotGeneral(config),
            self,
            other,
            out_rank,
            out_shape_hint,
        )
    }
1202
1203 pub fn reduce_sum(&self, axes: &[usize]) -> TracedTensor {
1212 let out_shape_hint = self.shape_hint.as_ref().map(|shape| {
1213 (0..shape.len())
1214 .filter(|d| !axes.contains(d))
1215 .map(|d| shape[d].clone())
1216 .collect()
1217 });
1218 apply_unary(
1219 StdTensorOp::ReduceSum {
1220 axes: axes.to_vec(),
1221 input_shape: DimExpr::input_shape(0, self.rank),
1222 },
1223 self,
1224 self.rank - axes.len(),
1225 out_shape_hint,
1226 )
1227 }
1228
1229 pub fn reshape(&self, shape: &[usize]) -> TracedTensor {
1237 apply_unary(
1238 StdTensorOp::Reshape {
1239 from_shape: DimExpr::input_shape(0, self.rank),
1240 to_shape: DimExpr::from_concrete(shape),
1241 },
1242 self,
1243 shape.len(),
1244 Some(shape.iter().copied().map(SymDim::from).collect()),
1245 )
1246 }
1247
1248 pub fn sym_size(&self, axis: usize) -> SymDim {
1258 assert!(
1259 axis < self.rank,
1260 "axis {axis} out of bounds for rank {}",
1261 self.rank
1262 );
1263 self.shape_hint
1264 .as_ref()
1265 .and_then(|shape| shape.get(axis))
1266 .filter(|dim| dim.constant_value().is_none())
1267 .cloned()
1268 .unwrap_or_else(|| SymDim::tensor_axis(self.id, axis))
1269 }
1270
1271 pub fn reshape_sym(&self, shape: &[SymDim]) -> Result<TracedTensor> {
1281 let tensor_map = [(self.id, 0usize)];
1282 let to_shape = shape
1283 .iter()
1284 .map(|dim| dim.to_dim_expr(&tensor_map).map_err(Error::Internal))
1285 .collect::<Result<Vec<_>>>()?;
1286 let out_shape_hint = Some(shape.to_vec());
1287 Ok(apply_unary(
1288 StdTensorOp::Reshape {
1289 from_shape: DimExpr::input_shape(0, self.rank),
1290 to_shape,
1291 },
1292 self,
1293 shape.len(),
1294 out_shape_hint,
1295 ))
1296 }
1297
1298 pub fn broadcast_in_dim(&self, shape: &[usize], dims: &[usize]) -> TracedTensor {
1307 apply_unary(
1308 StdTensorOp::BroadcastInDim {
1309 shape: DimExpr::from_concrete(shape),
1310 dims: dims.to_vec(),
1311 },
1312 self,
1313 shape.len(),
1314 Some(shape.iter().copied().map(SymDim::from).collect()),
1315 )
1316 }
1317
1318 pub fn transpose(&self, perm: &[usize]) -> TracedTensor {
1326 let out_shape_hint = self
1327 .shape_hint
1328 .as_ref()
1329 .map(|shape| perm.iter().map(|&p| shape[p].clone()).collect());
1330 apply_unary(
1331 StdTensorOp::Transpose {
1332 perm: perm.to_vec(),
1333 },
1334 self,
1335 self.rank,
1336 out_shape_hint,
1337 )
1338 }
1339
1340 pub fn extract_diag(&self, axis_a: usize, axis_b: usize) -> TracedTensor {
1348 assert!(
1349 axis_a < self.rank && axis_b < self.rank && axis_a != axis_b,
1350 "extract_diag: invalid axes"
1351 );
1352 let out_shape_hint = self.shape_hint.as_ref().map(|shape| {
1353 shape
1354 .iter()
1355 .enumerate()
1356 .filter_map(|(axis, dim)| (axis != axis_b).then_some(dim.clone()))
1357 .collect()
1358 });
1359 apply_unary(
1360 StdTensorOp::ExtractDiag { axis_a, axis_b },
1361 self,
1362 self.rank - 1,
1363 out_shape_hint,
1364 )
1365 }
1366
1367 pub fn embed_diag(&self, axis_a: usize, axis_b: usize) -> TracedTensor {
1375 assert!(
1376 axis_a < self.rank && axis_b <= self.rank,
1377 "embed_diag: invalid axes"
1378 );
1379 let out_shape_hint = self.shape_hint.as_ref().map(|shape| {
1380 let mut out_shape = shape.clone();
1381 out_shape.insert(axis_b, shape[axis_a].clone());
1382 out_shape
1383 });
1384 apply_unary(
1385 StdTensorOp::EmbedDiag { axis_a, axis_b },
1386 self,
1387 self.rank + 1,
1388 out_shape_hint,
1389 )
1390 }
1391
    /// Alias for [`Self::reduce_sum`].
    pub fn sum(&self, axes: &[usize]) -> TracedTensor {
        self.reduce_sum(axes)
    }

    /// Alias for [`Self::broadcast_in_dim`].
    pub fn broadcast(&self, shape: &[usize], dims: &[usize]) -> TracedTensor {
        self.broadcast_in_dim(shape, dims)
    }
1413
1414 pub fn shape_of(&self, axis: usize) -> TracedTensor {
1429 assert!(
1430 axis < self.rank,
1431 "axis {axis} out of bounds for rank {}",
1432 self.rank
1433 );
1434 apply_unary_with_dtype(
1435 StdTensorOp::ShapeOf { axis },
1436 self,
1437 0,
1438 Some(vec![]),
1439 DType::F64,
1440 )
1441 }
1442
1443 pub fn dynamic_truncate(&self, size: &TracedTensor, axis: usize) -> TracedTensor {
1461 assert!(
1462 axis < self.rank,
1463 "axis {axis} out of bounds for rank {}",
1464 self.rank
1465 );
1466 assert!(
1467 size.rank == 0,
1468 "dynamic_truncate size must be a scalar tensor, got rank {}",
1469 size.rank
1470 );
1471 apply_binary(
1472 StdTensorOp::DynamicTruncate { axis },
1473 self,
1474 size,
1475 self.rank,
1476 None,
1477 )
1478 }
1479
1480 pub fn pad_to_match(&self, reference: &TracedTensor, axis: usize) -> TracedTensor {
1496 assert!(
1497 axis < self.rank,
1498 "axis {axis} out of bounds for rank {}",
1499 self.rank
1500 );
1501 assert!(
1502 axis < reference.rank,
1503 "reference axis {axis} out of bounds for rank {}",
1504 reference.rank
1505 );
1506 apply_binary(
1507 StdTensorOp::PadToMatch { axis },
1508 self,
1509 reference,
1510 self.rank,
1511 reference.shape_hint.clone(),
1512 )
1513 }
1514}
1515
1516pub(crate) fn apply_unary(
1517 op: StdTensorOp,
1518 input: &TracedTensor,
1519 out_rank: usize,
1520 out_shape_hint: Option<Vec<SymDim>>,
1521) -> TracedTensor {
1522 apply_unary_with_dtype(op, input, out_rank, out_shape_hint, input.dtype)
1523}
1524
1525pub(crate) fn apply_unary_with_dtype(
1526 op: StdTensorOp,
1527 input: &TracedTensor,
1528 out_rank: usize,
1529 out_shape_hint: Option<Vec<SymDim>>,
1530 out_dtype: DType,
1531) -> TracedTensor {
1532 let mut builder = FragmentBuilder::new();
1533 builder.add_parent(input.fragment.clone());
1534 let input_ref = ValRef::External(input.fragment.vals()[input.val].key.clone());
1535 let outputs = builder.add_op(op, vec![input_ref], OpMode::Primal);
1536 builder.set_outputs(outputs.clone());
1537 let fragment = Arc::new(builder.build());
1538
1539 TracedTensor {
1540 id: next_traced_id(),
1541 rank: out_rank,
1542 dtype: out_dtype,
1543 fragment,
1544 val: outputs[0],
1545 data: None,
1546 shape_hint: out_shape_hint,
1547 inputs_map: input.inputs_map.clone(),
1548 extra_roots: input.extra_roots.clone(),
1549 checkpoint_chain: input.checkpoint_chain.clone(),
1550 }
1551}
1552
1553pub(crate) fn apply_nullary(
1554 op: StdTensorOp,
1555 rank: usize,
1556 dtype: DType,
1557 shape_hint: Option<Vec<SymDim>>,
1558) -> TracedTensor {
1559 let mut builder = FragmentBuilder::new();
1560 let outputs = builder.add_op(op, vec![], OpMode::Primal);
1561 builder.set_outputs(outputs.clone());
1562 let fragment = Arc::new(builder.build());
1563
1564 TracedTensor {
1565 id: next_traced_id(),
1566 rank,
1567 dtype,
1568 fragment,
1569 val: outputs[0],
1570 data: None,
1571 shape_hint,
1572 inputs_map: Arc::new(HashMap::new()),
1573 extra_roots: Vec::new(),
1574 checkpoint_chain: None,
1575 }
1576}
1577
1578pub(crate) fn apply_binary(
1579 op: StdTensorOp,
1580 lhs: &TracedTensor,
1581 rhs: &TracedTensor,
1582 out_rank: usize,
1583 out_shape_hint: Option<Vec<SymDim>>,
1584) -> TracedTensor {
1585 let mut builder = FragmentBuilder::new();
1586 builder.add_parent(lhs.fragment.clone());
1587 builder.add_parent(rhs.fragment.clone());
1588 let lhs_ref = ValRef::External(lhs.fragment.vals()[lhs.val].key.clone());
1589 let rhs_ref = ValRef::External(rhs.fragment.vals()[rhs.val].key.clone());
1590 let outputs = builder.add_op(op, vec![lhs_ref, rhs_ref], OpMode::Primal);
1591 builder.set_outputs(outputs.clone());
1592 let fragment = Arc::new(builder.build());
1593
1594 let mut merged = (*lhs.inputs_map).clone();
1595 merged.extend(rhs.inputs_map.iter().map(|(k, v)| (k.clone(), v.clone())));
1596 let mut extra_roots = lhs.extra_roots.clone();
1597 extra_roots.extend(rhs.extra_roots.iter().cloned());
1598
1599 TracedTensor {
1600 id: next_traced_id(),
1601 rank: out_rank,
1602 dtype: lhs.dtype,
1603 fragment,
1604 val: outputs[0],
1605 data: None,
1606 shape_hint: out_shape_hint,
1607 inputs_map: Arc::new(merged),
1608 extra_roots,
1609 checkpoint_chain: CheckpointNode::merge_chains(
1610 lhs.checkpoint_chain.clone(),
1611 rhs.checkpoint_chain.clone(),
1612 ),
1613 }
1614}
1615
1616pub(crate) fn apply_multi_output(
1617 op: StdTensorOp,
1618 input: &TracedTensor,
1619 output_shapes: Vec<Vec<SymDim>>,
1620) -> Vec<TracedTensor> {
1621 let mut builder = FragmentBuilder::new();
1622 builder.add_parent(input.fragment.clone());
1623 let input_ref = ValRef::External(input.fragment.vals()[input.val].key.clone());
1624 let outputs = builder.add_op(op, vec![input_ref], OpMode::Primal);
1625 builder.set_outputs(outputs.clone());
1626 let fragment = Arc::new(builder.build());
1627 assert_eq!(
1628 outputs.len(),
1629 output_shapes.len(),
1630 "apply_multi_output: output count must match output_shapes"
1631 );
1632
1633 outputs
1634 .iter()
1635 .zip(output_shapes)
1636 .map(|(&val, shape)| TracedTensor {
1637 id: next_traced_id(),
1638 rank: shape.len(),
1639 dtype: input.dtype,
1640 fragment: fragment.clone(),
1641 val,
1642 data: None,
1643 shape_hint: Some(shape),
1644 inputs_map: input.inputs_map.clone(),
1645 extra_roots: input.extra_roots.clone(),
1646 checkpoint_chain: input.checkpoint_chain.clone(),
1647 })
1648 .collect()
1649}
1650
1651impl TracedTensor {
1652 fn resolve_roots(&self) -> Vec<Arc<Fragment<StdTensorOp>>> {
1653 let mut roots = Vec::with_capacity(1 + self.extra_roots.len());
1654 roots.push(self.fragment.clone());
1655 roots.extend(self.extra_roots.iter().cloned());
1656 roots
1657 }
1658}
1659
1660fn leaf_input_key(tt: &TracedTensor) -> TensorInputKey {
1661 match &tt.fragment.vals()[tt.val].key {
1662 GlobalValKey::Input(key) => key.clone(),
1663 other => panic!("expected traced leaf input, got {:?}", other),
1664 }
1665}
1666
1667fn linear_input_key(fragment: &Fragment<StdTensorOp>, local_id: LocalValId) -> TensorInputKey {
1668 match &fragment.vals()[local_id].key {
1669 GlobalValKey::Input(key) => key.clone(),
1670 other => panic!("expected linear fragment input, got {:?}", other),
1671 }
1672}
1673
1674fn ones_tensor(dtype: DType, shape: Vec<usize>) -> Tensor {
1675 match dtype {
1676 DType::F32 => Tensor::F32(TypedTensor::ones(shape)),
1677 DType::F64 => Tensor::F64(TypedTensor::ones(shape)),
1678 DType::C32 => Tensor::C32(TypedTensor::ones(shape)),
1679 DType::C64 => Tensor::C64(TypedTensor::ones(shape)),
1680 }
1681}
1682
1683fn zeros_tensor(dtype: DType, shape: Vec<usize>) -> Tensor {
1684 match dtype {
1685 DType::F32 => Tensor::F32(TypedTensor::zeros(shape)),
1686 DType::F64 => Tensor::F64(TypedTensor::zeros(shape)),
1687 DType::C32 => Tensor::C32(TypedTensor::zeros(shape)),
1688 DType::C64 => Tensor::C64(TypedTensor::zeros(shape)),
1689 }
1690}
1691
1692pub fn eval_all<B: TensorBackend>(
1693 engine: &mut Engine<B>,
1694 outputs: &mut [&mut TracedTensor],
1695) -> Result<Vec<Tensor>> {
1696 let mut all_fragments = Vec::new();
1697 let mut output_keys = Vec::new();
1698 let mut all_inputs: HashMap<TensorInputKey, Arc<Tensor>> = HashMap::new();
1699
1700 for tt in outputs.iter() {
1701 all_fragments.extend(tt.resolve_roots());
1702 output_keys.push(tt.fragment.vals()[tt.val].key.clone());
1703 all_inputs.extend(tt.inputs_map.iter().map(|(k, v)| (k.clone(), v.clone())));
1704 }
1705
1706 let view = resolve(all_fragments);
1707 let graph = materialize_merge(&view, &output_keys);
1708 let compiled = compile(&graph);
1709
1710 let mut input_tensors = Vec::with_capacity(graph.inputs.len());
1711 let mut input_dtypes = Vec::with_capacity(graph.inputs.len());
1712 let mut input_shapes = Vec::with_capacity(graph.inputs.len());
1713 for key in &graph.inputs {
1714 match key {
1715 GlobalValKey::Input(k) => {
1716 let tensor = all_inputs.get(k).ok_or_else(|| {
1717 Error::MissingInput(format!("missing input data for key {:?}", k))
1718 })?;
1719 input_tensors.push(tensor.as_ref().clone());
1720 input_dtypes.push(tensor.dtype());
1721 input_shapes.push(DimExpr::from_concrete(tensor.shape()));
1722 }
1723 _ => {
1724 return Err(Error::Internal(
1725 "expected Input key in graph inputs".to_string(),
1726 ));
1727 }
1728 }
1729 }
1730 let exec = compile_std_to_exec(&compiled, &input_dtypes, &input_shapes);
1731
1732 let cached_exec = engine.get_or_compile(exec);
1733 let results: Vec<Tensor> = engine.eval_exec_ir(&cached_exec, input_tensors)?;
1734
1735 for (tt, result) in outputs.iter_mut().zip(results.iter()) {
1736 tt.data = Some(Arc::new(result.clone()));
1737 }
1738
1739 Ok(results)
1740}