pub struct Recorder<K> { /* private fields */ }
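Stateful recorder of eager operations. It owns a key source and records concrete eager graph invocations for reverse-mode automatic differentiation (AD).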
Implementations
impl<K> Recorder<K>
pub fn key_source_mut(&mut self) -> &mut K
Mutably borrow the underlying key source.
pub fn into_key_source(self) -> K
Consume the recorder and return the underlying key source.
pub fn fresh_input_keys<Op>(&mut self, count: usize) -> Vec<Op::InputKey>
where
    Op: GraphOperation,
    K: KeySource<Op>,
Return `count` fresh graph input keys, drawn from the underlying key source, for one eager graph invocation.
Examples
use tidu::eager::{KeySource, Recorder};
use computegraph::GraphOperation;
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
enum Op { Id }
impl GraphOperation for Op {
type Operand = f64;
type Context = ();
type InputKey = usize;
fn input_count(&self) -> usize { 1 }
fn output_count(&self) -> usize { 1 }
}
struct Keys(usize);
impl KeySource<Op> for Keys {
fn fresh_input_key(&mut self) -> usize {
let key = self.0;
self.0 += 1;
key
}
}
let mut recorder = Recorder::new(Keys(0));
assert_eq!(recorder.fresh_input_keys::<Op>(2), vec![0, 1]);
pub fn record_graph<Op>(
    &mut self,
    graph: RecordedGraph<Op>,
    inputs: &[EagerInput<Op>],
    outputs: &[Arc<Op::Operand>],
    retained_values: HashMap<ValueKey<Op>, Arc<Op::Operand>>,
) -> Vec<EagerOutput<Op>>
Record a concrete eager graph invocation for reverse-mode AD.
Examples
use std::collections::HashMap;
use std::sync::Arc;
use computegraph::{GraphOperation, LocalValueId, OperationRole, ValueKey};
use tidu::{
ADKey, DiffPassId, Primitive, PrimitiveBuilder, PrimitiveValue,
};
use tidu::eager::{EagerInput, KeySource, RecordedGraph, Recorder};
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
enum Key {
User(&'static str),
Generated(usize),
Tangent(Box<Key>, DiffPassId),
}
impl ADKey for Key {
fn tangent_of(&self, pass: DiffPassId) -> Self {
Self::Tangent(Box::new(self.clone()), pass)
}
}
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
enum Op { Id }
impl GraphOperation for Op {
type Operand = f64;
type Context = ();
type InputKey = Key;
fn input_count(&self) -> usize { 1 }
fn output_count(&self) -> usize { 1 }
}
impl Primitive for Op {
type ADContext = ();
fn add() -> Self { Self::Id }
fn jvp_rule(
&self,
_builder: &mut impl PrimitiveBuilder<Self>,
_primal_in: &[ValueKey<Self>],
_primal_out: &[ValueKey<Self>],
tangent_in: &[Option<LocalValueId>],
_ctx: &mut (),
) -> Vec<Option<LocalValueId>> {
vec![tangent_in[0]]
}
fn transpose_rule(
&self,
_builder: &mut impl PrimitiveBuilder<Self>,
cotangent_out: &[Option<LocalValueId>],
_inputs: &[PrimitiveValue<Self>],
_role: &OperationRole,
_ctx: &mut (),
) -> Vec<Option<LocalValueId>> {
vec![cotangent_out[0]]
}
}
struct Keys(usize);
impl KeySource<Op> for Keys {
fn fresh_input_key(&mut self) -> Key {
let key = Key::Generated(self.0);
self.0 += 1;
key
}
}
let mut recorder = Recorder::new(Keys(0));
let graph_inputs = recorder.fresh_input_keys::<Op>(1);
let graph = RecordedGraph::from_primitive(Op::Id, graph_inputs);
let input = EagerInput {
key: ValueKey::Input(Key::User("x")),
trace: None,
requires_grad: true,
data: Arc::new(2.0),
};
let outputs = recorder.record_graph(
graph,
&[input],
&[Arc::new(2.0)],
HashMap::new(),
);
assert!(outputs[0].trace.is_some());
Auto Trait Implementations
impl<K> Freeze for Recorder<K>
where
    K: Freeze,
impl<K> RefUnwindSafe for Recorder<K>
where
    K: RefUnwindSafe,
impl<K> Send for Recorder<K>
where
    K: Send,
impl<K> Sync for Recorder<K>
where
    K: Sync,
impl<K> Unpin for Recorder<K>
where
    K: Unpin,
impl<K> UnsafeUnpin for Recorder<K>
where
    K: UnsafeUnpin,
impl<K> UnwindSafe for Recorder<K>
where
    K: UnwindSafe,
Blanket Implementations
impl<T> BorrowMut<T> for T
where
    T: ?Sized,
fn borrow_mut(&mut self) -> &mut T
Mutably borrows from an owned value. Read more