1use num_traits::{Float, One, Zero};
2
3use crate::ScalarAd;
4
5pub fn powf<S: ScalarAd>(x: S, exponent: S::Real) -> S {
11 x.powf(exponent)
12}
13
14pub fn powf_frule<S: ScalarAd>(x: S, exponent: S::Real, dx: S) -> (S, S) {
22 let y = x.powf(exponent);
23 let dy = if exponent == S::Real::zero() {
24 S::from_real(S::Real::zero())
25 } else {
26 dx * (S::from_real(exponent) * x.powf(exponent - S::Real::one()))
27 };
28 (y, dy)
29}
30
31pub fn powf_rrule<S: ScalarAd>(x: S, exponent: S::Real, cotangent: S) -> S {
42 if exponent == S::Real::zero() {
43 return S::from_real(S::Real::zero());
44 }
45 cotangent * (S::from_real(exponent) * x.powf(exponent - S::Real::one())).conj()
46}
47
48pub fn powi<S: ScalarAd>(x: S, exponent: i32) -> S {
54 x.powi(exponent)
55}
56
57pub fn powi_frule<S: ScalarAd>(x: S, exponent: i32, dx: S) -> (S, S) {
65 let y = x.powi(exponent);
66 let dy = if exponent == 0 {
67 S::from_i32(0)
68 } else {
69 dx * (S::from_i32(exponent) * x.powi(exponent - 1))
70 };
71 (y, dy)
72}
73
74pub fn powi_rrule<S: ScalarAd>(x: S, exponent: i32, cotangent: S) -> S {
85 if exponent == 0 {
86 return S::from_i32(0);
87 }
88 cotangent * (S::from_i32(exponent) * x.powi(exponent - 1)).conj()
89}
90
91#[doc = "Primal `pow(x, exponent)`.\n\n# Examples\n```rust\nuse chainrules::pow;\n\nassert_eq!(pow(2.0_f64, 3.0_f64), 8.0);\n```"]
92pub fn pow<S: ScalarAd>(x: S, exponent: S) -> S {
93 x.pow(exponent)
94}
95fn zero<S: ScalarAd>() -> S {
96 S::from_i32(0)
97}
98fn nan<S: ScalarAd>() -> S {
99 S::from_real(S::Real::nan())
100}
101fn pow_x_scale<S: ScalarAd>(x: S, exponent: S) -> S {
102 if exponent == zero::<S>() {
103 zero::<S>()
104 } else {
105 (exponent * x.pow(exponent - S::from_i32(1))).conj()
106 }
107}
108fn pow_exp_scale<S: ScalarAd>(x: S, exponent: S) -> S {
109 if x == zero::<S>() && exponent.imag() == S::Real::zero() {
110 if exponent.real() > S::Real::zero() {
111 zero::<S>()
112 } else {
113 nan::<S>()
114 }
115 } else {
116 (x.pow(exponent) * x.ln()).conj()
117 }
118}
119#[doc = "Forward rule for `pow(x, exponent)`.\n\nWhen `x` is zero and `exponent` is a non-positive real scalar, the exponent-tangent path returns `NaN` to surface the singularity.\n\n# Examples\n```rust\nuse chainrules::pow_frule;\n\nlet (y, dy) = pow_frule(2.0_f64, 3.0_f64, 1.0, 0.0);\nassert_eq!(y, 8.0);\nassert!((dy - 12.0).abs() < 1e-12);\n```"]
120pub fn pow_frule<S: ScalarAd>(x: S, exponent: S, dx: S, dexponent: S) -> (S, S) {
121 let y = x.pow(exponent);
122 let dfdx = if dx == zero::<S>() {
123 zero::<S>()
124 } else {
125 dx * if exponent == zero::<S>() {
126 zero::<S>()
127 } else {
128 exponent * x.pow(exponent - S::from_i32(1))
129 }
130 };
131 let dfde = if dexponent == zero::<S>() {
132 zero::<S>()
133 } else {
134 dexponent
135 * if x == zero::<S>() && exponent.imag() == S::Real::zero() {
136 if exponent.real() > S::Real::zero() {
137 zero::<S>()
138 } else {
139 nan::<S>()
140 }
141 } else {
142 x.pow(exponent) * x.ln()
143 }
144 };
145 (y, dfdx + dfde)
146}
147#[doc = "Reverse rule for `pow(x, exponent)`.\n\nWhen `x` is zero and `exponent` is a non-positive real scalar, the exponent-cotangent path returns `NaN` to surface the singularity.\n\n# Examples\n```rust\nuse chainrules::pow_rrule;\n\nlet (dx, dexp) = pow_rrule(2.0_f64, 3.0_f64, 1.0);\nassert_eq!(dx, 12.0);\nassert!((dexp - 8.0_f64 * std::f64::consts::LN_2).abs() < 1e-12);\n```"]
148pub fn pow_rrule<S: ScalarAd>(x: S, exponent: S, cotangent: S) -> (S, S) {
149 let dfdx = if cotangent == zero::<S>() {
150 zero::<S>()
151 } else {
152 cotangent * pow_x_scale(x, exponent)
153 };
154 let dfde = if cotangent == zero::<S>() {
155 zero::<S>()
156 } else {
157 cotangent * pow_exp_scale(x, exponent)
158 };
159 (dfdx, dfde)
160}