// tidu/rules/primitive_builder.rs
use computegraph::graph::GraphBuilder;
use computegraph::{GraphOperation, LocalValueId, OperationRole, ValueRef};

/// An operand that a [`PrimitiveBuilder`] accepts as input to a primitive operation.
///
/// Each variant corresponds one-to-one with the [`ValueRef`] variant of the same name;
/// the `From` conversions in this module move the payload across unchanged.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum PrimitiveValue<Op: GraphOperation> {
    /// A value identified by a [`LocalValueId`], e.g. one returned from
    /// [`PrimitiveBuilder::add_primitive`].
    Local(LocalValueId),
    /// A value identified by a `computegraph::ValueKey`.
    // NOTE(review): presumably a value that lives outside the graph being built;
    // confirm against the `computegraph` documentation.
    External(computegraph::ValueKey<Op>),
}

13impl<Op: GraphOperation> From<PrimitiveValue<Op>> for ValueRef<Op> {
14 fn from(value: PrimitiveValue<Op>) -> Self {
15 match value {
16 PrimitiveValue::Local(id) => ValueRef::Local(id),
17 PrimitiveValue::External(key) => ValueRef::External(key),
18 }
19 }
20}
21
22impl<Op: GraphOperation> From<ValueRef<Op>> for PrimitiveValue<Op> {
23 fn from(value: ValueRef<Op>) -> Self {
24 match value {
25 ValueRef::Local(id) => PrimitiveValue::Local(id),
26 ValueRef::External(key) => PrimitiveValue::External(key),
27 }
28 }
29}
30
/// A sink that primitive operations are added to, one operation per call.
///
/// Within this module it is implemented by `GraphPrimitiveBuilder`, which forwards each
/// call to `GraphBuilder::add_operation`.
pub trait PrimitiveBuilder<Op: GraphOperation> {
    /// Adds the primitive `op`, applied to `inputs`, under the given `role`.
    ///
    /// Returns the [`LocalValueId`]s reported by the implementation
    /// (presumably one per output of `op` — confirm per implementation).
    fn add_primitive(
        &mut self,
        op: Op,
        inputs: Vec<PrimitiveValue<Op>>,
        role: OperationRole,
    ) -> Vec<LocalValueId>;
}

/// [`PrimitiveBuilder`] implementation backed by a mutably borrowed [`GraphBuilder`].
///
/// Because it holds `&'a mut GraphBuilder<Op>`, the graph builder stays exclusively
/// borrowed for as long as this wrapper is alive.
pub(crate) struct GraphPrimitiveBuilder<'a, Op: GraphOperation> {
    // The graph builder that receives every forwarded operation.
    inner: &'a mut GraphBuilder<Op>,
}

46impl<'a, Op: GraphOperation> GraphPrimitiveBuilder<'a, Op> {
47 pub(crate) fn new(inner: &'a mut GraphBuilder<Op>) -> Self {
48 Self { inner }
49 }
50}
51
52impl<Op: GraphOperation> PrimitiveBuilder<Op> for GraphPrimitiveBuilder<'_, Op> {
53 fn add_primitive(
54 &mut self,
55 op: Op,
56 inputs: Vec<PrimitiveValue<Op>>,
57 role: OperationRole,
58 ) -> Vec<LocalValueId> {
59 let inputs = inputs.into_iter().map(ValueRef::from).collect();
60 self.inner.add_operation(op, inputs, role)
61 }
62}