chainrules/unary/trig_extra.rs
1use crate::binary::{mul_frule, mul_rrule};
2use crate::unary::{
3 cos, cos_frule, cos_rrule, inv, inv_frule, inv_rrule, sin, sin_frule, sin_rrule, sincos, tan,
4 tan_frule, tan_rrule,
5};
6use crate::ScalarAd;
7use num_traits::{Float, FloatConst, Zero};
8
9fn pi<S: ScalarAd>() -> S {
10 S::from_real(S::Real::PI())
11}
12
13fn real<R: Float>(value: f64) -> R {
14 match R::from(value) {
15 Some(value) => value,
16 None => unreachable!("float constant conversion should succeed"),
17 }
18}
19
20fn deg2rad<S: ScalarAd>() -> S {
21 pi::<S>() / S::from_real(real::<S::Real>(180.0))
22}
23
24fn real_input<S: ScalarAd>(x: S) -> Option<S::Real> {
25 if x.imag().is_zero() {
26 Some(x.real())
27 } else {
28 None
29 }
30}
31
32fn sinpi_real<R: Float + FloatConst>(x: R) -> R {
33 let two = real::<R>(2.0);
34 let reduced = x - (x / two).floor() * two;
35 let zero = real::<R>(0.0);
36 let one = real::<R>(1.0);
37 let half = real::<R>(0.5);
38 let three_half = real::<R>(1.5);
39 if reduced == zero || reduced == one {
40 zero
41 } else if reduced == half {
42 one
43 } else if reduced == three_half {
44 -one
45 } else {
46 (R::PI() * reduced).sin()
47 }
48}
49
50fn cospi_real<R: Float + FloatConst>(x: R) -> R {
51 let two = real::<R>(2.0);
52 let reduced = x - (x / two).floor() * two;
53 let zero = real::<R>(0.0);
54 let one = real::<R>(1.0);
55 let half = real::<R>(0.5);
56 let three_half = real::<R>(1.5);
57 if reduced == zero {
58 one
59 } else if reduced == one {
60 -one
61 } else if reduced == half || reduced == three_half {
62 zero
63 } else {
64 (R::PI() * reduced).cos()
65 }
66}
67
68fn tand_real<R: Float + FloatConst>(x: R) -> R {
69 let one_eighty = real::<R>(180.0);
70 let reduced = x - (x / one_eighty).floor() * one_eighty;
71 let zero = real::<R>(0.0);
72 let forty_five = real::<R>(45.0);
73 let ninety = real::<R>(90.0);
74 let one_thirty_five = real::<R>(135.0);
75 if reduced == zero {
76 zero
77 } else if reduced == forty_five {
78 real::<R>(1.0)
79 } else if reduced == ninety {
80 R::infinity().copysign(sinpi_real(x / one_eighty))
81 } else if reduced == one_thirty_five {
82 real::<R>(-1.0)
83 } else {
84 (R::PI() * reduced / one_eighty).tan()
85 }
86}
87
88/// Primal `sec`.
89///
90/// # Examples
91/// ```rust
92/// use chainrules::sec;
93/// assert!((sec(0.5_f64) - 1.0 / 0.5_f64.cos()).abs() < 1e-12);
94/// ```
95pub fn sec<S: ScalarAd>(x: S) -> S {
96 inv(cos(x))
97}
98
99/// Forward rule for `sec`.
100///
101/// # Examples
102/// ```rust
103/// use chainrules::sec_frule;
104/// let (y, dy) = sec_frule(0.5_f64, 1.0);
105/// assert!((y - 1.0 / 0.5_f64.cos()).abs() < 1e-12);
106/// assert!((dy - (0.5_f64.sin() / 0.5_f64.cos().powi(2))).abs() < 1e-12);
107/// ```
108pub fn sec_frule<S: ScalarAd>(x: S, dx: S) -> (S, S) {
109 let (y, dy) = cos_frule(x, dx);
110 inv_frule(y, dy)
111}
112
113/// Reverse rule for `sec`.
114///
115/// # Examples
116/// ```rust
117/// use chainrules::sec_rrule;
118/// let dy = sec_rrule(0.5_f64, 1.0);
119/// assert!((dy - (0.5_f64.sin() / 0.5_f64.cos().powi(2))).abs() < 1e-12);
120/// ```
121pub fn sec_rrule<S: ScalarAd>(x: S, cotangent: S) -> S {
122 let y = sec(x);
123 let d_y = inv_rrule(y, cotangent);
124 cos_rrule(x, d_y)
125}
126
127/// Primal `csc`.
128///
129/// # Examples
130/// ```rust
131/// use chainrules::csc;
132/// assert!((csc(0.5_f64) - 1.0 / 0.5_f64.sin()).abs() < 1e-12);
133/// ```
134pub fn csc<S: ScalarAd>(x: S) -> S {
135 inv(sin(x))
136}
137
138/// Forward rule for `csc`.
139///
140/// # Examples
141/// ```rust
142/// use chainrules::csc_frule;
143/// let (_, dy) = csc_frule(0.5_f64, 1.0);
144/// assert!((dy + 0.5_f64.cos() / 0.5_f64.sin().powi(2)).abs() < 1e-12);
145/// ```
146pub fn csc_frule<S: ScalarAd>(x: S, dx: S) -> (S, S) {
147 let (y, dy) = sin_frule(x, dx);
148 inv_frule(y, dy)
149}
150
151/// Reverse rule for `csc`.
152///
153/// # Examples
154/// ```rust
155/// use chainrules::csc_rrule;
156/// let dy = csc_rrule(0.5_f64, 1.0);
157/// assert!((dy + 0.5_f64.cos() / 0.5_f64.sin().powi(2)).abs() < 1e-12);
158/// ```
159pub fn csc_rrule<S: ScalarAd>(x: S, cotangent: S) -> S {
160 let y = csc(x);
161 let d_y = inv_rrule(y, cotangent);
162 sin_rrule(x, d_y)
163}
164
165/// Primal `cot`.
166///
167/// # Examples
168/// ```rust
169/// use chainrules::cot;
170/// assert!((cot(0.5_f64) - 1.0 / 0.5_f64.tan()).abs() < 1e-12);
171/// ```
172pub fn cot<S: ScalarAd>(x: S) -> S {
173 inv(tan(x))
174}
175
176/// Forward rule for `cot`.
177///
178/// # Examples
179/// ```rust
180/// use chainrules::cot_frule;
181/// let (_, dy) = cot_frule(0.5_f64, 1.0);
182/// assert!((dy + 1.0 / 0.5_f64.sin().powi(2)).abs() < 1e-12);
183/// ```
184pub fn cot_frule<S: ScalarAd>(x: S, dx: S) -> (S, S) {
185 let (y, dy) = tan_frule(x, dx);
186 inv_frule(y, dy)
187}
188
189/// Reverse rule for `cot`.
190///
191/// # Examples
192/// ```rust
193/// use chainrules::cot_rrule;
194/// let dy = cot_rrule(0.5_f64, 1.0);
195/// assert!((dy + 1.0 / 0.5_f64.sin().powi(2)).abs() < 1e-12);
196/// ```
197pub fn cot_rrule<S: ScalarAd>(x: S, cotangent: S) -> S {
198 let y = cot(x);
199 let d_y = inv_rrule(y, cotangent);
200 tan_rrule(tan(x), d_y)
201}
202
203/// Primal `sinpi`.
204///
205/// # Examples
206///
207/// ```rust
208/// use chainrules::sinpi;
209///
210/// assert!((sinpi(0.25_f64) - 0.25_f64.mul_add(std::f64::consts::PI, 0.0).sin()).abs() < 1e-12);
211/// ```
212pub fn sinpi<S: ScalarAd>(x: S) -> S {
213 if let Some(x_real) = real_input(x) {
214 return S::from_real(sinpi_real(x_real));
215 }
216 sincos(pi::<S>() * x).0
217}
218
219/// Forward rule for `sinpi`.
220///
221/// # Examples
222///
223/// ```rust
224/// use chainrules::sinpi_frule;
225///
226/// let (_, dy) = sinpi_frule(0.25_f64, 1.0);
227/// assert!((dy - std::f64::consts::PI * (std::f64::consts::PI * 0.25_f64).cos()).abs() < 1e-12);
228/// ```
229pub fn sinpi_frule<S: ScalarAd>(x: S, dx: S) -> (S, S) {
230 let y = sinpi(x);
231 let scale = pi::<S>() * cospi(x);
232 (y, dx * scale)
233}
234
235/// Reverse rule for `sinpi`.
236///
237/// # Examples
238///
239/// ```rust
240/// use chainrules::sinpi_rrule;
241///
242/// let dy = sinpi_rrule(0.25_f64, 1.0);
243/// assert!((dy - std::f64::consts::PI * (std::f64::consts::PI * 0.25_f64).cos()).abs() < 1e-12);
244/// ```
245pub fn sinpi_rrule<S: ScalarAd>(x: S, cotangent: S) -> S {
246 cotangent * (pi::<S>() * cospi(x)).conj()
247}
248
249/// Primal `cospi`.
250///
251/// # Examples
252///
253/// ```rust
254/// use chainrules::cospi;
255///
256/// assert!((cospi(0.25_f64) - (std::f64::consts::PI * 0.25_f64).cos()).abs() < 1e-12);
257/// ```
258pub fn cospi<S: ScalarAd>(x: S) -> S {
259 if let Some(x_real) = real_input(x) {
260 return S::from_real(cospi_real(x_real));
261 }
262 sincos(pi::<S>() * x).1
263}
264
265/// Forward rule for `cospi`.
266///
267/// # Examples
268///
269/// ```rust
270/// use chainrules::cospi_frule;
271///
272/// let (_, dy) = cospi_frule(0.25_f64, 1.0);
273/// assert!((dy + std::f64::consts::PI * (std::f64::consts::PI * 0.25_f64).sin()).abs() < 1e-12);
274/// ```
275pub fn cospi_frule<S: ScalarAd>(x: S, dx: S) -> (S, S) {
276 let y = cospi(x);
277 let scale = -(pi::<S>() * sinpi(x));
278 (y, dx * scale)
279}
280
281/// Reverse rule for `cospi`.
282///
283/// # Examples
284///
285/// ```rust
286/// use chainrules::cospi_rrule;
287///
288/// let dy = cospi_rrule(0.25_f64, 1.0);
289/// assert!((dy + std::f64::consts::PI * (std::f64::consts::PI * 0.25_f64).sin()).abs() < 1e-12);
290/// ```
291pub fn cospi_rrule<S: ScalarAd>(x: S, cotangent: S) -> S {
292 cotangent * (-(pi::<S>() * sinpi(x))).conj()
293}
294
295/// Primal `sincospi`.
296///
297/// # Examples
298///
299/// ```rust
300/// use chainrules::sincospi;
301///
302/// let (s, c) = sincospi(0.25_f64);
303/// assert!((s - (std::f64::consts::FRAC_1_SQRT_2)).abs() < 1e-12);
304/// assert!((c - (std::f64::consts::FRAC_1_SQRT_2)).abs() < 1e-12);
305/// ```
306pub fn sincospi<S: ScalarAd>(x: S) -> (S, S) {
307 (sinpi(x), cospi(x))
308}
309
310/// Forward rule for `sincospi`.
311///
312/// # Examples
313///
314/// ```rust
315/// use chainrules::sincospi_frule;
316///
317/// let ((_, _), (ds, dc)) = sincospi_frule(0.25_f64, 1.0);
318/// assert!((ds - std::f64::consts::PI * (std::f64::consts::PI * 0.25_f64).cos()).abs() < 1e-12);
319/// assert!((dc + std::f64::consts::PI * (std::f64::consts::PI * 0.25_f64).sin()).abs() < 1e-12);
320/// ```
321pub fn sincospi_frule<S: ScalarAd>(x: S, dx: S) -> ((S, S), (S, S)) {
322 let sin_x = sinpi(x);
323 let cos_x = cospi(x);
324 (
325 (sin_x, cos_x),
326 (dx * (pi::<S>() * cos_x), dx * (-(pi::<S>() * sin_x))),
327 )
328}
329
330/// Reverse rule for `sincospi`.
331///
332/// # Examples
333///
334/// ```rust
335/// use chainrules::sincospi_rrule;
336///
337/// let dx = sincospi_rrule(0.25_f64, (1.0, 1.0));
338/// assert!(
339/// (dx - (std::f64::consts::PI * (std::f64::consts::PI * 0.25_f64).cos()
340/// - std::f64::consts::PI * (std::f64::consts::PI * 0.25_f64).sin()))
341/// .abs()
342/// < 1e-12
343/// );
344/// ```
345pub fn sincospi_rrule<S: ScalarAd>(x: S, cotangents: (S, S)) -> S {
346 let (cotangent_sin, cotangent_cos) = cotangents;
347 sinpi_rrule(x, cotangent_sin) + cospi_rrule(x, cotangent_cos)
348}
349
350/// Primal `sind`.
351///
352/// # Examples
353///
354/// ```rust
355/// use chainrules::sind;
356///
357/// assert!((sind(30.0_f64) - 0.5_f64).abs() < 1e-12);
358/// ```
359pub fn sind<S: ScalarAd>(x: S) -> S {
360 sinpi(x / S::from_real(real::<S::Real>(180.0)))
361}
362
363/// Forward rule for `sind`.
364///
365/// # Examples
366///
367/// ```rust
368/// use chainrules::sind_frule;
369///
370/// let (_, dy) = sind_frule(30.0_f64, 1.0);
371/// assert!((dy - std::f64::consts::PI / 180.0 * (30.0_f64.to_radians()).cos()).abs() < 1e-12);
372/// ```
373pub fn sind_frule<S: ScalarAd>(x: S, dx: S) -> (S, S) {
374 let scale = S::from_real(real::<S::Real>(1.0 / 180.0));
375 let (scaled_x, dscaled_x) = mul_frule(scale, x, S::from_i32(0), dx);
376 sinpi_frule(scaled_x, dscaled_x)
377}
378
379/// Reverse rule for `sind`.
380///
381/// # Examples
382///
383/// ```rust
384/// use chainrules::sind_rrule;
385///
386/// let dy = sind_rrule(30.0_f64, 1.0);
387/// assert!((dy - std::f64::consts::PI / 180.0 * (30.0_f64.to_radians()).cos()).abs() < 1e-12);
388/// ```
389pub fn sind_rrule<S: ScalarAd>(x: S, cotangent: S) -> S {
390 let scale = S::from_real(real::<S::Real>(1.0 / 180.0));
391 let scaled_x = scale * x;
392 let dscaled_x = sinpi_rrule(scaled_x, cotangent);
393 let (_, dx) = mul_rrule(scale, x, dscaled_x);
394 dx
395}
396
397/// Primal `cosd`.
398///
399/// # Examples
400///
401/// ```rust
402/// use chainrules::cosd;
403///
404/// assert!((cosd(60.0_f64) - 0.5_f64).abs() < 1e-12);
405/// ```
406pub fn cosd<S: ScalarAd>(x: S) -> S {
407 cospi(x / S::from_real(real::<S::Real>(180.0)))
408}
409
410/// Forward rule for `cosd`.
411///
412/// # Examples
413///
414/// ```rust
415/// use chainrules::cosd_frule;
416///
417/// let (_, dy) = cosd_frule(60.0_f64, 1.0);
418/// assert!((dy + std::f64::consts::PI / 180.0 * (60.0_f64.to_radians()).sin()).abs() < 1e-12);
419/// ```
420pub fn cosd_frule<S: ScalarAd>(x: S, dx: S) -> (S, S) {
421 let scale = S::from_real(real::<S::Real>(1.0 / 180.0));
422 let (scaled_x, dscaled_x) = mul_frule(scale, x, S::from_i32(0), dx);
423 cospi_frule(scaled_x, dscaled_x)
424}
425
426/// Reverse rule for `cosd`.
427///
428/// # Examples
429///
430/// ```rust
431/// use chainrules::cosd_rrule;
432///
433/// let dy = cosd_rrule(60.0_f64, 1.0);
434/// assert!((dy + std::f64::consts::PI / 180.0 * (60.0_f64.to_radians()).sin()).abs() < 1e-12);
435/// ```
436pub fn cosd_rrule<S: ScalarAd>(x: S, cotangent: S) -> S {
437 let scale = S::from_real(real::<S::Real>(1.0 / 180.0));
438 let scaled_x = scale * x;
439 let dscaled_x = cospi_rrule(scaled_x, cotangent);
440 let (_, dx) = mul_rrule(scale, x, dscaled_x);
441 dx
442}
443
444/// Primal `tand`.
445///
446/// # Examples
447///
448/// ```rust
449/// use chainrules::tand;
450///
451/// assert_eq!(tand(45.0_f64), 1.0);
452/// ```
453pub fn tand<S: ScalarAd>(x: S) -> S {
454 if let Some(x_real) = real_input(x) {
455 return S::from_real(tand_real(x_real));
456 }
457 tan(deg2rad::<S>() * x)
458}
459
460/// Forward rule for `tand`.
461///
462/// # Examples
463///
464/// ```rust
465/// use chainrules::tand_frule;
466///
467/// let (_, dy) = tand_frule(45.0_f64, 1.0);
468/// assert!((dy - 2.0 * std::f64::consts::PI / 180.0).abs() < 1e-12);
469/// ```
470pub fn tand_frule<S: ScalarAd>(x: S, dx: S) -> (S, S) {
471 let y = tand(x);
472 let scale = deg2rad::<S>() * (S::from_i32(1) + y * y);
473 (y, dx * scale)
474}
475
476/// Reverse rule for `tand`.
477///
478/// # Examples
479///
480/// ```rust
481/// use chainrules::tand_rrule;
482///
483/// let dy = tand_rrule(45.0_f64, 1.0);
484/// assert!((dy - 2.0 * std::f64::consts::PI / 180.0).abs() < 1e-12);
485/// ```
486pub fn tand_rrule<S: ScalarAd>(x: S, cotangent: S) -> S {
487 let y = tand(x);
488 let scale = deg2rad::<S>() * (S::from_i32(1) + y * y);
489 cotangent * scale.conj()
490}