Skip to main content

tensor4all_quanticstransform/
shift.rs

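//! Shift operator: f(x) = g(x + offset) on the grid x = 0, 1, ..., 2^R - 1.
//!
//! This transformation shifts the argument by a constant offset. With a periodic boundary
//! the shifted argument wraps around, f(x) = g((x + offset) mod 2^R); with an open boundary
//! terms whose shifted argument falls outside [0, 2^R) are dropped (weight zero).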
4
5use anyhow::Result;
6use num_complex::Complex64;
7use num_traits::{One, Zero};
8use tensor4all_simplett::{types::tensor3_zeros, AbstractTensorTrain, Tensor3Ops, TensorTrain};
9
10use crate::common::{
11    embed_single_var_mpo, tensortrain_to_linear_operator,
12    tensortrain_to_linear_operator_asymmetric, BoundaryCondition, QuanticsOperator,
13};
14
15/// Create a shift operator: f(x) = g(x + offset) mod 2^R
16///
17/// This MPO transforms a function g(x) to f(x) = g(x + offset) for x = 0, 1, ..., 2^R - 1.
18///
19/// # Arguments
20/// * `r` - Number of bits (sites)
21/// * `offset` - Shift amount (can be negative)
22/// * `bc` - Boundary condition
23///
24/// # Returns
25/// LinearOperator representing the shift transformation
26///
27/// # Examples
28///
29/// ```
30/// use tensor4all_quanticstransform::{shift_operator, BoundaryCondition};
31///
32/// // Create a shift operator for 4-bit (2^4 = 16 points) quantics representation
33/// let op = shift_operator(4, 3, BoundaryCondition::Periodic).unwrap();
34///
35/// // The operator has one MPO tensor per bit
36/// assert_eq!(op.mpo.node_count(), 4);
37/// ```
38pub fn shift_operator(r: usize, offset: i64, bc: BoundaryCondition) -> Result<QuanticsOperator> {
39    if r == 0 {
40        return Err(anyhow::anyhow!("Number of sites must be positive"));
41    }
42
43    let mpo = shift_mpo(r, offset, bc)?;
44    let site_dims = vec![2; r];
45    tensortrain_to_linear_operator(&mpo, &site_dims)
46}
47
48/// Create a shift operator for one variable in a multi-variable system.
49///
50/// Acts as shift on `target_var` and identity on all other variables.
51/// The resulting operator works on interleaved quantics encoding where each
52/// site has local dimension `2^nvariables`.
53///
54/// # Arguments
55/// * `r` - Number of bits (sites)
56/// * `offset` - Shift amount (can be negative)
57/// * `bc` - Boundary condition
58/// * `nvariables` - Total number of variables (must be at least 2)
59/// * `target_var` - Which variable to shift (0-indexed, must be < nvariables)
60///
61/// # Examples
62///
63/// ```
64/// use tensor4all_quanticstransform::{shift_operator_multivar, BoundaryCondition};
65///
66/// // Shift only the x-variable of a 2-variable function f(x, y) by 3
67/// let op = shift_operator_multivar(4, 3, BoundaryCondition::Periodic, 2, 0).unwrap();
68/// assert_eq!(op.mpo.node_count(), 4);
69/// ```
70pub fn shift_operator_multivar(
71    r: usize,
72    offset: i64,
73    bc: BoundaryCondition,
74    nvariables: usize,
75    target_var: usize,
76) -> Result<QuanticsOperator> {
77    if r == 0 {
78        return Err(anyhow::anyhow!("Number of sites must be positive"));
79    }
80
81    let mpo = shift_mpo(r, offset, bc)?;
82    let embedded = embed_single_var_mpo(&mpo, nvariables, target_var)?;
83    let dim_multi = 1 << nvariables;
84    let dims = vec![dim_multi; r];
85    tensortrain_to_linear_operator_asymmetric(&embedded, &dims, &dims)
86}
87
88/// Create the shift MPO as a TensorTrain.
89///
90/// The shift operation computes x + offset using binary addition with carry propagation.
91/// Uses big-endian convention: site n contains bit 2^(R-1-n) (MSB at site 0).
92/// This matches Julia Quantics.jl's convention.
93///
94/// At each site n, we compute: out_n = x_n + offset_n + carry_in (mod 2)
95/// with carry_out = (x_n + offset_n + carry_in) / 2
96///
97/// Carry propagates from LSB to MSB, so in big-endian:
98/// - Site 0 (MSB): applies BC on left, receives carry from right
99/// - Site R-1 (LSB): initial carry = 0, sends carry to left
100pub(crate) fn shift_mpo(
101    r: usize,
102    offset: i64,
103    bc: BoundaryCondition,
104) -> Result<TensorTrain<Complex64>> {
105    if r == 0 {
106        return Err(anyhow::anyhow!("Number of sites must be positive"));
107    }
108    if r > 63 {
109        anyhow::bail!("Number of sites must be at most 63 to avoid integer overflow");
110    }
111
112    if bc == BoundaryCondition::Open && offset < 0 {
113        let positive = offset
114            .checked_neg()
115            .ok_or_else(|| anyhow::anyhow!("open-boundary shift offset overflow"))?;
116        let mpo = shift_mpo(r, positive, bc)?;
117        return transpose_binary_operator_mpo(&mpo);
118    }
119
120    let n_max = 1i64 << r;
121
122    // Normalize offset to [0, 2^R)
123    let (nbc, offset_mod) = {
124        let offset_mod = offset.rem_euclid(n_max);
125        let nbc = (offset - offset_mod) / n_max;
126        (nbc, offset_mod as usize)
127    };
128
129    // Convert offset to binary (big-endian: MSB first)
130    // Site n contains bit 2^(R-1-n)
131    // offset_bits[n] = bit at position (R-1-n)
132    let offset_bits: Vec<usize> = (0..r).map(|n| (offset_mod >> (r - 1 - n)) & 1).collect();
133
134    let mut tensors = Vec::with_capacity(r);
135
136    // Carry states: index 0 = carry 0, index 1 = carry 1
137    // For addition, carry can be 0 or 1.
138    //
139    // In big-endian with TensorTrain (left-to-right contraction):
140    // - Carry flows right-to-left (LSB at R-1 to MSB at 0)
141    // - t[left, s, right] where left = carry_out (going left), right = carry_in (from right)
142
143    #[allow(clippy::needless_range_loop)]
144    for n in 0..r {
145        let y_bit = offset_bits[n]; // The constant bit at position (R-1-n)
146
147        if r == 1 {
148            // Single site case: no carry propagation needed
149            let mut t = tensor3_zeros(1, 4, 1);
150            for x_bit in 0..2 {
151                let sum = x_bit + y_bit;
152                let out_bit = sum % 2;
153                let bc_factor = match bc {
154                    BoundaryCondition::Periodic => Complex64::one(),
155                    BoundaryCondition::Open => {
156                        if sum >= 2 {
157                            Complex64::zero()
158                        } else {
159                            Complex64::one()
160                        }
161                    }
162                };
163                let s = out_bit * 2 + x_bit;
164                t.set3(0, s, 0, bc_factor);
165            }
166            tensors.push(t);
167        } else if n == 0 {
168            // First tensor (MSB): apply BC on left, receive carry from right
169            // Shape (1, 4, 2): left=1 (BC applied), site=4, right=2 (carry_in from site 1)
170            let bc_val = match bc {
171                BoundaryCondition::Periodic => Complex64::one(),
172                BoundaryCondition::Open => Complex64::zero(),
173            };
174
175            let mut t = tensor3_zeros(1, 4, 2);
176            for carry_in in 0..2 {
177                for x_bit in 0..2 {
178                    let sum = x_bit + y_bit + carry_in;
179                    let out_bit = sum % 2;
180                    let carry_out = sum / 2;
181
182                    // Weight by boundary condition based on carry_out
183                    let weight = if carry_out == 0 {
184                        Complex64::one()
185                    } else {
186                        bc_val
187                    };
188
189                    let s = out_bit * 2 + x_bit;
190                    t.set3(0, s, carry_in, weight);
191                }
192            }
193            tensors.push(t);
194        } else if n == r - 1 {
195            // Last tensor (LSB): initial carry = 0, send carry_out to left
196            // Shape (2, 4, 1): left=2 (carry_out to site R-2), site=4, right=1 (no input)
197            let mut t = tensor3_zeros(2, 4, 1);
198            for x_bit in 0..2 {
199                let sum = x_bit + y_bit; // carry_in = 0 at start
200                let out_bit = sum % 2;
201                let carry_out = sum / 2;
202                let s = out_bit * 2 + x_bit;
203                t.set3(carry_out, s, 0, Complex64::one());
204            }
205            tensors.push(t);
206        } else {
207            // Middle tensors: receive carry from right, send carry to left
208            // Shape (2, 4, 2): left=2 (carry_out), site=4, right=2 (carry_in)
209            let mut t = tensor3_zeros(2, 4, 2);
210            for carry_in in 0..2 {
211                for x_bit in 0..2 {
212                    let sum = x_bit + y_bit + carry_in;
213                    let out_bit = sum % 2;
214                    let carry_out = sum / 2;
215                    let s = out_bit * 2 + x_bit;
216                    t.set3(carry_out, s, carry_in, Complex64::one());
217                }
218            }
219            tensors.push(t);
220        }
221    }
222
223    let mut mpo = TensorTrain::new(tensors)
224        .map_err(|e| anyhow::anyhow!("Failed to create shift MPO: {}", e))?;
225
226    // Apply overall boundary condition factor for number of full cycles
227    if nbc != 0 {
228        let bc_factor = match bc {
229            BoundaryCondition::Periodic => Complex64::one(),
230            BoundaryCondition::Open => {
231                // `nbc` is an Euclidean quotient, so negative offsets in (-n_max, 0)
232                // still produce `nbc = -1`. Only true full-cycle offsets should zero.
233                if offset >= n_max || offset <= -n_max {
234                    Complex64::zero()
235                } else {
236                    Complex64::one()
237                }
238            }
239        };
240        mpo.scale(bc_factor);
241    }
242
243    Ok(mpo)
244}
245
246fn transpose_binary_operator_mpo(mpo: &TensorTrain<Complex64>) -> Result<TensorTrain<Complex64>> {
247    let mut transposed = Vec::with_capacity(mpo.len());
248    for site in 0..mpo.len() {
249        let tensor = mpo.site_tensor(site);
250        let mut new_tensor =
251            tensor3_zeros(tensor.left_dim(), tensor.site_dim(), tensor.right_dim());
252        for left in 0..tensor.left_dim() {
253            for right in 0..tensor.right_dim() {
254                for out_bit in 0..2 {
255                    for in_bit in 0..2 {
256                        let old_site = out_bit * 2 + in_bit;
257                        let new_site = in_bit * 2 + out_bit;
258                        new_tensor.set3(left, new_site, right, *tensor.get3(left, old_site, right));
259                    }
260                }
261            }
262        }
263        transposed.push(new_tensor);
264    }
265
266    TensorTrain::new(transposed)
267        .map_err(|e| anyhow::anyhow!("Failed to transpose binary shift MPO: {}", e))
268}
269
// Unit tests for this module are declared here and kept in a separate file (only compiled under `cfg(test)`).
270#[cfg(test)]
271mod tests;