Skip to main content

differentiate

Function `differentiate`

Source
pub fn differentiate<Op: PrimitiveOp>(
    view: &ResolvedView<Op>,
    outputs: &[GlobalValKey<Op>],
    wrt: &[Op::InputKey],
    pass: DiffPassId,
    ctx: &mut Op::ADContext,
    aliases: &HashMap<Op::InputKey, GlobalValKey<Op>>,
) -> LinearFragment<Op>
where Op::InputKey: ADKey,
Expand description

Differentiates a resolved computation graph with respect to the inputs listed in `wrt`, producing a linear fragment.

The transform walks the DAG reachable from `outputs` in dependency-first order. It delegates primitive-specific JVP (Jacobian-vector product) generation to `chainrules::PrimitiveOp::linearize`.

Examples

use computegraph::resolve::resolve;
use tidu::differentiate;

let view = resolve(vec![primal_fragment]);
let mut ctx = ();
let aliases = std::collections::HashMap::new();
let linear = differentiate(&view, &[output_key], &[input_key], 1, &mut ctx, &aliases);
assert_eq!(linear.tangent_outputs.len(), 1);