chainrules/unary/
trig.rs

1use crate::unary::one;
2use crate::ScalarAd;
3
4/// Primal `sin`.
5pub fn sin<S: ScalarAd>(x: S) -> S {
6    x.sin()
7}
8
9/// Forward rule for `sin`.
10pub fn sin_frule<S: ScalarAd>(x: S, dx: S) -> (S, S) {
11    let y = x.sin();
12    (y, dx * x.cos())
13}
14
15/// Reverse rule for `sin`.
16pub fn sin_rrule<S: ScalarAd>(x: S, cotangent: S) -> S {
17    cotangent * x.cos().conj()
18}
19
20/// Primal `cos`.
21pub fn cos<S: ScalarAd>(x: S) -> S {
22    x.cos()
23}
24
25/// Forward rule for `cos`.
26pub fn cos_frule<S: ScalarAd>(x: S, dx: S) -> (S, S) {
27    let y = x.cos();
28    (y, dx * -x.sin())
29}
30
31/// Reverse rule for `cos`.
32pub fn cos_rrule<S: ScalarAd>(x: S, cotangent: S) -> S {
33    cotangent * (-x.sin()).conj()
34}
35
36fn inverse_sqrt_one_minus_square<S: ScalarAd>(x: S) -> S {
37    one::<S>() / (one::<S>() - x * x).sqrt()
38}
39
40/// Primal `asin`.
41pub fn asin<S: ScalarAd>(x: S) -> S {
42    x.asin()
43}
44
45/// Forward rule for `asin`.
46pub fn asin_frule<S: ScalarAd>(x: S, dx: S) -> (S, S) {
47    let y = x.asin();
48    let scale = inverse_sqrt_one_minus_square(x);
49    (y, dx * scale)
50}
51
52/// Reverse rule for `asin`.
53pub fn asin_rrule<S: ScalarAd>(x: S, cotangent: S) -> S {
54    cotangent * inverse_sqrt_one_minus_square(x).conj()
55}
56
57/// Primal `acos`.
58pub fn acos<S: ScalarAd>(x: S) -> S {
59    x.acos()
60}
61
62/// Forward rule for `acos`.
63pub fn acos_frule<S: ScalarAd>(x: S, dx: S) -> (S, S) {
64    let y = x.acos();
65    let scale = -inverse_sqrt_one_minus_square(x);
66    (y, dx * scale)
67}
68
69/// Reverse rule for `acos`.
70pub fn acos_rrule<S: ScalarAd>(x: S, cotangent: S) -> S {
71    cotangent * (-inverse_sqrt_one_minus_square(x)).conj()
72}
73
74/// Primal `atan`.
75pub fn atan<S: ScalarAd>(x: S) -> S {
76    x.atan()
77}
78
79/// Forward rule for `atan`.
80pub fn atan_frule<S: ScalarAd>(x: S, dx: S) -> (S, S) {
81    let y = x.atan();
82    let scale = one::<S>() / (one::<S>() + x * x);
83    (y, dx * scale)
84}
85
86/// Reverse rule for `atan`.
87pub fn atan_rrule<S: ScalarAd>(x: S, cotangent: S) -> S {
88    cotangent * (one::<S>() / (one::<S>() + x * x)).conj()
89}
90
91#[doc = "Primal `tan`.\n\n# Examples\n```rust\nuse chainrules::tan;\n\nassert!((tan(0.5_f64) - 0.5_f64.tan()).abs() < 1e-12);\n```"]
92pub fn tan<S: ScalarAd>(x: S) -> S {
93    x.tan()
94}
95
96#[doc = "Forward rule for `tan`.\n\n# Examples\n```rust\nuse chainrules::tan_frule;\n\nlet (y, dy) = tan_frule(0.25_f64, 1.0);\nassert!((dy - (1.0 + 0.25_f64.tan().powi(2))).abs() < 1e-12);\n```"]
97pub fn tan_frule<S: ScalarAd>(x: S, dx: S) -> (S, S) {
98    let y = x.tan();
99    (y, dx * (one::<S>() + y * y))
100}
101
102#[doc = "Reverse rule for `tan`.\n\n# Examples\n```rust\nuse chainrules::tan_rrule;\n\nlet dy = tan_rrule(0.25_f64.tan(), 1.0);\nassert!((dy - (1.0 + 0.25_f64.tan().powi(2))).abs() < 1e-12);\n```"]
103pub fn tan_rrule<S: ScalarAd>(result: S, cotangent: S) -> S {
104    cotangent * (one::<S>() + result * result).conj()
105}
106
107#[doc = "Primal `sincos`.\n\n# Examples\n```rust\nuse chainrules::sincos;\n\nlet (s, c) = sincos(0.5_f64);\nassert!((s - 0.5_f64.sin()).abs() < 1e-12);\nassert!((c - 0.5_f64.cos()).abs() < 1e-12);\n```"]
108pub fn sincos<S: ScalarAd>(x: S) -> (S, S) {
109    (x.sin(), x.cos())
110}
111
112#[doc = "Forward rule for `sincos`.\n\n# Examples\n```rust\nuse chainrules::sincos_frule;\n\nlet ((s, c), (ds, dc)) = sincos_frule(0.25_f64, 1.0);\nassert!((ds - 0.25_f64.cos()).abs() < 1e-12);\nassert!((dc + 0.25_f64.sin()).abs() < 1e-12);\n```"]
113pub fn sincos_frule<S: ScalarAd>(x: S, dx: S) -> ((S, S), (S, S)) {
114    let sin_x = x.sin();
115    let cos_x = x.cos();
116    ((sin_x, cos_x), (dx * cos_x, dx * -sin_x))
117}
118
119#[doc = "Reverse rule for `sincos`.\n\n# Examples\n```rust\nuse chainrules::sincos_rrule;\n\nlet dx = sincos_rrule(0.25_f64, (1.0, 1.0));\nassert!((dx - (0.25_f64.cos() - 0.25_f64.sin())).abs() < 1e-12);\n```"]
120pub fn sincos_rrule<S: ScalarAd>(x: S, cotangents: (S, S)) -> S {
121    let (cotangent_sin, cotangent_cos) = cotangents;
122    let sin_x = x.sin();
123    let cos_x = x.cos();
124    cotangent_sin * cos_x.conj() + cotangent_cos * (-sin_x).conj()
125}