Skip to main content

tidu/
linearized_graph.rs

1use computegraph::graph::Graph;
2use computegraph::{GraphOperation, LocalValueId};
3
4/// Graph produced by linearizing a primitive computation graph.
5pub struct LinearizedGraph<Op: GraphOperation> {
6    graph: Graph<Op>,
7    tangent_inputs: Vec<(Op::InputKey, LocalValueId)>,
8    tangent_outputs: Vec<Option<LocalValueId>>,
9}
10
11impl<Op: GraphOperation> LinearizedGraph<Op> {
12    pub(crate) fn from_parts(
13        graph: Graph<Op>,
14        tangent_inputs: Vec<(Op::InputKey, LocalValueId)>,
15        tangent_outputs: Vec<Option<LocalValueId>>,
16    ) -> Self {
17        Self {
18            graph,
19            tangent_inputs,
20            tangent_outputs,
21        }
22    }
23
24    /// Borrow the lower-level graph representation.
25    pub fn as_graph(&self) -> &Graph<Op> {
26        &self.graph
27    }
28
29    /// Consume this value and return the lower-level graph representation.
30    pub fn into_graph(self) -> Graph<Op> {
31        self.graph
32    }
33
34    /// Tangent input keys and local value ids.
35    pub fn tangent_inputs(&self) -> &[(Op::InputKey, LocalValueId)] {
36        &self.tangent_inputs
37    }
38
39    /// Tangent outputs aligned with requested primal outputs.
40    pub fn tangent_outputs(&self) -> &[Option<LocalValueId>] {
41        &self.tangent_outputs
42    }
43}