ad_tensors_rs/ad_value.rs
1use tenferro_algebra::Scalar;
2use tenferro_tensor::Tensor;
3
/// Automatic differentiation mode of an [`AdValue`].
///
/// Each variant corresponds one-to-one to a variant of [`AdValue`] and is what
/// [`AdValue::mode`] reports. `Hash` is derived (like [`NodeId`] and [`TapeId`])
/// so modes can be used as keys in hashed collections.
///
/// # Examples
///
/// ```rust
/// use ad_tensors_rs::{AdMode, AdValue};
///
/// assert_eq!(AdValue::primal(1.0_f64).mode(), AdMode::Primal);
/// assert_ne!(AdMode::Forward, AdMode::Reverse);
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AdMode {
    /// Plain evaluation without derivative propagation.
    Primal,
    /// Forward-mode value carrying tangent information.
    Forward,
    /// Reverse-mode value carrying graph metadata.
    Reverse,
}
22
/// Opaque handle naming a single node of a reverse-mode graph.
///
/// The wrapped integer is public, so ids can be built and inspected directly.
///
/// # Examples
///
/// ```rust
/// use ad_tensors_rs::NodeId;
///
/// let id = NodeId(7);
/// assert_eq!(id.0, 7);
/// assert_eq!(id, NodeId(7));
/// ```
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct NodeId(pub u64);
35
/// Opaque handle naming one tape instance.
///
/// The wrapped integer is public, so ids can be built and inspected directly.
///
/// # Examples
///
/// ```rust
/// use ad_tensors_rs::TapeId;
///
/// let id = TapeId(2);
/// assert_eq!(id.0, 2);
/// assert_eq!(id, TapeId(2));
/// ```
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct TapeId(pub u64);
48
49/// Generic AD value that can wrap any user-defined payload type `T`.
50///
51/// This is the primary extension point of the crate.
52///
53/// # Examples
54///
55/// ```rust
56/// use ad_tensors_rs::{AdMode, AdValue, NodeId, TapeId};
57///
58/// let primal = AdValue::primal(3.0_f64);
59/// assert_eq!(primal.mode(), AdMode::Primal);
60///
61/// let dual = AdValue::forward(3.0_f64, 1.0_f64);
62/// assert_eq!(dual.mode(), AdMode::Forward);
63///
64/// let tracked = AdValue::reverse(3.0_f64, NodeId(1), TapeId(9), None);
65/// assert_eq!(tracked.mode(), AdMode::Reverse);
66/// ```
67#[derive(Debug, Clone, PartialEq)]
68pub enum AdValue<T> {
69 /// Primal-only value.
70 Primal(T),
71 /// Forward-mode value and tangent.
72 Forward { primal: T, tangent: T },
73 /// Reverse-mode value with graph metadata.
74 Reverse {
75 primal: T,
76 node: NodeId,
77 tape: TapeId,
78 tangent: Option<T>,
79 },
80}
81
82impl<T> AdValue<T> {
83 /// Creates a primal-only value.
84 ///
85 /// # Examples
86 ///
87 /// ```rust
88 /// use ad_tensors_rs::AdValue;
89 ///
90 /// let x = AdValue::primal(2_i32);
91 /// assert!(matches!(x, AdValue::Primal(2)));
92 /// ```
93 pub fn primal(value: T) -> Self {
94 Self::Primal(value)
95 }
96
97 /// Creates a forward-mode value.
98 ///
99 /// # Examples
100 ///
101 /// ```rust
102 /// use ad_tensors_rs::AdValue;
103 ///
104 /// let x = AdValue::forward(2.0_f64, 1.0_f64);
105 /// assert!(matches!(x, AdValue::Forward { .. }));
106 /// ```
107 pub fn forward(primal: T, tangent: T) -> Self {
108 Self::Forward { primal, tangent }
109 }
110
111 /// Creates a reverse-mode value.
112 ///
113 /// # Examples
114 ///
115 /// ```rust
116 /// use ad_tensors_rs::{AdValue, NodeId, TapeId};
117 ///
118 /// let x = AdValue::reverse(2.0_f64, NodeId(3), TapeId(5), Some(0.1));
119 /// assert!(matches!(x, AdValue::Reverse { .. }));
120 /// ```
121 pub fn reverse(primal: T, node: NodeId, tape: TapeId, tangent: Option<T>) -> Self {
122 Self::Reverse {
123 primal,
124 node,
125 tape,
126 tangent,
127 }
128 }
129
130 /// Returns the AD mode.
131 ///
132 /// # Examples
133 ///
134 /// ```rust
135 /// use ad_tensors_rs::{AdMode, AdValue};
136 ///
137 /// let x = AdValue::forward(1.0_f64, 1.0_f64);
138 /// assert_eq!(x.mode(), AdMode::Forward);
139 /// ```
140 pub fn mode(&self) -> AdMode {
141 match self {
142 Self::Primal(_) => AdMode::Primal,
143 Self::Forward { .. } => AdMode::Forward,
144 Self::Reverse { .. } => AdMode::Reverse,
145 }
146 }
147
148 /// Returns a reference to the primal payload.
149 ///
150 /// # Examples
151 ///
152 /// ```rust
153 /// use ad_tensors_rs::AdValue;
154 ///
155 /// let x = AdValue::forward(10_i32, 1_i32);
156 /// assert_eq!(x.primal_ref(), &10);
157 /// ```
158 pub fn primal_ref(&self) -> &T {
159 match self {
160 Self::Primal(value) => value,
161 Self::Forward { primal, .. } => primal,
162 Self::Reverse { primal, .. } => primal,
163 }
164 }
165
166 /// Returns a mutable reference to the primal payload.
167 ///
168 /// # Examples
169 ///
170 /// ```rust
171 /// use ad_tensors_rs::AdValue;
172 ///
173 /// let mut x = AdValue::primal(1_i32);
174 /// *x.primal_mut() = 7;
175 /// assert_eq!(x.primal_ref(), &7);
176 /// ```
177 pub fn primal_mut(&mut self) -> &mut T {
178 match self {
179 Self::Primal(value) => value,
180 Self::Forward { primal, .. } => primal,
181 Self::Reverse { primal, .. } => primal,
182 }
183 }
184
185 /// Returns a reference to tangent payload when available.
186 ///
187 /// # Examples
188 ///
189 /// ```rust
190 /// use ad_tensors_rs::AdValue;
191 ///
192 /// let x = AdValue::forward(2.0_f64, 3.0_f64);
193 /// assert_eq!(x.tangent_ref(), Some(&3.0));
194 /// ```
195 pub fn tangent_ref(&self) -> Option<&T> {
196 match self {
197 Self::Primal(_) => None,
198 Self::Forward { tangent, .. } => Some(tangent),
199 Self::Reverse { tangent, .. } => tangent.as_ref(),
200 }
201 }
202
203 /// Returns reverse-mode node id when available.
204 ///
205 /// # Examples
206 ///
207 /// ```rust
208 /// use ad_tensors_rs::{AdValue, NodeId, TapeId};
209 ///
210 /// let x = AdValue::reverse(1.0_f64, NodeId(4), TapeId(6), None);
211 /// assert_eq!(x.node_id(), Some(NodeId(4)));
212 /// ```
213 pub fn node_id(&self) -> Option<NodeId> {
214 match self {
215 Self::Reverse { node, .. } => Some(*node),
216 _ => None,
217 }
218 }
219
220 /// Returns reverse-mode tape id when available.
221 ///
222 /// # Examples
223 ///
224 /// ```rust
225 /// use ad_tensors_rs::{AdValue, NodeId, TapeId};
226 ///
227 /// let x = AdValue::reverse(1.0_f64, NodeId(4), TapeId(6), None);
228 /// assert_eq!(x.tape_id(), Some(TapeId(6)));
229 /// ```
230 pub fn tape_id(&self) -> Option<TapeId> {
231 match self {
232 Self::Reverse { tape, .. } => Some(*tape),
233 _ => None,
234 }
235 }
236
237 /// Maps the payload type while preserving AD mode.
238 ///
239 /// # Examples
240 ///
241 /// ```rust
242 /// use ad_tensors_rs::AdValue;
243 ///
244 /// let x = AdValue::forward(2_i32, 3_i32);
245 /// let y = x.map(|v| v as f64);
246 /// assert_eq!(y.primal_ref(), &2.0_f64);
247 /// assert_eq!(y.tangent_ref(), Some(&3.0_f64));
248 /// ```
249 pub fn map<U>(self, mut f: impl FnMut(T) -> U) -> AdValue<U> {
250 match self {
251 Self::Primal(value) => AdValue::Primal(f(value)),
252 Self::Forward { primal, tangent } => AdValue::Forward {
253 primal: f(primal),
254 tangent: f(tangent),
255 },
256 Self::Reverse {
257 primal,
258 node,
259 tape,
260 tangent,
261 } => AdValue::Reverse {
262 primal: f(primal),
263 node,
264 tape,
265 tangent: tangent.map(f),
266 },
267 }
268 }
269}
270
271impl<T> From<T> for AdValue<T> {
272 fn from(value: T) -> Self {
273 Self::Primal(value)
274 }
275}
276
277/// Scalar newtype carrying AD mode information.
278///
279/// # Examples
280///
281/// ```rust
282/// use ad_tensors_rs::{AdMode, AdScalar};
283///
284/// let x: AdScalar<f64> = 2.0_f64.into();
285/// assert_eq!(x.mode(), AdMode::Primal);
286/// ```
287#[derive(Debug, Clone, PartialEq)]
288pub struct AdScalar<T>(pub AdValue<T>);
289
290impl<T> AdScalar<T> {
291 /// Creates a primal scalar.
292 ///
293 /// # Examples
294 ///
295 /// ```rust
296 /// use ad_tensors_rs::{AdMode, AdScalar};
297 ///
298 /// let x = AdScalar::new_primal(1.5_f64);
299 /// assert_eq!(x.mode(), AdMode::Primal);
300 /// ```
301 pub fn new_primal(value: T) -> Self {
302 Self(AdValue::primal(value))
303 }
304
305 /// Creates a forward-mode scalar.
306 ///
307 /// # Examples
308 ///
309 /// ```rust
310 /// use ad_tensors_rs::{AdMode, AdScalar};
311 ///
312 /// let x = AdScalar::new_forward(2.0_f64, 1.0_f64);
313 /// assert_eq!(x.mode(), AdMode::Forward);
314 /// ```
315 pub fn new_forward(primal: T, tangent: T) -> Self {
316 Self(AdValue::forward(primal, tangent))
317 }
318
319 /// Creates a reverse-mode scalar.
320 ///
321 /// # Examples
322 ///
323 /// ```rust
324 /// use ad_tensors_rs::{AdMode, AdScalar, NodeId, TapeId};
325 ///
326 /// let x = AdScalar::new_reverse(2.0_f64, NodeId(1), TapeId(2), Some(0.4));
327 /// assert_eq!(x.mode(), AdMode::Reverse);
328 /// ```
329 pub fn new_reverse(primal: T, node: NodeId, tape: TapeId, tangent: Option<T>) -> Self {
330 Self(AdValue::reverse(primal, node, tape, tangent))
331 }
332
333 /// Returns AD mode.
334 ///
335 /// # Examples
336 ///
337 /// ```rust
338 /// use ad_tensors_rs::{AdMode, AdScalar};
339 ///
340 /// let x = AdScalar::new_primal(2.0_f64);
341 /// assert_eq!(x.mode(), AdMode::Primal);
342 /// ```
343 pub fn mode(&self) -> AdMode {
344 self.0.mode()
345 }
346
347 /// Returns reference to underlying [`AdValue`].
348 ///
349 /// # Examples
350 ///
351 /// ```rust
352 /// use ad_tensors_rs::{AdScalar, AdValue};
353 ///
354 /// let x = AdScalar::new_primal(2.0_f64);
355 /// assert!(matches!(x.as_value(), AdValue::Primal(_)));
356 /// ```
357 pub fn as_value(&self) -> &AdValue<T> {
358 &self.0
359 }
360
361 /// Consumes wrapper and returns the underlying [`AdValue`].
362 ///
363 /// # Examples
364 ///
365 /// ```rust
366 /// use ad_tensors_rs::{AdScalar, AdValue};
367 ///
368 /// let x = AdScalar::new_primal(2.0_f64).into_value();
369 /// assert!(matches!(x, AdValue::Primal(_)));
370 /// ```
371 pub fn into_value(self) -> AdValue<T> {
372 self.0
373 }
374
375 /// Returns primal scalar reference.
376 ///
377 /// # Examples
378 ///
379 /// ```rust
380 /// use ad_tensors_rs::AdScalar;
381 ///
382 /// let x = AdScalar::new_forward(2.0_f64, 1.0_f64);
383 /// assert_eq!(x.primal(), &2.0);
384 /// ```
385 pub fn primal(&self) -> &T {
386 self.0.primal_ref()
387 }
388
389 /// Returns tangent scalar reference when available.
390 ///
391 /// # Examples
392 ///
393 /// ```rust
394 /// use ad_tensors_rs::AdScalar;
395 ///
396 /// let x = AdScalar::new_forward(2.0_f64, 1.0_f64);
397 /// assert_eq!(x.tangent(), Some(&1.0));
398 /// ```
399 pub fn tangent(&self) -> Option<&T> {
400 self.0.tangent_ref()
401 }
402}
403
404impl<T> From<T> for AdScalar<T> {
405 fn from(value: T) -> Self {
406 Self(AdValue::Primal(value))
407 }
408}
409
410impl<T> From<AdValue<T>> for AdScalar<T> {
411 fn from(value: AdValue<T>) -> Self {
412 Self(value)
413 }
414}
415
416impl<T> From<AdScalar<T>> for AdValue<T> {
417 fn from(value: AdScalar<T>) -> Self {
418 value.0
419 }
420}
421
422/// Tensor newtype carrying AD mode information.
423///
424/// # Examples
425///
426/// ```rust
427/// use ad_tensors_rs::{AdMode, AdTensor};
428/// use tenferro_tensor::{MemoryOrder, Tensor};
429///
430/// let t = Tensor::<f64>::from_slice(&[1.0, 2.0], &[2], MemoryOrder::ColumnMajor).unwrap();
431/// let x: AdTensor<f64> = t.into();
432/// assert_eq!(x.mode(), AdMode::Primal);
433/// ```
434#[derive(Clone)]
435pub struct AdTensor<T: Scalar>(pub AdValue<Tensor<T>>);
436
437impl<T: Scalar> AdTensor<T> {
438 /// Creates a primal tensor.
439 ///
440 /// # Examples
441 ///
442 /// ```rust
443 /// use ad_tensors_rs::AdTensor;
444 /// use tenferro_tensor::{MemoryOrder, Tensor};
445 ///
446 /// let t = Tensor::<f64>::from_slice(&[1.0], &[1], MemoryOrder::ColumnMajor).unwrap();
447 /// let x = AdTensor::new_primal(t);
448 /// assert_eq!(x.dims(), &[1]);
449 /// ```
450 pub fn new_primal(tensor: Tensor<T>) -> Self {
451 Self(AdValue::primal(tensor))
452 }
453
454 /// Creates a forward-mode tensor.
455 ///
456 /// # Examples
457 ///
458 /// ```rust
459 /// use ad_tensors_rs::{AdMode, AdTensor};
460 /// use tenferro_tensor::{MemoryOrder, Tensor};
461 ///
462 /// let primal = Tensor::<f64>::from_slice(&[1.0], &[1], MemoryOrder::ColumnMajor).unwrap();
463 /// let tangent = Tensor::<f64>::from_slice(&[0.1], &[1], MemoryOrder::ColumnMajor).unwrap();
464 /// let x = AdTensor::new_forward(primal, tangent);
465 /// assert_eq!(x.mode(), AdMode::Forward);
466 /// ```
467 pub fn new_forward(primal: Tensor<T>, tangent: Tensor<T>) -> Self {
468 Self(AdValue::forward(primal, tangent))
469 }
470
471 /// Creates a reverse-mode tensor.
472 ///
473 /// # Examples
474 ///
475 /// ```rust
476 /// use ad_tensors_rs::{AdMode, AdTensor, NodeId, TapeId};
477 /// use tenferro_tensor::{MemoryOrder, Tensor};
478 ///
479 /// let primal = Tensor::<f64>::from_slice(&[1.0], &[1], MemoryOrder::ColumnMajor).unwrap();
480 /// let x = AdTensor::new_reverse(primal, NodeId(8), TapeId(3), None);
481 /// assert_eq!(x.mode(), AdMode::Reverse);
482 /// ```
483 pub fn new_reverse(
484 primal: Tensor<T>,
485 node: NodeId,
486 tape: TapeId,
487 tangent: Option<Tensor<T>>,
488 ) -> Self {
489 Self(AdValue::reverse(primal, node, tape, tangent))
490 }
491
492 /// Returns AD mode.
493 ///
494 /// # Examples
495 ///
496 /// ```rust
497 /// use ad_tensors_rs::{AdMode, AdTensor};
498 /// use tenferro_tensor::{MemoryOrder, Tensor};
499 ///
500 /// let t = Tensor::<f64>::from_slice(&[1.0], &[1], MemoryOrder::ColumnMajor).unwrap();
501 /// let x = AdTensor::new_primal(t);
502 /// assert_eq!(x.mode(), AdMode::Primal);
503 /// ```
504 pub fn mode(&self) -> AdMode {
505 self.0.mode()
506 }
507
508 /// Returns reference to underlying [`AdValue`].
509 ///
510 /// # Examples
511 ///
512 /// ```rust
513 /// use ad_tensors_rs::{AdTensor, AdValue};
514 /// use tenferro_tensor::{MemoryOrder, Tensor};
515 ///
516 /// let t = Tensor::<f64>::from_slice(&[1.0], &[1], MemoryOrder::ColumnMajor).unwrap();
517 /// let x = AdTensor::new_primal(t);
518 /// assert!(matches!(x.as_value(), AdValue::Primal(_)));
519 /// ```
520 pub fn as_value(&self) -> &AdValue<Tensor<T>> {
521 &self.0
522 }
523
524 /// Consumes wrapper and returns the underlying [`AdValue`].
525 ///
526 /// # Examples
527 ///
528 /// ```rust
529 /// use ad_tensors_rs::{AdTensor, AdValue};
530 /// use tenferro_tensor::{MemoryOrder, Tensor};
531 ///
532 /// let t = Tensor::<f64>::from_slice(&[1.0], &[1], MemoryOrder::ColumnMajor).unwrap();
533 /// let x = AdTensor::new_primal(t).into_value();
534 /// assert!(matches!(x, AdValue::Primal(_)));
535 /// ```
536 pub fn into_value(self) -> AdValue<Tensor<T>> {
537 self.0
538 }
539
540 /// Returns primal tensor reference.
541 ///
542 /// # Examples
543 ///
544 /// ```rust
545 /// use ad_tensors_rs::AdTensor;
546 /// use tenferro_tensor::{MemoryOrder, Tensor};
547 ///
548 /// let t = Tensor::<f64>::from_slice(&[1.0, 2.0], &[2], MemoryOrder::ColumnMajor).unwrap();
549 /// let x = AdTensor::new_primal(t);
550 /// assert_eq!(x.primal().dims(), &[2]);
551 /// ```
552 pub fn primal(&self) -> &Tensor<T> {
553 self.0.primal_ref()
554 }
555
556 /// Returns tangent tensor reference when available.
557 ///
558 /// # Examples
559 ///
560 /// ```rust
561 /// use ad_tensors_rs::AdTensor;
562 /// use tenferro_tensor::{MemoryOrder, Tensor};
563 ///
564 /// let primal = Tensor::<f64>::from_slice(&[1.0], &[1], MemoryOrder::ColumnMajor).unwrap();
565 /// let tangent = Tensor::<f64>::from_slice(&[0.5], &[1], MemoryOrder::ColumnMajor).unwrap();
566 /// let x = AdTensor::new_forward(primal, tangent);
567 /// assert_eq!(x.tangent().unwrap().dims(), &[1]);
568 /// ```
569 pub fn tangent(&self) -> Option<&Tensor<T>> {
570 self.0.tangent_ref()
571 }
572
573 /// Returns dimensions of the primal tensor.
574 ///
575 /// # Examples
576 ///
577 /// ```rust
578 /// use ad_tensors_rs::AdTensor;
579 /// use tenferro_tensor::{MemoryOrder, Tensor};
580 ///
581 /// let t = Tensor::<f64>::from_slice(&[1.0, 2.0], &[2], MemoryOrder::ColumnMajor).unwrap();
582 /// let x = AdTensor::new_primal(t);
583 /// assert_eq!(x.dims(), &[2]);
584 /// ```
585 pub fn dims(&self) -> &[usize] {
586 self.primal().dims()
587 }
588
589 /// Returns number of dimensions of the primal tensor.
590 ///
591 /// # Examples
592 ///
593 /// ```rust
594 /// use ad_tensors_rs::AdTensor;
595 /// use tenferro_tensor::{MemoryOrder, Tensor};
596 ///
597 /// let t = Tensor::<f64>::from_slice(&[1.0, 2.0], &[2], MemoryOrder::ColumnMajor).unwrap();
598 /// let x = AdTensor::new_primal(t);
599 /// assert_eq!(x.ndim(), 1);
600 /// ```
601 pub fn ndim(&self) -> usize {
602 self.dims().len()
603 }
604
605 /// Returns total number of elements in the primal tensor.
606 ///
607 /// # Examples
608 ///
609 /// ```rust
610 /// use ad_tensors_rs::AdTensor;
611 /// use tenferro_tensor::{MemoryOrder, Tensor};
612 ///
613 /// let t = Tensor::<f64>::from_slice(&[1.0, 2.0], &[2], MemoryOrder::ColumnMajor).unwrap();
614 /// let x = AdTensor::new_primal(t);
615 /// assert_eq!(x.len(), 2);
616 /// ```
617 pub fn len(&self) -> usize {
618 self.dims().iter().product()
619 }
620
621 /// Returns true when primal tensor has zero elements.
622 ///
623 /// # Examples
624 ///
625 /// ```rust
626 /// use ad_tensors_rs::AdTensor;
627 /// use tenferro_tensor::{MemoryOrder, Tensor};
628 ///
629 /// let t = Tensor::<f64>::from_slice(&[], &[0], MemoryOrder::ColumnMajor).unwrap();
630 /// let x = AdTensor::new_primal(t);
631 /// assert!(x.is_empty());
632 /// ```
633 pub fn is_empty(&self) -> bool {
634 self.len() == 0
635 }
636}
637
638impl<T: Scalar> From<Tensor<T>> for AdTensor<T> {
639 fn from(value: Tensor<T>) -> Self {
640 Self(AdValue::Primal(value))
641 }
642}
643
644impl<T: Scalar> From<AdValue<Tensor<T>>> for AdTensor<T> {
645 fn from(value: AdValue<Tensor<T>>) -> Self {
646 Self(value)
647 }
648}
649
650impl<T: Scalar> From<AdTensor<T>> for AdValue<Tensor<T>> {
651 fn from(value: AdTensor<T>) -> Self {
652 value.0
653 }
654}
655
#[cfg(test)]
mod tests {
    use super::*;
    use tenferro_tensor::MemoryOrder;

    #[test]
    fn ad_value_map_preserves_mode() {
        let x = AdValue::forward(2_i32, 3_i32);
        let y = x.map(|v| v as f64);
        assert_eq!(y.mode(), AdMode::Forward);
        assert_eq!(y.primal_ref(), &2.0_f64);
        assert_eq!(y.tangent_ref(), Some(&3.0_f64));
    }

    /// `map` on a reverse-mode value must keep node/tape ids and map the
    /// optional tangent.
    #[test]
    fn ad_value_map_preserves_reverse_metadata() {
        let x = AdValue::reverse(1_i32, NodeId(4), TapeId(6), Some(2_i32));
        let y = x.map(|v| v * 10);
        assert_eq!(y, AdValue::reverse(10, NodeId(4), TapeId(6), Some(20)));
        assert_eq!(y.node_id(), Some(NodeId(4)));
        assert_eq!(y.tape_id(), Some(TapeId(6)));
    }

    /// Accessors report absence for modes that do not carry the data.
    #[test]
    fn ad_value_accessors_report_absent_metadata() {
        let p = AdValue::primal(1_i32);
        assert_eq!(p.tangent_ref(), None);
        assert_eq!(p.node_id(), None);
        assert_eq!(p.tape_id(), None);

        let f = AdValue::forward(1_i32, 2_i32);
        assert_eq!(f.node_id(), None);
        assert_eq!(f.tape_id(), None);

        let r = AdValue::reverse(1_i32, NodeId(1), TapeId(2), None);
        assert_eq!(r.mode(), AdMode::Reverse);
        assert_eq!(r.tangent_ref(), None);
    }

    /// Mutating through `primal_mut` leaves the tangent untouched.
    #[test]
    fn ad_value_primal_mut_leaves_tangent_untouched() {
        let mut x = AdValue::forward(1_i32, 5_i32);
        *x.primal_mut() = 9;
        assert_eq!(x, AdValue::forward(9, 5));
    }

    /// `AdScalar` <-> `AdValue` conversions are lossless.
    #[test]
    fn ad_scalar_round_trips_through_ad_value() {
        let value = AdValue::forward(1.5_f64, 0.5_f64);
        let scalar: AdScalar<f64> = value.clone().into();
        assert_eq!(scalar.mode(), AdMode::Forward);
        assert_eq!(scalar.primal(), &1.5);
        assert_eq!(scalar.tangent(), Some(&0.5));
        let back: AdValue<f64> = scalar.into();
        assert_eq!(back, value);
    }

    #[test]
    fn ad_tensor_metadata() {
        let tensor =
            Tensor::<f64>::from_slice(&[1.0, 2.0], &[2], MemoryOrder::ColumnMajor).unwrap();
        let ad = AdTensor::new_primal(tensor);
        assert_eq!(ad.mode(), AdMode::Primal);
        assert_eq!(ad.dims(), &[2]);
        assert_eq!(ad.ndim(), 1);
        assert_eq!(ad.len(), 2);
    }

    /// Forward-mode tensors expose their tangent and convert back to `AdValue`;
    /// zero-extent tensors report empty.
    #[test]
    fn ad_tensor_forward_and_empty() {
        let primal = Tensor::<f64>::from_slice(&[1.0], &[1], MemoryOrder::ColumnMajor).unwrap();
        let tangent = Tensor::<f64>::from_slice(&[0.5], &[1], MemoryOrder::ColumnMajor).unwrap();
        let ad = AdTensor::new_forward(primal, tangent);
        assert_eq!(ad.mode(), AdMode::Forward);
        assert_eq!(ad.tangent().unwrap().dims(), &[1]);
        let value: AdValue<Tensor<f64>> = ad.into();
        assert!(matches!(value, AdValue::Forward { .. }));

        let empty = Tensor::<f64>::from_slice(&[], &[0], MemoryOrder::ColumnMajor).unwrap();
        let ad = AdTensor::new_primal(empty);
        assert!(ad.is_empty());
        assert_eq!(ad.len(), 0);
        assert_eq!(ad.ndim(), 1);
    }
}
680}