// chainrules/unary/trig.rs

use crate::unary::one;
use crate::ScalarAd;

4/// Primal `sin`.
5pub fn sin<S: ScalarAd>(x: S) -> S {
6    x.sin()
7}
8
9/// Forward rule for `sin`.
10pub fn sin_frule<S: ScalarAd>(x: S, dx: S) -> (S, S) {
11    let y = x.sin();
12    (y, dx * x.cos())
13}
14
15/// Reverse rule for `sin`.
16///
17/// Takes the original **input** `x`, not the result.
18pub fn sin_rrule<S: ScalarAd>(x: S, cotangent: S) -> S {
19    cotangent * x.cos().conj()
20}
21
22/// Primal `cos`.
23pub fn cos<S: ScalarAd>(x: S) -> S {
24    x.cos()
25}
26
27/// Forward rule for `cos`.
28pub fn cos_frule<S: ScalarAd>(x: S, dx: S) -> (S, S) {
29    let y = x.cos();
30    (y, dx * -x.sin())
31}
32
33/// Reverse rule for `cos`.
34///
35/// Takes the original **input** `x`, not the result.
36pub fn cos_rrule<S: ScalarAd>(x: S, cotangent: S) -> S {
37    cotangent * (-x.sin()).conj()
38}
39
40fn inverse_sqrt_one_minus_square<S: ScalarAd>(x: S) -> S {
41    one::<S>() / (one::<S>() - x * x).sqrt()
42}
43
44/// Primal `asin`.
45pub fn asin<S: ScalarAd>(x: S) -> S {
46    x.asin()
47}
48
49/// Forward rule for `asin`.
50pub fn asin_frule<S: ScalarAd>(x: S, dx: S) -> (S, S) {
51    let y = x.asin();
52    let scale = inverse_sqrt_one_minus_square(x);
53    (y, dx * scale)
54}
55
56/// Reverse rule for `asin`.
57pub fn asin_rrule<S: ScalarAd>(x: S, cotangent: S) -> S {
58    cotangent * inverse_sqrt_one_minus_square(x).conj()
59}
60
61/// Primal `acos`.
62pub fn acos<S: ScalarAd>(x: S) -> S {
63    x.acos()
64}
65
66/// Forward rule for `acos`.
67pub fn acos_frule<S: ScalarAd>(x: S, dx: S) -> (S, S) {
68    let y = x.acos();
69    let scale = -inverse_sqrt_one_minus_square(x);
70    (y, dx * scale)
71}
72
73/// Reverse rule for `acos`.
74pub fn acos_rrule<S: ScalarAd>(x: S, cotangent: S) -> S {
75    cotangent * (-inverse_sqrt_one_minus_square(x)).conj()
76}
77
78/// Primal `atan`.
79pub fn atan<S: ScalarAd>(x: S) -> S {
80    x.atan()
81}
82
83/// Forward rule for `atan`.
84pub fn atan_frule<S: ScalarAd>(x: S, dx: S) -> (S, S) {
85    let y = x.atan();
86    let scale = one::<S>() / (one::<S>() + x * x);
87    (y, dx * scale)
88}
89
90/// Reverse rule for `atan`.
91pub fn atan_rrule<S: ScalarAd>(x: S, cotangent: S) -> S {
92    cotangent * (one::<S>() / (one::<S>() + x * x)).conj()
93}
94
95#[doc = "Primal `tan`.\n\n# Examples\n```rust\nuse chainrules::tan;\n\nassert!((tan(0.5_f64) - 0.5_f64.tan()).abs() < 1e-12);\n```"]
96pub fn tan<S: ScalarAd>(x: S) -> S {
97    x.tan()
98}
99
100#[doc = "Forward rule for `tan`.\n\n# Examples\n```rust\nuse chainrules::tan_frule;\n\nlet (y, dy) = tan_frule(0.25_f64, 1.0);\nassert!((dy - (1.0 + 0.25_f64.tan().powi(2))).abs() < 1e-12);\n```"]
101pub fn tan_frule<S: ScalarAd>(x: S, dx: S) -> (S, S) {
102    let y = x.tan();
103    (y, dx * (one::<S>() + y * y))
104}
105
106#[doc = "Reverse rule for `tan`.\n\n# Examples\n```rust\nuse chainrules::tan_rrule;\n\nlet dy = tan_rrule(0.25_f64.tan(), 1.0);\nassert!((dy - (1.0 + 0.25_f64.tan().powi(2))).abs() < 1e-12);\n```"]
107pub fn tan_rrule<S: ScalarAd>(result: S, cotangent: S) -> S {
108    cotangent * (one::<S>() + result * result).conj()
109}
110
111#[doc = "Primal `sincos`.\n\n# Examples\n```rust\nuse chainrules::sincos;\n\nlet (s, c) = sincos(0.5_f64);\nassert!((s - 0.5_f64.sin()).abs() < 1e-12);\nassert!((c - 0.5_f64.cos()).abs() < 1e-12);\n```"]
112pub fn sincos<S: ScalarAd>(x: S) -> (S, S) {
113    (x.sin(), x.cos())
114}
115
116#[doc = "Forward rule for `sincos`.\n\n# Examples\n```rust\nuse chainrules::sincos_frule;\n\nlet ((s, c), (ds, dc)) = sincos_frule(0.25_f64, 1.0);\nassert!((ds - 0.25_f64.cos()).abs() < 1e-12);\nassert!((dc + 0.25_f64.sin()).abs() < 1e-12);\n```"]
117pub fn sincos_frule<S: ScalarAd>(x: S, dx: S) -> ((S, S), (S, S)) {
118    let sin_x = x.sin();
119    let cos_x = x.cos();
120    ((sin_x, cos_x), (dx * cos_x, dx * -sin_x))
121}
122
123#[doc = "Reverse rule for `sincos`.\n\n# Examples\n```rust\nuse chainrules::sincos_rrule;\n\nlet dx = sincos_rrule(0.25_f64, (1.0, 1.0));\nassert!((dx - (0.25_f64.cos() - 0.25_f64.sin())).abs() < 1e-12);\n```"]
124pub fn sincos_rrule<S: ScalarAd>(x: S, cotangents: (S, S)) -> S {
125    let (cotangent_sin, cotangent_cos) = cotangents;
126    let sin_x = x.sin();
127    let cos_x = x.cos();
128    cotangent_sin * cos_x.conj() + cotangent_cos * (-sin_x).conj()
129}