chainrules/unary/
trig_extra.rs

1use crate::binary::{mul_frule, mul_rrule};
2use crate::unary::{
3    cos, cos_frule, cos_rrule, inv, inv_frule, inv_rrule, sin, sin_frule, sin_rrule, sincos, tan,
4    tan_frule, tan_rrule,
5};
6use crate::ScalarAd;
7use num_traits::{Float, FloatConst, Zero};
8
9fn pi<S: ScalarAd>() -> S {
10    S::from_real(S::Real::PI())
11}
12
13fn real<R: Float>(value: f64) -> R {
14    match R::from(value) {
15        Some(value) => value,
16        None => unreachable!("float constant conversion should succeed"),
17    }
18}
19
20fn deg2rad<S: ScalarAd>() -> S {
21    pi::<S>() / S::from_real(real::<S::Real>(180.0))
22}
23
24fn real_input<S: ScalarAd>(x: S) -> Option<S::Real> {
25    if x.imag().is_zero() {
26        Some(x.real())
27    } else {
28        None
29    }
30}
31
32fn sinpi_real<R: Float + FloatConst>(x: R) -> R {
33    let two = real::<R>(2.0);
34    let reduced = x - (x / two).floor() * two;
35    let zero = real::<R>(0.0);
36    let one = real::<R>(1.0);
37    let half = real::<R>(0.5);
38    let three_half = real::<R>(1.5);
39    if reduced == zero || reduced == one {
40        zero
41    } else if reduced == half {
42        one
43    } else if reduced == three_half {
44        -one
45    } else {
46        (R::PI() * reduced).sin()
47    }
48}
49
50fn cospi_real<R: Float + FloatConst>(x: R) -> R {
51    let two = real::<R>(2.0);
52    let reduced = x - (x / two).floor() * two;
53    let zero = real::<R>(0.0);
54    let one = real::<R>(1.0);
55    let half = real::<R>(0.5);
56    let three_half = real::<R>(1.5);
57    if reduced == zero {
58        one
59    } else if reduced == one {
60        -one
61    } else if reduced == half || reduced == three_half {
62        zero
63    } else {
64        (R::PI() * reduced).cos()
65    }
66}
67
68fn tand_real<R: Float + FloatConst>(x: R) -> R {
69    let one_eighty = real::<R>(180.0);
70    let reduced = x - (x / one_eighty).floor() * one_eighty;
71    let zero = real::<R>(0.0);
72    let forty_five = real::<R>(45.0);
73    let ninety = real::<R>(90.0);
74    let one_thirty_five = real::<R>(135.0);
75    if reduced == zero {
76        zero
77    } else if reduced == forty_five {
78        real::<R>(1.0)
79    } else if reduced == ninety {
80        R::infinity().copysign(sinpi_real(x / one_eighty))
81    } else if reduced == one_thirty_five {
82        real::<R>(-1.0)
83    } else {
84        (R::PI() * reduced / one_eighty).tan()
85    }
86}
87
88/// Primal `sec`.
89///
90/// # Examples
91/// ```rust
92/// use chainrules::sec;
93/// assert!((sec(0.5_f64) - 1.0 / 0.5_f64.cos()).abs() < 1e-12);
94/// ```
95pub fn sec<S: ScalarAd>(x: S) -> S {
96    inv(cos(x))
97}
98
99/// Forward rule for `sec`.
100///
101/// # Examples
102/// ```rust
103/// use chainrules::sec_frule;
104/// let (y, dy) = sec_frule(0.5_f64, 1.0);
105/// assert!((y - 1.0 / 0.5_f64.cos()).abs() < 1e-12);
106/// assert!((dy - (0.5_f64.sin() / 0.5_f64.cos().powi(2))).abs() < 1e-12);
107/// ```
108pub fn sec_frule<S: ScalarAd>(x: S, dx: S) -> (S, S) {
109    let (y, dy) = cos_frule(x, dx);
110    inv_frule(y, dy)
111}
112
113/// Reverse rule for `sec`.
114///
115/// # Examples
116/// ```rust
117/// use chainrules::sec_rrule;
118/// let dy = sec_rrule(0.5_f64, 1.0);
119/// assert!((dy - (0.5_f64.sin() / 0.5_f64.cos().powi(2))).abs() < 1e-12);
120/// ```
121pub fn sec_rrule<S: ScalarAd>(x: S, cotangent: S) -> S {
122    let y = sec(x);
123    let d_y = inv_rrule(y, cotangent);
124    cos_rrule(x, d_y)
125}
126
127/// Primal `csc`.
128///
129/// # Examples
130/// ```rust
131/// use chainrules::csc;
132/// assert!((csc(0.5_f64) - 1.0 / 0.5_f64.sin()).abs() < 1e-12);
133/// ```
134pub fn csc<S: ScalarAd>(x: S) -> S {
135    inv(sin(x))
136}
137
138/// Forward rule for `csc`.
139///
140/// # Examples
141/// ```rust
142/// use chainrules::csc_frule;
143/// let (_, dy) = csc_frule(0.5_f64, 1.0);
144/// assert!((dy + 0.5_f64.cos() / 0.5_f64.sin().powi(2)).abs() < 1e-12);
145/// ```
146pub fn csc_frule<S: ScalarAd>(x: S, dx: S) -> (S, S) {
147    let (y, dy) = sin_frule(x, dx);
148    inv_frule(y, dy)
149}
150
151/// Reverse rule for `csc`.
152///
153/// # Examples
154/// ```rust
155/// use chainrules::csc_rrule;
156/// let dy = csc_rrule(0.5_f64, 1.0);
157/// assert!((dy + 0.5_f64.cos() / 0.5_f64.sin().powi(2)).abs() < 1e-12);
158/// ```
159pub fn csc_rrule<S: ScalarAd>(x: S, cotangent: S) -> S {
160    let y = csc(x);
161    let d_y = inv_rrule(y, cotangent);
162    sin_rrule(x, d_y)
163}
164
165/// Primal `cot`.
166///
167/// # Examples
168/// ```rust
169/// use chainrules::cot;
170/// assert!((cot(0.5_f64) - 1.0 / 0.5_f64.tan()).abs() < 1e-12);
171/// ```
172pub fn cot<S: ScalarAd>(x: S) -> S {
173    inv(tan(x))
174}
175
176/// Forward rule for `cot`.
177///
178/// # Examples
179/// ```rust
180/// use chainrules::cot_frule;
181/// let (_, dy) = cot_frule(0.5_f64, 1.0);
182/// assert!((dy + 1.0 / 0.5_f64.sin().powi(2)).abs() < 1e-12);
183/// ```
184pub fn cot_frule<S: ScalarAd>(x: S, dx: S) -> (S, S) {
185    let (y, dy) = tan_frule(x, dx);
186    inv_frule(y, dy)
187}
188
189/// Reverse rule for `cot`.
190///
191/// # Examples
192/// ```rust
193/// use chainrules::cot_rrule;
194/// let dy = cot_rrule(0.5_f64, 1.0);
195/// assert!((dy + 1.0 / 0.5_f64.sin().powi(2)).abs() < 1e-12);
196/// ```
197pub fn cot_rrule<S: ScalarAd>(x: S, cotangent: S) -> S {
198    let y = cot(x);
199    let d_y = inv_rrule(y, cotangent);
200    tan_rrule(tan(x), d_y)
201}
202
203/// Primal `sinpi`.
204///
205/// # Examples
206///
207/// ```rust
208/// use chainrules::sinpi;
209///
210/// assert!((sinpi(0.25_f64) - 0.25_f64.mul_add(std::f64::consts::PI, 0.0).sin()).abs() < 1e-12);
211/// ```
212pub fn sinpi<S: ScalarAd>(x: S) -> S {
213    if let Some(x_real) = real_input(x) {
214        return S::from_real(sinpi_real(x_real));
215    }
216    sincos(pi::<S>() * x).0
217}
218
219/// Forward rule for `sinpi`.
220///
221/// # Examples
222///
223/// ```rust
224/// use chainrules::sinpi_frule;
225///
226/// let (_, dy) = sinpi_frule(0.25_f64, 1.0);
227/// assert!((dy - std::f64::consts::PI * (std::f64::consts::PI * 0.25_f64).cos()).abs() < 1e-12);
228/// ```
229pub fn sinpi_frule<S: ScalarAd>(x: S, dx: S) -> (S, S) {
230    let y = sinpi(x);
231    let scale = pi::<S>() * cospi(x);
232    (y, dx * scale)
233}
234
235/// Reverse rule for `sinpi`.
236///
237/// # Examples
238///
239/// ```rust
240/// use chainrules::sinpi_rrule;
241///
242/// let dy = sinpi_rrule(0.25_f64, 1.0);
243/// assert!((dy - std::f64::consts::PI * (std::f64::consts::PI * 0.25_f64).cos()).abs() < 1e-12);
244/// ```
245pub fn sinpi_rrule<S: ScalarAd>(x: S, cotangent: S) -> S {
246    cotangent * (pi::<S>() * cospi(x)).conj()
247}
248
249/// Primal `cospi`.
250///
251/// # Examples
252///
253/// ```rust
254/// use chainrules::cospi;
255///
256/// assert!((cospi(0.25_f64) - (std::f64::consts::PI * 0.25_f64).cos()).abs() < 1e-12);
257/// ```
258pub fn cospi<S: ScalarAd>(x: S) -> S {
259    if let Some(x_real) = real_input(x) {
260        return S::from_real(cospi_real(x_real));
261    }
262    sincos(pi::<S>() * x).1
263}
264
265/// Forward rule for `cospi`.
266///
267/// # Examples
268///
269/// ```rust
270/// use chainrules::cospi_frule;
271///
272/// let (_, dy) = cospi_frule(0.25_f64, 1.0);
273/// assert!((dy + std::f64::consts::PI * (std::f64::consts::PI * 0.25_f64).sin()).abs() < 1e-12);
274/// ```
275pub fn cospi_frule<S: ScalarAd>(x: S, dx: S) -> (S, S) {
276    let y = cospi(x);
277    let scale = -(pi::<S>() * sinpi(x));
278    (y, dx * scale)
279}
280
281/// Reverse rule for `cospi`.
282///
283/// # Examples
284///
285/// ```rust
286/// use chainrules::cospi_rrule;
287///
288/// let dy = cospi_rrule(0.25_f64, 1.0);
289/// assert!((dy + std::f64::consts::PI * (std::f64::consts::PI * 0.25_f64).sin()).abs() < 1e-12);
290/// ```
291pub fn cospi_rrule<S: ScalarAd>(x: S, cotangent: S) -> S {
292    cotangent * (-(pi::<S>() * sinpi(x))).conj()
293}
294
295/// Primal `sincospi`.
296///
297/// # Examples
298///
299/// ```rust
300/// use chainrules::sincospi;
301///
302/// let (s, c) = sincospi(0.25_f64);
303/// assert!((s - (std::f64::consts::FRAC_1_SQRT_2)).abs() < 1e-12);
304/// assert!((c - (std::f64::consts::FRAC_1_SQRT_2)).abs() < 1e-12);
305/// ```
306pub fn sincospi<S: ScalarAd>(x: S) -> (S, S) {
307    (sinpi(x), cospi(x))
308}
309
310/// Forward rule for `sincospi`.
311///
312/// # Examples
313///
314/// ```rust
315/// use chainrules::sincospi_frule;
316///
317/// let ((_, _), (ds, dc)) = sincospi_frule(0.25_f64, 1.0);
318/// assert!((ds - std::f64::consts::PI * (std::f64::consts::PI * 0.25_f64).cos()).abs() < 1e-12);
319/// assert!((dc + std::f64::consts::PI * (std::f64::consts::PI * 0.25_f64).sin()).abs() < 1e-12);
320/// ```
321pub fn sincospi_frule<S: ScalarAd>(x: S, dx: S) -> ((S, S), (S, S)) {
322    let sin_x = sinpi(x);
323    let cos_x = cospi(x);
324    (
325        (sin_x, cos_x),
326        (dx * (pi::<S>() * cos_x), dx * (-(pi::<S>() * sin_x))),
327    )
328}
329
330/// Reverse rule for `sincospi`.
331///
332/// # Examples
333///
334/// ```rust
335/// use chainrules::sincospi_rrule;
336///
337/// let dx = sincospi_rrule(0.25_f64, (1.0, 1.0));
338/// assert!(
339///     (dx - (std::f64::consts::PI * (std::f64::consts::PI * 0.25_f64).cos()
340///         - std::f64::consts::PI * (std::f64::consts::PI * 0.25_f64).sin()))
341///         .abs()
342///         < 1e-12
343/// );
344/// ```
345pub fn sincospi_rrule<S: ScalarAd>(x: S, cotangents: (S, S)) -> S {
346    let (cotangent_sin, cotangent_cos) = cotangents;
347    sinpi_rrule(x, cotangent_sin) + cospi_rrule(x, cotangent_cos)
348}
349
350/// Primal `sind`.
351///
352/// # Examples
353///
354/// ```rust
355/// use chainrules::sind;
356///
357/// assert!((sind(30.0_f64) - 0.5_f64).abs() < 1e-12);
358/// ```
359pub fn sind<S: ScalarAd>(x: S) -> S {
360    sinpi(x / S::from_real(real::<S::Real>(180.0)))
361}
362
363/// Forward rule for `sind`.
364///
365/// # Examples
366///
367/// ```rust
368/// use chainrules::sind_frule;
369///
370/// let (_, dy) = sind_frule(30.0_f64, 1.0);
371/// assert!((dy - std::f64::consts::PI / 180.0 * (30.0_f64.to_radians()).cos()).abs() < 1e-12);
372/// ```
373pub fn sind_frule<S: ScalarAd>(x: S, dx: S) -> (S, S) {
374    let scale = S::from_real(real::<S::Real>(1.0 / 180.0));
375    let (scaled_x, dscaled_x) = mul_frule(scale, x, S::from_i32(0), dx);
376    sinpi_frule(scaled_x, dscaled_x)
377}
378
379/// Reverse rule for `sind`.
380///
381/// # Examples
382///
383/// ```rust
384/// use chainrules::sind_rrule;
385///
386/// let dy = sind_rrule(30.0_f64, 1.0);
387/// assert!((dy - std::f64::consts::PI / 180.0 * (30.0_f64.to_radians()).cos()).abs() < 1e-12);
388/// ```
389pub fn sind_rrule<S: ScalarAd>(x: S, cotangent: S) -> S {
390    let scale = S::from_real(real::<S::Real>(1.0 / 180.0));
391    let scaled_x = scale * x;
392    let dscaled_x = sinpi_rrule(scaled_x, cotangent);
393    let (_, dx) = mul_rrule(scale, x, dscaled_x);
394    dx
395}
396
397/// Primal `cosd`.
398///
399/// # Examples
400///
401/// ```rust
402/// use chainrules::cosd;
403///
404/// assert!((cosd(60.0_f64) - 0.5_f64).abs() < 1e-12);
405/// ```
406pub fn cosd<S: ScalarAd>(x: S) -> S {
407    cospi(x / S::from_real(real::<S::Real>(180.0)))
408}
409
410/// Forward rule for `cosd`.
411///
412/// # Examples
413///
414/// ```rust
415/// use chainrules::cosd_frule;
416///
417/// let (_, dy) = cosd_frule(60.0_f64, 1.0);
418/// assert!((dy + std::f64::consts::PI / 180.0 * (60.0_f64.to_radians()).sin()).abs() < 1e-12);
419/// ```
420pub fn cosd_frule<S: ScalarAd>(x: S, dx: S) -> (S, S) {
421    let scale = S::from_real(real::<S::Real>(1.0 / 180.0));
422    let (scaled_x, dscaled_x) = mul_frule(scale, x, S::from_i32(0), dx);
423    cospi_frule(scaled_x, dscaled_x)
424}
425
426/// Reverse rule for `cosd`.
427///
428/// # Examples
429///
430/// ```rust
431/// use chainrules::cosd_rrule;
432///
433/// let dy = cosd_rrule(60.0_f64, 1.0);
434/// assert!((dy + std::f64::consts::PI / 180.0 * (60.0_f64.to_radians()).sin()).abs() < 1e-12);
435/// ```
436pub fn cosd_rrule<S: ScalarAd>(x: S, cotangent: S) -> S {
437    let scale = S::from_real(real::<S::Real>(1.0 / 180.0));
438    let scaled_x = scale * x;
439    let dscaled_x = cospi_rrule(scaled_x, cotangent);
440    let (_, dx) = mul_rrule(scale, x, dscaled_x);
441    dx
442}
443
444/// Primal `tand`.
445///
446/// # Examples
447///
448/// ```rust
449/// use chainrules::tand;
450///
451/// assert_eq!(tand(45.0_f64), 1.0);
452/// ```
453pub fn tand<S: ScalarAd>(x: S) -> S {
454    if let Some(x_real) = real_input(x) {
455        return S::from_real(tand_real(x_real));
456    }
457    tan(deg2rad::<S>() * x)
458}
459
460/// Forward rule for `tand`.
461///
462/// # Examples
463///
464/// ```rust
465/// use chainrules::tand_frule;
466///
467/// let (_, dy) = tand_frule(45.0_f64, 1.0);
468/// assert!((dy - 2.0 * std::f64::consts::PI / 180.0).abs() < 1e-12);
469/// ```
470pub fn tand_frule<S: ScalarAd>(x: S, dx: S) -> (S, S) {
471    let y = tand(x);
472    let scale = deg2rad::<S>() * (S::from_i32(1) + y * y);
473    (y, dx * scale)
474}
475
476/// Reverse rule for `tand`.
477///
478/// # Examples
479///
480/// ```rust
481/// use chainrules::tand_rrule;
482///
483/// let dy = tand_rrule(45.0_f64, 1.0);
484/// assert!((dy - 2.0 * std::f64::consts::PI / 180.0).abs() < 1e-12);
485/// ```
486pub fn tand_rrule<S: ScalarAd>(x: S, cotangent: S) -> S {
487    let y = tand(x);
488    let scale = deg2rad::<S>() * (S::from_i32(1) + y * y);
489    cotangent * scale.conj()
490}