tensor4all_quanticstransform/affine.rs
1//! Affine transformation operator: y = A*x + b
2//!
3//! This implements general affine transformations with rational coefficients.
4//! The transformation computes y = A*x + b where A is an M×N rational matrix
5//! and b is an M-dimensional rational vector.
6//!
7//! Based on the algorithm from Quantics.jl/src/affine.jl
8
9use std::collections::HashMap;
10
11use anyhow::Result;
12use num_complex::Complex64;
13use num_integer::Integer;
14use num_rational::Rational64;
15use num_traits::One;
16use sprs::CsMat;
17use tensor4all_simplett::{types::tensor3_zeros, AbstractTensorTrain, Tensor3Ops, TensorTrain};
18
19use crate::common::{
20 tensortrain_to_linear_operator_asymmetric, BoundaryCondition, QuanticsOperator,
21};
22use tensor4all_simplett::tensor::Tensor3 as GenericTensor3;
23
/// Dense N-dimensional boolean array with one byte per element.
///
/// Elements are laid out row-major: the last axis varies fastest.
#[derive(Clone, Debug)]
struct BoolTensor<const N: usize> {
    data: Vec<u8>,
    dims: [usize; N],
}

type BoolTensor2 = BoolTensor<2>;
type BoolTensor3 = BoolTensor<3>;

impl<const N: usize> BoolTensor<N> {
    /// Allocate a tensor of shape `dims` with every element equal to `value`.
    fn from_elem(dims: [usize; N], value: bool) -> Self {
        let len = dims.iter().product::<usize>();
        let fill = u8::from(value);
        Self {
            data: vec![fill; len],
            dims,
        }
    }

    /// Shape of the tensor.
    fn dims(&self) -> &[usize; N] {
        &self.dims
    }

    /// Read the element at multi-index `idx`.
    fn get(&self, idx: [usize; N]) -> bool {
        let pos = self.offset(&idx);
        self.data[pos] != 0
    }

    /// Overwrite the element at multi-index `idx`.
    fn set(&mut self, idx: [usize; N], value: bool) {
        let pos = self.offset(&idx);
        self.data[pos] = u8::from(value);
    }

    /// Row-major linear position of `idx`, evaluated with Horner's scheme:
    /// `((i0 * d1 + i1) * d2 + i2) * ...`.
    fn offset(&self, idx: &[usize; N]) -> usize {
        idx.iter()
            .zip(self.dims.iter())
            .fold(0usize, |acc, (&i, &d)| acc * d + i)
    }
}
65
66/// Affine transformation parameters.
67///
68/// Represents the transformation y = A*x + b where:
69/// - A is an M x N matrix stored in column-major order
70/// - b is an M-dimensional vector
71/// - x is an N-dimensional input
72/// - y is an M-dimensional output
73///
74/// # Examples
75///
76/// ```
77/// use tensor4all_quanticstransform::AffineParams;
78/// use num_rational::Rational64;
79///
80/// // 1D shift: y = x + 3
81/// let params = AffineParams::from_integers(vec![1], vec![3], 1, 1).unwrap();
82/// assert_eq!(params.m, 1);
83/// assert_eq!(params.n, 1);
84///
85/// // 2D rotation: y = [[1,1],[1,-1]] * x
86/// // Column-major: [A[0,0], A[1,0], A[0,1], A[1,1]]
87/// let params = AffineParams::from_integers(
88/// vec![1, 1, 1, -1], vec![0, 0], 2, 2
89/// ).unwrap();
90/// assert_eq!(params.m, 2);
91/// assert_eq!(params.n, 2);
92///
93/// // With rational coefficients: y = (1/2)*x
94/// let params = AffineParams::new(
95/// vec![Rational64::new(1, 2)],
96/// vec![Rational64::from_integer(0)],
97/// 1, 1,
98/// ).unwrap();
99/// ```
100#[derive(Clone, Debug)]
101pub struct AffineParams {
102 /// Transformation matrix A (M×N), stored in column-major order
103 pub a: Vec<Rational64>,
104 /// Translation vector b (M elements)
105 pub b: Vec<Rational64>,
106 /// Number of output dimensions (M)
107 pub m: usize,
108 /// Number of input dimensions (N)
109 pub n: usize,
110}
111
112impl AffineParams {
113 /// Create new affine parameters.
114 ///
115 /// # Arguments
116 /// * `a` - M x N matrix in column-major order (length must be m*n)
117 /// * `b` - M-dimensional translation vector (length must be m)
118 /// * `m` - Number of output dimensions
119 /// * `n` - Number of input dimensions
120 ///
121 /// # Examples
122 ///
123 /// ```
124 /// use tensor4all_quanticstransform::AffineParams;
125 /// use num_rational::Rational64;
126 ///
127 /// // 1D identity: y = x
128 /// let params = AffineParams::new(
129 /// vec![Rational64::from_integer(1)],
130 /// vec![Rational64::from_integer(0)],
131 /// 1, 1,
132 /// ).unwrap();
133 ///
134 /// // Dimension mismatch errors
135 /// assert!(AffineParams::new(
136 /// vec![Rational64::from_integer(1)],
137 /// vec![Rational64::from_integer(0)],
138 /// 2, 1, // expects 2 elements in A, got 1
139 /// ).is_err());
140 /// ```
141 pub fn new(a: Vec<Rational64>, b: Vec<Rational64>, m: usize, n: usize) -> Result<Self> {
142 if a.len() != m * n {
143 return Err(anyhow::anyhow!(
144 "Matrix A has {} elements but expected {}×{}={}",
145 a.len(),
146 m,
147 n,
148 m * n
149 ));
150 }
151 if b.len() != m {
152 return Err(anyhow::anyhow!(
153 "Vector b has {} elements but expected {}",
154 b.len(),
155 m
156 ));
157 }
158 Ok(Self { a, b, m, n })
159 }
160
161 /// Create affine parameters from integer matrix and vector.
162 ///
163 /// Convenience method that converts integer values to rationals.
164 ///
165 /// # Examples
166 ///
167 /// ```
168 /// use tensor4all_quanticstransform::AffineParams;
169 ///
170 /// // 2D: y = [[1, 0], [0, 1]] * x + [1, 2] (shift by (1,2))
171 /// let params = AffineParams::from_integers(
172 /// vec![1, 0, 0, 1], vec![1, 2], 2, 2,
173 /// ).unwrap();
174 /// assert_eq!(params.m, 2);
175 /// assert_eq!(params.n, 2);
176 /// ```
177 pub fn from_integers(a: Vec<i64>, b: Vec<i64>, m: usize, n: usize) -> Result<Self> {
178 let a_rat: Vec<Rational64> = a.into_iter().map(Rational64::from_integer).collect();
179 let b_rat: Vec<Rational64> = b.into_iter().map(Rational64::from_integer).collect();
180 Self::new(a_rat, b_rat, m, n)
181 }
182
183 /// Get element A[i, j] (0-indexed)
184 #[allow(dead_code)]
185 fn get_a(&self, i: usize, j: usize) -> Rational64 {
186 self.a[i + self.m * j]
187 }
188
189 /// Convert to integer representation by scaling with LCM of denominators.
190 /// Returns (A_int, b_int, scale) where A_int = scale * A and b_int = scale * b.
191 fn to_integer_scaled(&self) -> (Vec<i64>, Vec<i64>, i64) {
192 // Find LCM of all denominators
193 let mut denom_lcm = 1i64;
194 for r in &self.a {
195 denom_lcm = denom_lcm.lcm(r.denom());
196 }
197 for r in &self.b {
198 denom_lcm = denom_lcm.lcm(r.denom());
199 }
200
201 // Scale to integers
202 let a_int: Vec<i64> = self
203 .a
204 .iter()
205 .map(|r| (r * denom_lcm).to_integer())
206 .collect();
207 let b_int: Vec<i64> = self
208 .b
209 .iter()
210 .map(|r| (r * denom_lcm).to_integer())
211 .collect();
212
213 (a_int, b_int, denom_lcm)
214 }
215}
216
217/// Remap site indices of the affine MPO from internal encoding to the convention
218/// expected by `tensortrain_to_linear_operator_asymmetric`.
219///
220/// Internal encoding: `site_idx = y_bits | (x_bits << m)` (y-minor, x-major)
221/// Expected encoding: `s = s_out * in_dim + s_in = y_bits * 2^n + x_bits` (x-minor, y-major)
222fn remap_affine_site_indices(
223 mpo: &TensorTrain<Complex64>,
224 m: usize,
225 n: usize,
226 site_dim: usize,
227) -> Result<TensorTrain<Complex64>> {
228 let input_dim = 1 << n;
229
230 // Build permutation table: perm[old_idx] = remapped index
231 let perm: Vec<usize> = (0..site_dim)
232 .map(|old_idx| {
233 let y_bits = old_idx & ((1 << m) - 1);
234 let x_bits = old_idx >> m;
235 y_bits * input_dim + x_bits
236 })
237 .collect();
238
239 let r = mpo.len();
240 let mut new_tensors = Vec::with_capacity(r);
241
242 for i in 0..r {
243 let tensor = mpo.site_tensor(i);
244 let left_dim = tensor.left_dim();
245 let right_dim = tensor.right_dim();
246
247 let mut t = tensor3_zeros(left_dim, site_dim, right_dim);
248 for l in 0..left_dim {
249 for (old_s, &new_s) in perm.iter().enumerate() {
250 for rr in 0..right_dim {
251 let val = *tensor.get3(l, old_s, rr);
252 if val != Complex64::new(0.0, 0.0) {
253 t.set3(l, new_s, rr, val);
254 }
255 }
256 }
257 }
258 new_tensors.push(t);
259 }
260
261 TensorTrain::new(new_tensors)
262 .map_err(|e| anyhow::anyhow!("Failed to create remapped MPO: {}", e))
263}
264
265/// Create the operator that realizes the coordinate map `y = A * x + b`.
266///
267/// This is the **forward** affine operator. It maps a quantics tensor train
268/// representing an `N`-variable state `x` to the quantics tensor train of
269/// the `M`-variable state `y = A * x + b`.
270///
271/// To build the **pullback** (`f(y) = g(A * y + b)`), call `.transpose()`
272/// on the returned operator; the pullback is exactly the transpose of the
273/// forward operator.
274///
275/// # Arguments
276///
277/// * `r` — bits per variable (number of sites in the output MPO).
278/// * `params` — rational `M × N` matrix `A` and `M`-vector `b` describing
279/// the affine map.
280/// * `bc` — length `M` array of boundary conditions for each output variable.
281/// `Periodic` wraps output coordinates modulo `2^r`; `Open` zeroes the
282/// out-of-range contributions.
283///
284/// # Errors
285///
286/// Returns an error if `r == 0` or if `bc.len() != params.m`.
287///
288/// # Examples
289///
290/// ```
291/// use tensor4all_quanticstransform::{affine_operator, AffineParams, BoundaryCondition};
292/// use num_rational::Rational64;
293///
294/// // Transform g(x, y) -> g(x + y, x - y) (rotation by 45 degrees, scaled)
295/// let a = vec![
296/// Rational64::from_integer(1), Rational64::from_integer(1), // row 0: x + y
297/// Rational64::from_integer(1), Rational64::from_integer(-1), // row 1: x - y
298/// ];
299/// let b = vec![Rational64::from_integer(0), Rational64::from_integer(0)];
300/// let params = AffineParams::new(a, b, 2, 2).unwrap();
301/// let bc = vec![BoundaryCondition::Periodic; 2];
302/// let op = affine_operator(4, ¶ms, &bc).unwrap();
303/// assert_eq!(op.mpo.node_count(), 4);
304/// ```
305///
306/// Using integer convenience constructor:
307///
308/// ```
309/// use tensor4all_quanticstransform::{affine_operator, AffineParams, BoundaryCondition};
310///
311/// // Identity transform: y = x (1D)
312/// let params = AffineParams::from_integers(vec![1], vec![0], 1, 1).unwrap();
313/// let bc = vec![BoundaryCondition::Periodic];
314/// let op = affine_operator(4, ¶ms, &bc).unwrap();
315/// assert_eq!(op.mpo.node_count(), 4);
316/// ```
317pub fn affine_operator(
318 r: usize,
319 params: &AffineParams,
320 bc: &[BoundaryCondition],
321) -> Result<QuanticsOperator> {
322 if r == 0 {
323 return Err(anyhow::anyhow!("Number of bits must be positive"));
324 }
325 if bc.len() != params.m {
326 return Err(anyhow::anyhow!(
327 "Boundary conditions length {} doesn't match output dimensions {}",
328 bc.len(),
329 params.m
330 ));
331 }
332
333 let mpo = affine_transform_mpo(r, params, bc)?;
334 // Site dimensions: M output variables, N input variables
335 // Input dimension per site: 2^N (N input bits)
336 // Output dimension per site: 2^M (M output bits)
337 let m = params.m;
338 let n = params.n;
339 let input_dim = 1 << n;
340 let output_dim = 1 << m;
341
342 // The internal affine MPO uses site encoding: site_idx = y_bits | (x_bits << m)
343 // (y-minor, x-major). But tensortrain_to_linear_operator_asymmetric expects
344 // s = s_out * in_dim + s_in = y_bits * 2^N + x_bits (x-minor, y-major).
345 // We need to remap the site indices.
346 let site_dim = input_dim * output_dim;
347 let remapped_mpo = remap_affine_site_indices(&mpo, m, n, site_dim)?;
348
349 let input_dims = vec![input_dim; r];
350 let output_dims = vec![output_dim; r];
351 tensortrain_to_linear_operator_asymmetric(&remapped_mpo, &input_dims, &output_dims)
352}
353
354/// Compute the full affine transformation matrix directly (for verification).
355///
356/// This computes the transformation matrix by directly evaluating y = A*x + b
357/// for all possible input values. The result is a sparse boolean matrix.
358///
359/// # Arguments
360/// * `r` - Number of bits per variable
361/// * `params` - Affine transformation parameters
362/// * `bc` - Boundary conditions for each output variable
363///
364/// # Returns
365/// Sparse matrix of size 2^(R*M) × 2^(R*N) where entry (y_flat, x_flat) = 1
366/// if the transformation maps x to y.
367///
368/// # Note
369/// This is only practical for small R due to exponential size.
370/// Use for testing/verification only.
371pub fn affine_transform_matrix(
372 r: usize,
373 params: &AffineParams,
374 bc: &[BoundaryCondition],
375) -> Result<CsMat<f64>> {
376 if r == 0 {
377 return Err(anyhow::anyhow!("Number of bits must be positive"));
378 }
379 if bc.len() != params.m {
380 return Err(anyhow::anyhow!(
381 "Boundary conditions length {} doesn't match output dimensions {}",
382 bc.len(),
383 params.m
384 ));
385 }
386
387 let (a_int, b_int, scale) = params.to_integer_scaled();
388 let m = params.m;
389 let n = params.n;
390
391 let bc_periodic: Vec<bool> = bc
392 .iter()
393 .map(|b| matches!(b, BoundaryCondition::Periodic))
394 .collect();
395
396 let input_size = 1usize << (r * n); // 2^(R*N)
397 let output_size = 1usize << (r * m); // 2^(R*M)
398 let modulus = 1i64 << r; // 2^R
399
400 let mut rows = Vec::new();
401 let mut cols = Vec::new();
402 let mut vals = Vec::new();
403
404 let mask = modulus - 1; // 2^R - 1
405
406 // Iterate over all (x, y) pairs, matching Julia's approach.
407 // For periodic BC with scale > 1, multiple y values can satisfy
408 // scale * y ≡ A*x + b (mod 2^R), so we must check all pairs.
409 for x_flat in 0..input_size {
410 // Decode x_flat to N-dimensional x vector
411 // x_flat = x[0] + x[1]*2^R + x[2]*2^(2R) + ...
412 let x: Vec<i64> = (0..n)
413 .map(|var| ((x_flat >> (var * r)) & ((1 << r) - 1)) as i64)
414 .collect();
415
416 // Compute v = A*x + b (unscaled)
417 let mut v: Vec<i64> = vec![0; m];
418 for i in 0..m {
419 v[i] = b_int[i];
420 for j in 0..n {
421 v[i] += a_int[i + m * j] * x[j];
422 }
423 }
424
425 for y_flat in 0..output_size {
426 // Decode y_flat to M-dimensional y vector
427 let y: Vec<i64> = (0..m)
428 .map(|var| ((y_flat >> (var * r)) & ((1 << r) - 1)) as i64)
429 .collect();
430
431 // Compute scale * y
432 let sy: Vec<i64> = y.iter().map(|&yi| scale * yi).collect();
433
434 // Check equiv(v, s*y, R, boundary) per component
435 let equiv = v.iter().zip(sy.iter()).enumerate().all(|(i, (&vi, &syi))| {
436 if bc_periodic[i] {
437 // Periodic: v ≡ s*y (mod 2^R)
438 (vi - syi) & mask == 0
439 } else {
440 // Open: v == s*y (exact)
441 vi == syi
442 }
443 });
444
445 if equiv {
446 rows.push(y_flat);
447 cols.push(x_flat);
448 vals.push(1.0);
449 }
450 }
451 }
452
453 // Build sparse matrix in CSR format
454 let triplet = sprs::TriMat::from_triplets((output_size, input_size), rows, cols, vals);
455 Ok(triplet.to_csr())
456}
457
458/// Create the affine transformation MPO as a TensorTrain.
459fn affine_transform_mpo(
460 r: usize,
461 params: &AffineParams,
462 bc: &[BoundaryCondition],
463) -> Result<TensorTrain<Complex64>> {
464 let (a_int, b_int, scale) = params.to_integer_scaled();
465 let m = params.m;
466 let n = params.n;
467
468 // Convert boundary conditions to weights
469 let bc_periodic: Vec<bool> = bc
470 .iter()
471 .map(|b| matches!(b, BoundaryCondition::Periodic))
472 .collect();
473
474 // Compute core tensors
475 let tensors = affine_transform_tensors(r, &a_int, &b_int, scale, m, n, &bc_periodic)?;
476
477 TensorTrain::new(tensors)
478 .map_err(|e| anyhow::anyhow!("Failed to create affine transform MPO: {}", e))
479}
480
481/// Create unfused affine transformation tensors.
482///
483/// Returns a vector of R tensors, where each tensor has shape:
484/// `[left_bond, 2, 2, ..., 2, right_bond]` with M+N physical indices of dimension 2.
485///
486/// The physical index order matches Quantics.jl:
487/// `(y[1], y[2], ..., y[M], x[1], x[2], ..., x[N])`
488/// where y are output variables and x are input variables.
489///
490/// # Arguments
491/// * `r` - Number of bits per variable (number of sites)
492/// * `params` - Affine transformation parameters
493/// * `bc` - Boundary conditions for each output variable
494///
495/// # Returns
496/// Vector of R tensors with unfused physical indices.
497///
498/// # Examples
499///
500/// ```
501/// use tensor4all_quanticstransform::{
502/// affine_transform_tensors_unfused, AffineParams, BoundaryCondition,
503/// };
504/// use tensor4all_simplett::Tensor3Ops;
505///
506/// let params = AffineParams::from_integers(vec![1, 1, 0, 1], vec![0, 0], 2, 2).unwrap();
507/// let bc = vec![BoundaryCondition::Periodic; 2];
508/// let tensors = affine_transform_tensors_unfused(4, ¶ms, &bc).unwrap();
509///
510/// // One tensor per site
511/// assert_eq!(tensors.len(), 4);
512///
513/// // Each tensor has fused site_dim = 2^(M+N) = 16 for M=2, N=2
514/// assert_eq!(tensors[0].site_dim(), 16);
515/// ```
516pub fn affine_transform_tensors_unfused(
517 r: usize,
518 params: &AffineParams,
519 bc: &[BoundaryCondition],
520) -> Result<Vec<GenericTensor3<Complex64>>> {
521 if r == 0 {
522 return Err(anyhow::anyhow!("Number of bits must be positive"));
523 }
524 if bc.len() != params.m {
525 return Err(anyhow::anyhow!(
526 "Boundary conditions length {} doesn't match output dimensions {}",
527 bc.len(),
528 params.m
529 ));
530 }
531
532 let (a_int, b_int, scale) = params.to_integer_scaled();
533 let m = params.m;
534 let n = params.n;
535
536 // Convert boundary conditions to weights
537 let bc_periodic: Vec<bool> = bc
538 .iter()
539 .map(|b| matches!(b, BoundaryCondition::Periodic))
540 .collect();
541
542 // Compute fused tensors first
543 let fused_tensors = affine_transform_tensors(r, &a_int, &b_int, scale, m, n, &bc_periodic)?;
544
545 // Convert fused tensors to unfused format
546 // Fused: [left, fused_site, right] where fused_site = 2^(M+N)
547 // Unfused: [left, 2, 2, ..., 2, right] with M+N dimensions of size 2
548 //
549 // Fused index encoding: site_idx = y_bits | (x_bits << M)
550 // where y_bits = y[0] + 2*y[1] + ... + 2^(M-1)*y[M-1]
551 // and x_bits = x[0] + 2*x[1] + ... + 2^(N-1)*x[N-1]
552 //
553 // Quantics.jl order: (y[0], y[1], ..., y[M-1], x[0], x[1], ..., x[N-1])
554 // We preserve that semantic index order:
555 // unfused[left, y0, y1, ..., yM-1, x0, x1, ..., xN-1, right]
556
557 let mut unfused_tensors = Vec::with_capacity(r);
558 let site_dim = 1 << (m + n);
559
560 for tensor in fused_tensors.iter() {
561 let left_dim = tensor.left_dim();
562 let right_dim = tensor.right_dim();
563
564 // Create unfused tensor
565 // Shape: [left_dim, 2^(M+N), right_dim] but we keep it as 3D for now
566 // The reshape to (M+N+2)-dimensional tensor will be done by the caller if needed
567 // For now, we provide a 3D tensor where the middle dimension is the fused site
568 // and document how to unfuse it.
569 //
570 // Actually, let's return it properly unfused using a flat storage with
571 // the correct index order for reshape.
572 //
573 // Total size: left_dim * 2^(M+N) * right_dim
574 // Shape for unfused: [left_dim, 2, 2, ..., 2, right_dim]
575 //
576 // Index mapping from fused to unfused:
577 // fused site_idx -> (y0, y1, ..., yM-1, x0, x1, ..., xN-1)
578 // site_idx = y0 + 2*y1 + ... + 2^(M-1)*yM-1 + 2^M * (x0 + 2*x1 + ...)
579
580 // Preserve the Quantics.jl physical index order
581 // (y0, y1, ..., yM-1, x0, x1, ..., xN-1).
582
583 let mut unfused_data = vec![Complex64::new(0.0, 0.0); left_dim * site_dim * right_dim];
584
585 for l in 0..left_dim {
586 for fused_idx in 0..site_dim {
587 for rr in 0..right_dim {
588 let val = tensor.get3(l, fused_idx, rr);
589 if val.norm() > 0.0 {
590 // The fused index encodes: site_idx = y_bits | (x_bits << M)
591 // This matches Quantics.jl's ordering (y variables first, then x)
592 // so we can use fused_idx directly.
593 //
594 // The caller can reshape [left, site_dim, right] to
595 // [left, 2, 2, ..., 2, right] with M+N dimensions of size 2,
596 // where indices are in order (y[0], y[1], ..., y[M-1], x[0], ..., x[N-1])
597 let flat_idx = l * site_dim * right_dim + fused_idx * right_dim + rr;
598 unfused_data[flat_idx] = *val;
599 }
600 }
601 }
602 }
603
604 // Create Tensor3 with shape [left_dim, site_dim, right_dim]
605 // The caller can reshape this to [left_dim, 2, 2, ..., 2, right_dim]
606 // with the understanding that the indices are ordered as (y0, y1, ..., x0, x1, ...)
607 let mut unfused_tensor =
608 GenericTensor3::from_elem([left_dim, site_dim, right_dim], Complex64::new(0.0, 0.0));
609 for l in 0..left_dim {
610 for s in 0..site_dim {
611 for r in 0..right_dim {
612 unfused_tensor[[l, s, r]] =
613 unfused_data[l * site_dim * right_dim + s * right_dim + r];
614 }
615 }
616 }
617
618 unfused_tensors.push(unfused_tensor);
619 }
620
621 Ok(unfused_tensors)
622}
623
624/// Information about the unfused tensor structure.
625///
626/// This helper provides metadata for reshaping the unfused tensors
627/// produced by [`affine_transform_tensors_unfused`].
628///
629/// # Examples
630///
631/// ```
632/// use tensor4all_quanticstransform::{AffineParams, UnfusedTensorInfo};
633///
634/// let params = AffineParams::from_integers(vec![1, 0, 0, 1], vec![0, 0], 2, 2).unwrap();
635/// let info = UnfusedTensorInfo::new(¶ms);
636///
637/// assert_eq!(info.m, 2);
638/// assert_eq!(info.n, 2);
639/// assert_eq!(info.num_physical_dims, 4);
640///
641/// // Get shape for a tensor with bond dims 3 and 5
642/// let shape = info.unfused_shape(3, 5);
643/// assert_eq!(shape, vec![3, 2, 2, 2, 2, 5]);
644///
645/// // Round-trip encode/decode
646/// let fused = info.encode_fused_index(&[1, 0], &[0, 1]);
647/// let (y_bits, x_bits) = info.decode_fused_index(fused);
648/// assert_eq!(y_bits, vec![1, 0]);
649/// assert_eq!(x_bits, vec![0, 1]);
650/// ```
651#[derive(Clone, Debug)]
652pub struct UnfusedTensorInfo {
653 /// Number of output variables (M)
654 pub m: usize,
655 /// Number of input variables (N)
656 pub n: usize,
657 /// Total physical dimensions per site (M + N)
658 pub num_physical_dims: usize,
659 /// Dimension of each physical index (always 2)
660 pub physical_dim: usize,
661}
662
663impl UnfusedTensorInfo {
664 /// Create info for the given affine parameters.
665 pub fn new(params: &AffineParams) -> Self {
666 Self {
667 m: params.m,
668 n: params.n,
669 num_physical_dims: params.m + params.n,
670 physical_dim: 2,
671 }
672 }
673
674 /// Get the shape for a fully unfused tensor at a given site.
675 ///
676 /// Returns `[left_bond, 2, 2, ..., 2, right_bond]` where there are M+N 2s.
677 pub fn unfused_shape(&self, left_bond: usize, right_bond: usize) -> Vec<usize> {
678 let mut shape = Vec::with_capacity(2 + self.num_physical_dims);
679 shape.push(left_bond);
680 shape.extend(std::iter::repeat_n(2, self.num_physical_dims));
681 shape.push(right_bond);
682 shape
683 }
684
685 /// Decode a fused site index to individual variable bits.
686 ///
687 /// Returns `(y_bits, x_bits)` where:
688 /// - `y_bits[i]` is the bit for output variable i
689 /// - `x_bits[j]` is the bit for input variable j
690 pub fn decode_fused_index(&self, fused_idx: usize) -> (Vec<usize>, Vec<usize>) {
691 let y_combined = fused_idx & ((1 << self.m) - 1);
692 let x_combined = fused_idx >> self.m;
693
694 let y_bits: Vec<usize> = (0..self.m).map(|i| (y_combined >> i) & 1).collect();
695 let x_bits: Vec<usize> = (0..self.n).map(|j| (x_combined >> j) & 1).collect();
696
697 (y_bits, x_bits)
698 }
699
700 /// Encode individual variable bits to a fused site index.
701 ///
702 /// # Arguments
703 /// * `y_bits` - Bits for output variables (length M)
704 /// * `x_bits` - Bits for input variables (length N)
705 pub fn encode_fused_index(&self, y_bits: &[usize], x_bits: &[usize]) -> usize {
706 let y_combined: usize = y_bits.iter().enumerate().map(|(i, &b)| b << i).sum();
707 let x_combined: usize = x_bits.iter().enumerate().map(|(j, &b)| b << j).sum();
708 y_combined | (x_combined << self.m)
709 }
710}
711
712/// Compute the core tensors for the affine transformation.
713///
714/// This implements the algorithm from Quantics.jl that handles:
715/// - Carry propagation for multi-bit arithmetic
716/// - Scaling factor s from rational to integer conversion
717///
718/// Uses big-endian convention: site 0 = MSB, site R-1 = LSB.
719///
720/// Carry propagation direction (matching shift.rs):
721/// - Arithmetic carry flows LSB → MSB (physical fact)
722/// - In big-endian: site R-1 → site 0 (right → left)
723/// - Tensor structure: t[left, site, right] where left=carry_out (going left), right=carry_in (from right)
724/// - Site 0 (MSB): BC applied on left, receives carry from right → shape (1, site_dim, num_carries)
725/// - Site R-1 (LSB): initial carry=0, sends carry to left → shape (num_carries, site_dim, 1)
726/// - Middle sites: shape (num_carries, site_dim, num_carries)
727fn affine_transform_tensors(
728 r: usize,
729 a_int: &[i64],
730 b_int: &[i64],
731 scale: i64,
732 m: usize,
733 n: usize,
734 bc_periodic: &[bool],
735) -> Result<Vec<tensor4all_simplett::Tensor3<Complex64>>> {
736 let site_dim = 1 << (m + n); // 2^(M+N) for fused representation
737
738 // Track sign separately and work with absolute value
739 // so that right-shifting always terminates (Julia PR #45 approach)
740 let bsign: Vec<i64> = b_int.iter().map(|&b| if b >= 0 { 1 } else { -1 }).collect();
741 let mut b_work: Vec<i64> = b_int.iter().map(|&b| b.abs()).collect();
742
743 // Process from LSB (site R-1) to MSB (site 0)
744 let mut carries: Vec<Vec<i64>> = vec![vec![0i64; m]];
745 let mut core_data_list: Vec<AffineCoreData> = Vec::with_capacity(r);
746
747 for _site in (0..r).rev() {
748 // Extract current bit: (b_work & 1) * bsign
749 let b_curr: Vec<i64> = b_work
750 .iter()
751 .zip(bsign.iter())
752 .map(|(&b, &s)| (b & 1) * s)
753 .collect();
754
755 let core_data = affine_transform_core(a_int, &b_curr, scale, m, n, &carries, true)?;
756 carries = core_data.carries_out.clone();
757 core_data_list.push(core_data);
758
759 // Shift right
760 b_work.iter_mut().for_each(|b| *b >>= 1);
761 }
762
763 // core_data_list is now in order: [site R-1, site R-2, ..., site 0]
764
765 // Extension loop: handle remaining bits of b for Open BC
766 // When abs(b) >= 2^R, high bits of b contribute to carries that affect validity.
767 // Extension tensors have site_dim=1 (activebit=false: only x=0, y=0).
768 // We fold them into the MSB tensor as a "cap matrix" (Julia approach).
769 let cap_matrix: Option<Vec<f64>> = if !bc_periodic.iter().all(|&p| p)
770 && b_work.iter().any(|&b| b > 0)
771 {
772 let mut ext_data_list: Vec<AffineCoreData> = Vec::new();
773 while b_work.iter().any(|&b| b > 0) {
774 let b_curr: Vec<i64> = b_work
775 .iter()
776 .zip(bsign.iter())
777 .map(|(&b, &s)| (b & 1) * s)
778 .collect();
779
780 let core_data = affine_transform_core(a_int, &b_curr, scale, m, n, &carries, false)?;
781 carries = core_data.carries_out.clone();
782 ext_data_list.push(core_data);
783
784 b_work.iter_mut().for_each(|b| *b >>= 1);
785 }
786
787 // Build cap matrix by contracting extension tensors with BC weights.
788 // Extension tensors have site_dim=1, so they are carry transition matrices:
789 // ext_matrix[cout_idx, cin_idx] = core_data.tensor[[cout_idx, cin_idx, 0]]
790 //
791 // Process: outermost (last computed) gets BC weights applied,
792 // then multiply inward toward the main tensor chain.
793
794 // Start with BC weights on the final carries
795 let bc_weights: Vec<f64> = carries
796 .iter()
797 .map(|c| {
798 if c.iter().all(|&ci| ci == 0) {
799 1.0
800 } else {
801 0.0
802 }
803 })
804 .collect();
805
806 // Contract extension tensors from outermost to innermost
807 // ext_data_list is [innermost, ..., outermost] (order of computation)
808 // We process from outermost to innermost
809 let mut current_weights = bc_weights;
810 for ext_data in ext_data_list.iter().rev() {
811 let num_cin = ext_data.tensor.dims()[1];
812 let mut new_weights = vec![0.0; num_cin];
813 for (cin_idx, nw) in new_weights.iter_mut().enumerate() {
814 for (cout_idx, &w) in current_weights.iter().enumerate() {
815 if w != 0.0 && ext_data.tensor.get([cout_idx, cin_idx, 0]) {
816 *nw += w;
817 }
818 }
819 }
820 current_weights = new_weights;
821 }
822
823 // current_weights now maps: MSB carry_out index -> effective BC weight
824 Some(current_weights)
825 } else {
826 None
827 };
828
829 // Build tensors in the same order, then reverse to get [site 0, site 1, ..., site R-1]
830 let mut tensors = Vec::with_capacity(r);
831
832 // Helper: compute BC weight for a carry-out index
833 let compute_bc_weight = |cout_idx: usize, core_data: &AffineCoreData| -> Complex64 {
834 if bc_periodic.iter().all(|&p| p) {
835 Complex64::one()
836 } else if let Some(ref cap) = cap_matrix {
837 // Extension loop was used: weight comes from cap matrix
838 Complex64::new(cap[cout_idx], 0.0)
839 } else {
840 // No extension: weight is 1 if carry is zero, 0 otherwise
841 let carry = &core_data.carries_out[cout_idx];
842 if carry.iter().all(|&c| c == 0) {
843 Complex64::one()
844 } else {
845 Complex64::new(0.0, 0.0)
846 }
847 }
848 };
849
850 for (idx, core_data) in core_data_list.iter().enumerate() {
851 // idx=0 corresponds to site R-1 (LSB), idx=R-1 corresponds to site 0 (MSB)
852 let actual_site = r - 1 - idx;
853 let num_carry_out = core_data.carries_out.len();
854 let num_carry_in = core_data.tensor.dims()[1];
855
856 // Tensor shape follows shift.rs pattern:
857 // t[left, site, right] where left=carry_out (going left), right=carry_in (from right)
858 //
859 // - Site 0 (MSB): left_dim=1 (BC applied), right_dim=num_carry (receives from right)
860 // - Site R-1 (LSB): left_dim=num_carry (sends to left), right_dim=1 (initial carry=0)
861 // - Middle: left_dim=num_carry, right_dim=num_carry
862 let is_msb = actual_site == 0;
863 let is_lsb = actual_site == r - 1;
864
865 let left_dim = if is_msb { 1 } else { num_carry_out };
866 let right_dim = if is_lsb { 1 } else { num_carry_in };
867
868 let mut t: tensor4all_simplett::Tensor3<Complex64> =
869 tensor3_zeros(left_dim, site_dim, right_dim);
870
871 if is_lsb && is_msb {
872 // R==1: single site case
873 for cout_idx in 0..num_carry_out {
874 let bc_weight = compute_bc_weight(cout_idx, core_data);
875
876 for site_idx in 0..site_dim {
877 if core_data.tensor.get([cout_idx, 0, site_idx]) {
878 let old = t.get3(0, site_idx, 0);
879 t.set3(0, site_idx, 0, *old + bc_weight);
880 }
881 }
882 }
883 } else if is_lsb {
884 // LSB (site R-1): initial carry_in=0, send carry_out to left
885 // Shape (num_carry_out, site_dim, 1)
886 // core_data.tensor[carry_out_idx, carry_in_idx, site_idx]
887 // Only carry_in_idx=0 matters (initial carry is the first entry: zero vector)
888 for cout_idx in 0..num_carry_out {
889 for site_idx in 0..site_dim {
890 if core_data.tensor.get([cout_idx, 0, site_idx]) {
891 t.set3(cout_idx, site_idx, 0, Complex64::one());
892 }
893 }
894 }
895 } else if is_msb {
896 // MSB (site 0): apply BC on carry_out, receive carry from right
897 for cout_idx in 0..num_carry_out {
898 let bc_weight = compute_bc_weight(cout_idx, core_data);
899
900 for cin_idx in 0..num_carry_in {
901 for site_idx in 0..site_dim {
902 if core_data.tensor.get([cout_idx, cin_idx, site_idx]) {
903 let old = t.get3(0, site_idx, cin_idx);
904 t.set3(0, site_idx, cin_idx, *old + bc_weight);
905 }
906 }
907 }
908 }
909 } else {
910 // Middle tensors: receive carry from right, send carry to left
911 // Shape (num_carry_out, site_dim, num_carry_in)
912 for cout_idx in 0..num_carry_out {
913 for cin_idx in 0..num_carry_in {
914 for site_idx in 0..site_dim {
915 if core_data.tensor.get([cout_idx, cin_idx, site_idx]) {
916 t.set3(cout_idx, site_idx, cin_idx, Complex64::one());
917 }
918 }
919 }
920 }
921 }
922
923 tensors.push(t);
924 }
925
926 // tensors is in order [site R-1, ..., site 0], reverse to get [site 0, ..., site R-1]
927 tensors.reverse();
928
929 Ok(tensors)
930}
931
932/// Core tensor data for affine transformation.
933///
934/// Shape: (num_carry_out, num_carry_in, site_dim)
935/// where site_dim = 2^(M+N)
936struct AffineCoreData {
937 /// Possible outgoing carry vectors
938 carries_out: Vec<Vec<i64>>,
939 /// Tensor data: tensor[carry_out_idx, carry_in_idx, site_idx]
940 tensor: BoolTensor3,
941}
942
943/// Compute a single core tensor for the affine transformation.
944///
945/// The core tensor encodes: 2 * carry_out = A * x + b_curr - scale * y + carry_in
946///
947/// Returns AffineCoreData containing:
948/// - carries_out: list of possible outgoing carry vectors
949/// - tensor: shape (num_carry_out, num_carry_in, site_dim)
950fn affine_transform_core(
951 a_int: &[i64],
952 b_curr: &[i64],
953 scale: i64,
954 m: usize,
955 n: usize,
956 carries_in: &[Vec<i64>],
957 activebit: bool,
958) -> Result<AffineCoreData> {
959 let mut carry_out_map: HashMap<Vec<i64>, BoolTensor2> = HashMap::new();
960 let x_range = if activebit { 1 << n } else { 1 };
961 let y_range = if activebit { 1 << m } else { 1 };
962 let site_dim = x_range * y_range;
963 let num_carry_in = carries_in.len();
964
965 // Iterate over all input carries
966 for (c_idx, carry_in) in carries_in.iter().enumerate() {
967 // Iterate over all possible x values (N bits)
968 for x_bits in 0..x_range {
969 let x: Vec<i64> = (0..n).map(|j| ((x_bits >> j) & 1) as i64).collect();
970
971 // Compute z = A*x + b + carry_in
972 let mut z: Vec<i64> = vec![0; m];
973 for i in 0..m {
974 z[i] = carry_in[i] + b_curr[i];
975 for j in 0..n {
976 z[i] += a_int[i + m * j] * x[j];
977 }
978 }
979
980 if scale % 2 == 1 {
981 // Scale is odd: unique y that satisfies condition
982 let y: Vec<i64> = z.iter().map(|&zi| zi & 1).collect();
983
984 // When bits are inactive, y must be zero (Julia PR #45 fix)
985 if !activebit && y.iter().any(|&yi| yi != 0) {
986 continue;
987 }
988
989 let y_bits: usize = y
990 .iter()
991 .enumerate()
992 .map(|(i, &yi)| (yi as usize) << i)
993 .sum();
994
995 // Compute carry_out = (z - scale * y) / 2
996 let carry_out: Vec<i64> = z
997 .iter()
998 .zip(y.iter())
999 .map(|(&zi, &yi)| (zi - scale * yi) >> 1)
1000 .collect();
1001
1002 // Site index: y bits in lower positions, x bits in upper positions
1003 let site_idx = y_bits | (x_bits << m);
1004
1005 let entry = carry_out_map
1006 .entry(carry_out)
1007 .or_insert_with(|| BoolTensor2::from_elem([num_carry_in, site_dim], false));
1008 entry.set([c_idx, site_idx], true);
1009 } else {
1010 // Scale is even: z must be even for valid y
1011 if z.iter().any(|&zi| zi % 2 != 0) {
1012 continue;
1013 }
1014
1015 // y can be any value
1016 for y_bits in 0..y_range {
1017 let y: Vec<i64> = (0..m).map(|i| ((y_bits >> i) & 1) as i64).collect();
1018
1019 // Compute carry_out = (z - scale * y) / 2
1020 let carry_out: Vec<i64> = z
1021 .iter()
1022 .zip(y.iter())
1023 .map(|(&zi, &yi)| (zi - scale * yi) >> 1)
1024 .collect();
1025
1026 let site_idx = y_bits | (x_bits << m);
1027
1028 let entry = carry_out_map
1029 .entry(carry_out)
1030 .or_insert_with(|| BoolTensor2::from_elem([num_carry_in, site_dim], false));
1031 entry.set([c_idx, site_idx], true);
1032 }
1033 }
1034 }
1035 }
1036
1037 // Convert to sorted vectors for deterministic ordering
1038 let mut carries_out: Vec<Vec<i64>> = carry_out_map.keys().cloned().collect();
1039 carries_out.sort();
1040
1041 let num_carry_out = carries_out.len();
1042
1043 // Build 3D tensor: (num_carry_out, num_carry_in, site_dim)
1044 let mut tensor = BoolTensor3::from_elem([num_carry_out, num_carry_in, site_dim], false);
1045 for (cout_idx, carry) in carries_out.iter().enumerate() {
1046 let data_2d = &carry_out_map[carry];
1047 for cin_idx in 0..num_carry_in {
1048 for site_idx in 0..site_dim {
1049 tensor.set(
1050 [cout_idx, cin_idx, site_idx],
1051 data_2d.get([cin_idx, site_idx]),
1052 );
1053 }
1054 }
1055 }
1056
1057 Ok(AffineCoreData {
1058 carries_out,
1059 tensor,
1060 })
1061}
1062
1063#[cfg(test)]
1064mod tests;