chainrules/
power.rs

1use num_traits::{Float, One, Zero};
2
3use crate::ScalarAd;
4
5/// Primal `powf`.
6/// ```rust
7/// use chainrules::powf;
8/// assert_eq!(powf(2.0_f64, 3.0), 8.0);
9/// ```
10pub fn powf<S: ScalarAd>(x: S, exponent: S::Real) -> S {
11    x.powf(exponent)
12}
13
14/// Forward rule for `powf` with fixed exponent. Returns `(primal, tangent)`.
15/// ```rust
16/// use chainrules::powf_frule;
17/// let (y, dy) = powf_frule(2.0_f64, 3.0, 1.0);
18/// assert_eq!(y, 8.0);
19/// assert_eq!(dy, 12.0);
20/// ```
21pub fn powf_frule<S: ScalarAd>(x: S, exponent: S::Real, dx: S) -> (S, S) {
22    let y = x.powf(exponent);
23    let dy = if exponent == S::Real::zero() {
24        S::from_real(S::Real::zero())
25    } else {
26        dx * (S::from_real(exponent) * x.powf(exponent - S::Real::one()))
27    };
28    (y, dy)
29}
30
31/// Reverse rule for `powf` with fixed exponent.
32///
33/// # Examples
34///
35/// ```rust
36/// use chainrules::powf_rrule;
37///
38/// let dx = powf_rrule(2.0_f64, 3.0, 1.0);
39/// assert_eq!(dx, 12.0);
40/// ```
41pub fn powf_rrule<S: ScalarAd>(x: S, exponent: S::Real, cotangent: S) -> S {
42    if exponent == S::Real::zero() {
43        return S::from_real(S::Real::zero());
44    }
45    cotangent * (S::from_real(exponent) * x.powf(exponent - S::Real::one())).conj()
46}
47
48/// Primal `powi`.
49/// ```rust
50/// use chainrules::powi;
51/// assert_eq!(powi(2.0_f64, 4), 16.0);
52/// ```
53pub fn powi<S: ScalarAd>(x: S, exponent: i32) -> S {
54    x.powi(exponent)
55}
56
57/// Forward rule for `powi` with fixed integer exponent. Returns `(primal, tangent)`.
58/// ```rust
59/// use chainrules::powi_frule;
60/// let (y, dy) = powi_frule(2.0_f64, 4, 1.0);
61/// assert_eq!(y, 16.0);
62/// assert_eq!(dy, 32.0);
63/// ```
64pub fn powi_frule<S: ScalarAd>(x: S, exponent: i32, dx: S) -> (S, S) {
65    let y = x.powi(exponent);
66    let dy = if exponent == 0 {
67        S::from_i32(0)
68    } else {
69        dx * (S::from_i32(exponent) * x.powi(exponent - 1))
70    };
71    (y, dy)
72}
73
74/// Reverse rule for `powi` with fixed integer exponent.
75///
76/// # Examples
77///
78/// ```rust
79/// use chainrules::powi_rrule;
80///
81/// let dx = powi_rrule(2.0_f64, 4, 1.0);
82/// assert_eq!(dx, 32.0);
83/// ```
84pub fn powi_rrule<S: ScalarAd>(x: S, exponent: i32, cotangent: S) -> S {
85    if exponent == 0 {
86        return S::from_i32(0);
87    }
88    cotangent * (S::from_i32(exponent) * x.powi(exponent - 1)).conj()
89}
90
91#[doc = "Primal `pow(x, exponent)`.\n\n# Examples\n```rust\nuse chainrules::pow;\n\nassert_eq!(pow(2.0_f64, 3.0_f64), 8.0);\n```"]
92pub fn pow<S: ScalarAd>(x: S, exponent: S) -> S {
93    x.pow(exponent)
94}
95fn zero<S: ScalarAd>() -> S {
96    S::from_i32(0)
97}
98fn nan<S: ScalarAd>() -> S {
99    S::from_real(S::Real::nan())
100}
101fn pow_x_scale<S: ScalarAd>(x: S, exponent: S) -> S {
102    if exponent == zero::<S>() {
103        zero::<S>()
104    } else {
105        (exponent * x.pow(exponent - S::from_i32(1))).conj()
106    }
107}
108fn pow_exp_scale<S: ScalarAd>(x: S, exponent: S) -> S {
109    if x == zero::<S>() && exponent.imag() == S::Real::zero() {
110        if exponent.real() > S::Real::zero() {
111            zero::<S>()
112        } else {
113            nan::<S>()
114        }
115    } else {
116        (x.pow(exponent) * x.ln()).conj()
117    }
118}
119#[doc = "Forward rule for `pow(x, exponent)`.\n\nWhen `x` is zero and `exponent` is a non-positive real scalar, the exponent-tangent path returns `NaN` to surface the singularity.\n\n# Examples\n```rust\nuse chainrules::pow_frule;\n\nlet (y, dy) = pow_frule(2.0_f64, 3.0_f64, 1.0, 0.0);\nassert_eq!(y, 8.0);\nassert!((dy - 12.0).abs() < 1e-12);\n```"]
120pub fn pow_frule<S: ScalarAd>(x: S, exponent: S, dx: S, dexponent: S) -> (S, S) {
121    let y = x.pow(exponent);
122    let dfdx = if dx == zero::<S>() {
123        zero::<S>()
124    } else {
125        dx * if exponent == zero::<S>() {
126            zero::<S>()
127        } else {
128            exponent * x.pow(exponent - S::from_i32(1))
129        }
130    };
131    let dfde = if dexponent == zero::<S>() {
132        zero::<S>()
133    } else {
134        dexponent
135            * if x == zero::<S>() && exponent.imag() == S::Real::zero() {
136                if exponent.real() > S::Real::zero() {
137                    zero::<S>()
138                } else {
139                    nan::<S>()
140                }
141            } else {
142                x.pow(exponent) * x.ln()
143            }
144    };
145    (y, dfdx + dfde)
146}
147#[doc = "Reverse rule for `pow(x, exponent)`.\n\nWhen `x` is zero and `exponent` is a non-positive real scalar, the exponent-cotangent path returns `NaN` to surface the singularity.\n\n# Examples\n```rust\nuse chainrules::pow_rrule;\n\nlet (dx, dexp) = pow_rrule(2.0_f64, 3.0_f64, 1.0);\nassert_eq!(dx, 12.0);\nassert!((dexp - 8.0_f64 * std::f64::consts::LN_2).abs() < 1e-12);\n```"]
148pub fn pow_rrule<S: ScalarAd>(x: S, exponent: S, cotangent: S) -> (S, S) {
149    let dfdx = if cotangent == zero::<S>() {
150        zero::<S>()
151    } else {
152        cotangent * pow_x_scale(x, exponent)
153    };
154    let dfde = if cotangent == zero::<S>() {
155        zero::<S>()
156    } else {
157        cotangent * pow_exp_scale(x, exponent)
158    };
159    (dfdx, dfde)
160}