// Source: tidu/rules/primitive_builder.rs

1use computegraph::graph::GraphBuilder;
2use computegraph::{GraphOperation, LocalValueId, OperationRole, ValueRef};
3
4/// Reference to a value available to a primitive AD rule.
5#[derive(Clone, Debug, PartialEq, Eq, Hash)]
6pub enum PrimitiveValue<Op: GraphOperation> {
7    /// Value produced inside the graph being built.
8    Local(LocalValueId),
9    /// Value from the source primitive computation graph.
10    External(computegraph::ValueKey<Op>),
11}
12
13impl<Op: GraphOperation> From<PrimitiveValue<Op>> for ValueRef<Op> {
14    fn from(value: PrimitiveValue<Op>) -> Self {
15        match value {
16            PrimitiveValue::Local(id) => ValueRef::Local(id),
17            PrimitiveValue::External(key) => ValueRef::External(key),
18        }
19    }
20}
21
22impl<Op: GraphOperation> From<ValueRef<Op>> for PrimitiveValue<Op> {
23    fn from(value: ValueRef<Op>) -> Self {
24        match value {
25            ValueRef::Local(id) => PrimitiveValue::Local(id),
26            ValueRef::External(key) => PrimitiveValue::External(key),
27        }
28    }
29}
30
31/// Builder used by primitive JVP and transpose rules to append primitive applications.
32pub trait PrimitiveBuilder<Op: GraphOperation> {
33    /// Add one primitive application and return local ids for its outputs.
34    fn add_primitive(
35        &mut self,
36        op: Op,
37        inputs: Vec<PrimitiveValue<Op>>,
38        role: OperationRole,
39    ) -> Vec<LocalValueId>;
40}
41
42pub(crate) struct GraphPrimitiveBuilder<'a, Op: GraphOperation> {
43    inner: &'a mut GraphBuilder<Op>,
44}
45
46impl<'a, Op: GraphOperation> GraphPrimitiveBuilder<'a, Op> {
47    pub(crate) fn new(inner: &'a mut GraphBuilder<Op>) -> Self {
48        Self { inner }
49    }
50}
51
52impl<Op: GraphOperation> PrimitiveBuilder<Op> for GraphPrimitiveBuilder<'_, Op> {
53    fn add_primitive(
54        &mut self,
55        op: Op,
56        inputs: Vec<PrimitiveValue<Op>>,
57        role: OperationRole,
58    ) -> Vec<LocalValueId> {
59        let inputs = inputs.into_iter().map(ValueRef::from).collect();
60        self.inner.add_operation(op, inputs, role)
61    }
62}