chainrules/
real_ops.rs

1use num_traits::Float;
2
3use crate::ScalarAd;
4
5/// Primal `atan2(y, x)` for ordered real scalars.
6///
7/// # Examples
8///
9/// ```rust
10/// use chainrules::atan2;
11///
12/// assert!((atan2(3.0_f64, 4.0_f64) - 3.0_f64.atan2(4.0_f64)).abs() < 1e-12);
13/// ```
14pub fn atan2<S>(y: S, x: S) -> S
15where
16    S: ScalarAd<Real = S> + Float,
17{
18    Float::atan2(y, x)
19}
20
21/// Forward rule for `atan2(y, x)`.
22///
23/// Returns `(primal, tangent)`.
24///
25/// # Examples
26///
27/// ```rust
28/// use chainrules::atan2_frule;
29///
30/// let (_y, dy) = atan2_frule(3.0_f64, 4.0_f64, 0.5_f64, 0.25_f64);
31/// assert!((dy - 0.05_f64).abs() < 1e-12);
32/// ```
33pub fn atan2_frule<S>(y: S, x: S, dy: S, dx: S) -> (S, S)
34where
35    S: ScalarAd<Real = S> + Float,
36{
37    let primal = Float::atan2(y, x);
38    let denom = x * x + y * y;
39    let tangent = dy * (x / denom) + dx * ((-y) / denom);
40    (primal, tangent)
41}
42
43/// Reverse rule for `atan2(y, x)`.
44///
45/// Returns cotangents with respect to `(y, x)`.
46///
47/// # Examples
48///
49/// ```rust
50/// use chainrules::atan2_rrule;
51///
52/// let (dy, dx) = atan2_rrule(3.0_f64, 4.0_f64, 2.0_f64);
53/// assert!((dy - 0.32_f64).abs() < 1e-12);
54/// assert!((dx + 0.24_f64).abs() < 1e-12);
55/// ```
56pub fn atan2_rrule<S>(y: S, x: S, cotangent: S) -> (S, S)
57where
58    S: ScalarAd<Real = S> + Float,
59{
60    let denom = x * x + y * y;
61    (cotangent * (x / denom), cotangent * ((-y) / denom))
62}