Skip to main content

tidu/
linear_fragment.rs

1use computegraph::fragment::Fragment;
2use computegraph::{GraphOp, LocalValId};
3
4/// A linear fragment produced by [`crate::differentiate`] or [`crate::transpose`].
5///
6/// # Examples
7///
8/// ```ignore
9/// use computegraph::resolve::resolve;
10/// use tidu::differentiate;
11///
12/// let view = resolve(vec![primal_fragment]);
13/// let mut ctx = ();
14/// let linear = differentiate(&view, &[output_key], &[input_key], 1, &mut ctx);
15/// assert_eq!(linear.tangent_inputs.len(), 1);
16/// ```
17pub struct LinearFragment<Op: GraphOp> {
18    /// The fragment containing linear ops.
19    pub fragment: Fragment<Op>,
20    /// `(primal_input_key, tangent_local_val_id)` pairs.
21    pub tangent_inputs: Vec<(Op::InputKey, LocalValId)>,
22    /// Tangent outputs, aligned with the requested outputs of the source transform.
23    /// `None` means the corresponding output is inactive.
24    pub tangent_outputs: Vec<Option<LocalValId>>,
25}