ad_tensors_rs/
ad_value.rs

1use tenferro_algebra::Scalar;
2use tenferro_tensor::Tensor;
3
/// Which flavour of automatic differentiation a value participates in.
///
/// # Examples
///
/// ```rust
/// use ad_tensors_rs::AdMode;
///
/// assert_eq!(AdMode::Primal, AdMode::Primal);
/// ```
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AdMode {
    /// Ordinary evaluation; no derivative information is propagated.
    Primal,
    /// Forward mode: the value travels together with a tangent.
    Forward,
    /// Reverse mode: the value carries computation-graph metadata.
    Reverse,
}
22
/// Opaque handle naming a single node of a reverse-mode graph.
///
/// The wrapped `u64` is public, so a handle can be built from, and read back
/// as, a raw integer.
///
/// # Examples
///
/// ```rust
/// use ad_tensors_rs::NodeId;
///
/// let node = NodeId(7);
/// assert_eq!(node.0, 7);
/// ```
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct NodeId(pub u64);
35
/// Opaque handle naming one tape instance.
///
/// The wrapped `u64` is public, so a handle can be built from, and read back
/// as, a raw integer.
///
/// # Examples
///
/// ```rust
/// use ad_tensors_rs::TapeId;
///
/// let tape = TapeId(2);
/// assert_eq!(tape.0, 2);
/// ```
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct TapeId(pub u64);
48
49/// Generic AD value that can wrap any user-defined payload type `T`.
50///
51/// This is the primary extension point of the crate.
52///
53/// # Examples
54///
55/// ```rust
56/// use ad_tensors_rs::{AdMode, AdValue, NodeId, TapeId};
57///
58/// let primal = AdValue::primal(3.0_f64);
59/// assert_eq!(primal.mode(), AdMode::Primal);
60///
61/// let dual = AdValue::forward(3.0_f64, 1.0_f64);
62/// assert_eq!(dual.mode(), AdMode::Forward);
63///
64/// let tracked = AdValue::reverse(3.0_f64, NodeId(1), TapeId(9), None);
65/// assert_eq!(tracked.mode(), AdMode::Reverse);
66/// ```
67#[derive(Debug, Clone, PartialEq)]
68pub enum AdValue<T> {
69    /// Primal-only value.
70    Primal(T),
71    /// Forward-mode value and tangent.
72    Forward { primal: T, tangent: T },
73    /// Reverse-mode value with graph metadata.
74    Reverse {
75        primal: T,
76        node: NodeId,
77        tape: TapeId,
78        tangent: Option<T>,
79    },
80}
81
82impl<T> AdValue<T> {
83    /// Creates a primal-only value.
84    ///
85    /// # Examples
86    ///
87    /// ```rust
88    /// use ad_tensors_rs::AdValue;
89    ///
90    /// let x = AdValue::primal(2_i32);
91    /// assert!(matches!(x, AdValue::Primal(2)));
92    /// ```
93    pub fn primal(value: T) -> Self {
94        Self::Primal(value)
95    }
96
97    /// Creates a forward-mode value.
98    ///
99    /// # Examples
100    ///
101    /// ```rust
102    /// use ad_tensors_rs::AdValue;
103    ///
104    /// let x = AdValue::forward(2.0_f64, 1.0_f64);
105    /// assert!(matches!(x, AdValue::Forward { .. }));
106    /// ```
107    pub fn forward(primal: T, tangent: T) -> Self {
108        Self::Forward { primal, tangent }
109    }
110
111    /// Creates a reverse-mode value.
112    ///
113    /// # Examples
114    ///
115    /// ```rust
116    /// use ad_tensors_rs::{AdValue, NodeId, TapeId};
117    ///
118    /// let x = AdValue::reverse(2.0_f64, NodeId(3), TapeId(5), Some(0.1));
119    /// assert!(matches!(x, AdValue::Reverse { .. }));
120    /// ```
121    pub fn reverse(primal: T, node: NodeId, tape: TapeId, tangent: Option<T>) -> Self {
122        Self::Reverse {
123            primal,
124            node,
125            tape,
126            tangent,
127        }
128    }
129
130    /// Returns the AD mode.
131    ///
132    /// # Examples
133    ///
134    /// ```rust
135    /// use ad_tensors_rs::{AdMode, AdValue};
136    ///
137    /// let x = AdValue::forward(1.0_f64, 1.0_f64);
138    /// assert_eq!(x.mode(), AdMode::Forward);
139    /// ```
140    pub fn mode(&self) -> AdMode {
141        match self {
142            Self::Primal(_) => AdMode::Primal,
143            Self::Forward { .. } => AdMode::Forward,
144            Self::Reverse { .. } => AdMode::Reverse,
145        }
146    }
147
148    /// Returns a reference to the primal payload.
149    ///
150    /// # Examples
151    ///
152    /// ```rust
153    /// use ad_tensors_rs::AdValue;
154    ///
155    /// let x = AdValue::forward(10_i32, 1_i32);
156    /// assert_eq!(x.primal_ref(), &10);
157    /// ```
158    pub fn primal_ref(&self) -> &T {
159        match self {
160            Self::Primal(value) => value,
161            Self::Forward { primal, .. } => primal,
162            Self::Reverse { primal, .. } => primal,
163        }
164    }
165
166    /// Returns a mutable reference to the primal payload.
167    ///
168    /// # Examples
169    ///
170    /// ```rust
171    /// use ad_tensors_rs::AdValue;
172    ///
173    /// let mut x = AdValue::primal(1_i32);
174    /// *x.primal_mut() = 7;
175    /// assert_eq!(x.primal_ref(), &7);
176    /// ```
177    pub fn primal_mut(&mut self) -> &mut T {
178        match self {
179            Self::Primal(value) => value,
180            Self::Forward { primal, .. } => primal,
181            Self::Reverse { primal, .. } => primal,
182        }
183    }
184
185    /// Returns a reference to tangent payload when available.
186    ///
187    /// # Examples
188    ///
189    /// ```rust
190    /// use ad_tensors_rs::AdValue;
191    ///
192    /// let x = AdValue::forward(2.0_f64, 3.0_f64);
193    /// assert_eq!(x.tangent_ref(), Some(&3.0));
194    /// ```
195    pub fn tangent_ref(&self) -> Option<&T> {
196        match self {
197            Self::Primal(_) => None,
198            Self::Forward { tangent, .. } => Some(tangent),
199            Self::Reverse { tangent, .. } => tangent.as_ref(),
200        }
201    }
202
203    /// Returns reverse-mode node id when available.
204    ///
205    /// # Examples
206    ///
207    /// ```rust
208    /// use ad_tensors_rs::{AdValue, NodeId, TapeId};
209    ///
210    /// let x = AdValue::reverse(1.0_f64, NodeId(4), TapeId(6), None);
211    /// assert_eq!(x.node_id(), Some(NodeId(4)));
212    /// ```
213    pub fn node_id(&self) -> Option<NodeId> {
214        match self {
215            Self::Reverse { node, .. } => Some(*node),
216            _ => None,
217        }
218    }
219
220    /// Returns reverse-mode tape id when available.
221    ///
222    /// # Examples
223    ///
224    /// ```rust
225    /// use ad_tensors_rs::{AdValue, NodeId, TapeId};
226    ///
227    /// let x = AdValue::reverse(1.0_f64, NodeId(4), TapeId(6), None);
228    /// assert_eq!(x.tape_id(), Some(TapeId(6)));
229    /// ```
230    pub fn tape_id(&self) -> Option<TapeId> {
231        match self {
232            Self::Reverse { tape, .. } => Some(*tape),
233            _ => None,
234        }
235    }
236
237    /// Maps the payload type while preserving AD mode.
238    ///
239    /// # Examples
240    ///
241    /// ```rust
242    /// use ad_tensors_rs::AdValue;
243    ///
244    /// let x = AdValue::forward(2_i32, 3_i32);
245    /// let y = x.map(|v| v as f64);
246    /// assert_eq!(y.primal_ref(), &2.0_f64);
247    /// assert_eq!(y.tangent_ref(), Some(&3.0_f64));
248    /// ```
249    pub fn map<U>(self, mut f: impl FnMut(T) -> U) -> AdValue<U> {
250        match self {
251            Self::Primal(value) => AdValue::Primal(f(value)),
252            Self::Forward { primal, tangent } => AdValue::Forward {
253                primal: f(primal),
254                tangent: f(tangent),
255            },
256            Self::Reverse {
257                primal,
258                node,
259                tape,
260                tangent,
261            } => AdValue::Reverse {
262                primal: f(primal),
263                node,
264                tape,
265                tangent: tangent.map(f),
266            },
267        }
268    }
269}
270
271impl<T> From<T> for AdValue<T> {
272    fn from(value: T) -> Self {
273        Self::Primal(value)
274    }
275}
276
277/// Scalar newtype carrying AD mode information.
278///
279/// # Examples
280///
281/// ```rust
282/// use ad_tensors_rs::{AdMode, AdScalar};
283///
284/// let x: AdScalar<f64> = 2.0_f64.into();
285/// assert_eq!(x.mode(), AdMode::Primal);
286/// ```
287#[derive(Debug, Clone, PartialEq)]
288pub struct AdScalar<T>(pub AdValue<T>);
289
290impl<T> AdScalar<T> {
291    /// Creates a primal scalar.
292    ///
293    /// # Examples
294    ///
295    /// ```rust
296    /// use ad_tensors_rs::{AdMode, AdScalar};
297    ///
298    /// let x = AdScalar::new_primal(1.5_f64);
299    /// assert_eq!(x.mode(), AdMode::Primal);
300    /// ```
301    pub fn new_primal(value: T) -> Self {
302        Self(AdValue::primal(value))
303    }
304
305    /// Creates a forward-mode scalar.
306    ///
307    /// # Examples
308    ///
309    /// ```rust
310    /// use ad_tensors_rs::{AdMode, AdScalar};
311    ///
312    /// let x = AdScalar::new_forward(2.0_f64, 1.0_f64);
313    /// assert_eq!(x.mode(), AdMode::Forward);
314    /// ```
315    pub fn new_forward(primal: T, tangent: T) -> Self {
316        Self(AdValue::forward(primal, tangent))
317    }
318
319    /// Creates a reverse-mode scalar.
320    ///
321    /// # Examples
322    ///
323    /// ```rust
324    /// use ad_tensors_rs::{AdMode, AdScalar, NodeId, TapeId};
325    ///
326    /// let x = AdScalar::new_reverse(2.0_f64, NodeId(1), TapeId(2), Some(0.4));
327    /// assert_eq!(x.mode(), AdMode::Reverse);
328    /// ```
329    pub fn new_reverse(primal: T, node: NodeId, tape: TapeId, tangent: Option<T>) -> Self {
330        Self(AdValue::reverse(primal, node, tape, tangent))
331    }
332
333    /// Returns AD mode.
334    ///
335    /// # Examples
336    ///
337    /// ```rust
338    /// use ad_tensors_rs::{AdMode, AdScalar};
339    ///
340    /// let x = AdScalar::new_primal(2.0_f64);
341    /// assert_eq!(x.mode(), AdMode::Primal);
342    /// ```
343    pub fn mode(&self) -> AdMode {
344        self.0.mode()
345    }
346
347    /// Returns reference to underlying [`AdValue`].
348    ///
349    /// # Examples
350    ///
351    /// ```rust
352    /// use ad_tensors_rs::{AdScalar, AdValue};
353    ///
354    /// let x = AdScalar::new_primal(2.0_f64);
355    /// assert!(matches!(x.as_value(), AdValue::Primal(_)));
356    /// ```
357    pub fn as_value(&self) -> &AdValue<T> {
358        &self.0
359    }
360
361    /// Consumes wrapper and returns the underlying [`AdValue`].
362    ///
363    /// # Examples
364    ///
365    /// ```rust
366    /// use ad_tensors_rs::{AdScalar, AdValue};
367    ///
368    /// let x = AdScalar::new_primal(2.0_f64).into_value();
369    /// assert!(matches!(x, AdValue::Primal(_)));
370    /// ```
371    pub fn into_value(self) -> AdValue<T> {
372        self.0
373    }
374
375    /// Returns primal scalar reference.
376    ///
377    /// # Examples
378    ///
379    /// ```rust
380    /// use ad_tensors_rs::AdScalar;
381    ///
382    /// let x = AdScalar::new_forward(2.0_f64, 1.0_f64);
383    /// assert_eq!(x.primal(), &2.0);
384    /// ```
385    pub fn primal(&self) -> &T {
386        self.0.primal_ref()
387    }
388
389    /// Returns tangent scalar reference when available.
390    ///
391    /// # Examples
392    ///
393    /// ```rust
394    /// use ad_tensors_rs::AdScalar;
395    ///
396    /// let x = AdScalar::new_forward(2.0_f64, 1.0_f64);
397    /// assert_eq!(x.tangent(), Some(&1.0));
398    /// ```
399    pub fn tangent(&self) -> Option<&T> {
400        self.0.tangent_ref()
401    }
402}
403
404impl<T> From<T> for AdScalar<T> {
405    fn from(value: T) -> Self {
406        Self(AdValue::Primal(value))
407    }
408}
409
410impl<T> From<AdValue<T>> for AdScalar<T> {
411    fn from(value: AdValue<T>) -> Self {
412        Self(value)
413    }
414}
415
416impl<T> From<AdScalar<T>> for AdValue<T> {
417    fn from(value: AdScalar<T>) -> Self {
418        value.0
419    }
420}
421
422/// Tensor newtype carrying AD mode information.
423///
424/// # Examples
425///
426/// ```rust
427/// use ad_tensors_rs::{AdMode, AdTensor};
428/// use tenferro_tensor::{MemoryOrder, Tensor};
429///
430/// let t = Tensor::<f64>::from_slice(&[1.0, 2.0], &[2], MemoryOrder::ColumnMajor).unwrap();
431/// let x: AdTensor<f64> = t.into();
432/// assert_eq!(x.mode(), AdMode::Primal);
433/// ```
434#[derive(Clone)]
435pub struct AdTensor<T: Scalar>(pub AdValue<Tensor<T>>);
436
437impl<T: Scalar> AdTensor<T> {
438    /// Creates a primal tensor.
439    ///
440    /// # Examples
441    ///
442    /// ```rust
443    /// use ad_tensors_rs::AdTensor;
444    /// use tenferro_tensor::{MemoryOrder, Tensor};
445    ///
446    /// let t = Tensor::<f64>::from_slice(&[1.0], &[1], MemoryOrder::ColumnMajor).unwrap();
447    /// let x = AdTensor::new_primal(t);
448    /// assert_eq!(x.dims(), &[1]);
449    /// ```
450    pub fn new_primal(tensor: Tensor<T>) -> Self {
451        Self(AdValue::primal(tensor))
452    }
453
454    /// Creates a forward-mode tensor.
455    ///
456    /// # Examples
457    ///
458    /// ```rust
459    /// use ad_tensors_rs::{AdMode, AdTensor};
460    /// use tenferro_tensor::{MemoryOrder, Tensor};
461    ///
462    /// let primal = Tensor::<f64>::from_slice(&[1.0], &[1], MemoryOrder::ColumnMajor).unwrap();
463    /// let tangent = Tensor::<f64>::from_slice(&[0.1], &[1], MemoryOrder::ColumnMajor).unwrap();
464    /// let x = AdTensor::new_forward(primal, tangent);
465    /// assert_eq!(x.mode(), AdMode::Forward);
466    /// ```
467    pub fn new_forward(primal: Tensor<T>, tangent: Tensor<T>) -> Self {
468        Self(AdValue::forward(primal, tangent))
469    }
470
471    /// Creates a reverse-mode tensor.
472    ///
473    /// # Examples
474    ///
475    /// ```rust
476    /// use ad_tensors_rs::{AdMode, AdTensor, NodeId, TapeId};
477    /// use tenferro_tensor::{MemoryOrder, Tensor};
478    ///
479    /// let primal = Tensor::<f64>::from_slice(&[1.0], &[1], MemoryOrder::ColumnMajor).unwrap();
480    /// let x = AdTensor::new_reverse(primal, NodeId(8), TapeId(3), None);
481    /// assert_eq!(x.mode(), AdMode::Reverse);
482    /// ```
483    pub fn new_reverse(
484        primal: Tensor<T>,
485        node: NodeId,
486        tape: TapeId,
487        tangent: Option<Tensor<T>>,
488    ) -> Self {
489        Self(AdValue::reverse(primal, node, tape, tangent))
490    }
491
492    /// Returns AD mode.
493    ///
494    /// # Examples
495    ///
496    /// ```rust
497    /// use ad_tensors_rs::{AdMode, AdTensor};
498    /// use tenferro_tensor::{MemoryOrder, Tensor};
499    ///
500    /// let t = Tensor::<f64>::from_slice(&[1.0], &[1], MemoryOrder::ColumnMajor).unwrap();
501    /// let x = AdTensor::new_primal(t);
502    /// assert_eq!(x.mode(), AdMode::Primal);
503    /// ```
504    pub fn mode(&self) -> AdMode {
505        self.0.mode()
506    }
507
508    /// Returns reference to underlying [`AdValue`].
509    ///
510    /// # Examples
511    ///
512    /// ```rust
513    /// use ad_tensors_rs::{AdTensor, AdValue};
514    /// use tenferro_tensor::{MemoryOrder, Tensor};
515    ///
516    /// let t = Tensor::<f64>::from_slice(&[1.0], &[1], MemoryOrder::ColumnMajor).unwrap();
517    /// let x = AdTensor::new_primal(t);
518    /// assert!(matches!(x.as_value(), AdValue::Primal(_)));
519    /// ```
520    pub fn as_value(&self) -> &AdValue<Tensor<T>> {
521        &self.0
522    }
523
524    /// Consumes wrapper and returns the underlying [`AdValue`].
525    ///
526    /// # Examples
527    ///
528    /// ```rust
529    /// use ad_tensors_rs::{AdTensor, AdValue};
530    /// use tenferro_tensor::{MemoryOrder, Tensor};
531    ///
532    /// let t = Tensor::<f64>::from_slice(&[1.0], &[1], MemoryOrder::ColumnMajor).unwrap();
533    /// let x = AdTensor::new_primal(t).into_value();
534    /// assert!(matches!(x, AdValue::Primal(_)));
535    /// ```
536    pub fn into_value(self) -> AdValue<Tensor<T>> {
537        self.0
538    }
539
540    /// Returns primal tensor reference.
541    ///
542    /// # Examples
543    ///
544    /// ```rust
545    /// use ad_tensors_rs::AdTensor;
546    /// use tenferro_tensor::{MemoryOrder, Tensor};
547    ///
548    /// let t = Tensor::<f64>::from_slice(&[1.0, 2.0], &[2], MemoryOrder::ColumnMajor).unwrap();
549    /// let x = AdTensor::new_primal(t);
550    /// assert_eq!(x.primal().dims(), &[2]);
551    /// ```
552    pub fn primal(&self) -> &Tensor<T> {
553        self.0.primal_ref()
554    }
555
556    /// Returns tangent tensor reference when available.
557    ///
558    /// # Examples
559    ///
560    /// ```rust
561    /// use ad_tensors_rs::AdTensor;
562    /// use tenferro_tensor::{MemoryOrder, Tensor};
563    ///
564    /// let primal = Tensor::<f64>::from_slice(&[1.0], &[1], MemoryOrder::ColumnMajor).unwrap();
565    /// let tangent = Tensor::<f64>::from_slice(&[0.5], &[1], MemoryOrder::ColumnMajor).unwrap();
566    /// let x = AdTensor::new_forward(primal, tangent);
567    /// assert_eq!(x.tangent().unwrap().dims(), &[1]);
568    /// ```
569    pub fn tangent(&self) -> Option<&Tensor<T>> {
570        self.0.tangent_ref()
571    }
572
573    /// Returns dimensions of the primal tensor.
574    ///
575    /// # Examples
576    ///
577    /// ```rust
578    /// use ad_tensors_rs::AdTensor;
579    /// use tenferro_tensor::{MemoryOrder, Tensor};
580    ///
581    /// let t = Tensor::<f64>::from_slice(&[1.0, 2.0], &[2], MemoryOrder::ColumnMajor).unwrap();
582    /// let x = AdTensor::new_primal(t);
583    /// assert_eq!(x.dims(), &[2]);
584    /// ```
585    pub fn dims(&self) -> &[usize] {
586        self.primal().dims()
587    }
588
589    /// Returns number of dimensions of the primal tensor.
590    ///
591    /// # Examples
592    ///
593    /// ```rust
594    /// use ad_tensors_rs::AdTensor;
595    /// use tenferro_tensor::{MemoryOrder, Tensor};
596    ///
597    /// let t = Tensor::<f64>::from_slice(&[1.0, 2.0], &[2], MemoryOrder::ColumnMajor).unwrap();
598    /// let x = AdTensor::new_primal(t);
599    /// assert_eq!(x.ndim(), 1);
600    /// ```
601    pub fn ndim(&self) -> usize {
602        self.dims().len()
603    }
604
605    /// Returns total number of elements in the primal tensor.
606    ///
607    /// # Examples
608    ///
609    /// ```rust
610    /// use ad_tensors_rs::AdTensor;
611    /// use tenferro_tensor::{MemoryOrder, Tensor};
612    ///
613    /// let t = Tensor::<f64>::from_slice(&[1.0, 2.0], &[2], MemoryOrder::ColumnMajor).unwrap();
614    /// let x = AdTensor::new_primal(t);
615    /// assert_eq!(x.len(), 2);
616    /// ```
617    pub fn len(&self) -> usize {
618        self.dims().iter().product()
619    }
620
621    /// Returns true when primal tensor has zero elements.
622    ///
623    /// # Examples
624    ///
625    /// ```rust
626    /// use ad_tensors_rs::AdTensor;
627    /// use tenferro_tensor::{MemoryOrder, Tensor};
628    ///
629    /// let t = Tensor::<f64>::from_slice(&[], &[0], MemoryOrder::ColumnMajor).unwrap();
630    /// let x = AdTensor::new_primal(t);
631    /// assert!(x.is_empty());
632    /// ```
633    pub fn is_empty(&self) -> bool {
634        self.len() == 0
635    }
636}
637
638impl<T: Scalar> From<Tensor<T>> for AdTensor<T> {
639    fn from(value: Tensor<T>) -> Self {
640        Self(AdValue::Primal(value))
641    }
642}
643
644impl<T: Scalar> From<AdValue<Tensor<T>>> for AdTensor<T> {
645    fn from(value: AdValue<Tensor<T>>) -> Self {
646        Self(value)
647    }
648}
649
650impl<T: Scalar> From<AdTensor<T>> for AdValue<Tensor<T>> {
651    fn from(value: AdTensor<T>) -> Self {
652        value.0
653    }
654}
655
#[cfg(test)]
mod tests {
    use super::*;
    use tenferro_tensor::MemoryOrder;

    #[test]
    fn ad_value_map_preserves_mode() {
        let x = AdValue::forward(2_i32, 3_i32);
        let y = x.map(|v| v as f64);
        assert_eq!(y.mode(), AdMode::Forward);
        assert_eq!(y.primal_ref(), &2.0_f64);
        assert_eq!(y.tangent_ref(), Some(&3.0_f64));
    }

    // The Primal branch of `map` must stay primal and never invent a tangent.
    #[test]
    fn ad_value_map_preserves_primal_mode() {
        let y = AdValue::primal(4_i32).map(|v| v * 2);
        assert_eq!(y, AdValue::Primal(8));
        assert_eq!(y.tangent_ref(), None);
    }

    // The Reverse branch of `map` must carry node/tape ids through untouched
    // and map the optional tangent too.
    #[test]
    fn ad_value_map_preserves_reverse_metadata() {
        let x = AdValue::reverse(2_i32, NodeId(4), TapeId(6), Some(5_i32));
        let y = x.map(|v| v * 10);
        assert_eq!(y.mode(), AdMode::Reverse);
        assert_eq!(y.node_id(), Some(NodeId(4)));
        assert_eq!(y.tape_id(), Some(TapeId(6)));
        assert_eq!(y, AdValue::reverse(20, NodeId(4), TapeId(6), Some(50)));

        let without_tangent = AdValue::reverse(1_i32, NodeId(1), TapeId(2), None).map(|v| v + 1);
        assert_eq!(without_tangent.tangent_ref(), None);
        assert_eq!(without_tangent.primal_ref(), &2);
    }

    #[test]
    fn non_reverse_values_have_no_graph_metadata() {
        let p = AdValue::primal(1.0_f64);
        let f = AdValue::forward(1.0_f64, 0.5_f64);
        assert_eq!(p.node_id(), None);
        assert_eq!(p.tape_id(), None);
        assert_eq!(f.node_id(), None);
        assert_eq!(f.tape_id(), None);
    }

    #[test]
    fn ad_scalar_round_trips_through_ad_value() {
        let value = AdValue::forward(1.5_f64, 2.5_f64);
        let scalar: AdScalar<f64> = value.clone().into();
        assert_eq!(scalar.mode(), AdMode::Forward);
        assert_eq!(scalar.primal(), &1.5);
        assert_eq!(scalar.tangent(), Some(&2.5));

        let back: AdValue<f64> = scalar.into();
        assert_eq!(back, value);
    }

    #[test]
    fn ad_tensor_metadata() {
        let tensor =
            Tensor::<f64>::from_slice(&[1.0, 2.0], &[2], MemoryOrder::ColumnMajor).unwrap();
        let ad = AdTensor::new_primal(tensor);
        assert_eq!(ad.mode(), AdMode::Primal);
        assert_eq!(ad.dims(), &[2]);
        assert_eq!(ad.ndim(), 1);
        assert_eq!(ad.len(), 2);
        assert!(!ad.is_empty());
    }

    // A zero-length axis makes the element count zero.
    #[test]
    fn ad_tensor_empty() {
        let tensor = Tensor::<f64>::from_slice(&[], &[0], MemoryOrder::ColumnMajor).unwrap();
        let ad = AdTensor::new_primal(tensor);
        assert_eq!(ad.len(), 0);
        assert!(ad.is_empty());
    }
}