Skip to main content

Recorder

Struct Recorder 

Source
pub struct Recorder<K> { /* private fields */ }
Expand description

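Stateful recorder for eager operations. It owns a downstream key source, which it uses to mint fresh graph input keys. It also records concrete eager graph invocations for reverse-mode AD.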

Implementations§

Source§

impl<K> Recorder<K>

Source

pub fn new(key_source: K) -> Self

Create a recorder from a downstream key source.

Source

pub fn key_source_mut(&mut self) -> &mut K

Mutably borrow the underlying key source.

Source

pub fn into_key_source(self) -> K

Consume the recorder and return the underlying key source.

Source

pub fn fresh_input_keys<Op>(&mut self, count: usize) -> Vec<Op::InputKey>
where Op: GraphOperation, K: KeySource<Op>,

Return `count` fresh graph input keys for one eager graph invocation. The keys are obtained from the underlying key source.

§Examples
use tidu::eager::{KeySource, Recorder};
use computegraph::GraphOperation;

// A single-variant operation with one input and one output. It serves as
// the `Op` type parameter of `fresh_input_keys`.
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
enum Op { Id }

impl GraphOperation for Op {
    type Operand = f64;
    type Context = ();
    // Graph input keys are plain integers in this example.
    type InputKey = usize;

    fn input_count(&self) -> usize { 1 }
    fn output_count(&self) -> usize { 1 }
}

// Key source that hands out consecutive integers, starting from the wrapped value.
struct Keys(usize);

impl KeySource<Op> for Keys {
    fn fresh_input_key(&mut self) -> usize {
        let key = self.0;
        // Advance the counter so the next call yields a different key.
        self.0 += 1;
        key
    }
}

// With the counter starting at 0, requesting two keys yields 0 and 1.
let mut recorder = Recorder::new(Keys(0));
assert_eq!(recorder.fresh_input_keys::<Op>(2), vec![0, 1]);
Source

pub fn record_graph<Op>(
    &mut self,
    graph: RecordedGraph<Op>,
    inputs: &[EagerInput<Op>],
    outputs: &[Arc<Op::Operand>],
    retained_values: HashMap<ValueKey<Op>, Arc<Op::Operand>>,
) -> Vec<EagerOutput<Op>>
where Op: Primitive, Op::InputKey: ADKey, K: KeySource<Op>,

Record a concrete eager graph invocation for reverse-mode AD.

§Examples
use std::collections::HashMap;
use std::sync::Arc;
use computegraph::{GraphOperation, LocalValueId, OperationRole, ValueKey};
use tidu::{
    ADKey, DiffPassId, Primitive, PrimitiveBuilder, PrimitiveValue,
};
use tidu::eager::{EagerInput, KeySource, RecordedGraph, Recorder};

// Input-key type with three kinds of key: user-named keys, keys minted by the
// key source below, and tangent keys derived from another key for a given
// differentiation pass. `Tangent` holds a `Box<Key>` because the enum is recursive.
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
enum Key {
    User(&'static str),
    Generated(usize),
    Tangent(Box<Key>, DiffPassId),
}

impl ADKey for Key {
    // The tangent key wraps a copy of the original key together with the pass id.
    fn tangent_of(&self, pass: DiffPassId) -> Self {
        Self::Tangent(Box::new(self.clone()), pass)
    }
}

// Single-variant operation with one input and one output.
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
enum Op { Id }

impl GraphOperation for Op {
    type Operand = f64;
    type Context = ();
    type InputKey = Key;

    fn input_count(&self) -> usize { 1 }
    fn output_count(&self) -> usize { 1 }
}

impl Primitive for Op {
    type ADContext = ();
    // NOTE(review): `add` returns the one-input `Id` op. Presumably `add` is the
    // operation used to accumulate contributions and would normally take two
    // inputs. Confirm this is acceptable for the example; with a single path
    // here it may never be invoked.
    fn add() -> Self { Self::Id }

    // Forward-mode rule: the output tangent is the input tangent, passed through unchanged.
    fn jvp_rule(
        &self,
        _builder: &mut impl PrimitiveBuilder<Self>,
        _primal_in: &[ValueKey<Self>],
        _primal_out: &[ValueKey<Self>],
        tangent_in: &[Option<LocalValueId>],
        _ctx: &mut (),
    ) -> Vec<Option<LocalValueId>> {
        vec![tangent_in[0]]
    }

    // Transpose rule: the input cotangent is the output cotangent, passed through unchanged.
    fn transpose_rule(
        &self,
        _builder: &mut impl PrimitiveBuilder<Self>,
        cotangent_out: &[Option<LocalValueId>],
        _inputs: &[PrimitiveValue<Self>],
        _role: &OperationRole,
        _ctx: &mut (),
    ) -> Vec<Option<LocalValueId>> {
        vec![cotangent_out[0]]
    }
}

// Key source that mints `Key::Generated(n)` from an incrementing counter.
struct Keys(usize);
impl KeySource<Op> for Keys {
    fn fresh_input_key(&mut self) -> Key {
        let key = Key::Generated(self.0);
        self.0 += 1;
        key
    }
}

// Request one fresh graph-input key, then build a recorded graph from the
// `Id` primitive over that key.
let mut recorder = Recorder::new(Keys(0));
let graph_inputs = recorder.fresh_input_keys::<Op>(1);
let graph = RecordedGraph::from_primitive(Op::Id, graph_inputs);
// The eager input has its own user key, distinct from the graph's generated
// input key. It has no existing trace, is marked as requiring a gradient,
// and carries the concrete value 2.0.
let input = EagerInput {
    key: ValueKey::Input(Key::User("x")),
    trace: None,
    requires_grad: true,
    data: Arc::new(2.0),
};

// Record the invocation with one input, one concrete output value (equal to
// the input data here), and no retained values.
let outputs = recorder.record_graph(
    graph,
    &[input],
    &[Arc::new(2.0)],
    HashMap::new(),
);
// The recorded output carries a trace (the input was marked `requires_grad: true`).
assert!(outputs[0].trace.is_some());

Auto Trait Implementations§

§

impl<K> Freeze for Recorder<K>
where K: Freeze,

§

impl<K> RefUnwindSafe for Recorder<K>
where K: RefUnwindSafe,

§

impl<K> Send for Recorder<K>
where K: Send,

§

impl<K> Sync for Recorder<K>
where K: Sync,

§

impl<K> Unpin for Recorder<K>
where K: Unpin,

§

impl<K> UnsafeUnpin for Recorder<K>
where K: UnsafeUnpin,

§

impl<K> UnwindSafe for Recorder<K>
where K: UnwindSafe,

Blanket Implementations§

Source§

impl<T> Any for T
where T: 'static + ?Sized,

Source§

fn type_id(&self) -> TypeId

Gets the TypeId of self. Read more
Source§

impl<T> Borrow<T> for T
where T: ?Sized,

Source§

fn borrow(&self) -> &T

Immutably borrows from an owned value. Read more
Source§

impl<T> BorrowMut<T> for T
where T: ?Sized,

Source§

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value. Read more
Source§

impl<T> From<T> for T

Source§

fn from(t: T) -> T

Returns the argument unchanged.

Source§

impl<T, U> Into<U> for T
where U: From<T>,

Source§

fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

Source§

impl<T, U> TryFrom<U> for T
where U: Into<T>,

Source§

type Error = Infallible

The type returned in the event of a conversion error.
Source§

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.
Source§

impl<T, U> TryInto<U> for T
where U: TryFrom<T>,

Source§

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.
Source§

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.