pub struct AdContext { /* private fields */ }
Explicit automatic-differentiation context.
AdContext owns the extension AD rules used by traced AD transforms.
§Examples
use tenferro_ad::AdContext;
let ad = AdContext::builder().build().unwrap();
assert!(ad.extension_rules().lookup_rule("example.missing.v1").is_none());
Implementations§
Source§impl AdContext
impl AdContext
Sourcepub fn builder() -> AdContextBuilder
pub fn builder() -> AdContextBuilder
Start building an explicit AD context.
§Examples
use tenferro_ad::AdContext;
let _builder = AdContext::builder();
Sourcepub fn extension_rules(&self) -> &ExtensionRuleSet
pub fn extension_rules(&self) -> &ExtensionRuleSet
Return the extension rules owned by this context.
§Examples
use tenferro_ad::AdContext;
let ad = AdContext::builder().build().unwrap();
assert!(!ad.extension_rules().is_rule_registered("example.missing.v1"));
Sourcepub fn grad(
&self,
output: &TracedTensor,
wrt: &TracedTensor,
) -> Result<TracedTensor>
pub fn grad( &self, output: &TracedTensor, wrt: &TracedTensor, ) -> Result<TracedTensor>
Gradient of a scalar traced output with respect to a traced input.
For complex scalar outputs, tenferro returns the Hermitian-adjoint
cotangent. To compare seed-1 scalar gradients with JAX’s public
grad values, use the complex conjugate of this result. See
https://tensor4all.org/tenferro-rs/guides/complex-ad.html.
§Examples
use tenferro_ad::AdContext;
use tenferro_runtime::TracedTensor;
let ad = AdContext::builder().build().unwrap();
let x = TracedTensor::from_vec_col_major(vec![], vec![3.0_f64]).unwrap();
let loss = (&x * &x).unwrap();
let grad = ad.grad(&loss, &x).unwrap();
assert_eq!(grad.rank, 0);
Sourcepub fn grad_optional(
&self,
output: &TracedTensor,
wrt: &TracedTensor,
) -> Result<Option<TracedTensor>>
pub fn grad_optional( &self, output: &TracedTensor, wrt: &TracedTensor, ) -> Result<Option<TracedTensor>>
Gradient of a scalar traced output with respect to a traced input, like grad, but returns None when wrt is inactive.
§Examples
use tenferro_ad::AdContext;
use tenferro_runtime::TracedTensor;
let ad = AdContext::builder().build().unwrap();
let x = TracedTensor::from_vec_col_major(vec![], vec![3.0_f64]).unwrap();
let loss = (&x * &x).unwrap();
assert!(ad.grad_optional(&loss, &x).unwrap().is_some());
Sourcepub fn jvp(
&self,
output: &TracedTensor,
wrt: &TracedTensor,
tangent: &TracedTensor,
) -> Result<TracedTensor>
pub fn jvp( &self, output: &TracedTensor, wrt: &TracedTensor, tangent: &TracedTensor, ) -> Result<TracedTensor>
Forward-mode Jacobian-vector product.
§Examples
use tenferro_ad::AdContext;
use tenferro_runtime::TracedTensor;
let ad = AdContext::builder().build().unwrap();
let x = TracedTensor::from_vec_col_major(vec![], vec![3.0_f64]).unwrap();
let dx = TracedTensor::from_vec_col_major(vec![], vec![1.0_f64]).unwrap();
let y = (&x * &x).unwrap();
let dy = ad.jvp(&y, &x, &dx).unwrap();
assert_eq!(dy.rank, 0);
Sourcepub fn jvp_optional(
&self,
output: &TracedTensor,
wrt: &TracedTensor,
tangent: &TracedTensor,
) -> Result<Option<TracedTensor>>
pub fn jvp_optional( &self, output: &TracedTensor, wrt: &TracedTensor, tangent: &TracedTensor, ) -> Result<Option<TracedTensor>>
Forward-mode Jacobian-vector product that returns None for inactive output.
§Examples
use tenferro_ad::AdContext;
use tenferro_runtime::TracedTensor;
let ad = AdContext::builder().build().unwrap();
let x = TracedTensor::from_vec_col_major(vec![], vec![3.0_f64]).unwrap();
let dx = TracedTensor::from_vec_col_major(vec![], vec![1.0_f64]).unwrap();
let y = (&x * &x).unwrap();
assert!(ad.jvp_optional(&y, &x, &dx).unwrap().is_some());
Sourcepub fn vjp(
&self,
output: &TracedTensor,
wrt: &TracedTensor,
cotangent: &TracedTensor,
) -> Result<TracedTensor>
pub fn vjp( &self, output: &TracedTensor, wrt: &TracedTensor, cotangent: &TracedTensor, ) -> Result<TracedTensor>
Reverse-mode vector-Jacobian product.
Complex cotangents use tenferro’s Hermitian real-inner-product convention. Non-real complex cotangent seeds therefore need an explicit seed-convention comparison when matching JAX. See https://tensor4all.org/tenferro-rs/guides/complex-ad.html.
§Examples
use tenferro_ad::AdContext;
use tenferro_runtime::TracedTensor;
let ad = AdContext::builder().build().unwrap();
let x = TracedTensor::from_vec_col_major(vec![], vec![3.0_f64]).unwrap();
let dy = TracedTensor::from_vec_col_major(vec![], vec![1.0_f64]).unwrap();
let y = (&x * &x).unwrap();
let dx = ad.vjp(&y, &x, &dy).unwrap();
assert_eq!(dx.rank, 0);
Sourcepub fn vjp_optional(
&self,
output: &TracedTensor,
wrt: &TracedTensor,
cotangent: &TracedTensor,
) -> Result<Option<TracedTensor>>
pub fn vjp_optional( &self, output: &TracedTensor, wrt: &TracedTensor, cotangent: &TracedTensor, ) -> Result<Option<TracedTensor>>
Reverse-mode vector-Jacobian product that returns None for inactive input.
§Examples
use tenferro_ad::AdContext;
use tenferro_runtime::TracedTensor;
let ad = AdContext::builder().build().unwrap();
let x = TracedTensor::from_vec_col_major(vec![], vec![3.0_f64]).unwrap();
let dy = TracedTensor::from_vec_col_major(vec![], vec![1.0_f64]).unwrap();
let y = (&x * &x).unwrap();
assert!(ad.vjp_optional(&y, &x, &dy).unwrap().is_some());
Trait Implementations§
Auto Trait Implementations§
impl Freeze for AdContext
impl !RefUnwindSafe for AdContext
impl Send for AdContext
impl Sync for AdContext
impl Unpin for AdContext
impl UnsafeUnpin for AdContext
impl !UnwindSafe for AdContext
Blanket Implementations§
Source§impl<T> BorrowMut<T> for T where
T: ?Sized,
impl<T> BorrowMut<T> for T where
T: ?Sized,
Source§fn borrow_mut(&mut self) -> &mut T
fn borrow_mut(&mut self) -> &mut T
Source§impl<T> CloneToUninit for T where
T: Clone,
impl<T> CloneToUninit for T where
T: Clone,
§impl<T> DistributionExt for T where
T: ?Sized,
impl<T> DistributionExt for T where
T: ?Sized,
fn rand<T>(&self, rng: &mut (impl Rng + ?Sized)) -> T where
Self: Distribution<T>,
Source§impl<T> IntoEither for T
impl<T> IntoEither for T
Source§fn into_either(self, into_left: bool) -> Either<Self, Self>
fn into_either(self, into_left: bool) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self>
if into_left is true.
Converts self into a Right variant of Either<Self, Self>
otherwise. Read more
Source§fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self>
if into_left(&self) returns true.
Converts self into a Right variant of Either<Self, Self>
otherwise. Read more