1use std::sync::atomic::{AtomicU64, Ordering};
2use std::sync::Arc;
3
4use computegraph::resolve::resolve;
5use computegraph::types::ValueKey;
6use tenferro_ops::input_key::TensorInputKey;
7use tenferro_ops::ExtensionRuleSet;
8use tenferro_ops::ShapeGuardContext;
9use tenferro_runtime::ad_support::{
10 checkpoint_chain as tensor_checkpoint_chain, checkpoint_tensor,
11 extra_roots as tensor_extra_roots, inputs_map as tensor_inputs_map, leaf_input_key,
12 linear_input_key, metadata_scopes as tensor_metadata_scopes, metadata_scopes_with_new,
13 ones_tensor, push_metadata_scope, register_scoped_graph_metadata, registered_meta,
14 resolve_roots as tensor_resolve_roots, shape_hint as tensor_shape_hint, tensor_from_parts,
15 tensor_meta_from_tensor, TracedTensorParts,
16};
17use tenferro_runtime::{Error, GraphCompiler, GraphExecutor, Result, TracedTensor};
18use tenferro_tensor::TensorBackend;
19use tidu::{linear_transpose, linearize, ADRuleError};
20
/// Process-wide counter backing [`next_pass_id`]; starts at zero.
static NEXT_DIFF_PASS_ID: AtomicU64 = AtomicU64::new(0);

/// Hands out the identifier for the next differentiation pass.
///
/// The counter is incremented atomically and the value it held *before* the
/// increment is returned, so successive calls observe `0, 1, 2, ...`.
fn next_pass_id() -> u64 {
    // Only the counter itself is updated here, hence the relaxed ordering.
    const STEP: u64 = 1;
    NEXT_DIFF_PASS_ID.fetch_add(STEP, Ordering::Relaxed)
}
26
27pub(crate) fn next_input_key() -> TensorInputKey {
28 tenferro_runtime::ad_support::allocate_input_key()
29}
30
31fn error_shape_hint(tensor: &TracedTensor) -> Vec<usize> {
32 tensor
33 .try_concrete_shape()
34 .unwrap_or_else(|| vec![0; tensor.rank])
35}
36
37fn shape_guard_context(extension_rules: Option<&ExtensionRuleSet>) -> ShapeGuardContext {
38 let ctx = ShapeGuardContext::with_global_metadata();
39 match extension_rules {
40 Some(rules) => ctx.with_extension_rules(rules.clone()),
41 None => ctx,
42 }
43}
44
45fn ad_rule_error(transform: &'static str, err: ADRuleError) -> Error {
46 match err {
47 ADRuleError::Unsupported { op, .. } => {
48 Error::Internal(format!("unsupported {transform} AD rule for {op}"))
49 }
50 ADRuleError::InvalidInput { op, message, .. } => Error::InvalidGraphBuild {
51 op: transform,
52 message: format!("{op}: {message}"),
53 },
54 }
55}
56
57pub(crate) fn grad_with_rules(
58 output: &TracedTensor,
59 wrt: &TracedTensor,
60 extension_rules: &ExtensionRuleSet,
61) -> Result<TracedTensor> {
62 grad_with_optional_rules(output, wrt, Some(extension_rules))
63}
64
65pub(crate) fn jvp_with_rules(
66 output: &TracedTensor,
67 wrt: &TracedTensor,
68 tangent: &TracedTensor,
69 extension_rules: &ExtensionRuleSet,
70) -> Result<TracedTensor> {
71 let wrt_input_key = leaf_input_key(wrt)?;
72 jvp_optional_impl(output, wrt, tangent, Some(extension_rules))?
73 .ok_or_else(|| Error::Internal(format!("jvp output is inactive for {:?}", wrt_input_key)))
74}
75
76pub(crate) fn grad_optional_with_rules(
77 output: &TracedTensor,
78 wrt: &TracedTensor,
79 extension_rules: &ExtensionRuleSet,
80) -> Result<Option<TracedTensor>> {
81 if output.rank != 0 {
82 return Err(Error::NonScalarGrad {
83 shape: error_shape_hint(output),
84 });
85 }
86
87 let ones = ones_tensor(output.dtype, vec![])?;
88 let seed = TracedTensor::from_tensor_concrete_shape(ones)?;
89 vjp_optional_impl(output, wrt, &seed, Some(extension_rules))
90}
91
92pub(crate) fn jvp_optional_with_rules(
93 output: &TracedTensor,
94 wrt: &TracedTensor,
95 tangent: &TracedTensor,
96 extension_rules: &ExtensionRuleSet,
97) -> Result<Option<TracedTensor>> {
98 jvp_optional_impl(output, wrt, tangent, Some(extension_rules))
99}
100
101pub(crate) fn vjp_with_rules(
102 output: &TracedTensor,
103 wrt: &TracedTensor,
104 cotangent: &TracedTensor,
105 extension_rules: &ExtensionRuleSet,
106) -> Result<TracedTensor> {
107 let wrt_input_key = leaf_input_key(wrt)?;
108 vjp_optional_impl(output, wrt, cotangent, Some(extension_rules))?
109 .ok_or_else(|| Error::Internal(format!("vjp output is inactive for {:?}", wrt_input_key)))
110}
111
112pub(crate) fn vjp_optional_with_rules(
113 output: &TracedTensor,
114 wrt: &TracedTensor,
115 cotangent: &TracedTensor,
116 extension_rules: &ExtensionRuleSet,
117) -> Result<Option<TracedTensor>> {
118 vjp_optional_impl(output, wrt, cotangent, Some(extension_rules))
119}
120
121fn grad_with_optional_rules(
122 output: &TracedTensor,
123 wrt: &TracedTensor,
124 extension_rules: Option<&ExtensionRuleSet>,
125) -> Result<TracedTensor> {
126 if output.rank != 0 {
127 return Err(Error::NonScalarGrad {
128 shape: error_shape_hint(output),
129 });
130 }
131
132 let ones = ones_tensor(output.dtype, vec![])?;
133 let seed = TracedTensor::from_tensor_concrete_shape(ones)?;
134 let wrt_input_key = leaf_input_key(wrt)?;
135 vjp_optional_impl(output, wrt, &seed, extension_rules)?
136 .ok_or_else(|| Error::Internal(format!("grad output is inactive for {:?}", wrt_input_key)))
137}
138
139pub trait TracedTensorAdExt {
153 fn grad(&self, wrt: &TracedTensor) -> Result<TracedTensor>;
181
182 fn grad_optional(&self, wrt: &TracedTensor) -> Result<Option<TracedTensor>>;
197
198 fn checkpoint<B: TensorBackend>(
219 &mut self,
220 compiler: &mut GraphCompiler,
221 executor: &mut GraphExecutor<B>,
222 ) -> Result<()>;
223
224 fn jvp(&self, wrt: &TracedTensor, tangent: &TracedTensor) -> Result<TracedTensor>;
248
249 fn jvp_optional(
265 &self,
266 wrt: &TracedTensor,
267 tangent: &TracedTensor,
268 ) -> Result<Option<TracedTensor>>;
269
270 fn vjp(&self, wrt: &TracedTensor, cotangent: &TracedTensor) -> Result<TracedTensor>;
299
300 fn vjp_optional(
316 &self,
317 wrt: &TracedTensor,
318 cotangent: &TracedTensor,
319 ) -> Result<Option<TracedTensor>>;
320}
321
322impl TracedTensorAdExt for TracedTensor {
323 fn grad(&self, wrt: &TracedTensor) -> Result<TracedTensor> {
324 grad_with_optional_rules(self, wrt, None)
325 }
326
327 fn grad_optional(&self, wrt: &TracedTensor) -> Result<Option<TracedTensor>> {
328 if self.rank != 0 {
329 return Err(Error::NonScalarGrad {
330 shape: error_shape_hint(self),
331 });
332 }
333
334 let ones = ones_tensor(self.dtype, vec![])?;
335 let seed = TracedTensor::from_tensor_concrete_shape(ones)?;
336 vjp_optional_impl(self, wrt, &seed, None)
337 }
338
339 fn checkpoint<B: TensorBackend>(
340 &mut self,
341 compiler: &mut GraphCompiler,
342 executor: &mut GraphExecutor<B>,
343 ) -> Result<()> {
344 let data = if let Some(data) = self.attached_data() {
345 Arc::clone(data)
346 } else {
347 let program = compiler.compile(self)?;
348 Arc::new(executor.run(&program)?)
349 };
350 checkpoint_tensor(self, data);
351 Ok(())
352 }
353
354 fn jvp(&self, wrt: &TracedTensor, tangent: &TracedTensor) -> Result<TracedTensor> {
355 let wrt_input_key = leaf_input_key(wrt)?;
356 self.jvp_optional(wrt, tangent)?.ok_or_else(|| {
357 Error::Internal(format!("jvp output is inactive for {:?}", wrt_input_key))
358 })
359 }
360
361 fn jvp_optional(
362 &self,
363 wrt: &TracedTensor,
364 tangent: &TracedTensor,
365 ) -> Result<Option<TracedTensor>> {
366 jvp_optional_impl(self, wrt, tangent, None)
367 }
368
369 fn vjp(&self, wrt: &TracedTensor, cotangent: &TracedTensor) -> Result<TracedTensor> {
370 let wrt_input_key = leaf_input_key(wrt)?;
371 self.vjp_optional(wrt, cotangent)?.ok_or_else(|| {
372 Error::Internal(format!("vjp output is inactive for {:?}", wrt_input_key))
373 })
374 }
375
376 fn vjp_optional(
377 &self,
378 wrt: &TracedTensor,
379 cotangent: &TracedTensor,
380 ) -> Result<Option<TracedTensor>> {
381 vjp_optional_impl(self, wrt, cotangent, None)
382 }
383}
384
385fn jvp_optional_impl(
386 output: &TracedTensor,
387 wrt: &TracedTensor,
388 tangent: &TracedTensor,
389 extension_rules: Option<&ExtensionRuleSet>,
390) -> Result<Option<TracedTensor>> {
391 let wrt_input_key = leaf_input_key(wrt)?;
392 let output_key = output.graph().values()[output.val].key.clone();
393 let checkpoint_chain = tensor_checkpoint_chain(output);
394 let aliases = checkpoint_chain
395 .as_ref()
396 .map(|chain| chain.collect_aliases())
397 .unwrap_or_default();
398 let checkpoint_graphs = checkpoint_chain
399 .as_ref()
400 .map(|chain| chain.collect_graphs())
401 .unwrap_or_default();
402 let mut roots = tensor_resolve_roots(output);
403 roots.extend(checkpoint_graphs.iter().cloned());
404 let view = resolve(roots);
405 let mut ad_ctx = shape_guard_context(extension_rules);
406 let linear = linearize(
407 &view,
408 std::slice::from_ref(&output_key),
409 std::slice::from_ref(&wrt_input_key),
410 next_pass_id(),
411 &mut ad_ctx,
412 &aliases,
413 )
414 .map_err(|err| ad_rule_error("jvp", err))?;
415 let Some(tangent_output) = linear.tangent_outputs()[0] else {
416 return Ok(None);
417 };
418 let tangent_input_key = linear_input_key(linear.as_graph(), linear.tangent_inputs()[0].1)?;
419 let tangent_data =
420 tangent
421 .attached_data()
422 .cloned()
423 .ok_or_else(|| Error::InvalidGraphBuild {
424 op: "jvp",
425 message: "jvp tangent must have concrete tensor data".to_string(),
426 })?;
427 let metadata_scope = register_scoped_graph_metadata(
428 linear.as_graph(),
429 vec![(
430 ValueKey::Input(tangent_input_key.clone()),
431 tensor_meta_from_tensor(tangent_data.as_ref()),
432 )],
433 )?;
434
435 let mut inputs_map = (*tensor_inputs_map(output)).clone();
436 if let Some(chain) = &checkpoint_chain {
437 inputs_map.extend(chain.collect_inputs());
438 }
439 inputs_map.insert(tangent_input_key, tangent_data);
440
441 let mut extra_roots = vec![Arc::clone(output.graph())];
442 extra_roots.extend(checkpoint_graphs);
443 extra_roots.extend(tensor_extra_roots(output));
444
445 Ok(Some(tensor_from_parts(TracedTensorParts {
446 rank: output.rank,
447 dtype: output.dtype,
448 graph: Arc::new(linear.into_graph()),
449 val: tangent_output,
450 data: None,
451 shape_hint: tensor_shape_hint(output),
452 inputs_map: Arc::new(inputs_map),
453 extra_roots,
454 checkpoint_chain,
455 metadata_scopes: metadata_scopes_with_new(
456 metadata_scope,
457 [
458 tensor_metadata_scopes(output),
459 tensor_metadata_scopes(wrt),
460 tensor_metadata_scopes(tangent),
461 ],
462 ),
463 })))
464}
465
466fn vjp_optional_impl(
467 output: &TracedTensor,
468 wrt: &TracedTensor,
469 cotangent: &TracedTensor,
470 extension_rules: Option<&ExtensionRuleSet>,
471) -> Result<Option<TracedTensor>> {
472 let wrt_input_key = leaf_input_key(wrt)?;
473 let output_key = output.graph().values()[output.val].key.clone();
474 let checkpoint_chain = tensor_checkpoint_chain(output);
475 let aliases = checkpoint_chain
476 .as_ref()
477 .map(|chain| chain.collect_aliases())
478 .unwrap_or_default();
479 let checkpoint_graphs = checkpoint_chain
480 .as_ref()
481 .map(|chain| chain.collect_graphs())
482 .unwrap_or_default();
483 let mut roots = tensor_resolve_roots(output);
484 roots.extend(checkpoint_graphs.iter().cloned());
485 let view = resolve(roots);
486 let mut ad_ctx = shape_guard_context(extension_rules);
487 let linear = linearize(
488 &view,
489 std::slice::from_ref(&output_key),
490 std::slice::from_ref(&wrt_input_key),
491 next_pass_id(),
492 &mut ad_ctx,
493 &aliases,
494 )
495 .map_err(|err| ad_rule_error("vjp", err))?;
496 if linear.tangent_outputs()[0].is_none() {
497 return Ok(None);
498 }
499 let linear_seed_key = linear_input_key(linear.as_graph(), linear.tangent_inputs()[0].1)?;
500 let linear_metadata_scope = register_scoped_graph_metadata(
501 linear.as_graph(),
502 vec![(
503 ValueKey::Input(linear_seed_key),
504 registered_meta(&wrt.graph().values()[wrt.val].key)?,
505 )],
506 )?;
507 ad_ctx.refresh_global_metadata();
508 let transposed =
509 linear_transpose(&linear, &mut ad_ctx).map_err(|err| ad_rule_error("vjp", err))?;
510 let cotangent_input_key =
511 linear_input_key(transposed.as_graph(), transposed.tangent_inputs()[0].1)?;
512 let cotangent_data =
513 cotangent
514 .attached_data()
515 .cloned()
516 .ok_or_else(|| Error::InvalidGraphBuild {
517 op: "vjp",
518 message: "vjp cotangent must have concrete tensor data".to_string(),
519 })?;
520 let transposed_metadata_scope = register_scoped_graph_metadata(
521 transposed.as_graph(),
522 vec![(
523 ValueKey::Input(cotangent_input_key.clone()),
524 tensor_meta_from_tensor(cotangent_data.as_ref()),
525 )],
526 )?;
527 let linear_graph = Arc::new(linear.into_graph());
528 let Some(cotangent_output) = transposed.tangent_outputs()[0] else {
529 return Ok(None);
530 };
531
532 let mut inputs_map = (*tensor_inputs_map(output)).clone();
533 if let Some(chain) = &checkpoint_chain {
534 inputs_map.extend(chain.collect_inputs());
535 }
536 inputs_map.insert(cotangent_input_key.clone(), cotangent_data);
537
538 let mut extra_roots = vec![Arc::clone(output.graph()), linear_graph];
539 extra_roots.extend(checkpoint_graphs);
540 extra_roots.extend(tensor_extra_roots(output));
541
542 Ok(Some(tensor_from_parts(TracedTensorParts {
543 rank: wrt.rank,
544 dtype: wrt.dtype,
545 graph: Arc::new(transposed.into_graph()),
546 val: cotangent_output,
547 data: None,
548 shape_hint: tensor_shape_hint(wrt),
549 inputs_map: Arc::new(inputs_map),
550 extra_roots,
551 checkpoint_chain,
552 metadata_scopes: {
553 let mut scopes = metadata_scopes_with_new(
554 linear_metadata_scope,
555 [
556 tensor_metadata_scopes(output),
557 tensor_metadata_scopes(wrt),
558 tensor_metadata_scopes(cotangent),
559 ],
560 );
561 push_metadata_scope(&mut scopes, Arc::new(transposed_metadata_scope));
562 scopes
563 },
564 })))
565}