1use std::collections::HashMap;
2use std::fmt;
3use std::sync::atomic::{AtomicU64, Ordering};
4use std::sync::Arc;
5
6use computegraph::graph::{Graph, GraphBuilder};
7use computegraph::types::{OperationRole, ValueKey, ValueRef};
8use computegraph::LocalValueId;
9use num_complex::{Complex32, Complex64};
10use tenferro_ops::ad::context::GlobalMetadataScope;
11use tenferro_ops::broadcast::{broadcast_input_plan, broadcast_shape, broadcast_shapes};
12use tenferro_ops::dim_expr::DimExpr;
13use tenferro_ops::input_key::TensorInputKey;
14use tenferro_ops::std_tensor_op::StdTensorOp;
15use tenferro_tensor::{
16 CompareDir, DType, DotGeneralConfig, GatherConfig, PadConfig, ScatterConfig, SliceConfig,
17 Tensor, TensorScalar,
18};
19
20use super::error::{Error, Result};
21use super::sym_dim::SymDim;
22use crate::checkpoint::CheckpointNode;
23use crate::metadata::{
24 concrete_tensor_meta, metadata_scopes_for_scope, metadata_scopes_with_new, push_metadata_scope,
25 register_scoped_graph_metadata, register_scoped_value_metadata, symbolic_input_meta,
26 tensor_meta,
27};
28use crate::scalar_semantics::round_real_to_i64;
29
/// Counter backing `next_input_key`; every allocation takes the current value and
/// increments it by one.
static NEXT_INPUT_ID: AtomicU64 = AtomicU64::new(0);
/// Counter backing `next_traced_id`; every allocation takes the current value and
/// increments it by one.
static NEXT_TRACED_ID: AtomicU64 = AtomicU64::new(0);

/// Identifier of a [`TracedTensor`], allocated by `next_traced_id`.
pub type TracedTensorId = u64;
34
35pub(crate) fn next_input_key() -> TensorInputKey {
36 TensorInputKey::User {
37 id: NEXT_INPUT_ID.fetch_add(1, Ordering::Relaxed),
38 }
39}
40
41pub(crate) fn next_traced_id() -> TracedTensorId {
42 NEXT_TRACED_ID.fetch_add(1, Ordering::Relaxed)
43}
44
/// A traced tensor value: a value inside a computation graph plus the metadata needed to
/// keep building operations on top of it.
#[derive(Clone)]
pub struct TracedTensor {
    /// Identifier from `next_traced_id`; also used to name this tensor's symbolic axes
    /// (`SymDim::tensor_axis(id, axis)`).
    pub id: TracedTensorId,
    /// Number of axes.
    pub rank: usize,
    /// Element dtype.
    pub dtype: DType,
    /// Graph containing this tensor's value.
    pub(crate) graph: Arc<Graph<StdTensorOp>>,
    /// Id of this tensor's value inside `graph` (indexes `graph.values()`).
    pub val: LocalValueId,
    /// Concrete tensor bound to this value; set by the `from_tensor_*` constructors,
    /// `None` for `input_*` placeholders.
    pub(crate) data: Option<Arc<Tensor>>,
    /// Per-axis dimensions (concrete or symbolic); `None` when only the rank is known.
    pub(crate) shape_hint: Option<Vec<SymDim>>,
    /// Concrete tensors bound to graph input keys.
    pub(crate) inputs_map: Arc<HashMap<TensorInputKey, Arc<Tensor>>>,
    /// Additional graphs carried alongside `graph` (empty for fresh inputs; exact role is
    /// not visible in this file section — TODO confirm).
    pub(crate) extra_roots: Vec<Arc<Graph<StdTensorOp>>>,
    /// Checkpoint bookkeeping; `None` for fresh inputs.
    pub(crate) checkpoint_chain: Option<Arc<CheckpointNode>>,
    /// Metadata scopes registered for this tensor's graph values (fresh inputs carry the
    /// scope registered at construction).
    pub(crate) metadata_scopes: Vec<Arc<GlobalMetadataScope>>,
}
59
60impl fmt::Debug for TracedTensor {
61 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
62 f.debug_struct("TracedTensor")
63 .field("id", &self.id)
64 .field("rank", &self.rank)
65 .field("dtype", &self.dtype)
66 .field("val", &self.val)
67 .field("shape_hint", &self.shape_hint)
68 .field("has_data", &self.data.is_some())
69 .finish_non_exhaustive()
70 }
71}
72
73pub(crate) fn try_concrete_shape(tensor: &TracedTensor) -> Option<Vec<usize>> {
74 tensor
75 .shape_hint
76 .as_ref()?
77 .iter()
78 .map(SymDim::constant_value)
79 .collect()
80}
81
82pub(crate) fn concrete_shape(tensor: &TracedTensor) -> Result<Vec<usize>> {
83 tensor
84 .shape_hint
85 .as_ref()
86 .ok_or_else(|| Error::InvalidGraphBuild {
87 op: "TracedTensor::concrete_shape",
88 message: format!("missing shape hint for traced tensor {}", tensor.id),
89 })?
90 .iter()
91 .map(|dim| {
92 dim.constant_value()
93 .ok_or_else(|| Error::InvalidGraphBuild {
94 op: "TracedTensor::concrete_shape",
95 message: format!("symbolic dimension in shape hint for tensor {}", tensor.id),
96 })
97 })
98 .collect()
99}
100
101pub(crate) fn broadcast_to(tensor: &TracedTensor, target_shape: &[usize]) -> Result<TracedTensor> {
106 let tensor_shape = concrete_shape(tensor)?;
107 if tensor_shape == target_shape {
108 return Ok(tensor.clone());
109 }
110
111 let plan = broadcast_input_plan(&tensor_shape, target_shape).map_err(|err| {
112 Error::InvalidGraphBuild {
113 op: "broadcast_to",
114 message: err.to_string(),
115 }
116 })?;
117
118 let source = if plan.source_shape == tensor_shape {
119 tensor.clone()
120 } else {
121 tensor.reshape(&plan.source_shape)
122 };
123 source.broadcast_in_dim(target_shape, &plan.dims)
124}
125
126pub(crate) fn broadcast_binary(
128 a: &TracedTensor,
129 b: &TracedTensor,
130) -> Result<(TracedTensor, TracedTensor)> {
131 if a.shape_hint == b.shape_hint && a.rank == b.rank {
132 return Ok((a.clone(), b.clone()));
133 }
134 if (try_concrete_shape(a).is_none() || try_concrete_shape(b).is_none()) && a.rank == b.rank {
135 return Ok((a.clone(), b.clone()));
136 }
137 let a_shape = concrete_shape(a)?;
138 let b_shape = concrete_shape(b)?;
139 let target = broadcast_shape(&a_shape, &b_shape).map_err(|err| Error::InvalidGraphBuild {
140 op: "broadcast_binary",
141 message: err.to_string(),
142 })?;
143 Ok((broadcast_to(a, &target)?, broadcast_to(b, &target)?))
144}
145
146pub(crate) fn broadcast_ternary(
147 a: &TracedTensor,
148 b: &TracedTensor,
149 c: &TracedTensor,
150) -> Result<(TracedTensor, TracedTensor, TracedTensor)> {
151 let a_shape = concrete_shape(a)?;
152 let b_shape = concrete_shape(b)?;
153 let c_shape = concrete_shape(c)?;
154 let target = broadcast_shapes([a_shape.as_slice(), b_shape.as_slice(), c_shape.as_slice()])
155 .map_err(|err| Error::InvalidGraphBuild {
156 op: "broadcast_ternary",
157 message: err.to_string(),
158 })?;
159 Ok((
160 broadcast_to(a, &target)?,
161 broadcast_to(b, &target)?,
162 broadcast_to(c, &target)?,
163 ))
164}
165
166fn scale_with_constant(input: &TracedTensor, op: StdTensorOp) -> TracedTensor {
167 let scalar = apply_nullary(op, 0, input.dtype, Some(vec![]));
168 apply_binary(
169 StdTensorOp::Mul,
170 input,
171 &scalar,
172 input.rank,
173 input.shape_hint.clone(),
174 )
175}
176
177fn inferred_output_dtype_or_fallback(
178 op: &StdTensorOp,
179 inputs: &[DType],
180 fallback: DType,
181 context: &'static str,
182) -> DType {
183 match crate::shape_infer::infer_output_dtype(op, inputs) {
184 Ok(dtype) => dtype,
185 Err(err) => {
186 debug_assert!(
187 false,
188 "{context}: built-in traced dtype inference failed for {op:?}: {err}"
189 );
190 fallback
191 }
192 }
193}
194
195fn traced_input_shape_exprs(input_idx: usize, tensor: &TracedTensor) -> Vec<DimExpr> {
196 match tensor.shape_hint.as_ref() {
197 Some(shape) => shape
198 .iter()
199 .enumerate()
200 .map(|(axis, dim)| {
201 dim.constant_value()
202 .map_or(DimExpr::InputDim { input_idx, axis }, DimExpr::Const)
203 })
204 .collect(),
205 None => (0..tensor.rank)
206 .map(|axis| DimExpr::InputDim { input_idx, axis })
207 .collect(),
208 }
209}
210
211fn traced_input_sym_shape(tensor: &TracedTensor) -> Vec<SymDim> {
212 tensor.shape_hint.clone().unwrap_or_else(|| {
213 (0..tensor.rank)
214 .map(|axis| SymDim::tensor_axis(tensor.id, axis))
215 .collect()
216 })
217}
218
219pub(crate) fn infer_traced_single_output_shape(
220 op_name: &'static str,
221 op: &StdTensorOp,
222 inputs: &[&TracedTensor],
223) -> Result<(usize, Option<Vec<SymDim>>)> {
224 let input_shape_exprs: Vec<Vec<DimExpr>> = inputs
225 .iter()
226 .enumerate()
227 .map(|(input_idx, tensor)| traced_input_shape_exprs(input_idx, tensor))
228 .collect();
229 let input_shape_refs: Vec<&[DimExpr]> = input_shape_exprs.iter().map(Vec::as_slice).collect();
230 let output_shapes =
231 crate::shape_infer::infer_output_shapes(op, &input_shape_refs).map_err(|err| {
232 Error::InvalidGraphBuild {
233 op: op_name,
234 message: err.to_string(),
235 }
236 })?;
237 let output_shape = output_shapes
238 .first()
239 .ok_or_else(|| Error::InvalidGraphBuild {
240 op: op_name,
241 message: "shape inference returned no outputs".into(),
242 })?;
243 if output_shapes.len() != 1 {
244 return Err(Error::InvalidGraphBuild {
245 op: op_name,
246 message: format!(
247 "expected single-output shape inference, got {} outputs",
248 output_shapes.len()
249 ),
250 });
251 }
252
253 let input_sym_shapes: Vec<Vec<SymDim>> = inputs
254 .iter()
255 .map(|tensor| traced_input_sym_shape(tensor))
256 .collect();
257 let input_sym_refs: Vec<&[SymDim]> = input_sym_shapes.iter().map(Vec::as_slice).collect();
258 let out_shape_hint = output_shape
259 .iter()
260 .map(|dim| SymDim::from_dim_expr(dim, &input_sym_refs))
261 .collect();
262 Ok((output_shape.len(), Some(out_shape_hint)))
263}
264
265fn register_metadata_or_internal(
266 result: std::result::Result<GlobalMetadataScope, impl std::fmt::Display>,
267) -> Result<GlobalMetadataScope> {
268 result.map_err(|err| Error::Internal(format!("metadata registration failed: {err}")))
269}
270
271fn reduction_output_meta(
272 tensor: &TracedTensor,
273 axes: &[usize],
274 op: &'static str,
275) -> Result<(usize, Option<Vec<SymDim>>)> {
276 let mut seen = vec![false; tensor.rank];
277 for &axis in axes {
278 if axis >= tensor.rank {
279 return Err(Error::InvalidGraphBuild {
280 op,
281 message: format!("axis {axis} out of bounds for rank {}", tensor.rank),
282 });
283 }
284 if seen[axis] {
285 return Err(Error::InvalidGraphBuild {
286 op,
287 message: format!("duplicate reduction axis {axis}"),
288 });
289 }
290 seen[axis] = true;
291 }
292
293 let out_shape_hint = tensor.shape_hint.as_ref().map(|shape| {
294 (0..shape.len())
295 .filter(|d| !axes.contains(d))
296 .map(|d| shape[d].clone())
297 .collect()
298 });
299 Ok((tensor.rank - axes.len(), out_shape_hint))
300}
301
302fn validate_traced_axis(tensor: &TracedTensor, axis: usize, op: &'static str) -> Result<()> {
303 if axis >= tensor.rank {
304 return Err(Error::InvalidGraphBuild {
305 op,
306 message: format!("axis {axis} out of bounds for rank {}", tensor.rank),
307 });
308 }
309 Ok(())
310}
311
312fn validate_traced_axes(rank: usize, axes: &[usize], op: &'static str) -> Result<()> {
313 let mut seen = vec![false; rank];
314 for &axis in axes {
315 if axis >= rank {
316 return Err(Error::InvalidGraphBuild {
317 op,
318 message: format!("axis {axis} out of bounds for rank {rank}"),
319 });
320 }
321 if seen[axis] {
322 return Err(Error::InvalidGraphBuild {
323 op,
324 message: format!("duplicate axis {axis}"),
325 });
326 }
327 seen[axis] = true;
328 }
329 Ok(())
330}
331
332fn validate_traced_insert_axis(rank: usize, axis: usize, op: &'static str) -> Result<()> {
333 if axis > rank {
334 return Err(Error::InvalidGraphBuild {
335 op,
336 message: format!("axis {axis} out of bounds for rank {rank} insertion"),
337 });
338 }
339 Ok(())
340}
341
342fn validate_traced_perm(rank: usize, perm: &[usize], op: &'static str) -> Result<()> {
343 if perm.len() != rank {
344 return Err(Error::InvalidGraphBuild {
345 op,
346 message: format!(
347 "permutation length {} does not match rank {rank}",
348 perm.len()
349 ),
350 });
351 }
352 let mut seen = vec![false; rank];
353 for &axis in perm {
354 if axis >= rank {
355 return Err(Error::InvalidGraphBuild {
356 op,
357 message: format!("permutation axis {axis} out of bounds for rank {rank}"),
358 });
359 }
360 if seen[axis] {
361 return Err(Error::InvalidGraphBuild {
362 op,
363 message: format!("duplicate permutation axis {axis}"),
364 });
365 }
366 seen[axis] = true;
367 }
368 Ok(())
369}
370
371fn validate_broadcast_in_dim_args(
372 input: &TracedTensor,
373 output_shape: &[SymDim],
374 dims: &[usize],
375 op: &'static str,
376) -> Result<()> {
377 if dims.len() != input.rank {
378 return Err(Error::InvalidGraphBuild {
379 op,
380 message: format!(
381 "dims length {} must match input rank {}",
382 dims.len(),
383 input.rank
384 ),
385 });
386 }
387
388 let mut seen = vec![false; output_shape.len()];
389 for &dim in dims {
390 if dim >= output_shape.len() {
391 return Err(Error::InvalidGraphBuild {
392 op,
393 message: format!(
394 "broadcast dim {dim} out of bounds for output rank {}",
395 output_shape.len()
396 ),
397 });
398 }
399 if seen[dim] {
400 return Err(Error::InvalidGraphBuild {
401 op,
402 message: format!("duplicate broadcast dim {dim}"),
403 });
404 }
405 seen[dim] = true;
406 }
407
408 if let Some(input_shape) = input.shape_hint.as_ref() {
409 for (input_axis, &output_axis) in dims.iter().enumerate() {
410 let input_dim = &input_shape[input_axis];
411 let output_dim = &output_shape[output_axis];
412 if input_dim != output_dim && input_dim.constant_value() != Some(1) {
413 return Err(Error::InvalidGraphBuild {
414 op,
415 message: format!(
416 "input axis {input_axis} with dim {input_dim:?} cannot broadcast to \
417 output axis {output_axis} with dim {output_dim:?}"
418 ),
419 });
420 }
421 }
422 }
423
424 Ok(())
425}
426
427impl std::ops::Add for &TracedTensor {
428 type Output = Result<TracedTensor>;
429
430 fn add(self, rhs: &TracedTensor) -> Result<TracedTensor> {
431 TracedTensor::add(self, rhs)
432 }
433}
434
435impl std::ops::Sub for &TracedTensor {
436 type Output = Result<TracedTensor>;
437
438 fn sub(self, rhs: &TracedTensor) -> Result<TracedTensor> {
439 TracedTensor::sub(self, rhs)
440 }
441}
442
443impl std::ops::Mul for &TracedTensor {
444 type Output = Result<TracedTensor>;
445
446 fn mul(self, rhs: &TracedTensor) -> Result<TracedTensor> {
447 TracedTensor::mul(self, rhs)
448 }
449}
450
451impl std::ops::Mul<f64> for &TracedTensor {
452 type Output = TracedTensor;
453
454 fn mul(self, rhs: f64) -> TracedTensor {
455 self.scale_real(rhs)
456 }
457}
458
459impl std::ops::Mul<&TracedTensor> for f64 {
460 type Output = TracedTensor;
461
462 fn mul(self, rhs: &TracedTensor) -> TracedTensor {
463 rhs.scale_real(self)
464 }
465}
466
467impl std::ops::Neg for &TracedTensor {
468 type Output = TracedTensor;
469
470 fn neg(self) -> TracedTensor {
471 TracedTensor::neg(self)
472 }
473}
474
475impl std::ops::Div for &TracedTensor {
476 type Output = Result<TracedTensor>;
477
478 fn div(self, rhs: &TracedTensor) -> Result<TracedTensor> {
479 TracedTensor::div(self, rhs)
480 }
481}
482
483impl TracedTensor {
484 pub fn graph(&self) -> &Arc<Graph<StdTensorOp>> {
495 &self.graph
496 }
497
498 pub fn attached_data(&self) -> Option<&Arc<Tensor>> {
516 self.data.as_ref()
517 }
518
519 pub fn from_tensor_concrete_shape(tensor: Tensor) -> Result<Self> {
540 let shape = tensor.shape().to_vec();
541 let rank = shape.len();
542 let dtype = tensor.dtype();
543 let key = next_input_key();
544 let id = next_traced_id();
545 let data = Arc::new(tensor);
546
547 let mut builder = GraphBuilder::new();
548 let val = builder.add_input(key.clone());
549 builder.set_outputs(vec![val]);
550 let graph = Arc::new(builder.build());
551 let metadata_scope = register_metadata_or_internal(register_scoped_value_metadata(
552 graph.values()[val].key.clone(),
553 concrete_tensor_meta(dtype, &shape),
554 ))?;
555
556 let mut map = HashMap::new();
557 map.insert(key, Arc::clone(&data));
558
559 Ok(Self {
560 id,
561 rank,
562 dtype,
563 graph,
564 val,
565 data: Some(data),
566 shape_hint: Some(shape.into_iter().map(SymDim::from).collect()),
567 inputs_map: Arc::new(map),
568 extra_roots: Vec::new(),
569 checkpoint_chain: None,
570 metadata_scopes: metadata_scopes_for_scope(metadata_scope),
571 })
572 }
573
574 pub fn from_tensor_symbolic_shape(tensor: Tensor) -> Result<Self> {
595 let rank = tensor.shape().len();
596 let dtype = tensor.dtype();
597 let key = next_input_key();
598 let id = next_traced_id();
599 let data = Arc::new(tensor);
600
601 let mut builder = GraphBuilder::new();
602 let val = builder.add_input(key.clone());
603 builder.set_outputs(vec![val]);
604 let graph = Arc::new(builder.build());
605 let metadata_scope = register_metadata_or_internal(register_scoped_value_metadata(
606 graph.values()[val].key.clone(),
607 symbolic_input_meta(dtype, id, rank),
608 ))?;
609
610 let mut map = HashMap::new();
611 map.insert(key, Arc::clone(&data));
612
613 Ok(Self {
614 id,
615 rank,
616 dtype,
617 graph,
618 val,
619 data: Some(data),
620 shape_hint: None,
621 inputs_map: Arc::new(map),
622 extra_roots: Vec::new(),
623 checkpoint_chain: None,
624 metadata_scopes: metadata_scopes_for_scope(metadata_scope),
625 })
626 }
627
628 pub fn input_concrete_shape(dtype: DType, shape: &[usize]) -> Result<Self> {
645 let shape = shape.to_vec();
646 let rank = shape.len();
647 let key = next_input_key();
648 let id = next_traced_id();
649
650 let mut builder = GraphBuilder::new();
651 let val = builder.add_input(key.clone());
652 builder.set_outputs(vec![val]);
653 let graph = Arc::new(builder.build());
654 let metadata_scope = register_metadata_or_internal(register_scoped_value_metadata(
655 graph.values()[val].key.clone(),
656 concrete_tensor_meta(dtype, &shape),
657 ))?;
658
659 Ok(Self {
660 id,
661 rank,
662 dtype,
663 graph,
664 val,
665 data: None,
666 shape_hint: Some(shape.into_iter().map(SymDim::from).collect()),
667 inputs_map: Arc::new(HashMap::new()),
668 extra_roots: Vec::new(),
669 checkpoint_chain: None,
670 metadata_scopes: metadata_scopes_for_scope(metadata_scope),
671 })
672 }
673
674 pub fn input_symbolic_shape(dtype: DType, rank: usize) -> Result<Self> {
691 let key = next_input_key();
692 let id = next_traced_id();
693
694 let mut builder = GraphBuilder::new();
695 let val = builder.add_input(key.clone());
696 builder.set_outputs(vec![val]);
697 let graph = Arc::new(builder.build());
698 let metadata_scope = register_metadata_or_internal(register_scoped_value_metadata(
699 graph.values()[val].key.clone(),
700 symbolic_input_meta(dtype, id, rank),
701 ))?;
702
703 Ok(Self {
704 id,
705 rank,
706 dtype,
707 graph,
708 val,
709 data: None,
710 shape_hint: None,
711 inputs_map: Arc::new(HashMap::new()),
712 extra_roots: Vec::new(),
713 checkpoint_chain: None,
714 metadata_scopes: metadata_scopes_for_scope(metadata_scope),
715 })
716 }
717
718 pub fn from_vec_col_major<T: TensorScalar>(shape: Vec<usize>, data: Vec<T>) -> Result<Self> {
736 Self::from_tensor_concrete_shape(Tensor::from_vec_col_major(shape, data)?)
737 }
738
739 pub fn is_concrete_shape(&self) -> bool {
754 try_concrete_shape(self).is_some()
755 }
756
757 pub fn try_concrete_shape(&self) -> Option<Vec<usize>> {
778 try_concrete_shape(self)
779 }
780
781 pub fn concrete_shape(&self) -> Result<Vec<usize>> {
787 concrete_shape(self)
788 }
789
790 pub fn input_key(&self) -> Option<TensorInputKey> {
793 match &self.graph.values()[self.val].key {
794 ValueKey::Input(key) => Some(key.clone()),
795 _ => None,
796 }
797 }
798
799 pub fn add(&self, other: &TracedTensor) -> Result<TracedTensor> {
813 let (lhs, rhs) = broadcast_binary(self, other)?;
814 Ok(apply_binary(
815 StdTensorOp::Add,
816 &lhs,
817 &rhs,
818 lhs.rank,
819 lhs.shape_hint.clone(),
820 ))
821 }
822
823 pub fn sub(&self, other: &TracedTensor) -> Result<TracedTensor> {
827 let (lhs, rhs) = broadcast_binary(self, other)?;
828 let rhs = rhs.neg();
829 Ok(apply_binary(
830 StdTensorOp::Add,
831 &lhs,
832 &rhs,
833 lhs.rank,
834 lhs.shape_hint.clone(),
835 ))
836 }
837
838 pub fn mul(&self, other: &TracedTensor) -> Result<TracedTensor> {
852 let (lhs, rhs) = broadcast_binary(self, other)?;
853 Ok(apply_binary(
854 StdTensorOp::Mul,
855 &lhs,
856 &rhs,
857 lhs.rank,
858 lhs.shape_hint.clone(),
859 ))
860 }
861
862 pub fn div(&self, other: &TracedTensor) -> Result<TracedTensor> {
876 let (lhs, rhs) = broadcast_binary(self, other)?;
877 Ok(apply_binary(
878 StdTensorOp::Div,
879 &lhs,
880 &rhs,
881 lhs.rank,
882 lhs.shape_hint.clone(),
883 ))
884 }
885
886 pub fn compare(&self, other: &TracedTensor, dir: CompareDir) -> Result<TracedTensor> {
888 let (lhs, rhs) = broadcast_binary(self, other)?;
889 Ok(apply_binary(
890 StdTensorOp::Compare(dir),
891 &lhs,
892 &rhs,
893 lhs.rank,
894 lhs.shape_hint.clone(),
895 ))
896 }
897
898 pub fn maximum(&self, other: &TracedTensor) -> Result<TracedTensor> {
900 apply_broadcast_binary_op(StdTensorOp::Maximum, self, other)
901 }
902
903 pub fn minimum(&self, other: &TracedTensor) -> Result<TracedTensor> {
905 apply_broadcast_binary_op(StdTensorOp::Minimum, self, other)
906 }
907
908 pub fn where_select(
910 condition: &TracedTensor,
911 on_true: &TracedTensor,
912 on_false: &TracedTensor,
913 ) -> Result<TracedTensor> {
914 apply_broadcast_ternary_op(StdTensorOp::Select, condition, on_true, on_false)
915 }
916
917 pub fn select(
919 condition: &TracedTensor,
920 on_true: &TracedTensor,
921 on_false: &TracedTensor,
922 ) -> Result<TracedTensor> {
923 Self::where_select(condition, on_true, on_false)
924 }
925
926 pub fn clamp(&self, lower: &TracedTensor, upper: &TracedTensor) -> Result<TracedTensor> {
928 apply_broadcast_ternary_op(StdTensorOp::Clamp, self, lower, upper)
929 }
930
931 fn apply_same_shape_unary(&self, op: StdTensorOp) -> TracedTensor {
932 apply_unary(op, self, self.rank, self.shape_hint.clone())
933 }
934
935 pub fn neg(&self) -> TracedTensor {
948 self.apply_same_shape_unary(StdTensorOp::Neg)
949 }
950
951 pub fn conj(&self) -> TracedTensor {
966 self.apply_same_shape_unary(StdTensorOp::Conj)
967 }
968
969 pub fn abs(&self) -> TracedTensor {
981 self.apply_same_shape_unary(StdTensorOp::Abs)
982 }
983
984 pub fn sign(&self) -> TracedTensor {
994 self.apply_same_shape_unary(StdTensorOp::Sign)
995 }
996
997 pub fn scale_real(&self, factor: f64) -> TracedTensor {
1007 let op = match self.dtype {
1008 DType::F64 => StdTensorOp::constant(factor),
1009 DType::F32 => StdTensorOp::constant(factor as f32),
1010 DType::I32 => StdTensorOp::constant(round_real_to_i64(factor) as i32),
1011 DType::I64 => StdTensorOp::constant(round_real_to_i64(factor)),
1012 DType::Bool => StdTensorOp::constant(factor != 0.0),
1013 DType::C64 => StdTensorOp::constant(Complex64::new(factor, 0.0)),
1014 DType::C32 => StdTensorOp::constant(Complex32::new(factor as f32, 0.0)),
1015 };
1016 scale_with_constant(self, op)
1017 }
1018
1019 pub fn scale_complex(&self, factor: Complex64) -> Result<TracedTensor> {
1037 match self.dtype {
1038 DType::C64 => Ok(scale_with_constant(self, StdTensorOp::constant(factor))),
1039 DType::C32 => Ok(scale_with_constant(
1040 self,
1041 StdTensorOp::constant(Complex32::new(factor.re as f32, factor.im as f32)),
1042 )),
1043 DType::F32 | DType::F64 | DType::I32 | DType::I64 | DType::Bool => {
1044 Err(Error::InvalidGraphBuild {
1045 op: "scale_complex",
1046 message: format!("requires complex tensor dtype, got {:?}", self.dtype),
1047 })
1048 }
1049 }
1050 }
1051
1052 pub fn exp(&self) -> TracedTensor {
1062 self.apply_same_shape_unary(StdTensorOp::Exp)
1063 }
1064
1065 pub fn log(&self) -> TracedTensor {
1075 self.apply_same_shape_unary(StdTensorOp::Log)
1076 }
1077
1078 pub fn sin(&self) -> TracedTensor {
1088 self.apply_same_shape_unary(StdTensorOp::Sin)
1089 }
1090
1091 pub fn cos(&self) -> TracedTensor {
1101 self.apply_same_shape_unary(StdTensorOp::Cos)
1102 }
1103
1104 pub fn tanh(&self) -> TracedTensor {
1114 self.apply_same_shape_unary(StdTensorOp::Tanh)
1115 }
1116
1117 pub fn sqrt(&self) -> TracedTensor {
1127 self.apply_same_shape_unary(StdTensorOp::Sqrt)
1128 }
1129
1130 pub fn rsqrt(&self) -> TracedTensor {
1140 self.apply_same_shape_unary(StdTensorOp::Rsqrt)
1141 }
1142
1143 pub fn pow(&self, other: &TracedTensor) -> Result<TracedTensor> {
1154 let (lhs, rhs) = broadcast_binary(self, other)?;
1155 Ok(apply_binary(
1156 StdTensorOp::Pow,
1157 &lhs,
1158 &rhs,
1159 lhs.rank,
1160 lhs.shape_hint.clone(),
1161 ))
1162 }
1163
1164 pub fn expm1(&self) -> TracedTensor {
1174 self.apply_same_shape_unary(StdTensorOp::Expm1)
1175 }
1176
1177 pub fn log1p(&self) -> TracedTensor {
1187 self.apply_same_shape_unary(StdTensorOp::Log1p)
1188 }
1189
1190 pub fn convert(&self, to: DType) -> Result<TracedTensor> {
1211 tenferro_tensor::validate::validate_convert_dtype("TracedTensor::convert", self.dtype, to)?;
1212 Ok(self.cast(to))
1213 }
1214
1215 pub fn cast(&self, to: DType) -> TracedTensor {
1231 if self.dtype == to {
1232 return self.clone();
1233 }
1234
1235 apply_unary_with_dtype(
1236 StdTensorOp::Convert {
1237 from: self.dtype,
1238 to,
1239 },
1240 self,
1241 self.rank,
1242 self.shape_hint.clone(),
1243 to,
1244 )
1245 }
1246
1247 pub fn dot_general(
1270 &self,
1271 other: &TracedTensor,
1272 config: DotGeneralConfig,
1273 ) -> Result<TracedTensor> {
1274 config
1275 .validate_dims_with_ranks(self.rank, other.rank)
1276 .map_err(|message| Error::InvalidGraphBuild {
1277 op: "dot_general",
1278 message,
1279 })?;
1280 let lhs_free: Vec<usize> = (0..self.rank)
1281 .filter(|d| {
1282 !config.lhs_contracting_dims.contains(d) && !config.lhs_batch_dims.contains(d)
1283 })
1284 .collect();
1285 let rhs_free: Vec<usize> = (0..other.rank)
1286 .filter(|d| {
1287 !config.rhs_contracting_dims.contains(d) && !config.rhs_batch_dims.contains(d)
1288 })
1289 .collect();
1290 let out_rank = config.lhs_batch_dims.len() + lhs_free.len() + rhs_free.len();
1291 let out_shape_hint = match (&self.shape_hint, &other.shape_hint) {
1292 (Some(lhs_shape), Some(rhs_shape)) => {
1293 let mut out_shape = Vec::with_capacity(out_rank);
1294 for &d in &lhs_free {
1295 out_shape.push(lhs_shape[d].clone());
1296 }
1297 for &d in &rhs_free {
1298 out_shape.push(rhs_shape[d].clone());
1299 }
1300 for &d in &config.lhs_batch_dims {
1301 out_shape.push(lhs_shape[d].clone());
1302 }
1303 Some(out_shape)
1304 }
1305 _ => None,
1306 };
1307
1308 Ok(apply_binary(
1309 StdTensorOp::DotGeneral { config },
1310 self,
1311 other,
1312 out_rank,
1313 out_shape_hint,
1314 ))
1315 }
1316
1317 pub fn matmul(&self, other: &TracedTensor) -> Result<TracedTensor> {
1319 if self.rank != 2 {
1320 return Err(Error::InvalidGraphBuild {
1321 op: "TracedTensor::matmul",
1322 message: format!("matmul requires rank-2 inputs, got lhs rank {}", self.rank),
1323 });
1324 }
1325 if other.rank != 2 {
1326 return Err(Error::InvalidGraphBuild {
1327 op: "TracedTensor::matmul",
1328 message: format!("matmul requires rank-2 inputs, got rhs rank {}", other.rank),
1329 });
1330 }
1331 if let (Some(lhs_shape), Some(rhs_shape)) = (&self.shape_hint, &other.shape_hint) {
1332 if let (Some(lhs_cols), Some(rhs_rows)) =
1333 (lhs_shape[1].constant_value(), rhs_shape[0].constant_value())
1334 {
1335 if lhs_cols != rhs_rows {
1336 return Err(Error::InvalidGraphBuild {
1337 op: "TracedTensor::matmul",
1338 message: format!(
1339 "matmul dimension mismatch: lhs columns {lhs_cols} != rhs rows {rhs_rows}"
1340 ),
1341 });
1342 }
1343 }
1344 }
1345 self.dot_general(
1346 other,
1347 DotGeneralConfig {
1348 lhs_contracting_dims: vec![1],
1349 rhs_contracting_dims: vec![0],
1350 lhs_batch_dims: vec![],
1351 rhs_batch_dims: vec![],
1352 },
1353 )
1354 }
1355
1356 pub fn reduce_sum(&self, axes: &[usize]) -> Result<TracedTensor> {
1372 let (out_rank, out_shape_hint) =
1373 reduction_output_meta(self, axes, "TracedTensor::reduce_sum")?;
1374 Ok(apply_unary(
1375 StdTensorOp::ReduceSum {
1376 axes: axes.to_vec(),
1377 },
1378 self,
1379 out_rank,
1380 out_shape_hint,
1381 ))
1382 }
1383
1384 pub fn reduce_max(&self, axes: &[usize]) -> Result<TracedTensor> {
1402 let (out_rank, out_shape_hint) =
1403 reduction_output_meta(self, axes, "TracedTensor::reduce_max")?;
1404 Ok(apply_unary(
1405 StdTensorOp::ReduceMax {
1406 axes: axes.to_vec(),
1407 },
1408 self,
1409 out_rank,
1410 out_shape_hint,
1411 ))
1412 }
1413
1414 pub fn reduce_min(&self, axes: &[usize]) -> Result<TracedTensor> {
1432 let (out_rank, out_shape_hint) =
1433 reduction_output_meta(self, axes, "TracedTensor::reduce_min")?;
1434 Ok(apply_unary(
1435 StdTensorOp::ReduceMin {
1436 axes: axes.to_vec(),
1437 },
1438 self,
1439 out_rank,
1440 out_shape_hint,
1441 ))
1442 }
1443
1444 pub fn reduce_prod(&self, axes: &[usize]) -> Result<TracedTensor> {
1459 let (out_rank, out_shape_hint) =
1460 reduction_output_meta(self, axes, "TracedTensor::reduce_prod")?;
1461 Ok(apply_unary(
1462 StdTensorOp::ReduceProd {
1463 axes: axes.to_vec(),
1464 },
1465 self,
1466 out_rank,
1467 out_shape_hint,
1468 ))
1469 }
1470
1471 pub fn reshape(&self, shape: &[usize]) -> TracedTensor {
1481 apply_unary(
1482 StdTensorOp::Reshape {
1483 to_shape: DimExpr::from_concrete(shape),
1484 },
1485 self,
1486 shape.len(),
1487 Some(shape.iter().copied().map(SymDim::from).collect()),
1488 )
1489 }
1490
1491 pub fn sym_size(&self, axis: usize) -> Result<SymDim> {
1522 validate_traced_axis(self, axis, "TracedTensor::sym_size")?;
1523 Ok(self
1524 .shape_hint
1525 .as_ref()
1526 .and_then(|shape| shape.get(axis))
1527 .filter(|dim| dim.constant_value().is_none())
1528 .cloned()
1529 .unwrap_or_else(|| SymDim::tensor_axis(self.id, axis)))
1530 }
1531
1532 pub fn axis_sym_dim(&self, axis: usize) -> Result<SymDim> {
1561 validate_traced_axis(self, axis, "TracedTensor::axis_sym_dim")?;
1562 match self.shape_hint.as_ref().and_then(|shape| shape.get(axis)) {
1563 Some(dim) => Ok(dim.clone()),
1564 None => Ok(SymDim::tensor_axis(self.id, axis)),
1565 }
1566 }
1567
1568 pub fn sym_shape(&self) -> Option<&[SymDim]> {
1590 self.shape_hint.as_deref()
1591 }
1592
1593 pub fn reshape_sym(&self, shape: &[SymDim]) -> Result<TracedTensor> {
1606 let tensor_map = [(self.id, 0usize)];
1607 let to_shape = shape
1608 .iter()
1609 .map(|dim| dim.to_dim_expr(&tensor_map).map_err(Error::Internal))
1610 .collect::<Result<Vec<_>>>()?;
1611 let out_shape_hint = Some(shape.to_vec());
1612 Ok(apply_unary(
1613 StdTensorOp::Reshape { to_shape },
1614 self,
1615 shape.len(),
1616 out_shape_hint,
1617 ))
1618 }
1619
1620 pub fn broadcast_in_dim(&self, shape: &[usize], dims: &[usize]) -> Result<TracedTensor> {
1637 let out_shape_hint: Vec<SymDim> = shape.iter().copied().map(SymDim::from).collect();
1638 validate_broadcast_in_dim_args(
1639 self,
1640 &out_shape_hint,
1641 dims,
1642 "TracedTensor::broadcast_in_dim",
1643 )?;
1644 Ok(apply_unary(
1645 StdTensorOp::BroadcastInDim {
1646 shape: DimExpr::from_concrete(shape),
1647 dims: dims.to_vec(),
1648 },
1649 self,
1650 shape.len(),
1651 Some(out_shape_hint),
1652 ))
1653 }
1654
1655 pub fn broadcast_in_dim_sym(
1689 &self,
1690 shape: &[SymDim],
1691 dims: &[usize],
1692 shape_refs: &[&TracedTensor],
1693 ) -> Result<TracedTensor> {
1694 validate_broadcast_in_dim_args(self, shape, dims, "TracedTensor::broadcast_in_dim_sym")?;
1695
1696 let mut dedup_refs: Vec<&TracedTensor> = Vec::with_capacity(shape_refs.len());
1700 let mut tensor_map: Vec<(u64, usize)> = vec![(self.id, 0)];
1701 for &t in shape_refs {
1702 if !tensor_map.iter().any(|(id, _)| *id == t.id) {
1703 let idx = tensor_map.len();
1704 tensor_map.push((t.id, idx));
1705 dedup_refs.push(t);
1706 }
1707 }
1708
1709 let to_shape: Vec<DimExpr> = shape
1710 .iter()
1711 .map(|dim| {
1712 dim.to_dim_expr(&tensor_map)
1713 .map_err(|err| Error::InvalidGraphBuild {
1714 op: "broadcast_in_dim_sym",
1715 message: format!(
1716 "unresolved symbolic dimension: {err}; \
1717 pass every referenced tensor via `shape_refs`"
1718 ),
1719 })
1720 })
1721 .collect::<Result<Vec<_>>>()?;
1722
1723 let max_used_idx = DimExpr::max_input_idx_all(&to_shape).unwrap_or(0);
1730 let used_refs: Vec<&TracedTensor> = dedup_refs.into_iter().take(max_used_idx).collect();
1731
1732 let out_shape_hint = Some(shape.to_vec());
1733 Ok(apply_unary_with_shape_refs(
1734 StdTensorOp::BroadcastInDim {
1735 shape: to_shape,
1736 dims: dims.to_vec(),
1737 },
1738 self,
1739 &used_refs,
1740 shape.len(),
1741 out_shape_hint,
1742 ))
1743 }
1744
1745 pub fn slice(&self, config: SliceConfig) -> Result<TracedTensor> {
1747 let op = StdTensorOp::Slice(config);
1748 let (out_rank, out_shape_hint) =
1749 infer_traced_single_output_shape("TracedTensor::slice", &op, &[self])?;
1750 Ok(apply_unary(op, self, out_rank, out_shape_hint))
1751 }
1752
1753 pub fn pad(&self, config: PadConfig) -> Result<TracedTensor> {
1755 let op = StdTensorOp::Pad(config);
1756 let (out_rank, out_shape_hint) =
1757 infer_traced_single_output_shape("TracedTensor::pad", &op, &[self])?;
1758 Ok(apply_unary(op, self, out_rank, out_shape_hint))
1759 }
1760
1761 pub fn reverse(&self, axes: &[usize]) -> Result<TracedTensor> {
1763 validate_traced_axes(self.rank, axes, "TracedTensor::reverse")?;
1764 Ok(apply_unary(
1765 StdTensorOp::Reverse {
1766 axes: axes.to_vec(),
1767 },
1768 self,
1769 self.rank,
1770 self.shape_hint.clone(),
1771 ))
1772 }
1773
1774 pub fn gather(&self, indices: &TracedTensor, config: GatherConfig) -> Result<TracedTensor> {
1776 let op = StdTensorOp::Gather(config);
1777 let (out_rank, out_shape_hint) =
1778 infer_traced_single_output_shape("TracedTensor::gather", &op, &[self, indices])?;
1779 Ok(apply_binary_preserve_input_dtypes(
1780 op,
1781 self,
1782 indices,
1783 out_rank,
1784 out_shape_hint,
1785 self.dtype,
1786 ))
1787 }
1788
1789 pub fn scatter(
1791 &self,
1792 indices: &TracedTensor,
1793 updates: &TracedTensor,
1794 config: ScatterConfig,
1795 ) -> Result<TracedTensor> {
1796 let op = StdTensorOp::Scatter(config);
1797 let (out_rank, out_shape_hint) = infer_traced_single_output_shape(
1798 "TracedTensor::scatter",
1799 &op,
1800 &[self, indices, updates],
1801 )?;
1802 let out_dtype = crate::shape_infer::promote_dtype(self.dtype, updates.dtype);
1803 let operand = if self.dtype != out_dtype {
1804 self.cast(out_dtype)
1805 } else {
1806 self.clone()
1807 };
1808 let updates = if updates.dtype != out_dtype {
1809 updates.cast(out_dtype)
1810 } else {
1811 updates.clone()
1812 };
1813 Ok(apply_ternary_with_output_dtype(
1814 op,
1815 &operand,
1816 indices,
1817 &updates,
1818 out_rank,
1819 out_shape_hint,
1820 out_dtype,
1821 ))
1822 }
1823
1824 pub fn dynamic_slice(&self, starts: &TracedTensor, sizes: &[usize]) -> Result<TracedTensor> {
1826 let op = StdTensorOp::DynamicSlice {
1827 slice_sizes: sizes.to_vec(),
1828 };
1829 let (out_rank, out_shape_hint) =
1830 infer_traced_single_output_shape("TracedTensor::dynamic_slice", &op, &[self, starts])?;
1831 Ok(apply_binary_preserve_input_dtypes(
1832 op,
1833 self,
1834 starts,
1835 out_rank,
1836 out_shape_hint,
1837 self.dtype,
1838 ))
1839 }
1840
1841 pub fn tril(&self, k: i64) -> TracedTensor {
1843 apply_unary(
1844 StdTensorOp::Tril { k },
1845 self,
1846 self.rank,
1847 self.shape_hint.clone(),
1848 )
1849 }
1850
1851 pub fn triu(&self, k: i64) -> TracedTensor {
1853 apply_unary(
1854 StdTensorOp::Triu { k },
1855 self,
1856 self.rank,
1857 self.shape_hint.clone(),
1858 )
1859 }
1860
1861 pub fn transpose(&self, perm: &[usize]) -> Result<TracedTensor> {
1877 validate_traced_perm(self.rank, perm, "TracedTensor::transpose")?;
1878 let out_shape_hint = self
1879 .shape_hint
1880 .as_ref()
1881 .map(|shape| perm.iter().map(|&p| shape[p].clone()).collect());
1882 Ok(apply_unary(
1883 StdTensorOp::Transpose {
1884 perm: perm.to_vec(),
1885 },
1886 self,
1887 self.rank,
1888 out_shape_hint,
1889 ))
1890 }
1891
1892 pub fn extract_diag(&self, axis_a: usize, axis_b: usize) -> Result<TracedTensor> {
1908 validate_traced_axis(self, axis_a, "TracedTensor::extract_diag")?;
1909 validate_traced_axis(self, axis_b, "TracedTensor::extract_diag")?;
1910 if axis_a == axis_b {
1911 return Err(Error::InvalidGraphBuild {
1912 op: "TracedTensor::extract_diag",
1913 message: "diagonal axes must be distinct".into(),
1914 });
1915 }
1916 let out_shape_hint = self.shape_hint.as_ref().map(|shape| {
1917 shape
1918 .iter()
1919 .enumerate()
1920 .filter_map(|(axis, dim)| (axis != axis_b).then_some(dim.clone()))
1921 .collect()
1922 });
1923 Ok(apply_unary(
1924 StdTensorOp::ExtractDiag { axis_a, axis_b },
1925 self,
1926 self.rank - 1,
1927 out_shape_hint,
1928 ))
1929 }
1930
1931 pub fn embed_diag(&self, axis_a: usize, axis_b: usize) -> Result<TracedTensor> {
1947 validate_traced_axis(self, axis_a, "TracedTensor::embed_diag")?;
1948 validate_traced_insert_axis(self.rank, axis_b, "TracedTensor::embed_diag")?;
1949 let out_shape_hint = self.shape_hint.as_ref().map(|shape| {
1950 let mut out_shape = shape.clone();
1951 out_shape.insert(axis_b, shape[axis_a].clone());
1952 out_shape
1953 });
1954 Ok(apply_unary(
1955 StdTensorOp::EmbedDiag { axis_a, axis_b },
1956 self,
1957 self.rank + 1,
1958 out_shape_hint,
1959 ))
1960 }
1961
1962 pub fn shape_of(&self, axis: usize) -> Result<TracedTensor> {
1985 validate_traced_axis(self, axis, "TracedTensor::shape_of")?;
1986 Ok(apply_unary_with_dtype(
1987 StdTensorOp::ShapeOf { axis },
1988 self,
1989 0,
1990 Some(vec![]),
1991 DType::F64,
1992 ))
1993 }
1994
1995 pub fn dynamic_truncate(&self, size: &TracedTensor, axis: usize) -> Result<TracedTensor> {
2021 validate_traced_axis(self, axis, "TracedTensor::dynamic_truncate")?;
2022 if size.rank != 0 {
2023 return Err(Error::InvalidGraphBuild {
2024 op: "TracedTensor::dynamic_truncate",
2025 message: format!("size must be a scalar tensor, got rank {}", size.rank),
2026 });
2027 }
2028 Ok(apply_binary(
2029 StdTensorOp::DynamicTruncate { axis },
2030 self,
2031 size,
2032 self.rank,
2033 None,
2034 ))
2035 }
2036
2037 pub fn pad_to_match(&self, reference: &TracedTensor, axis: usize) -> Result<TracedTensor> {
2061 validate_traced_axis(self, axis, "TracedTensor::pad_to_match")?;
2062 validate_traced_axis(reference, axis, "TracedTensor::pad_to_match")?;
2063 Ok(apply_binary(
2064 StdTensorOp::PadToMatch { axis },
2065 self,
2066 reference,
2067 self.rank,
2068 reference.shape_hint.clone(),
2069 ))
2070 }
2071}
2072
2073pub(crate) fn apply_unary(
2074 op: StdTensorOp,
2075 input: &TracedTensor,
2076 out_rank: usize,
2077 out_shape_hint: Option<Vec<SymDim>>,
2078) -> TracedTensor {
2079 let out_dtype =
2080 inferred_output_dtype_or_fallback(&op, &[input.dtype], input.dtype, "apply_unary");
2081 apply_unary_with_dtype(op, input, out_rank, out_shape_hint, out_dtype)
2082}
2083
2084pub(crate) fn apply_unary_with_dtype(
2085 op: StdTensorOp,
2086 input: &TracedTensor,
2087 out_rank: usize,
2088 out_shape_hint: Option<Vec<SymDim>>,
2089 out_dtype: DType,
2090) -> TracedTensor {
2091 let mut builder = GraphBuilder::new();
2092 builder.add_parent(input.graph.clone());
2093 let input_ref = ValueRef::External(input.graph.values()[input.val].key.clone());
2094 let outputs = builder.add_operation(op, vec![input_ref], OperationRole::Primary);
2095 builder.set_outputs(outputs.clone());
2096 let graph = Arc::new(builder.build());
2097 let metadata_scope =
2098 register_single_output_metadata(graph.as_ref(), outputs[0], out_dtype, &out_shape_hint);
2099
2100 TracedTensor {
2101 id: next_traced_id(),
2102 rank: out_rank,
2103 dtype: out_dtype,
2104 graph,
2105 val: outputs[0],
2106 data: None,
2107 shape_hint: out_shape_hint,
2108 inputs_map: input.inputs_map.clone(),
2109 extra_roots: input.extra_roots.clone(),
2110 checkpoint_chain: input.checkpoint_chain.clone(),
2111 metadata_scopes: metadata_scopes_with_new(
2112 metadata_scope,
2113 [input.metadata_scopes.as_slice()],
2114 ),
2115 }
2116}
2117
2118pub(crate) fn apply_unary_with_shape_refs(
2127 op: StdTensorOp,
2128 input: &TracedTensor,
2129 shape_refs: &[&TracedTensor],
2130 out_rank: usize,
2131 out_shape_hint: Option<Vec<SymDim>>,
2132) -> TracedTensor {
2133 let mut builder = GraphBuilder::new();
2134 builder.add_parent(input.graph.clone());
2135 for t in shape_refs {
2136 builder.add_parent(t.graph.clone());
2137 }
2138 let mut op_inputs: Vec<ValueRef<StdTensorOp>> = Vec::with_capacity(1 + shape_refs.len());
2139 op_inputs.push(ValueRef::External(
2140 input.graph.values()[input.val].key.clone(),
2141 ));
2142 for t in shape_refs {
2143 op_inputs.push(ValueRef::External(t.graph.values()[t.val].key.clone()));
2144 }
2145 let outputs = builder.add_operation(op, op_inputs, OperationRole::Primary);
2146 builder.set_outputs(outputs.clone());
2147 let graph = Arc::new(builder.build());
2148 let metadata_scope =
2149 register_single_output_metadata(graph.as_ref(), outputs[0], input.dtype, &out_shape_hint);
2150
2151 let mut merged = (*input.inputs_map).clone();
2152 for t in shape_refs {
2153 merged.extend(t.inputs_map.iter().map(|(k, v)| (k.clone(), v.clone())));
2154 }
2155
2156 let mut extra_roots = input.extra_roots.clone();
2157 for t in shape_refs {
2158 extra_roots.extend(t.extra_roots.iter().cloned());
2159 }
2160
2161 let mut checkpoint_chain = input.checkpoint_chain.clone();
2162 for t in shape_refs {
2163 checkpoint_chain =
2164 CheckpointNode::merge_chains(checkpoint_chain, t.checkpoint_chain.clone());
2165 }
2166
2167 TracedTensor {
2168 id: next_traced_id(),
2169 rank: out_rank,
2170 dtype: input.dtype,
2171 graph,
2172 val: outputs[0],
2173 data: None,
2174 shape_hint: out_shape_hint,
2175 inputs_map: Arc::new(merged),
2176 extra_roots,
2177 checkpoint_chain,
2178 metadata_scopes: {
2179 let mut scopes =
2180 metadata_scopes_with_new(metadata_scope, [input.metadata_scopes.as_slice()]);
2181 for t in shape_refs {
2182 for scope in &t.metadata_scopes {
2183 push_metadata_scope(&mut scopes, Arc::clone(scope));
2184 }
2185 }
2186 scopes
2187 },
2188 }
2189}
2190
2191pub(crate) fn apply_nullary(
2192 op: StdTensorOp,
2193 rank: usize,
2194 dtype: DType,
2195 shape_hint: Option<Vec<SymDim>>,
2196) -> TracedTensor {
2197 let mut builder = GraphBuilder::new();
2198 let outputs = builder.add_operation(op, vec![], OperationRole::Primary);
2199 builder.set_outputs(outputs.clone());
2200 let graph = Arc::new(builder.build());
2201 let metadata_scope =
2202 register_single_output_metadata(graph.as_ref(), outputs[0], dtype, &shape_hint);
2203
2204 TracedTensor {
2205 id: next_traced_id(),
2206 rank,
2207 dtype,
2208 graph,
2209 val: outputs[0],
2210 data: None,
2211 shape_hint,
2212 inputs_map: Arc::new(HashMap::new()),
2213 extra_roots: Vec::new(),
2214 checkpoint_chain: None,
2215 metadata_scopes: metadata_scopes_for_scope(metadata_scope),
2216 }
2217}
2218
2219pub(crate) fn apply_binary(
2220 op: StdTensorOp,
2221 lhs: &TracedTensor,
2222 rhs: &TracedTensor,
2223 out_rank: usize,
2224 out_shape_hint: Option<Vec<SymDim>>,
2225) -> TracedTensor {
2226 let input_dtype = crate::shape_infer::promote_dtype_for_binary_op(&op, lhs.dtype, rhs.dtype);
2227 let out_dtype = inferred_output_dtype_or_fallback(
2228 &op,
2229 &[lhs.dtype, rhs.dtype],
2230 input_dtype,
2231 "apply_binary",
2232 );
2233
2234 let lhs = if lhs.dtype != input_dtype {
2236 lhs.cast(input_dtype)
2237 } else {
2238 lhs.clone()
2239 };
2240 let rhs = if rhs.dtype != input_dtype {
2241 rhs.cast(input_dtype)
2242 } else {
2243 rhs.clone()
2244 };
2245
2246 apply_binary_with_output_dtype(op, &lhs, &rhs, out_rank, out_shape_hint, out_dtype)
2247}
2248
2249pub(crate) fn apply_binary_preserve_input_dtypes(
2250 op: StdTensorOp,
2251 lhs: &TracedTensor,
2252 rhs: &TracedTensor,
2253 out_rank: usize,
2254 out_shape_hint: Option<Vec<SymDim>>,
2255 out_dtype: DType,
2256) -> TracedTensor {
2257 apply_binary_with_output_dtype(op, lhs, rhs, out_rank, out_shape_hint, out_dtype)
2258}
2259
2260pub(crate) fn apply_broadcast_binary_op(
2261 op: StdTensorOp,
2262 lhs: &TracedTensor,
2263 rhs: &TracedTensor,
2264) -> Result<TracedTensor> {
2265 let (lhs, rhs) = broadcast_binary(lhs, rhs)?;
2266 Ok(apply_binary(
2267 op,
2268 &lhs,
2269 &rhs,
2270 lhs.rank,
2271 lhs.shape_hint.clone(),
2272 ))
2273}
2274
2275pub(crate) fn apply_broadcast_ternary_op(
2276 op: StdTensorOp,
2277 first: &TracedTensor,
2278 second: &TracedTensor,
2279 third: &TracedTensor,
2280) -> Result<TracedTensor> {
2281 let (first, second, third) = broadcast_ternary(first, second, third)?;
2282 Ok(apply_ternary(
2283 op,
2284 &first,
2285 &second,
2286 &third,
2287 first.rank,
2288 first.shape_hint.clone(),
2289 ))
2290}
2291
2292pub(crate) fn apply_ternary(
2293 op: StdTensorOp,
2294 first: &TracedTensor,
2295 second: &TracedTensor,
2296 third: &TracedTensor,
2297 out_rank: usize,
2298 out_shape_hint: Option<Vec<SymDim>>,
2299) -> TracedTensor {
2300 let fallback_dtype =
2301 crate::shape_infer::promote_dtypes([first.dtype, second.dtype, third.dtype]);
2302 let out_dtype = inferred_output_dtype_or_fallback(
2303 &op,
2304 &[first.dtype, second.dtype, third.dtype],
2305 fallback_dtype,
2306 "apply_ternary",
2307 );
2308 let (first, second, third) = match op {
2309 StdTensorOp::Select => {
2310 let value_dtype = crate::shape_infer::promote_dtype(second.dtype, third.dtype);
2311 let second = if second.dtype != value_dtype {
2312 second.cast(value_dtype)
2313 } else {
2314 second.clone()
2315 };
2316 let third = if third.dtype != value_dtype {
2317 third.cast(value_dtype)
2318 } else {
2319 third.clone()
2320 };
2321 (first.clone(), second, third)
2322 }
2323 _ => {
2324 let input_dtype =
2325 crate::shape_infer::promote_dtypes([first.dtype, second.dtype, third.dtype]);
2326 let first = if first.dtype != input_dtype {
2327 first.cast(input_dtype)
2328 } else {
2329 first.clone()
2330 };
2331 let second = if second.dtype != input_dtype {
2332 second.cast(input_dtype)
2333 } else {
2334 second.clone()
2335 };
2336 let third = if third.dtype != input_dtype {
2337 third.cast(input_dtype)
2338 } else {
2339 third.clone()
2340 };
2341 (first, second, third)
2342 }
2343 };
2344 apply_ternary_with_output_dtype(
2345 op,
2346 &first,
2347 &second,
2348 &third,
2349 out_rank,
2350 out_shape_hint,
2351 out_dtype,
2352 )
2353}
2354
2355fn apply_binary_with_output_dtype(
2356 op: StdTensorOp,
2357 lhs: &TracedTensor,
2358 rhs: &TracedTensor,
2359 out_rank: usize,
2360 out_shape_hint: Option<Vec<SymDim>>,
2361 out_dtype: DType,
2362) -> TracedTensor {
2363 let lhs_ref = ValueRef::External(lhs.graph.values()[lhs.val].key.clone());
2364 let rhs_ref = ValueRef::External(rhs.graph.values()[rhs.val].key.clone());
2365
2366 let mut builder = GraphBuilder::new();
2367 builder.add_parent(lhs.graph.clone());
2368 builder.add_parent(rhs.graph.clone());
2369 let outputs = builder.add_operation(op, vec![lhs_ref, rhs_ref], OperationRole::Primary);
2370 builder.set_outputs(outputs.clone());
2371 let graph = Arc::new(builder.build());
2372 let metadata_scope =
2373 register_single_output_metadata(graph.as_ref(), outputs[0], out_dtype, &out_shape_hint);
2374
2375 let mut merged = (*lhs.inputs_map).clone();
2376 merged.extend(rhs.inputs_map.iter().map(|(k, v)| (k.clone(), v.clone())));
2377 let mut extra_roots = lhs.extra_roots.clone();
2378 extra_roots.extend(rhs.extra_roots.iter().cloned());
2379
2380 TracedTensor {
2381 id: next_traced_id(),
2382 rank: out_rank,
2383 dtype: out_dtype,
2384 graph,
2385 val: outputs[0],
2386 data: None,
2387 shape_hint: out_shape_hint,
2388 inputs_map: Arc::new(merged),
2389 extra_roots,
2390 checkpoint_chain: CheckpointNode::merge_chains(
2391 lhs.checkpoint_chain.clone(),
2392 rhs.checkpoint_chain.clone(),
2393 ),
2394 metadata_scopes: metadata_scopes_with_new(
2395 metadata_scope,
2396 [
2397 lhs.metadata_scopes.as_slice(),
2398 rhs.metadata_scopes.as_slice(),
2399 ],
2400 ),
2401 }
2402}
2403
2404fn apply_ternary_with_output_dtype(
2405 op: StdTensorOp,
2406 first: &TracedTensor,
2407 second: &TracedTensor,
2408 third: &TracedTensor,
2409 out_rank: usize,
2410 out_shape_hint: Option<Vec<SymDim>>,
2411 out_dtype: DType,
2412) -> TracedTensor {
2413 let first_ref = ValueRef::External(first.graph.values()[first.val].key.clone());
2414 let second_ref = ValueRef::External(second.graph.values()[second.val].key.clone());
2415 let third_ref = ValueRef::External(third.graph.values()[third.val].key.clone());
2416
2417 let mut builder = GraphBuilder::new();
2418 builder.add_parent(first.graph.clone());
2419 builder.add_parent(second.graph.clone());
2420 builder.add_parent(third.graph.clone());
2421 let outputs = builder.add_operation(
2422 op,
2423 vec![first_ref, second_ref, third_ref],
2424 OperationRole::Primary,
2425 );
2426 builder.set_outputs(outputs.clone());
2427 let graph = Arc::new(builder.build());
2428 let metadata_scope =
2429 register_single_output_metadata(graph.as_ref(), outputs[0], out_dtype, &out_shape_hint);
2430
2431 let mut merged = (*first.inputs_map).clone();
2432 merged.extend(
2433 second
2434 .inputs_map
2435 .iter()
2436 .map(|(k, v)| (k.clone(), v.clone())),
2437 );
2438 merged.extend(third.inputs_map.iter().map(|(k, v)| (k.clone(), v.clone())));
2439
2440 let mut extra_roots = first.extra_roots.clone();
2441 extra_roots.extend(second.extra_roots.iter().cloned());
2442 extra_roots.extend(third.extra_roots.iter().cloned());
2443
2444 let checkpoint_chain = CheckpointNode::merge_chains(
2445 CheckpointNode::merge_chains(
2446 first.checkpoint_chain.clone(),
2447 second.checkpoint_chain.clone(),
2448 ),
2449 third.checkpoint_chain.clone(),
2450 );
2451
2452 TracedTensor {
2453 id: next_traced_id(),
2454 rank: out_rank,
2455 dtype: out_dtype,
2456 graph,
2457 val: outputs[0],
2458 data: None,
2459 shape_hint: out_shape_hint,
2460 inputs_map: Arc::new(merged),
2461 extra_roots,
2462 checkpoint_chain,
2463 metadata_scopes: metadata_scopes_with_new(
2464 metadata_scope,
2465 [
2466 first.metadata_scopes.as_slice(),
2467 second.metadata_scopes.as_slice(),
2468 third.metadata_scopes.as_slice(),
2469 ],
2470 ),
2471 }
2472}
2473
2474fn register_single_output_metadata(
2475 graph: &Graph<StdTensorOp>,
2476 output: LocalValueId,
2477 dtype: DType,
2478 shape_hint: &Option<Vec<SymDim>>,
2479) -> GlobalMetadataScope {
2480 if let Some(shape) = shape_hint {
2481 register_scoped_value_metadata(
2484 graph.values()[output].key.clone(),
2485 tensor_meta(dtype, shape.clone()),
2486 )
2487 .expect("fresh traced graph output metadata registration failed")
2488 } else {
2489 register_scoped_graph_metadata(graph, std::iter::empty())
2492 .expect("fresh traced graph metadata registration failed")
2493 }
2494}
2495
2496impl TracedTensor {
2497 pub(crate) fn resolve_roots(&self) -> Vec<Arc<Graph<StdTensorOp>>> {
2498 let mut roots = Vec::with_capacity(1 + self.extra_roots.len());
2499 roots.push(self.graph.clone());
2500 roots.extend(self.extra_roots.iter().cloned());
2501 roots
2502 }
2503}
2504
2505#[cfg(test)]
2506mod tests;