chainrules/binary_special.rs
1use num_traits::Float;
2
3fn select_first_for_min<R: Float>(x: R, y: R) -> bool {
4 !x.is_nan() && (y.is_nan() || x < y)
5}
6
7fn select_first_for_max<R: Float>(x: R, y: R) -> bool {
8 !x.is_nan() && (y.is_nan() || x > y)
9}
10
11/// Primal `hypot`.
12///
13/// # Examples
14///
15/// ```rust
16/// use chainrules::hypot;
17///
18/// assert_eq!(hypot(3.0_f64, 4.0_f64), 5.0);
19/// ```
20pub fn hypot<R: Float>(x: R, y: R) -> R {
21 x.hypot(y)
22}
23
24/// Forward rule for `hypot`.
25///
26/// # Examples
27///
28/// ```rust
29/// use chainrules::hypot_frule;
30///
31/// let (r, dr) = hypot_frule(3.0_f64, 4.0_f64, 0.5_f64, 0.25_f64);
32/// assert_eq!(r, 5.0);
33/// assert!((dr - 0.5).abs() < 1e-12);
34/// ```
35pub fn hypot_frule<R: Float>(x: R, y: R, dx: R, dy: R) -> (R, R) {
36 let r = x.hypot(y);
37 let inv_r = R::one() / r;
38 (r, dx * (x * inv_r) + dy * (y * inv_r))
39}
40
41/// Reverse rule for `hypot`.
42///
43/// # Examples
44///
45/// ```rust
46/// use chainrules::hypot_rrule;
47///
48/// let (dx, dy) = hypot_rrule(3.0_f64, 4.0_f64, 1.0_f64);
49/// assert!((dx - 0.6).abs() < 1e-12);
50/// assert!((dy - 0.8).abs() < 1e-12);
51/// ```
52pub fn hypot_rrule<R: Float>(x: R, y: R, cotangent: R) -> (R, R) {
53 let r = x.hypot(y);
54 let inv_r = R::one() / r;
55 (cotangent * (x * inv_r), cotangent * (y * inv_r))
56}
57
58/// Primal `min`.
59///
60/// The primal follows `Float::min`. For differentiation, ties route the
61/// tangent/cotangent to the second argument. If exactly one input is `NaN`,
62/// the non-`NaN` input receives the gradient.
63///
64/// # Examples
65///
66/// ```rust
67/// use chainrules::min;
68///
69/// assert_eq!(min(1.5_f64, 2.5_f64), 1.5);
70/// assert_eq!(min(2.0_f64, 2.0_f64), 2.0);
71/// ```
72pub fn min<R: Float>(x: R, y: R) -> R {
73 x.min(y)
74}
75
76/// Forward rule for `min`.
77///
78/// When `x == y`, the tangent comes from `y`.
79///
80/// # Examples
81///
82/// ```rust
83/// use chainrules::min_frule;
84///
85/// let (z, dz) = min_frule(1.0_f64, 2.0_f64, 0.25, 0.5);
86/// assert_eq!(z, 1.0);
87/// assert_eq!(dz, 0.25);
88/// ```
89pub fn min_frule<R: Float>(x: R, y: R, dx: R, dy: R) -> (R, R) {
90 let z = x.min(y);
91 if select_first_for_min(x, y) {
92 (z, dx)
93 } else {
94 (z, dy)
95 }
96}
97
98/// Reverse rule for `min`.
99///
100/// When `x == y`, the cotangent goes to `y`. If exactly one input is `NaN`,
101/// the non-`NaN` input receives the cotangent.
102///
103/// # Examples
104///
105/// ```rust
106/// use chainrules::min_rrule;
107///
108/// let (dx, dy) = min_rrule(1.0_f64, 2.0_f64, 0.5);
109/// assert_eq!(dx, 0.5);
110/// assert_eq!(dy, 0.0);
111/// ```
112pub fn min_rrule<R: Float>(x: R, y: R, cotangent: R) -> (R, R) {
113 if select_first_for_min(x, y) {
114 (cotangent, R::zero())
115 } else {
116 (R::zero(), cotangent)
117 }
118}
119
120/// Primal `max`.
121///
122/// The primal follows `Float::max`. For differentiation, ties route the
123/// tangent/cotangent to the second argument. If exactly one input is `NaN`,
124/// the non-`NaN` input receives the gradient.
125///
126/// # Examples
127///
128/// ```rust
129/// use chainrules::max;
130///
131/// assert_eq!(max(1.5_f64, 2.5_f64), 2.5);
132/// assert_eq!(max(2.0_f64, 2.0_f64), 2.0);
133/// ```
134pub fn max<R: Float>(x: R, y: R) -> R {
135 x.max(y)
136}
137
138/// Forward rule for `max`.
139///
140/// When `x == y`, the tangent comes from `y`.
141///
142/// # Examples
143///
144/// ```rust
145/// use chainrules::max_frule;
146///
147/// let (z, dz) = max_frule(1.0_f64, 2.0_f64, 0.25, 0.5);
148/// assert_eq!(z, 2.0);
149/// assert_eq!(dz, 0.5);
150/// ```
151pub fn max_frule<R: Float>(x: R, y: R, dx: R, dy: R) -> (R, R) {
152 let z = x.max(y);
153 if select_first_for_max(x, y) {
154 (z, dx)
155 } else {
156 (z, dy)
157 }
158}
159
160/// Reverse rule for `max`.
161///
162/// When `x == y`, the cotangent goes to `y`. If exactly one input is `NaN`,
163/// the non-`NaN` input receives the cotangent.
164///
165/// # Examples
166///
167/// ```rust
168/// use chainrules::max_rrule;
169///
170/// let (dx, dy) = max_rrule(1.0_f64, 2.0_f64, 0.5);
171/// assert_eq!(dx, 0.0);
172/// assert_eq!(dy, 0.5);
173/// ```
174pub fn max_rrule<R: Float>(x: R, y: R, cotangent: R) -> (R, R) {
175 if select_first_for_max(x, y) {
176 (cotangent, R::zero())
177 } else {
178 (R::zero(), cotangent)
179 }
180}