pub fn linearize<Op: Primitive>(
view: &ResolvedView<Op>,
outputs: &[ValueKey<Op>],
wrt: &[Op::InputKey],
pass: DiffPassId,
ctx: &mut Op::ADContext,
aliases: &HashMap<Op::InputKey, ValueKey<Op>>,
) -> LinearizedGraph<Op>
where
Op::InputKey: ADKey,
Linearizes a resolved computation graph, producing a linear graph of tangent computations.
The transform walks the DAG reachable from `outputs` in dependency-first
order (each node is visited after the nodes it depends on) and delegates
primitive-specific JVP generation to [`crate::Primitive::try_jvp_rule`].
Examples
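Note: this example is flagged by rustdoc as not tested. `primal_graph`, `output_key`, and `input_key` are placeholders assumed to be defined elsewhere. It calls the `try_linearize` variant, whose result is propagated with `?`.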
use computegraph::resolve::resolve;
use tidu::try_linearize;
let view = resolve(vec![primal_graph]);
let mut ctx = ();
let aliases = std::collections::HashMap::new();
let linear = try_linearize(&view, &[output_key], &[input_key], 1, &mut ctx, &aliases)?;
assert_eq!(linear.tangent_outputs().len(), 1);