pub fn differentiate<Op: PrimitiveOp>(
view: &ResolvedView<Op>,
outputs: &[GlobalValKey<Op>],
wrt: &[Op::InputKey],
pass: DiffPassId,
ctx: &mut Op::ADContext,
aliases: &HashMap<Op::InputKey, GlobalValKey<Op>>,
) -> LinearFragment<Op>
where
Op::InputKey: ADKey,
Differentiate a resolved computation graph, producing a linear fragment.
The transform walks the DAG reachable from `outputs` in dependency-first
order and delegates primitive-specific JVP generation to
[`chainrules::PrimitiveOp::linearize`].
# Examples
use computegraph::resolve::resolve;
use tidu::differentiate;
let view = resolve(vec![primal_fragment]);
let mut ctx = ();
let aliases = std::collections::HashMap::new();
let linear = differentiate(&view, &[output_key], &[input_key], 1, &mut ctx, &aliases);
assert_eq!(linear.tangent_outputs.len(), 1);