chainrules/unary/basic.rs

1use crate::ScalarAd;
2
3/// Primal `conj`.
4pub fn conj<S: ScalarAd>(x: S) -> S {
5    x.conj()
6}
7
8/// Forward rule for `conj`.
9pub fn conj_frule<S: ScalarAd>(x: S, dx: S) -> (S, S) {
10    (x.conj(), dx.conj())
11}
12
13/// Reverse rule for `conj`.
14pub fn conj_rrule<S: ScalarAd>(cotangent: S) -> S {
15    cotangent.conj()
16}
17
18/// Primal `sqrt`.
19pub fn sqrt<S: ScalarAd>(x: S) -> S {
20    x.sqrt()
21}
22
23/// Forward rule for `sqrt`.
24pub fn sqrt_frule<S: ScalarAd>(x: S, dx: S) -> (S, S) {
25    let y = x.sqrt();
26    let dy = dx / (S::from_i32(2) * y);
27    (y, dy)
28}
29
30/// Reverse rule for `sqrt`.
31pub fn sqrt_rrule<S: ScalarAd>(result: S, cotangent: S) -> S {
32    cotangent / (S::from_i32(2) * result.conj())
33}