chainrules/unary/complex_parts.rs
1use crate::ScalarAd;
2use num_complex::Complex;
3use num_traits::{Float, One, Zero};
4
5trait ComplexProjectionScalar: ScalarAd {
6 fn from_parts(re: Self::Real, im: Self::Real) -> Self;
7}
8
9impl ComplexProjectionScalar for num_complex::Complex32 {
10 fn from_parts(re: Self::Real, im: Self::Real) -> Self {
11 Complex::new(re, im)
12 }
13}
14
15impl ComplexProjectionScalar for num_complex::Complex64 {
16 fn from_parts(re: Self::Real, im: Self::Real) -> Self {
17 Complex::new(re, im)
18 }
19}
20
21/// Primal `abs`.
22///
23/// # Examples
24///
25/// ```rust
26/// use chainrules::abs;
27/// use num_complex::Complex64;
28///
29/// assert_eq!(abs(Complex64::new(3.0, 4.0)), 5.0);
30/// ```
31#[inline]
32pub fn abs<S: ScalarAd>(x: S) -> S::Real {
33 x.abs()
34}
35
36/// Primal `abs2`.
37///
38/// # Examples
39///
40/// ```rust
41/// use chainrules::abs2;
42/// use num_complex::Complex64;
43///
44/// assert_eq!(abs2(Complex64::new(3.0, 4.0)), 25.0);
45/// ```
46#[inline]
47pub fn abs2<S: ScalarAd>(x: S) -> S::Real {
48 x.abs2()
49}
50
51/// Primal `real`.
52///
53/// # Examples
54///
55/// ```rust
56/// use chainrules::real;
57/// use num_complex::Complex64;
58///
59/// assert_eq!(real(Complex64::new(3.0, 4.0)), 3.0);
60/// ```
61#[inline]
62pub fn real<S: ScalarAd>(x: S) -> S::Real {
63 x.real()
64}
65
66/// Primal `imag`.
67///
68/// # Examples
69///
70/// ```rust
71/// use chainrules::imag;
72/// use num_complex::Complex64;
73///
74/// assert_eq!(imag(Complex64::new(3.0, 4.0)), 4.0);
75/// ```
76#[inline]
77pub fn imag<S: ScalarAd>(x: S) -> S::Real {
78 x.imag()
79}
80
81/// Primal `angle`.
82///
83/// # Examples
84///
85/// ```rust
86/// use chainrules::angle;
87/// use num_complex::Complex64;
88///
89/// assert!((angle(Complex64::new(3.0, 4.0)) - 0.9272952180016122).abs() < 1e-12);
90/// ```
91#[inline]
92pub fn angle<S: ScalarAd>(x: S) -> S::Real {
93 x.angle()
94}
95
96/// Construct a complex number from real and imaginary parts.
97///
98/// # Examples
99///
100/// ```rust
101/// use chainrules::complex;
102/// use num_complex::Complex64;
103///
104/// assert_eq!(complex(3.0_f64, 4.0_f64), Complex64::new(3.0, 4.0));
105/// ```
106#[inline]
107pub fn complex<R: Float>(re: R, im: R) -> Complex<R> {
108 Complex::new(re, im)
109}
110
111/// Forward rule for `abs2`.
112///
113/// # Examples
114///
115/// ```rust
116/// use chainrules::abs2_frule;
117/// use num_complex::Complex64;
118///
119/// let z = Complex64::new(3.0, 4.0);
120/// let dz = Complex64::new(1.0, -2.0);
121/// let (y, dy) = abs2_frule(z, dz);
122/// assert_eq!(y, 25.0);
123/// assert_eq!(dy, -10.0);
124/// ```
125#[inline]
126pub fn abs2_frule<S: ScalarAd>(x: S, dx: S) -> (S::Real, S::Real) {
127 let y = x.abs2();
128 let two = S::Real::one() + S::Real::one();
129 let dy = two * (x.real() * dx.real() + x.imag() * dx.imag());
130 (y, dy)
131}
132
133/// Reverse rule for `abs2`.
134///
135/// # Examples
136///
137/// ```rust
138/// use chainrules::abs2_rrule;
139/// use num_complex::Complex64;
140///
141/// let z = Complex64::new(3.0, 4.0);
142/// assert_eq!(abs2_rrule(z, 1.25), Complex64::new(7.5, 10.0));
143/// ```
144#[inline]
145pub fn abs2_rrule<R: Float>(x: Complex<R>, cotangent: R) -> Complex<R> {
146 let two = R::one() + R::one();
147 Complex::new(two * cotangent * x.re, two * cotangent * x.im)
148}
149
150/// Reverse rule for `real`.
151///
152/// # Examples
153///
154/// ```rust
155/// use chainrules::real_rrule;
156/// use num_complex::Complex64;
157///
158/// let grad: Complex64 = real_rrule(2.0);
159/// assert_eq!(grad, Complex64::new(2.0, 0.0));
160/// ```
161#[inline]
162#[allow(private_bounds)]
163pub fn real_rrule<S: ComplexProjectionScalar>(cotangent: S::Real) -> S {
164 S::from_parts(cotangent, S::Real::zero())
165}
166
167/// Reverse rule for `imag`.
168///
169/// # Examples
170///
171/// ```rust
172/// use chainrules::imag_rrule;
173/// use num_complex::Complex64;
174///
175/// let grad: Complex64 = imag_rrule(2.0);
176/// assert_eq!(grad, Complex64::new(0.0, 2.0));
177/// ```
178#[inline]
179#[allow(private_bounds)]
180pub fn imag_rrule<S: ComplexProjectionScalar>(cotangent: S::Real) -> S {
181 S::from_parts(S::Real::zero(), cotangent)
182}
183
184/// Reverse rule for `angle`.
185///
186/// # Examples
187///
188/// ```rust
189/// use chainrules::angle_rrule;
190/// use num_complex::Complex64;
191///
192/// assert_eq!(angle_rrule(Complex64::new(3.0, 4.0), 1.0), Complex64::new(-0.16, 0.12));
193/// ```
194#[inline]
195pub fn angle_rrule<R: Float>(x: Complex<R>, cotangent: R) -> Complex<R> {
196 let denom = x.re * x.re + x.im * x.im;
197 Complex::new(-x.im * cotangent / denom, x.re * cotangent / denom)
198}