chainrules/unary/
nonsmooth.rs

1use num_traits::Float;
2
3/// Primal `round`.
4///
5/// The corresponding forward and reverse rules use a zero-gradient policy at
6/// every point, including integer inputs.
7///
8/// # Examples
9///
10/// ```rust
11/// use chainrules::round;
12///
13/// assert_eq!(round(1.4_f64), 1.0);
14/// assert_eq!(round(1.5_f64), 2.0);
15/// ```
16pub fn round<R: Float>(x: R) -> R {
17    x.round()
18}
19
20/// Forward rule for `round`.
21///
22/// The tangent is always zero.
23///
24/// # Examples
25///
26/// ```rust
27/// use chainrules::round_frule;
28///
29/// let (y, dy) = round_frule(1.6_f64, 0.75);
30/// assert_eq!(y, 2.0);
31/// assert_eq!(dy, 0.0);
32/// ```
33pub fn round_frule<R: Float>(x: R, _dx: R) -> (R, R) {
34    (x.round(), R::zero())
35}
36
37/// Reverse rule for `round`.
38///
39/// The cotangent is always zero.
40///
41/// # Examples
42///
43/// ```rust
44/// use chainrules::round_rrule;
45///
46/// assert_eq!(round_rrule(1.0_f64, 0.5), 0.0);
47/// ```
48pub fn round_rrule<R: Float>(_x: R, _cotangent: R) -> R {
49    R::zero()
50}
51
52/// Primal `floor`.
53///
54/// The corresponding forward and reverse rules use a zero-gradient policy at
55/// every point.
56///
57/// # Examples
58///
59/// ```rust
60/// use chainrules::floor;
61///
62/// assert_eq!(floor(1.9_f64), 1.0);
63/// ```
64pub fn floor<R: Float>(x: R) -> R {
65    x.floor()
66}
67
68/// Forward rule for `floor`.
69///
70/// The tangent is always zero.
71///
72/// # Examples
73///
74/// ```rust
75/// use chainrules::floor_frule;
76///
77/// let (y, dy) = floor_frule(1.6_f64, 0.75);
78/// assert_eq!(y, 1.0);
79/// assert_eq!(dy, 0.0);
80/// ```
81pub fn floor_frule<R: Float>(x: R, _dx: R) -> (R, R) {
82    (x.floor(), R::zero())
83}
84
85/// Reverse rule for `floor`.
86///
87/// The cotangent is always zero.
88///
89/// # Examples
90///
91/// ```rust
92/// use chainrules::floor_rrule;
93///
94/// assert_eq!(floor_rrule(1.0_f64, 0.5), 0.0);
95/// ```
96pub fn floor_rrule<R: Float>(_x: R, _cotangent: R) -> R {
97    R::zero()
98}
99
100/// Primal `ceil`.
101///
102/// The corresponding forward and reverse rules use a zero-gradient policy at
103/// every point.
104///
105/// # Examples
106///
107/// ```rust
108/// use chainrules::ceil;
109///
110/// assert_eq!(ceil(1.1_f64), 2.0);
111/// ```
112pub fn ceil<R: Float>(x: R) -> R {
113    x.ceil()
114}
115
116/// Forward rule for `ceil`.
117///
118/// The tangent is always zero.
119///
120/// # Examples
121///
122/// ```rust
123/// use chainrules::ceil_frule;
124///
125/// let (y, dy) = ceil_frule(1.1_f64, 0.75);
126/// assert_eq!(y, 2.0);
127/// assert_eq!(dy, 0.0);
128/// ```
129pub fn ceil_frule<R: Float>(x: R, _dx: R) -> (R, R) {
130    (x.ceil(), R::zero())
131}
132
133/// Reverse rule for `ceil`.
134///
135/// The cotangent is always zero.
136///
137/// # Examples
138///
139/// ```rust
140/// use chainrules::ceil_rrule;
141///
142/// assert_eq!(ceil_rrule(1.0_f64, 0.5), 0.0);
143/// ```
144pub fn ceil_rrule<R: Float>(_x: R, _cotangent: R) -> R {
145    R::zero()
146}
147
148/// Primal `sign`.
149///
150/// The primal follows Julia-style `sign`: it returns signed zero for zero
151/// inputs, `+1`/`-1` for positive/negative infinities, and `x.signum()`
152/// otherwise.
153///
154/// The corresponding forward and reverse rules use a zero-gradient policy at
155/// every point.
156///
157/// # Examples
158///
159/// ```rust
160/// use chainrules::sign;
161///
162/// assert_eq!(sign(-3.0_f64), -1.0);
163/// assert_eq!(sign(0.0_f64), 0.0);
164/// assert_eq!(sign(-0.0_f64).is_sign_negative(), true);
165/// ```
166pub fn sign<R: Float>(x: R) -> R {
167    if x == R::zero() {
168        x
169    } else {
170        x.signum()
171    }
172}
173
174/// Forward rule for `sign`.
175///
176/// The tangent is always zero.
177///
178/// # Examples
179///
180/// ```rust
181/// use chainrules::sign_frule;
182///
183/// let (y, dy) = sign_frule(-2.0_f64, 0.75);
184/// assert_eq!(y, -1.0);
185/// assert_eq!(dy, 0.0);
186/// ```
187pub fn sign_frule<R: Float>(x: R, _dx: R) -> (R, R) {
188    (sign(x), R::zero())
189}
190
191/// Reverse rule for `sign`.
192///
193/// The cotangent is always zero.
194///
195/// # Examples
196///
197/// ```rust
198/// use chainrules::sign_rrule;
199///
200/// assert_eq!(sign_rrule(1.0_f64, 0.5), 0.0);
201/// ```
202pub fn sign_rrule<R: Float>(_x: R, _cotangent: R) -> R {
203    R::zero()
204}