TrackedTensor

Struct TrackedTensor

pub struct TrackedTensor<V: Differentiable> { /* private fields */ }

Value wrapper for reverse-mode AD.

Wraps any Differentiable value and can connect it to a Tape so that gradients can be computed for it.

Created via Tape::leaf for gradient-tracked values, or TrackedTensor::new for values that do not require gradients.

Examples

use chainrules::{Tape, TrackedTensor};
use tenferro_tensor::{MemoryOrder, Tensor};
use tenferro_device::LogicalMemorySpace;

let tape = Tape::<Tensor<f64>>::new();
let a = tape.leaf(Tensor::ones(
    &[2, 3],
    LogicalMemorySpace::MainMemory,
    MemoryOrder::ColumnMajor,
));
assert!(a.requires_grad());
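To illustrate what "connected to a tape" means, here is a minimal, self-contained scalar sketch of reverse-mode AD. Everything in it (ToyTape, ToyTracked, mul, gradients) is hypothetical. It only mirrors the documented behaviour: leaf values get a tape node and require gradients, values built with new do not, and a reverse sweep over the recorded nodes accumulates gradients with the chain rule. It is not how chainrules or tenferro are implemented.

```rust
use std::cell::RefCell;

/// Hypothetical scalar stand-in for `TrackedTensor`: a value plus an optional
/// node id on a tape. `requires_grad` is simply "has a node".
#[derive(Clone, Copy, Debug)]
pub struct ToyTracked {
    pub value: f64,
    pub node: Option<usize>,
}

impl ToyTracked {
    /// Analogue of `TrackedTensor::new`: not recorded on any tape.
    pub fn new(value: f64) -> Self {
        ToyTracked { value, node: None }
    }

    pub fn requires_grad(&self) -> bool {
        self.node.is_some()
    }
}

/// Hypothetical toy tape: node k stores its parents together with the local
/// partial derivative d(node k)/d(parent).
#[derive(Default)]
pub struct ToyTape {
    nodes: RefCell<Vec<Vec<(usize, f64)>>>,
}

impl ToyTape {
    fn push(&self, parents: Vec<(usize, f64)>) -> usize {
        let mut nodes = self.nodes.borrow_mut();
        nodes.push(parents);
        nodes.len() - 1
    }

    /// Analogue of `Tape::leaf`: a node with no parents.
    pub fn leaf(&self, value: f64) -> ToyTracked {
        ToyTracked { value, node: Some(self.push(Vec::new())) }
    }

    /// Records z = a * b; only operands that are on the tape become parents.
    pub fn mul(&self, a: ToyTracked, b: ToyTracked) -> ToyTracked {
        let value = a.value * b.value;
        let mut parents = Vec::new();
        if let Some(i) = a.node {
            parents.push((i, b.value)); // dz/da = b
        }
        if let Some(j) = b.node {
            parents.push((j, a.value)); // dz/db = a
        }
        if parents.is_empty() {
            return ToyTracked::new(value); // no tracked inputs: nothing to record
        }
        ToyTracked { value, node: Some(self.push(parents)) }
    }

    /// Reverse sweep: seed d(out)/d(out) = 1, visit nodes newest-first and push
    /// each adjoint to its parents (chain rule). Returns one adjoint per node.
    pub fn gradients(&self, out: ToyTracked) -> Vec<f64> {
        let nodes = self.nodes.borrow();
        let mut adj = vec![0.0; nodes.len()];
        if let Some(o) = out.node {
            adj[o] = 1.0;
            for k in (0..=o).rev() {
                for &(p, local) in &nodes[k] {
                    let contrib = adj[k] * local;
                    adj[p] += contrib;
                }
            }
        }
        adj
    }
}
```

For example, with x = tape.leaf(3.0), y = tape.leaf(4.0) and c = ToyTracked::new(10.0), the product z = (x * y) * c yields gradients 40 for x and 30 for y, while c never appears on the tape.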

Implementations

impl<V: Differentiable> TrackedTensor<V>

pub fn new(value: V) -> Self

Creates a tracked value that is not attached to any tape, so requires_grad() returns false.

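Examples
use chainrules::TrackedTensor;
use tenferro_device::LogicalMemorySpace;
use tenferro_tensor::{MemoryOrder, Tensor};
let value: Tensor<f64> = Tensor::ones(&[2, 3], LogicalMemorySpace::MainMemory, MemoryOrder::ColumnMajor);
let x = TrackedTensor::new(value); // not attached to any tape
assert!(!x.requires_grad());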

pub fn value(&self) -> &V

Returns the underlying value.

Examples
// Borrows the wrapped value; `tracked` remains usable afterwards.
let v = tracked.value();

pub fn into_value(self) -> V

Consumes the wrapper and returns the underlying value.

Examples
// Moves the value out; `tracked` cannot be used afterwards.
let raw = tracked.into_value();

pub fn requires_grad(&self) -> bool

Returns whether this value participates in gradient propagation.

Examples
// Holds when `tracked` was created with Tape::leaf.
assert!(tracked.requires_grad());

pub fn node_id(&self) -> Option<NodeId>

Returns the graph node ID if this value is connected to a tape, or None otherwise.

Examples
if let Some(id) = tracked.node_id() {
    println!("node = {}", id.index());
}

pub fn tangent(&self) -> Option<&V::Tangent>

Returns the tangent used for Hessian-vector products (HVP), or None if no tangent is set.

Examples
if let Some(t) = tracked.tangent() {
    // use tangent
}

pub fn has_tangent(&self) -> bool

Returns whether this tracked value has a tangent for HVP.

Examples
// True only when a tangent has been set for HVP.
assert_eq!(tracked.has_tangent(), tracked.tangent().is_some());
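The tangent is described only as being "for HVP". One common way to obtain Hessian-vector products is forward-over-reverse: run the reverse (gradient) pass on values that also carry a tangent, so the tangent of the gradient equals H·v. The scalar sketch below assumes that scheme; Dual and grad_and_hvp_cube are hypothetical names and say nothing definitive about how chainrules propagates tangents.

```rust
use std::ops::{Add, Mul};

/// Dual number: a primal value `v` and a tangent `t` (directional derivative).
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Dual {
    pub v: f64,
    pub t: f64,
}

impl Add for Dual {
    type Output = Dual;
    fn add(self, o: Dual) -> Dual {
        Dual { v: self.v + o.v, t: self.t + o.t }
    }
}

impl Mul for Dual {
    type Output = Dual;
    fn mul(self, o: Dual) -> Dual {
        // Product rule for the tangent part.
        Dual { v: self.v * o.v, t: self.t * o.v + self.v * o.t }
    }
}

/// Forward-over-reverse for f(x) = x^3: a hand-written reverse pass evaluated on
/// dual numbers. Returns (f'(x0), f''(x0) * dir): the gradient and the HVP along `dir`.
pub fn grad_and_hvp_cube(x0: f64, dir: f64) -> (f64, f64) {
    let x = Dual { v: x0, t: dir }; // seed the input tangent with the direction
    let one = Dual { v: 1.0, t: 0.0 };

    // Forward sweep: y1 = x * x, y = y1 * x (y itself is not needed below).
    let y1 = x * x;
    let _y = y1 * x;

    // Reverse sweep, starting from dy/dy = 1.
    let y_bar = one;
    let y1_bar = y_bar * x; // y = y1 * x, adjoint w.r.t. y1
    // x receives y_bar * y1 from y = y1 * x, and y1_bar * x twice from y1 = x * x.
    let x_bar = y_bar * y1 + y1_bar * x + y1_bar * x;

    (x_bar.v, x_bar.t)
}
```

For f(x) = x³ at x = 2 with direction 0.5, it returns (12.0, 6.0): f'(2) = 12 and f''(2) · 0.5 = 12 · 0.5 = 6.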

pub fn detach(self) -> Self

Consumes self and returns the same value detached from the tape, so it no longer requires gradients.

Examples
let detached = tracked.detach();
assert!(!detached.requires_grad());
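Detaching lets part of an expression be treated as a constant. In the hypothetical sketch below (MiniTape and MiniVar are illustrative, not library types), differentiating x * x gives 2x, while x * x.detach() gives only x, because the detached operand records no edge back to x.

```rust
/// Hypothetical stand-in for `TrackedTensor`: a value plus an optional tape node.
#[derive(Clone, Copy, Debug)]
pub struct MiniVar {
    pub value: f64,
    pub node: Option<usize>,
}

impl MiniVar {
    /// Mirrors the documented effect of `detach`: same value, no tape node,
    /// so it no longer requires gradients.
    pub fn detach(self) -> MiniVar {
        MiniVar { value: self.value, node: None }
    }
}

/// Hypothetical minimal tape: node k stores (parent, local partial) pairs.
#[derive(Default)]
pub struct MiniTape {
    nodes: Vec<Vec<(usize, f64)>>,
}

impl MiniTape {
    pub fn leaf(&mut self, value: f64) -> MiniVar {
        self.nodes.push(Vec::new());
        MiniVar { value, node: Some(self.nodes.len() - 1) }
    }

    /// Records a * b; a detached operand contributes no parent edge.
    pub fn mul(&mut self, a: MiniVar, b: MiniVar) -> MiniVar {
        let mut parents = Vec::new();
        if let Some(i) = a.node {
            parents.push((i, b.value)); // d(ab)/da = b
        }
        if let Some(j) = b.node {
            parents.push((j, a.value)); // d(ab)/db = a
        }
        self.nodes.push(parents);
        MiniVar { value: a.value * b.value, node: Some(self.nodes.len() - 1) }
    }

    /// d(out)/d(wrt) by a reverse sweep; 0.0 if either is off the tape.
    pub fn grad(&self, out: MiniVar, wrt: MiniVar) -> f64 {
        let (o, w) = match (out.node, wrt.node) {
            (Some(o), Some(w)) => (o, w),
            _ => return 0.0,
        };
        let mut adj = vec![0.0; self.nodes.len()];
        adj[o] = 1.0;
        for k in (0..=o).rev() {
            for &(p, local) in &self.nodes[k] {
                let contrib = adj[k] * local;
                adj[p] += contrib;
            }
        }
        adj[w]
    }
}
```

At x = 3 the two gradients are 6.0 for x * x and 3.0 for x * x.detach().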

Auto Trait Implementations

impl<V> Freeze for TrackedTensor<V>
where V: Freeze, <V as Differentiable>::Tangent: Freeze,

impl<V> RefUnwindSafe for TrackedTensor<V>

impl<V> Send for TrackedTensor<V>
where V: Send, <V as Differentiable>::Tangent: Send,

impl<V> Sync for TrackedTensor<V>
where V: Sync, <V as Differentiable>::Tangent: Sync,

impl<V> Unpin for TrackedTensor<V>
where V: Unpin, <V as Differentiable>::Tangent: Unpin,

impl<V> UnwindSafe for TrackedTensor<V>
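The where clauses above make TrackedTensor<V> Send, Sync and Unpin only when both V and <V as Differentiable>::Tangent are, which is consistent with the struct storing a value of the tangent type. Auto traits are derived structurally from a type's fields, as this sketch shows with a hypothetical stand-in Differentiable trait and Wrapper struct; neither is the chainrules definition, and Freeze is omitted because it is compiler-internal.

```rust
/// Hypothetical stand-in for `Differentiable`: only the associated type matters here.
pub trait Differentiable {
    type Tangent;
}

/// Hypothetical wrapper that, like the bounds above suggest, stores a `V`
/// and optionally a `V::Tangent`.
pub struct Wrapper<V: Differentiable> {
    pub value: V,
    pub tangent: Option<V::Tangent>,
}

impl Differentiable for f64 {
    type Tangent = f64;
}

/// A value type whose tangent is an `Rc`, which is neither Send nor Sync.
pub struct SharedTangent;
impl Differentiable for SharedTangent {
    type Tangent = std::rc::Rc<f64>;
}

/// Compiles only if `T: Send + Sync`.
pub fn assert_send_sync<T: Send + Sync>() {}

pub fn auto_traits_hold() -> bool {
    // f64 and its tangent (f64) are Send + Sync, so Wrapper<f64> is too.
    assert_send_sync::<Wrapper<f64>>();
    // The next line would fail to compile: Rc<f64> is neither Send nor Sync.
    // assert_send_sync::<Wrapper<SharedTangent>>();
    true
}
```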

Blanket Implementations

impl<T> Any for T
where T: 'static + ?Sized,

fn type_id(&self) -> TypeId

Gets the TypeId of self.

impl<T> Borrow<T> for T
where T: ?Sized,

fn borrow(&self) -> &T

Immutably borrows from an owned value.

impl<T> BorrowMut<T> for T
where T: ?Sized,

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value.

impl<T> From<T> for T

fn from(t: T) -> T

Returns the argument unchanged.

impl<T, U> Into<U> for T
where U: From<T>,

fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

impl<T, U> TryFrom<U> for T
where U: Into<T>,

type Error = Infallible

The type returned in the event of a conversion error.

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.

impl<T, U> TryInto<U> for T
where U: TryFrom<T>,

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.