tensor4all_quanticstransform/shift.rs
1//! Shift operator: f(x) = g(x + offset) mod 2^R
2//!
3//! This transformation shifts the argument by a constant offset.
4
5use anyhow::Result;
6use num_complex::Complex64;
7use num_traits::{One, Zero};
8use tensor4all_simplett::{types::tensor3_zeros, AbstractTensorTrain, Tensor3Ops, TensorTrain};
9
10use crate::common::{
11 embed_single_var_mpo, tensortrain_to_linear_operator,
12 tensortrain_to_linear_operator_asymmetric, BoundaryCondition, QuanticsOperator,
13};
14
15/// Create a shift operator: f(x) = g(x + offset) mod 2^R
16///
17/// This MPO transforms a function g(x) to f(x) = g(x + offset) for x = 0, 1, ..., 2^R - 1.
18///
19/// # Arguments
20/// * `r` - Number of bits (sites)
21/// * `offset` - Shift amount (can be negative)
22/// * `bc` - Boundary condition
23///
24/// # Returns
25/// LinearOperator representing the shift transformation
26///
27/// # Examples
28///
29/// ```
30/// use tensor4all_quanticstransform::{shift_operator, BoundaryCondition};
31///
32/// // Create a shift operator for 4-bit (2^4 = 16 points) quantics representation
33/// let op = shift_operator(4, 3, BoundaryCondition::Periodic).unwrap();
34///
35/// // The operator has one MPO tensor per bit
36/// assert_eq!(op.mpo.node_count(), 4);
37/// ```
38pub fn shift_operator(r: usize, offset: i64, bc: BoundaryCondition) -> Result<QuanticsOperator> {
39 if r == 0 {
40 return Err(anyhow::anyhow!("Number of sites must be positive"));
41 }
42
43 let mpo = shift_mpo(r, offset, bc)?;
44 let site_dims = vec![2; r];
45 tensortrain_to_linear_operator(&mpo, &site_dims)
46}
47
48/// Create a shift operator for one variable in a multi-variable system.
49///
50/// Acts as shift on `target_var` and identity on all other variables.
51/// The resulting operator works on interleaved quantics encoding where each
52/// site has local dimension `2^nvariables`.
53///
54/// # Arguments
55/// * `r` - Number of bits (sites)
56/// * `offset` - Shift amount (can be negative)
57/// * `bc` - Boundary condition
58/// * `nvariables` - Total number of variables (must be at least 2)
59/// * `target_var` - Which variable to shift (0-indexed, must be < nvariables)
60///
61/// # Examples
62///
63/// ```
64/// use tensor4all_quanticstransform::{shift_operator_multivar, BoundaryCondition};
65///
66/// // Shift only the x-variable of a 2-variable function f(x, y) by 3
67/// let op = shift_operator_multivar(4, 3, BoundaryCondition::Periodic, 2, 0).unwrap();
68/// assert_eq!(op.mpo.node_count(), 4);
69/// ```
70pub fn shift_operator_multivar(
71 r: usize,
72 offset: i64,
73 bc: BoundaryCondition,
74 nvariables: usize,
75 target_var: usize,
76) -> Result<QuanticsOperator> {
77 if r == 0 {
78 return Err(anyhow::anyhow!("Number of sites must be positive"));
79 }
80
81 let mpo = shift_mpo(r, offset, bc)?;
82 let embedded = embed_single_var_mpo(&mpo, nvariables, target_var)?;
83 let dim_multi = 1 << nvariables;
84 let dims = vec![dim_multi; r];
85 tensortrain_to_linear_operator_asymmetric(&embedded, &dims, &dims)
86}
87
88/// Create the shift MPO as a TensorTrain.
89///
90/// The shift operation computes x + offset using binary addition with carry propagation.
91/// Uses big-endian convention: site n contains bit 2^(R-1-n) (MSB at site 0).
92/// This matches Julia Quantics.jl's convention.
93///
94/// At each site n, we compute: out_n = x_n + offset_n + carry_in (mod 2)
95/// with carry_out = (x_n + offset_n + carry_in) / 2
96///
97/// Carry propagates from LSB to MSB, so in big-endian:
98/// - Site 0 (MSB): applies BC on left, receives carry from right
99/// - Site R-1 (LSB): initial carry = 0, sends carry to left
100pub(crate) fn shift_mpo(
101 r: usize,
102 offset: i64,
103 bc: BoundaryCondition,
104) -> Result<TensorTrain<Complex64>> {
105 if r == 0 {
106 return Err(anyhow::anyhow!("Number of sites must be positive"));
107 }
108 if r > 63 {
109 anyhow::bail!("Number of sites must be at most 63 to avoid integer overflow");
110 }
111
112 if bc == BoundaryCondition::Open && offset < 0 {
113 let positive = offset
114 .checked_neg()
115 .ok_or_else(|| anyhow::anyhow!("open-boundary shift offset overflow"))?;
116 let mpo = shift_mpo(r, positive, bc)?;
117 return transpose_binary_operator_mpo(&mpo);
118 }
119
120 let n_max = 1i64 << r;
121
122 // Normalize offset to [0, 2^R)
123 let (nbc, offset_mod) = {
124 let offset_mod = offset.rem_euclid(n_max);
125 let nbc = (offset - offset_mod) / n_max;
126 (nbc, offset_mod as usize)
127 };
128
129 // Convert offset to binary (big-endian: MSB first)
130 // Site n contains bit 2^(R-1-n)
131 // offset_bits[n] = bit at position (R-1-n)
132 let offset_bits: Vec<usize> = (0..r).map(|n| (offset_mod >> (r - 1 - n)) & 1).collect();
133
134 let mut tensors = Vec::with_capacity(r);
135
136 // Carry states: index 0 = carry 0, index 1 = carry 1
137 // For addition, carry can be 0 or 1.
138 //
139 // In big-endian with TensorTrain (left-to-right contraction):
140 // - Carry flows right-to-left (LSB at R-1 to MSB at 0)
141 // - t[left, s, right] where left = carry_out (going left), right = carry_in (from right)
142
143 #[allow(clippy::needless_range_loop)]
144 for n in 0..r {
145 let y_bit = offset_bits[n]; // The constant bit at position (R-1-n)
146
147 if r == 1 {
148 // Single site case: no carry propagation needed
149 let mut t = tensor3_zeros(1, 4, 1);
150 for x_bit in 0..2 {
151 let sum = x_bit + y_bit;
152 let out_bit = sum % 2;
153 let bc_factor = match bc {
154 BoundaryCondition::Periodic => Complex64::one(),
155 BoundaryCondition::Open => {
156 if sum >= 2 {
157 Complex64::zero()
158 } else {
159 Complex64::one()
160 }
161 }
162 };
163 let s = out_bit * 2 + x_bit;
164 t.set3(0, s, 0, bc_factor);
165 }
166 tensors.push(t);
167 } else if n == 0 {
168 // First tensor (MSB): apply BC on left, receive carry from right
169 // Shape (1, 4, 2): left=1 (BC applied), site=4, right=2 (carry_in from site 1)
170 let bc_val = match bc {
171 BoundaryCondition::Periodic => Complex64::one(),
172 BoundaryCondition::Open => Complex64::zero(),
173 };
174
175 let mut t = tensor3_zeros(1, 4, 2);
176 for carry_in in 0..2 {
177 for x_bit in 0..2 {
178 let sum = x_bit + y_bit + carry_in;
179 let out_bit = sum % 2;
180 let carry_out = sum / 2;
181
182 // Weight by boundary condition based on carry_out
183 let weight = if carry_out == 0 {
184 Complex64::one()
185 } else {
186 bc_val
187 };
188
189 let s = out_bit * 2 + x_bit;
190 t.set3(0, s, carry_in, weight);
191 }
192 }
193 tensors.push(t);
194 } else if n == r - 1 {
195 // Last tensor (LSB): initial carry = 0, send carry_out to left
196 // Shape (2, 4, 1): left=2 (carry_out to site R-2), site=4, right=1 (no input)
197 let mut t = tensor3_zeros(2, 4, 1);
198 for x_bit in 0..2 {
199 let sum = x_bit + y_bit; // carry_in = 0 at start
200 let out_bit = sum % 2;
201 let carry_out = sum / 2;
202 let s = out_bit * 2 + x_bit;
203 t.set3(carry_out, s, 0, Complex64::one());
204 }
205 tensors.push(t);
206 } else {
207 // Middle tensors: receive carry from right, send carry to left
208 // Shape (2, 4, 2): left=2 (carry_out), site=4, right=2 (carry_in)
209 let mut t = tensor3_zeros(2, 4, 2);
210 for carry_in in 0..2 {
211 for x_bit in 0..2 {
212 let sum = x_bit + y_bit + carry_in;
213 let out_bit = sum % 2;
214 let carry_out = sum / 2;
215 let s = out_bit * 2 + x_bit;
216 t.set3(carry_out, s, carry_in, Complex64::one());
217 }
218 }
219 tensors.push(t);
220 }
221 }
222
223 let mut mpo = TensorTrain::new(tensors)
224 .map_err(|e| anyhow::anyhow!("Failed to create shift MPO: {}", e))?;
225
226 // Apply overall boundary condition factor for number of full cycles
227 if nbc != 0 {
228 let bc_factor = match bc {
229 BoundaryCondition::Periodic => Complex64::one(),
230 BoundaryCondition::Open => {
231 // `nbc` is an Euclidean quotient, so negative offsets in (-n_max, 0)
232 // still produce `nbc = -1`. Only true full-cycle offsets should zero.
233 if offset >= n_max || offset <= -n_max {
234 Complex64::zero()
235 } else {
236 Complex64::one()
237 }
238 }
239 };
240 mpo.scale(bc_factor);
241 }
242
243 Ok(mpo)
244}
245
246fn transpose_binary_operator_mpo(mpo: &TensorTrain<Complex64>) -> Result<TensorTrain<Complex64>> {
247 let mut transposed = Vec::with_capacity(mpo.len());
248 for site in 0..mpo.len() {
249 let tensor = mpo.site_tensor(site);
250 let mut new_tensor =
251 tensor3_zeros(tensor.left_dim(), tensor.site_dim(), tensor.right_dim());
252 for left in 0..tensor.left_dim() {
253 for right in 0..tensor.right_dim() {
254 for out_bit in 0..2 {
255 for in_bit in 0..2 {
256 let old_site = out_bit * 2 + in_bit;
257 let new_site = in_bit * 2 + out_bit;
258 new_tensor.set3(left, new_site, right, *tensor.get3(left, old_site, right));
259 }
260 }
261 }
262 }
263 transposed.push(new_tensor);
264 }
265
266 TensorTrain::new(transposed)
267 .map_err(|e| anyhow::anyhow!("Failed to transpose binary shift MPO: {}", e))
268}
269
270#[cfg(test)]
271mod tests;