chainrules/unary/roots.rs
1use crate::ScalarAd;
2
3/// Primal `cbrt`.
4///
5/// # Examples
6///
7/// ```rust
8/// use chainrules::cbrt;
9///
10/// assert_eq!(cbrt(8.0_f64), 2.0);
11/// ```
12pub fn cbrt<S: ScalarAd>(x: S) -> S {
13 x.cbrt()
14}
15
16/// Forward rule for `cbrt`.
17///
18/// # Examples
19///
20/// ```rust
21/// use chainrules::cbrt_frule;
22///
23/// let (y, dy) = cbrt_frule(8.0_f64, 1.0);
24/// assert_eq!(y, 2.0);
25/// assert!((dy - (1.0 / 12.0)).abs() < 1e-12);
26/// ```
27pub fn cbrt_frule<S: ScalarAd>(x: S, dx: S) -> (S, S) {
28 let y = x.cbrt();
29 let scale = S::from_i32(1) / (S::from_i32(3) * y * y);
30 (y, dx * scale)
31}
32
33/// Reverse rule for `cbrt`.
34///
35/// # Examples
36///
37/// ```rust
38/// use chainrules::cbrt_rrule;
39///
40/// let dx = cbrt_rrule(2.0_f64, 1.0);
41/// assert!((dx - (1.0 / 12.0)).abs() < 1e-12);
42/// ```
43pub fn cbrt_rrule<S: ScalarAd>(result: S, cotangent: S) -> S {
44 cotangent * (S::from_i32(1) / (S::from_i32(3) * result * result)).conj()
45}
46
47/// Primal `inv`.
48///
49/// # Examples
50///
51/// ```rust
52/// use chainrules::inv;
53///
54/// assert_eq!(inv(4.0_f64), 0.25);
55/// ```
56pub fn inv<S: ScalarAd>(x: S) -> S {
57 x.recip()
58}
59
60/// Forward rule for `inv`.
61///
62/// # Examples
63///
64/// ```rust
65/// use chainrules::inv_frule;
66///
67/// let (y, dy) = inv_frule(4.0_f64, 2.0);
68/// assert_eq!(y, 0.25);
69/// assert!((dy + 0.125).abs() < 1e-12);
70/// ```
71pub fn inv_frule<S: ScalarAd>(x: S, dx: S) -> (S, S) {
72 let y = x.recip();
73 (y, dx * (-(y * y)))
74}
75
76/// Reverse rule for `inv`.
77///
78/// # Examples
79///
80/// ```rust
81/// use chainrules::inv_rrule;
82///
83/// let dx = inv_rrule(0.25_f64, 2.0);
84/// assert!((dx + 0.125).abs() < 1e-12);
85/// ```
86pub fn inv_rrule<S: ScalarAd>(result: S, cotangent: S) -> S {
87 cotangent * (-(result * result)).conj()
88}