chainrules/unary/exp_log.rs

1use crate::unary::one;
2use crate::ScalarAd;
3use num_traits::FloatConst;
4fn ln_2<S: ScalarAd>() -> S {
5    S::from_real(S::Real::LN_2())
6}
7fn ln_10<S: ScalarAd>() -> S {
8    S::from_real(S::Real::LN_10())
9}
10/// Primal `exp`.
11pub fn exp<S: ScalarAd>(x: S) -> S {
12    x.exp()
13}
14/// Forward rule for `exp`.
15pub fn exp_frule<S: ScalarAd>(x: S, dx: S) -> (S, S) {
16    let y = x.exp();
17    (y, dx * y)
18}
19/// Reverse rule for `exp`.
20///
21/// Takes the forward **result** `exp(x)`, not the input `x`.
22pub fn exp_rrule<S: ScalarAd>(result: S, cotangent: S) -> S {
23    cotangent * result.conj()
24}
25/// Primal `exp(x) - 1`.
26pub fn expm1<S: ScalarAd>(x: S) -> S {
27    x.expm1()
28}
29/// Forward rule for `exp(x) - 1`.
30pub fn expm1_frule<S: ScalarAd>(x: S, dx: S) -> (S, S) {
31    let y = x.expm1();
32    let scale = y + one::<S>();
33    (y, dx * scale)
34}
35/// Reverse rule for `exp(x) - 1`.
36///
37/// Takes the forward **result** `expm1(x)`, not the input `x`.
38pub fn expm1_rrule<S: ScalarAd>(result: S, cotangent: S) -> S {
39    cotangent * (result + one::<S>()).conj()
40}
41#[doc = "Primal `2^x`.\n\n# Examples\n```rust\nuse chainrules::exp2;\n\nassert!((exp2(3.0_f64) - 8.0).abs() < 1e-12);\n```"]
42pub fn exp2<S: ScalarAd>(x: S) -> S {
43    x.exp2()
44}
45#[doc = "Forward rule for `2^x`.\n\n# Examples\n```rust\nuse chainrules::exp2_frule;\n\nlet (y, dy) = exp2_frule(3.0_f64, 1.0);\nassert!((y - 8.0).abs() < 1e-12);\nassert!((dy - 8.0_f64 * std::f64::consts::LN_2).abs() < 1e-12);\n```"]
46pub fn exp2_frule<S: ScalarAd>(x: S, dx: S) -> (S, S) {
47    let y = x.exp2();
48    (y, dx * (y * ln_2::<S>()))
49}
50#[doc = "Reverse rule for `2^x`.\n\n# Examples\n```rust\nuse chainrules::exp2_rrule;\n\nlet dy = exp2_rrule(8.0_f64, 1.0);\nassert!((dy - 8.0_f64 * std::f64::consts::LN_2).abs() < 1e-12);\n```"]
51pub fn exp2_rrule<S: ScalarAd>(result: S, cotangent: S) -> S {
52    cotangent * (result * ln_2::<S>()).conj()
53}
54#[doc = "Primal `10^x`.\n\n# Examples\n```rust\nuse chainrules::exp10;\n\nassert!((exp10(2.0_f64) - 100.0).abs() < 1e-12);\n```"]
55pub fn exp10<S: ScalarAd>(x: S) -> S {
56    x.exp10()
57}
58#[doc = "Forward rule for `10^x`.\n\n# Examples\n```rust\nuse chainrules::exp10_frule;\n\nlet (y, dy) = exp10_frule(2.0_f64, 0.5);\nassert!((y - 100.0).abs() < 1e-12);\nassert!((dy - 0.5_f64 * 100.0_f64 * std::f64::consts::LN_10).abs() < 1e-12);\n```"]
59pub fn exp10_frule<S: ScalarAd>(x: S, dx: S) -> (S, S) {
60    let y = x.exp10();
61    (y, dx * (y * ln_10::<S>()))
62}
63#[doc = "Reverse rule for `10^x`.\n\n# Examples\n```rust\nuse chainrules::exp10_rrule;\n\nlet dy = exp10_rrule(100.0_f64, 0.5);\nassert!((dy - 0.5_f64 * 100.0_f64 * std::f64::consts::LN_10).abs() < 1e-12);\n```"]
64pub fn exp10_rrule<S: ScalarAd>(result: S, cotangent: S) -> S {
65    cotangent * (result * ln_10::<S>()).conj()
66}
67/// Primal `log`.
68pub fn log<S: ScalarAd>(x: S) -> S {
69    x.ln()
70}
71/// Forward rule for `log`.
72pub fn log_frule<S: ScalarAd>(x: S, dx: S) -> (S, S) {
73    let y = x.ln();
74    let dy = dx * (one::<S>() / x);
75    (y, dy)
76}
77/// Reverse rule for `log`.
78///
79/// Takes the original **input** `x`, not the result.
80pub fn log_rrule<S: ScalarAd>(x: S, cotangent: S) -> S {
81    cotangent * (one::<S>() / x).conj()
82}
83/// Primal `log(1 + x)`.
84pub fn log1p<S: ScalarAd>(x: S) -> S {
85    x.log1p()
86}
87/// Forward rule for `log(1 + x)`.
88pub fn log1p_frule<S: ScalarAd>(x: S, dx: S) -> (S, S) {
89    let y = x.log1p();
90    let dy = dx * (one::<S>() / (one::<S>() + x));
91    (y, dy)
92}
93/// Reverse rule for `log(1 + x)`.
94///
95/// Takes the original **input** `x`, not the result.
96pub fn log1p_rrule<S: ScalarAd>(x: S, cotangent: S) -> S {
97    cotangent * (one::<S>() / (one::<S>() + x)).conj()
98}
99#[doc = "Primal `log2`.\n\n# Examples\n```rust\nuse chainrules::log2;\n\nassert_eq!(log2(8.0_f64), 3.0);\n```"]
100pub fn log2<S: ScalarAd>(x: S) -> S {
101    x.log2()
102}
103#[doc = "Forward rule for `log2`.\n\n# Examples\n```rust\nuse chainrules::log2_frule;\n\nlet (y, dy) = log2_frule(8.0_f64, 2.0);\nassert!((y - 3.0).abs() < 1e-12);\nassert!((dy - (2.0_f64 / (8.0_f64 * std::f64::consts::LN_2))).abs() < 1e-12);\n```"]
104pub fn log2_frule<S: ScalarAd>(x: S, dx: S) -> (S, S) {
105    let y = x.log2();
106    let scale = one::<S>() / (x * ln_2::<S>());
107    (y, dx * scale)
108}
109#[doc = "Reverse rule for `log2`.\n\n# Examples\n```rust\nuse chainrules::log2_rrule;\n\nlet dy = log2_rrule(8.0_f64, 2.0);\nassert!((dy - (2.0_f64 / (8.0_f64 * std::f64::consts::LN_2))).abs() < 1e-12);\n```"]
110pub fn log2_rrule<S: ScalarAd>(x: S, cotangent: S) -> S {
111    cotangent * (one::<S>() / (x * ln_2::<S>())).conj()
112}
113#[doc = "Primal `log10`.\n\n# Examples\n```rust\nuse chainrules::log10;\n\nassert_eq!(log10(100.0_f64), 2.0);\n```"]
114pub fn log10<S: ScalarAd>(x: S) -> S {
115    x.log10()
116}
117#[doc = "Forward rule for `log10`.\n\n# Examples\n```rust\nuse chainrules::log10_frule;\n\nlet (y, dy) = log10_frule(100.0_f64, 2.0);\nassert!((y - 2.0).abs() < 1e-12);\nassert!((dy - (2.0_f64 / (100.0_f64 * std::f64::consts::LN_10))).abs() < 1e-12);\n```"]
118pub fn log10_frule<S: ScalarAd>(x: S, dx: S) -> (S, S) {
119    let y = x.log10();
120    let scale = one::<S>() / (x * ln_10::<S>());
121    (y, dx * scale)
122}
123#[doc = "Reverse rule for `log10`.\n\n# Examples\n```rust\nuse chainrules::log10_rrule;\n\nlet dy = log10_rrule(100.0_f64, 2.0);\nassert!((dy - (2.0_f64 / (100.0_f64 * std::f64::consts::LN_10))).abs() < 1e-12);\n```"]
124pub fn log10_rrule<S: ScalarAd>(x: S, cotangent: S) -> S {
125    cotangent * (one::<S>() / (x * ln_10::<S>())).conj()
126}