chainrules/binary.rs
1use crate::ScalarAd;
2
3/// Primal `add`.
4///
5/// # Examples
6///
7/// ```rust
8/// use chainrules::add;
9///
10/// assert_eq!(add(1.5_f64, 2.0_f64), 3.5_f64);
11/// ```
12pub fn add<S: ScalarAd>(x: S, y: S) -> S {
13 x + y
14}
15
16/// Forward rule for `add`.
17///
18/// Returns `(primal, tangent)`.
19///
20/// # Examples
21///
22/// ```rust
23/// use chainrules::add_frule;
24///
25/// let (y, dy) = add_frule(2.0_f64, 3.0_f64, 0.1_f64, 0.2_f64);
26/// assert_eq!(y, 5.0_f64);
27/// assert!((dy - 0.3_f64).abs() < 1e-12);
28/// ```
29pub fn add_frule<S: ScalarAd>(x: S, y: S, dx: S, dy: S) -> (S, S) {
30 (x + y, dx + dy)
31}
32
33/// Reverse rule for `add`.
34///
35/// Returns cotangents with respect to `(x, y)`.
36///
37/// # Examples
38///
39/// ```rust
40/// use chainrules::add_rrule;
41///
42/// let (dx, dy) = add_rrule(1.25_f64);
43/// assert_eq!(dx, 1.25_f64);
44/// assert_eq!(dy, 1.25_f64);
45/// ```
46pub fn add_rrule<S: ScalarAd>(cotangent: S) -> (S, S) {
47 (cotangent, cotangent)
48}
49
50/// Primal `sub`.
51///
52/// # Examples
53///
54/// ```rust
55/// use chainrules::sub;
56///
57/// assert_eq!(sub(5.0_f64, 2.0_f64), 3.0_f64);
58/// ```
59pub fn sub<S: ScalarAd>(x: S, y: S) -> S {
60 x - y
61}
62
63/// Forward rule for `sub`.
64///
65/// Returns `(primal, tangent)`.
66///
67/// # Examples
68///
69/// ```rust
70/// use chainrules::sub_frule;
71///
72/// let (y, dy) = sub_frule(5.0_f64, 2.0_f64, 0.3_f64, 0.1_f64);
73/// assert_eq!(y, 3.0_f64);
74/// assert!((dy - 0.2_f64).abs() < 1e-12);
75/// ```
76pub fn sub_frule<S: ScalarAd>(x: S, y: S, dx: S, dy: S) -> (S, S) {
77 (x - y, dx - dy)
78}
79
80/// Reverse rule for `sub`.
81///
82/// Returns cotangents with respect to `(x, y)`.
83///
84/// # Examples
85///
86/// ```rust
87/// use chainrules::sub_rrule;
88///
89/// let (dx, dy) = sub_rrule(2.0_f64);
90/// assert_eq!(dx, 2.0_f64);
91/// assert_eq!(dy, -2.0_f64);
92/// ```
93pub fn sub_rrule<S: ScalarAd>(cotangent: S) -> (S, S) {
94 (cotangent, -cotangent)
95}
96
97/// Primal `mul`.
98///
99/// # Examples
100///
101/// ```rust
102/// use chainrules::mul;
103///
104/// assert_eq!(mul(2.0_f64, 4.0_f64), 8.0_f64);
105/// ```
106pub fn mul<S: ScalarAd>(x: S, y: S) -> S {
107 x * y
108}
109
110/// Forward rule for `mul`.
111///
112/// Returns `(primal, tangent)`.
113///
114/// # Examples
115///
116/// ```rust
117/// use chainrules::mul_frule;
118///
119/// let (y, dy) = mul_frule(2.0_f64, 4.0_f64, 0.5_f64, 0.25_f64);
120/// assert_eq!(y, 8.0_f64);
121/// assert_eq!(dy, 2.5_f64);
122/// ```
123pub fn mul_frule<S: ScalarAd>(x: S, y: S, dx: S, dy: S) -> (S, S) {
124 let primal = x * y;
125 let tangent = dx * y + dy * x;
126 (primal, tangent)
127}
128
129/// Reverse rule for `mul`.
130///
131/// Returns cotangents with respect to `(x, y)`.
132///
133/// # Examples
134///
135/// ```rust
136/// use chainrules::mul_rrule;
137///
138/// let (dx, dy) = mul_rrule(2.0_f64, 4.0_f64, 1.0_f64);
139/// assert_eq!(dx, 4.0_f64);
140/// assert_eq!(dy, 2.0_f64);
141/// ```
142pub fn mul_rrule<S: ScalarAd>(x: S, y: S, cotangent: S) -> (S, S) {
143 (cotangent * y.conj(), cotangent * x.conj())
144}
145
146/// Primal `div`.
147///
148/// # Examples
149///
150/// ```rust
151/// use chainrules::div;
152///
153/// assert_eq!(div(8.0_f64, 2.0_f64), 4.0_f64);
154/// ```
155pub fn div<S: ScalarAd>(x: S, y: S) -> S {
156 x / y
157}
158
159/// Forward rule for `div`.
160///
161/// Returns `(primal, tangent)`.
162///
163/// When `y` is zero, the derivative produces NaN/Inf following IEEE 754
164/// semantics, consistent with standard AD behavior for division by zero.
165///
166/// # Examples
167///
168/// ```rust
169/// use chainrules::div_frule;
170///
171/// let (y, dy) = div_frule(8.0_f64, 2.0_f64, 0.5_f64, 0.25_f64);
172/// assert_eq!(y, 4.0_f64);
173/// assert!((dy + 0.25_f64).abs() < 1e-12);
174/// ```
175pub fn div_frule<S: ScalarAd>(x: S, y: S, dx: S, dy: S) -> (S, S) {
176 let primal = x / y;
177 let inv_y = S::from_i32(1) / y;
178 let dfdx = inv_y;
179 let dfdy = -(x * inv_y * inv_y);
180 let tangent = dx * dfdx + dy * dfdy;
181 (primal, tangent)
182}
183
184/// Reverse rule for `div`.
185///
186/// Returns cotangents with respect to `(x, y)`.
187///
188/// When `y` is zero, the derivatives produce NaN/Inf following IEEE 754
189/// semantics, consistent with standard AD behavior for division by zero.
190///
191/// # Examples
192///
193/// ```rust
194/// use chainrules::div_rrule;
195///
196/// let (dx, dy) = div_rrule(8.0_f64, 2.0_f64, 1.0_f64);
197/// assert_eq!(dx, 0.5_f64);
198/// assert_eq!(dy, -2.0_f64);
199/// ```
200pub fn div_rrule<S: ScalarAd>(x: S, y: S, cotangent: S) -> (S, S) {
201 let inv_y = S::from_i32(1) / y;
202 let dfdx = inv_y.conj();
203 let dfdy = (-(x * inv_y * inv_y)).conj();
204 (cotangent * dfdx, cotangent * dfdy)
205}