// tensor4all_quanticstransform/affine.rs

//! Affine transformation operator: y = A*x + b
//!
//! This implements general affine transformations with rational coefficients.
//! The transformation computes y = A*x + b where A is an M×N rational matrix
//! and b is an M-dimensional rational vector.
//!
//! Based on the algorithm from Quantics.jl/src/affine.jl

9use std::collections::HashMap;
10
11use anyhow::Result;
12use num_complex::Complex64;
13use num_integer::Integer;
14use num_rational::Rational64;
15use num_traits::One;
16use sprs::CsMat;
17use tensor4all_simplett::{types::tensor3_zeros, AbstractTensorTrain, Tensor3Ops, TensorTrain};
18
19use crate::common::{
20    tensortrain_to_linear_operator_asymmetric, BoundaryCondition, QuanticsOperator,
21};
22use tensor4all_simplett::tensor::Tensor3 as GenericTensor3;
23
/// Dense boolean tensor with `N` axes, stored row-major in a flat `Vec<u8>`
/// (0 = false, nonzero = true). Used as a compact carry-transition table.
#[derive(Clone, Debug)]
struct BoolTensor<const N: usize> {
    data: Vec<u8>,
    dims: [usize; N],
}

type BoolTensor2 = BoolTensor<2>;
type BoolTensor3 = BoolTensor<3>;

impl<const N: usize> BoolTensor<N> {
    /// Build a tensor of shape `dims` with every entry set to `value`.
    fn from_elem(dims: [usize; N], value: bool) -> Self {
        let len = dims.iter().product();
        Self {
            data: vec![u8::from(value); len],
            dims,
        }
    }

    /// Shape of the tensor.
    fn dims(&self) -> &[usize; N] {
        &self.dims
    }

    /// Read the entry at `idx`.
    fn get(&self, idx: [usize; N]) -> bool {
        self.data[self.offset(&idx)] != 0
    }

    /// Write the entry at `idx`.
    fn set(&mut self, idx: [usize; N], value: bool) {
        let at = self.offset(&idx);
        self.data[at] = u8::from(value);
    }

    /// Row-major linear offset of `idx` (the last axis is contiguous).
    fn offset(&self, idx: &[usize; N]) -> usize {
        idx.iter()
            .zip(self.dims.iter())
            .fold(0, |acc, (&i, &d)| acc * d + i)
    }
}
65
/// Affine transformation parameters.
///
/// Represents the transformation y = A*x + b where:
/// - A is an M x N matrix stored in column-major order (element (i, j) at `a[i + m*j]`)
/// - b is an M-dimensional vector
/// - x is an N-dimensional input
/// - y is an M-dimensional output
///
/// # Examples
///
/// ```
/// use tensor4all_quanticstransform::AffineParams;
/// use num_rational::Rational64;
///
/// // 1D shift: y = x + 3
/// let params = AffineParams::from_integers(vec![1], vec![3], 1, 1).unwrap();
/// assert_eq!(params.m, 1);
/// assert_eq!(params.n, 1);
///
/// // 2D rotation: y = [[1,1],[1,-1]] * x
/// // Column-major: [A[0,0], A[1,0], A[0,1], A[1,1]]
/// let params = AffineParams::from_integers(
///     vec![1, 1, 1, -1], vec![0, 0], 2, 2
/// ).unwrap();
/// assert_eq!(params.m, 2);
/// assert_eq!(params.n, 2);
///
/// // With rational coefficients: y = (1/2)*x
/// let params = AffineParams::new(
///     vec![Rational64::new(1, 2)],
///     vec![Rational64::from_integer(0)],
///     1, 1,
/// ).unwrap();
/// ```
#[derive(Clone, Debug)]
pub struct AffineParams {
    /// Transformation matrix A (M×N), stored in column-major order
    /// (length is always `m * n`, enforced by the constructors)
    pub a: Vec<Rational64>,
    /// Translation vector b (length `m`, enforced by the constructors)
    pub b: Vec<Rational64>,
    /// Number of output dimensions (M)
    pub m: usize,
    /// Number of input dimensions (N)
    pub n: usize,
}
111
112impl AffineParams {
113    /// Create new affine parameters.
114    ///
115    /// # Arguments
116    /// * `a` - M x N matrix in column-major order (length must be m*n)
117    /// * `b` - M-dimensional translation vector (length must be m)
118    /// * `m` - Number of output dimensions
119    /// * `n` - Number of input dimensions
120    ///
121    /// # Examples
122    ///
123    /// ```
124    /// use tensor4all_quanticstransform::AffineParams;
125    /// use num_rational::Rational64;
126    ///
127    /// // 1D identity: y = x
128    /// let params = AffineParams::new(
129    ///     vec![Rational64::from_integer(1)],
130    ///     vec![Rational64::from_integer(0)],
131    ///     1, 1,
132    /// ).unwrap();
133    ///
134    /// // Dimension mismatch errors
135    /// assert!(AffineParams::new(
136    ///     vec![Rational64::from_integer(1)],
137    ///     vec![Rational64::from_integer(0)],
138    ///     2, 1, // expects 2 elements in A, got 1
139    /// ).is_err());
140    /// ```
141    pub fn new(a: Vec<Rational64>, b: Vec<Rational64>, m: usize, n: usize) -> Result<Self> {
142        if a.len() != m * n {
143            return Err(anyhow::anyhow!(
144                "Matrix A has {} elements but expected {}×{}={}",
145                a.len(),
146                m,
147                n,
148                m * n
149            ));
150        }
151        if b.len() != m {
152            return Err(anyhow::anyhow!(
153                "Vector b has {} elements but expected {}",
154                b.len(),
155                m
156            ));
157        }
158        Ok(Self { a, b, m, n })
159    }
160
161    /// Create affine parameters from integer matrix and vector.
162    ///
163    /// Convenience method that converts integer values to rationals.
164    ///
165    /// # Examples
166    ///
167    /// ```
168    /// use tensor4all_quanticstransform::AffineParams;
169    ///
170    /// // 2D: y = [[1, 0], [0, 1]] * x + [1, 2] (shift by (1,2))
171    /// let params = AffineParams::from_integers(
172    ///     vec![1, 0, 0, 1], vec![1, 2], 2, 2,
173    /// ).unwrap();
174    /// assert_eq!(params.m, 2);
175    /// assert_eq!(params.n, 2);
176    /// ```
177    pub fn from_integers(a: Vec<i64>, b: Vec<i64>, m: usize, n: usize) -> Result<Self> {
178        let a_rat: Vec<Rational64> = a.into_iter().map(Rational64::from_integer).collect();
179        let b_rat: Vec<Rational64> = b.into_iter().map(Rational64::from_integer).collect();
180        Self::new(a_rat, b_rat, m, n)
181    }
182
183    /// Get element A[i, j] (0-indexed)
184    #[allow(dead_code)]
185    fn get_a(&self, i: usize, j: usize) -> Rational64 {
186        self.a[i + self.m * j]
187    }
188
189    /// Convert to integer representation by scaling with LCM of denominators.
190    /// Returns (A_int, b_int, scale) where A_int = scale * A and b_int = scale * b.
191    fn to_integer_scaled(&self) -> (Vec<i64>, Vec<i64>, i64) {
192        // Find LCM of all denominators
193        let mut denom_lcm = 1i64;
194        for r in &self.a {
195            denom_lcm = denom_lcm.lcm(r.denom());
196        }
197        for r in &self.b {
198            denom_lcm = denom_lcm.lcm(r.denom());
199        }
200
201        // Scale to integers
202        let a_int: Vec<i64> = self
203            .a
204            .iter()
205            .map(|r| (r * denom_lcm).to_integer())
206            .collect();
207        let b_int: Vec<i64> = self
208            .b
209            .iter()
210            .map(|r| (r * denom_lcm).to_integer())
211            .collect();
212
213        (a_int, b_int, denom_lcm)
214    }
215}
216
217/// Remap site indices of the affine MPO from internal encoding to the convention
218/// expected by `tensortrain_to_linear_operator_asymmetric`.
219///
220/// Internal encoding: `site_idx = y_bits | (x_bits << m)` (y-minor, x-major)
221/// Expected encoding: `s = s_out * in_dim + s_in = y_bits * 2^n + x_bits` (x-minor, y-major)
222fn remap_affine_site_indices(
223    mpo: &TensorTrain<Complex64>,
224    m: usize,
225    n: usize,
226    site_dim: usize,
227) -> Result<TensorTrain<Complex64>> {
228    let input_dim = 1 << n;
229
230    // Build permutation table: perm[old_idx] = remapped index
231    let perm: Vec<usize> = (0..site_dim)
232        .map(|old_idx| {
233            let y_bits = old_idx & ((1 << m) - 1);
234            let x_bits = old_idx >> m;
235            y_bits * input_dim + x_bits
236        })
237        .collect();
238
239    let r = mpo.len();
240    let mut new_tensors = Vec::with_capacity(r);
241
242    for i in 0..r {
243        let tensor = mpo.site_tensor(i);
244        let left_dim = tensor.left_dim();
245        let right_dim = tensor.right_dim();
246
247        let mut t = tensor3_zeros(left_dim, site_dim, right_dim);
248        for l in 0..left_dim {
249            for (old_s, &new_s) in perm.iter().enumerate() {
250                for rr in 0..right_dim {
251                    let val = *tensor.get3(l, old_s, rr);
252                    if val != Complex64::new(0.0, 0.0) {
253                        t.set3(l, new_s, rr, val);
254                    }
255                }
256            }
257        }
258        new_tensors.push(t);
259    }
260
261    TensorTrain::new(new_tensors)
262        .map_err(|e| anyhow::anyhow!("Failed to create remapped MPO: {}", e))
263}
264
265/// Create the operator that realizes the coordinate map `y = A * x + b`.
266///
267/// This is the **forward** affine operator. It maps a quantics tensor train
268/// representing an `N`-variable state `x` to the quantics tensor train of
269/// the `M`-variable state `y = A * x + b`.
270///
271/// To build the **pullback** (`f(y) = g(A * y + b)`), call `.transpose()`
272/// on the returned operator; the pullback is exactly the transpose of the
273/// forward operator.
274///
275/// # Arguments
276///
277/// * `r` — bits per variable (number of sites in the output MPO).
278/// * `params` — rational `M × N` matrix `A` and `M`-vector `b` describing
279///   the affine map.
280/// * `bc` — length `M` array of boundary conditions for each output variable.
281///   `Periodic` wraps output coordinates modulo `2^r`; `Open` zeroes the
282///   out-of-range contributions.
283///
284/// # Errors
285///
286/// Returns an error if `r == 0` or if `bc.len() != params.m`.
287///
288/// # Examples
289///
290/// ```
291/// use tensor4all_quanticstransform::{affine_operator, AffineParams, BoundaryCondition};
292/// use num_rational::Rational64;
293///
294/// // Transform g(x, y) -> g(x + y, x - y) (rotation by 45 degrees, scaled)
295/// let a = vec![
296///     Rational64::from_integer(1), Rational64::from_integer(1),  // row 0: x + y
297///     Rational64::from_integer(1), Rational64::from_integer(-1), // row 1: x - y
298/// ];
299/// let b = vec![Rational64::from_integer(0), Rational64::from_integer(0)];
300/// let params = AffineParams::new(a, b, 2, 2).unwrap();
301/// let bc = vec![BoundaryCondition::Periodic; 2];
302/// let op = affine_operator(4, &params, &bc).unwrap();
303/// assert_eq!(op.mpo.node_count(), 4);
304/// ```
305///
306/// Using integer convenience constructor:
307///
308/// ```
309/// use tensor4all_quanticstransform::{affine_operator, AffineParams, BoundaryCondition};
310///
311/// // Identity transform: y = x (1D)
312/// let params = AffineParams::from_integers(vec![1], vec![0], 1, 1).unwrap();
313/// let bc = vec![BoundaryCondition::Periodic];
314/// let op = affine_operator(4, &params, &bc).unwrap();
315/// assert_eq!(op.mpo.node_count(), 4);
316/// ```
317pub fn affine_operator(
318    r: usize,
319    params: &AffineParams,
320    bc: &[BoundaryCondition],
321) -> Result<QuanticsOperator> {
322    if r == 0 {
323        return Err(anyhow::anyhow!("Number of bits must be positive"));
324    }
325    if bc.len() != params.m {
326        return Err(anyhow::anyhow!(
327            "Boundary conditions length {} doesn't match output dimensions {}",
328            bc.len(),
329            params.m
330        ));
331    }
332
333    let mpo = affine_transform_mpo(r, params, bc)?;
334    // Site dimensions: M output variables, N input variables
335    // Input dimension per site: 2^N (N input bits)
336    // Output dimension per site: 2^M (M output bits)
337    let m = params.m;
338    let n = params.n;
339    let input_dim = 1 << n;
340    let output_dim = 1 << m;
341
342    // The internal affine MPO uses site encoding: site_idx = y_bits | (x_bits << m)
343    // (y-minor, x-major). But tensortrain_to_linear_operator_asymmetric expects
344    // s = s_out * in_dim + s_in = y_bits * 2^N + x_bits (x-minor, y-major).
345    // We need to remap the site indices.
346    let site_dim = input_dim * output_dim;
347    let remapped_mpo = remap_affine_site_indices(&mpo, m, n, site_dim)?;
348
349    let input_dims = vec![input_dim; r];
350    let output_dims = vec![output_dim; r];
351    tensortrain_to_linear_operator_asymmetric(&remapped_mpo, &input_dims, &output_dims)
352}
353
354/// Compute the full affine transformation matrix directly (for verification).
355///
356/// This computes the transformation matrix by directly evaluating y = A*x + b
357/// for all possible input values. The result is a sparse boolean matrix.
358///
359/// # Arguments
360/// * `r` - Number of bits per variable
361/// * `params` - Affine transformation parameters
362/// * `bc` - Boundary conditions for each output variable
363///
364/// # Returns
365/// Sparse matrix of size 2^(R*M) × 2^(R*N) where entry (y_flat, x_flat) = 1
366/// if the transformation maps x to y.
367///
368/// # Note
369/// This is only practical for small R due to exponential size.
370/// Use for testing/verification only.
371pub fn affine_transform_matrix(
372    r: usize,
373    params: &AffineParams,
374    bc: &[BoundaryCondition],
375) -> Result<CsMat<f64>> {
376    if r == 0 {
377        return Err(anyhow::anyhow!("Number of bits must be positive"));
378    }
379    if bc.len() != params.m {
380        return Err(anyhow::anyhow!(
381            "Boundary conditions length {} doesn't match output dimensions {}",
382            bc.len(),
383            params.m
384        ));
385    }
386
387    let (a_int, b_int, scale) = params.to_integer_scaled();
388    let m = params.m;
389    let n = params.n;
390
391    let bc_periodic: Vec<bool> = bc
392        .iter()
393        .map(|b| matches!(b, BoundaryCondition::Periodic))
394        .collect();
395
396    let input_size = 1usize << (r * n); // 2^(R*N)
397    let output_size = 1usize << (r * m); // 2^(R*M)
398    let modulus = 1i64 << r; // 2^R
399
400    let mut rows = Vec::new();
401    let mut cols = Vec::new();
402    let mut vals = Vec::new();
403
404    let mask = modulus - 1; // 2^R - 1
405
406    // Iterate over all (x, y) pairs, matching Julia's approach.
407    // For periodic BC with scale > 1, multiple y values can satisfy
408    // scale * y ≡ A*x + b (mod 2^R), so we must check all pairs.
409    for x_flat in 0..input_size {
410        // Decode x_flat to N-dimensional x vector
411        // x_flat = x[0] + x[1]*2^R + x[2]*2^(2R) + ...
412        let x: Vec<i64> = (0..n)
413            .map(|var| ((x_flat >> (var * r)) & ((1 << r) - 1)) as i64)
414            .collect();
415
416        // Compute v = A*x + b (unscaled)
417        let mut v: Vec<i64> = vec![0; m];
418        for i in 0..m {
419            v[i] = b_int[i];
420            for j in 0..n {
421                v[i] += a_int[i + m * j] * x[j];
422            }
423        }
424
425        for y_flat in 0..output_size {
426            // Decode y_flat to M-dimensional y vector
427            let y: Vec<i64> = (0..m)
428                .map(|var| ((y_flat >> (var * r)) & ((1 << r) - 1)) as i64)
429                .collect();
430
431            // Compute scale * y
432            let sy: Vec<i64> = y.iter().map(|&yi| scale * yi).collect();
433
434            // Check equiv(v, s*y, R, boundary) per component
435            let equiv = v.iter().zip(sy.iter()).enumerate().all(|(i, (&vi, &syi))| {
436                if bc_periodic[i] {
437                    // Periodic: v ≡ s*y (mod 2^R)
438                    (vi - syi) & mask == 0
439                } else {
440                    // Open: v == s*y (exact)
441                    vi == syi
442                }
443            });
444
445            if equiv {
446                rows.push(y_flat);
447                cols.push(x_flat);
448                vals.push(1.0);
449            }
450        }
451    }
452
453    // Build sparse matrix in CSR format
454    let triplet = sprs::TriMat::from_triplets((output_size, input_size), rows, cols, vals);
455    Ok(triplet.to_csr())
456}
457
458/// Create the affine transformation MPO as a TensorTrain.
459fn affine_transform_mpo(
460    r: usize,
461    params: &AffineParams,
462    bc: &[BoundaryCondition],
463) -> Result<TensorTrain<Complex64>> {
464    let (a_int, b_int, scale) = params.to_integer_scaled();
465    let m = params.m;
466    let n = params.n;
467
468    // Convert boundary conditions to weights
469    let bc_periodic: Vec<bool> = bc
470        .iter()
471        .map(|b| matches!(b, BoundaryCondition::Periodic))
472        .collect();
473
474    // Compute core tensors
475    let tensors = affine_transform_tensors(r, &a_int, &b_int, scale, m, n, &bc_periodic)?;
476
477    TensorTrain::new(tensors)
478        .map_err(|e| anyhow::anyhow!("Failed to create affine transform MPO: {}", e))
479}
480
481/// Create unfused affine transformation tensors.
482///
483/// Returns a vector of R tensors, where each tensor has shape:
484/// `[left_bond, 2, 2, ..., 2, right_bond]` with M+N physical indices of dimension 2.
485///
486/// The physical index order matches Quantics.jl:
487/// `(y[1], y[2], ..., y[M], x[1], x[2], ..., x[N])`
488/// where y are output variables and x are input variables.
489///
490/// # Arguments
491/// * `r` - Number of bits per variable (number of sites)
492/// * `params` - Affine transformation parameters
493/// * `bc` - Boundary conditions for each output variable
494///
495/// # Returns
496/// Vector of R tensors with unfused physical indices.
497///
498/// # Examples
499///
500/// ```
501/// use tensor4all_quanticstransform::{
502///     affine_transform_tensors_unfused, AffineParams, BoundaryCondition,
503/// };
504/// use tensor4all_simplett::Tensor3Ops;
505///
506/// let params = AffineParams::from_integers(vec![1, 1, 0, 1], vec![0, 0], 2, 2).unwrap();
507/// let bc = vec![BoundaryCondition::Periodic; 2];
508/// let tensors = affine_transform_tensors_unfused(4, &params, &bc).unwrap();
509///
510/// // One tensor per site
511/// assert_eq!(tensors.len(), 4);
512///
513/// // Each tensor has fused site_dim = 2^(M+N) = 16 for M=2, N=2
514/// assert_eq!(tensors[0].site_dim(), 16);
515/// ```
516pub fn affine_transform_tensors_unfused(
517    r: usize,
518    params: &AffineParams,
519    bc: &[BoundaryCondition],
520) -> Result<Vec<GenericTensor3<Complex64>>> {
521    if r == 0 {
522        return Err(anyhow::anyhow!("Number of bits must be positive"));
523    }
524    if bc.len() != params.m {
525        return Err(anyhow::anyhow!(
526            "Boundary conditions length {} doesn't match output dimensions {}",
527            bc.len(),
528            params.m
529        ));
530    }
531
532    let (a_int, b_int, scale) = params.to_integer_scaled();
533    let m = params.m;
534    let n = params.n;
535
536    // Convert boundary conditions to weights
537    let bc_periodic: Vec<bool> = bc
538        .iter()
539        .map(|b| matches!(b, BoundaryCondition::Periodic))
540        .collect();
541
542    // Compute fused tensors first
543    let fused_tensors = affine_transform_tensors(r, &a_int, &b_int, scale, m, n, &bc_periodic)?;
544
545    // Convert fused tensors to unfused format
546    // Fused: [left, fused_site, right] where fused_site = 2^(M+N)
547    // Unfused: [left, 2, 2, ..., 2, right] with M+N dimensions of size 2
548    //
549    // Fused index encoding: site_idx = y_bits | (x_bits << M)
550    // where y_bits = y[0] + 2*y[1] + ... + 2^(M-1)*y[M-1]
551    // and   x_bits = x[0] + 2*x[1] + ... + 2^(N-1)*x[N-1]
552    //
553    // Quantics.jl order: (y[0], y[1], ..., y[M-1], x[0], x[1], ..., x[N-1])
554    // We preserve that semantic index order:
555    // unfused[left, y0, y1, ..., yM-1, x0, x1, ..., xN-1, right]
556
557    let mut unfused_tensors = Vec::with_capacity(r);
558    let site_dim = 1 << (m + n);
559
560    for tensor in fused_tensors.iter() {
561        let left_dim = tensor.left_dim();
562        let right_dim = tensor.right_dim();
563
564        // Create unfused tensor
565        // Shape: [left_dim, 2^(M+N), right_dim] but we keep it as 3D for now
566        // The reshape to (M+N+2)-dimensional tensor will be done by the caller if needed
567        // For now, we provide a 3D tensor where the middle dimension is the fused site
568        // and document how to unfuse it.
569        //
570        // Actually, let's return it properly unfused using a flat storage with
571        // the correct index order for reshape.
572        //
573        // Total size: left_dim * 2^(M+N) * right_dim
574        // Shape for unfused: [left_dim, 2, 2, ..., 2, right_dim]
575        //
576        // Index mapping from fused to unfused:
577        // fused site_idx -> (y0, y1, ..., yM-1, x0, x1, ..., xN-1)
578        // site_idx = y0 + 2*y1 + ... + 2^(M-1)*yM-1 + 2^M * (x0 + 2*x1 + ...)
579
580        // Preserve the Quantics.jl physical index order
581        // (y0, y1, ..., yM-1, x0, x1, ..., xN-1).
582
583        let mut unfused_data = vec![Complex64::new(0.0, 0.0); left_dim * site_dim * right_dim];
584
585        for l in 0..left_dim {
586            for fused_idx in 0..site_dim {
587                for rr in 0..right_dim {
588                    let val = tensor.get3(l, fused_idx, rr);
589                    if val.norm() > 0.0 {
590                        // The fused index encodes: site_idx = y_bits | (x_bits << M)
591                        // This matches Quantics.jl's ordering (y variables first, then x)
592                        // so we can use fused_idx directly.
593                        //
594                        // The caller can reshape [left, site_dim, right] to
595                        // [left, 2, 2, ..., 2, right] with M+N dimensions of size 2,
596                        // where indices are in order (y[0], y[1], ..., y[M-1], x[0], ..., x[N-1])
597                        let flat_idx = l * site_dim * right_dim + fused_idx * right_dim + rr;
598                        unfused_data[flat_idx] = *val;
599                    }
600                }
601            }
602        }
603
604        // Create Tensor3 with shape [left_dim, site_dim, right_dim]
605        // The caller can reshape this to [left_dim, 2, 2, ..., 2, right_dim]
606        // with the understanding that the indices are ordered as (y0, y1, ..., x0, x1, ...)
607        let mut unfused_tensor =
608            GenericTensor3::from_elem([left_dim, site_dim, right_dim], Complex64::new(0.0, 0.0));
609        for l in 0..left_dim {
610            for s in 0..site_dim {
611                for r in 0..right_dim {
612                    unfused_tensor[[l, s, r]] =
613                        unfused_data[l * site_dim * right_dim + s * right_dim + r];
614                }
615            }
616        }
617
618        unfused_tensors.push(unfused_tensor);
619    }
620
621    Ok(unfused_tensors)
622}
623
/// Information about the unfused tensor structure.
///
/// This helper provides metadata for reshaping the unfused tensors
/// produced by [`affine_transform_tensors_unfused`]: the number of
/// physical indices per site and conversions between the fused site
/// index and per-variable bits.
///
/// # Examples
///
/// ```
/// use tensor4all_quanticstransform::{AffineParams, UnfusedTensorInfo};
///
/// let params = AffineParams::from_integers(vec![1, 0, 0, 1], vec![0, 0], 2, 2).unwrap();
/// let info = UnfusedTensorInfo::new(&params);
///
/// assert_eq!(info.m, 2);
/// assert_eq!(info.n, 2);
/// assert_eq!(info.num_physical_dims, 4);
///
/// // Get shape for a tensor with bond dims 3 and 5
/// let shape = info.unfused_shape(3, 5);
/// assert_eq!(shape, vec![3, 2, 2, 2, 2, 5]);
///
/// // Round-trip encode/decode
/// let fused = info.encode_fused_index(&[1, 0], &[0, 1]);
/// let (y_bits, x_bits) = info.decode_fused_index(fused);
/// assert_eq!(y_bits, vec![1, 0]);
/// assert_eq!(x_bits, vec![0, 1]);
/// ```
#[derive(Clone, Debug)]
pub struct UnfusedTensorInfo {
    /// Number of output variables (M)
    pub m: usize,
    /// Number of input variables (N)
    pub n: usize,
    /// Total physical dimensions per site (M + N)
    pub num_physical_dims: usize,
    /// Dimension of each physical index (always 2 — one quantics bit)
    pub physical_dim: usize,
}
662
663impl UnfusedTensorInfo {
664    /// Create info for the given affine parameters.
665    pub fn new(params: &AffineParams) -> Self {
666        Self {
667            m: params.m,
668            n: params.n,
669            num_physical_dims: params.m + params.n,
670            physical_dim: 2,
671        }
672    }
673
674    /// Get the shape for a fully unfused tensor at a given site.
675    ///
676    /// Returns `[left_bond, 2, 2, ..., 2, right_bond]` where there are M+N 2s.
677    pub fn unfused_shape(&self, left_bond: usize, right_bond: usize) -> Vec<usize> {
678        let mut shape = Vec::with_capacity(2 + self.num_physical_dims);
679        shape.push(left_bond);
680        shape.extend(std::iter::repeat_n(2, self.num_physical_dims));
681        shape.push(right_bond);
682        shape
683    }
684
685    /// Decode a fused site index to individual variable bits.
686    ///
687    /// Returns `(y_bits, x_bits)` where:
688    /// - `y_bits[i]` is the bit for output variable i
689    /// - `x_bits[j]` is the bit for input variable j
690    pub fn decode_fused_index(&self, fused_idx: usize) -> (Vec<usize>, Vec<usize>) {
691        let y_combined = fused_idx & ((1 << self.m) - 1);
692        let x_combined = fused_idx >> self.m;
693
694        let y_bits: Vec<usize> = (0..self.m).map(|i| (y_combined >> i) & 1).collect();
695        let x_bits: Vec<usize> = (0..self.n).map(|j| (x_combined >> j) & 1).collect();
696
697        (y_bits, x_bits)
698    }
699
700    /// Encode individual variable bits to a fused site index.
701    ///
702    /// # Arguments
703    /// * `y_bits` - Bits for output variables (length M)
704    /// * `x_bits` - Bits for input variables (length N)
705    pub fn encode_fused_index(&self, y_bits: &[usize], x_bits: &[usize]) -> usize {
706        let y_combined: usize = y_bits.iter().enumerate().map(|(i, &b)| b << i).sum();
707        let x_combined: usize = x_bits.iter().enumerate().map(|(j, &b)| b << j).sum();
708        y_combined | (x_combined << self.m)
709    }
710}
711
712/// Compute the core tensors for the affine transformation.
713///
714/// This implements the algorithm from Quantics.jl that handles:
715/// - Carry propagation for multi-bit arithmetic
716/// - Scaling factor s from rational to integer conversion
717///
718/// Uses big-endian convention: site 0 = MSB, site R-1 = LSB.
719///
720/// Carry propagation direction (matching shift.rs):
721/// - Arithmetic carry flows LSB → MSB (physical fact)
722/// - In big-endian: site R-1 → site 0 (right → left)
723/// - Tensor structure: t[left, site, right] where left=carry_out (going left), right=carry_in (from right)
724/// - Site 0 (MSB): BC applied on left, receives carry from right → shape (1, site_dim, num_carries)
725/// - Site R-1 (LSB): initial carry=0, sends carry to left → shape (num_carries, site_dim, 1)
726/// - Middle sites: shape (num_carries, site_dim, num_carries)
727fn affine_transform_tensors(
728    r: usize,
729    a_int: &[i64],
730    b_int: &[i64],
731    scale: i64,
732    m: usize,
733    n: usize,
734    bc_periodic: &[bool],
735) -> Result<Vec<tensor4all_simplett::Tensor3<Complex64>>> {
736    let site_dim = 1 << (m + n); // 2^(M+N) for fused representation
737
738    // Track sign separately and work with absolute value
739    // so that right-shifting always terminates (Julia PR #45 approach)
740    let bsign: Vec<i64> = b_int.iter().map(|&b| if b >= 0 { 1 } else { -1 }).collect();
741    let mut b_work: Vec<i64> = b_int.iter().map(|&b| b.abs()).collect();
742
743    // Process from LSB (site R-1) to MSB (site 0)
744    let mut carries: Vec<Vec<i64>> = vec![vec![0i64; m]];
745    let mut core_data_list: Vec<AffineCoreData> = Vec::with_capacity(r);
746
747    for _site in (0..r).rev() {
748        // Extract current bit: (b_work & 1) * bsign
749        let b_curr: Vec<i64> = b_work
750            .iter()
751            .zip(bsign.iter())
752            .map(|(&b, &s)| (b & 1) * s)
753            .collect();
754
755        let core_data = affine_transform_core(a_int, &b_curr, scale, m, n, &carries, true)?;
756        carries = core_data.carries_out.clone();
757        core_data_list.push(core_data);
758
759        // Shift right
760        b_work.iter_mut().for_each(|b| *b >>= 1);
761    }
762
763    // core_data_list is now in order: [site R-1, site R-2, ..., site 0]
764
765    // Extension loop: handle remaining bits of b for Open BC
766    // When abs(b) >= 2^R, high bits of b contribute to carries that affect validity.
767    // Extension tensors have site_dim=1 (activebit=false: only x=0, y=0).
768    // We fold them into the MSB tensor as a "cap matrix" (Julia approach).
769    let cap_matrix: Option<Vec<f64>> = if !bc_periodic.iter().all(|&p| p)
770        && b_work.iter().any(|&b| b > 0)
771    {
772        let mut ext_data_list: Vec<AffineCoreData> = Vec::new();
773        while b_work.iter().any(|&b| b > 0) {
774            let b_curr: Vec<i64> = b_work
775                .iter()
776                .zip(bsign.iter())
777                .map(|(&b, &s)| (b & 1) * s)
778                .collect();
779
780            let core_data = affine_transform_core(a_int, &b_curr, scale, m, n, &carries, false)?;
781            carries = core_data.carries_out.clone();
782            ext_data_list.push(core_data);
783
784            b_work.iter_mut().for_each(|b| *b >>= 1);
785        }
786
787        // Build cap matrix by contracting extension tensors with BC weights.
788        // Extension tensors have site_dim=1, so they are carry transition matrices:
789        //   ext_matrix[cout_idx, cin_idx] = core_data.tensor[[cout_idx, cin_idx, 0]]
790        //
791        // Process: outermost (last computed) gets BC weights applied,
792        // then multiply inward toward the main tensor chain.
793
794        // Start with BC weights on the final carries
795        let bc_weights: Vec<f64> = carries
796            .iter()
797            .map(|c| {
798                if c.iter().all(|&ci| ci == 0) {
799                    1.0
800                } else {
801                    0.0
802                }
803            })
804            .collect();
805
806        // Contract extension tensors from outermost to innermost
807        // ext_data_list is [innermost, ..., outermost] (order of computation)
808        // We process from outermost to innermost
809        let mut current_weights = bc_weights;
810        for ext_data in ext_data_list.iter().rev() {
811            let num_cin = ext_data.tensor.dims()[1];
812            let mut new_weights = vec![0.0; num_cin];
813            for (cin_idx, nw) in new_weights.iter_mut().enumerate() {
814                for (cout_idx, &w) in current_weights.iter().enumerate() {
815                    if w != 0.0 && ext_data.tensor.get([cout_idx, cin_idx, 0]) {
816                        *nw += w;
817                    }
818                }
819            }
820            current_weights = new_weights;
821        }
822
823        // current_weights now maps: MSB carry_out index -> effective BC weight
824        Some(current_weights)
825    } else {
826        None
827    };
828
829    // Build tensors in the same order, then reverse to get [site 0, site 1, ..., site R-1]
830    let mut tensors = Vec::with_capacity(r);
831
832    // Helper: compute BC weight for a carry-out index
833    let compute_bc_weight = |cout_idx: usize, core_data: &AffineCoreData| -> Complex64 {
834        if bc_periodic.iter().all(|&p| p) {
835            Complex64::one()
836        } else if let Some(ref cap) = cap_matrix {
837            // Extension loop was used: weight comes from cap matrix
838            Complex64::new(cap[cout_idx], 0.0)
839        } else {
840            // No extension: weight is 1 if carry is zero, 0 otherwise
841            let carry = &core_data.carries_out[cout_idx];
842            if carry.iter().all(|&c| c == 0) {
843                Complex64::one()
844            } else {
845                Complex64::new(0.0, 0.0)
846            }
847        }
848    };
849
850    for (idx, core_data) in core_data_list.iter().enumerate() {
851        // idx=0 corresponds to site R-1 (LSB), idx=R-1 corresponds to site 0 (MSB)
852        let actual_site = r - 1 - idx;
853        let num_carry_out = core_data.carries_out.len();
854        let num_carry_in = core_data.tensor.dims()[1];
855
856        // Tensor shape follows shift.rs pattern:
857        // t[left, site, right] where left=carry_out (going left), right=carry_in (from right)
858        //
859        // - Site 0 (MSB): left_dim=1 (BC applied), right_dim=num_carry (receives from right)
860        // - Site R-1 (LSB): left_dim=num_carry (sends to left), right_dim=1 (initial carry=0)
861        // - Middle: left_dim=num_carry, right_dim=num_carry
862        let is_msb = actual_site == 0;
863        let is_lsb = actual_site == r - 1;
864
865        let left_dim = if is_msb { 1 } else { num_carry_out };
866        let right_dim = if is_lsb { 1 } else { num_carry_in };
867
868        let mut t: tensor4all_simplett::Tensor3<Complex64> =
869            tensor3_zeros(left_dim, site_dim, right_dim);
870
871        if is_lsb && is_msb {
872            // R==1: single site case
873            for cout_idx in 0..num_carry_out {
874                let bc_weight = compute_bc_weight(cout_idx, core_data);
875
876                for site_idx in 0..site_dim {
877                    if core_data.tensor.get([cout_idx, 0, site_idx]) {
878                        let old = t.get3(0, site_idx, 0);
879                        t.set3(0, site_idx, 0, *old + bc_weight);
880                    }
881                }
882            }
883        } else if is_lsb {
884            // LSB (site R-1): initial carry_in=0, send carry_out to left
885            // Shape (num_carry_out, site_dim, 1)
886            // core_data.tensor[carry_out_idx, carry_in_idx, site_idx]
887            // Only carry_in_idx=0 matters (initial carry is the first entry: zero vector)
888            for cout_idx in 0..num_carry_out {
889                for site_idx in 0..site_dim {
890                    if core_data.tensor.get([cout_idx, 0, site_idx]) {
891                        t.set3(cout_idx, site_idx, 0, Complex64::one());
892                    }
893                }
894            }
895        } else if is_msb {
896            // MSB (site 0): apply BC on carry_out, receive carry from right
897            for cout_idx in 0..num_carry_out {
898                let bc_weight = compute_bc_weight(cout_idx, core_data);
899
900                for cin_idx in 0..num_carry_in {
901                    for site_idx in 0..site_dim {
902                        if core_data.tensor.get([cout_idx, cin_idx, site_idx]) {
903                            let old = t.get3(0, site_idx, cin_idx);
904                            t.set3(0, site_idx, cin_idx, *old + bc_weight);
905                        }
906                    }
907                }
908            }
909        } else {
910            // Middle tensors: receive carry from right, send carry to left
911            // Shape (num_carry_out, site_dim, num_carry_in)
912            for cout_idx in 0..num_carry_out {
913                for cin_idx in 0..num_carry_in {
914                    for site_idx in 0..site_dim {
915                        if core_data.tensor.get([cout_idx, cin_idx, site_idx]) {
916                            t.set3(cout_idx, site_idx, cin_idx, Complex64::one());
917                        }
918                    }
919                }
920            }
921        }
922
923        tensors.push(t);
924    }
925
926    // tensors is in order [site R-1, ..., site 0], reverse to get [site 0, ..., site R-1]
927    tensors.reverse();
928
929    Ok(tensors)
930}
931
/// Core tensor data for a single bit position of the affine transformation,
/// as produced by `affine_transform_core`.
///
/// Shape: (num_carry_out, num_carry_in, site_dim)
/// where site_dim = 2^(M+N) when the bit is active, and 1 otherwise
/// (inactive bits force x = y = 0, collapsing the site dimension).
struct AffineCoreData {
    /// Possible outgoing carry vectors (each of length M), sorted for
    /// deterministic ordering; position in this list is the
    /// `carry_out_idx` axis of `tensor`.
    carries_out: Vec<Vec<i64>>,
    /// Tensor data: tensor[carry_out_idx, carry_in_idx, site_idx];
    /// `true` marks an allowed (carry_in, site) -> carry_out transition.
    tensor: BoolTensor3,
}
942
943/// Compute a single core tensor for the affine transformation.
944///
945/// The core tensor encodes: 2 * carry_out = A * x + b_curr - scale * y + carry_in
946///
947/// Returns AffineCoreData containing:
948/// - carries_out: list of possible outgoing carry vectors
949/// - tensor: shape (num_carry_out, num_carry_in, site_dim)
950fn affine_transform_core(
951    a_int: &[i64],
952    b_curr: &[i64],
953    scale: i64,
954    m: usize,
955    n: usize,
956    carries_in: &[Vec<i64>],
957    activebit: bool,
958) -> Result<AffineCoreData> {
959    let mut carry_out_map: HashMap<Vec<i64>, BoolTensor2> = HashMap::new();
960    let x_range = if activebit { 1 << n } else { 1 };
961    let y_range = if activebit { 1 << m } else { 1 };
962    let site_dim = x_range * y_range;
963    let num_carry_in = carries_in.len();
964
965    // Iterate over all input carries
966    for (c_idx, carry_in) in carries_in.iter().enumerate() {
967        // Iterate over all possible x values (N bits)
968        for x_bits in 0..x_range {
969            let x: Vec<i64> = (0..n).map(|j| ((x_bits >> j) & 1) as i64).collect();
970
971            // Compute z = A*x + b + carry_in
972            let mut z: Vec<i64> = vec![0; m];
973            for i in 0..m {
974                z[i] = carry_in[i] + b_curr[i];
975                for j in 0..n {
976                    z[i] += a_int[i + m * j] * x[j];
977                }
978            }
979
980            if scale % 2 == 1 {
981                // Scale is odd: unique y that satisfies condition
982                let y: Vec<i64> = z.iter().map(|&zi| zi & 1).collect();
983
984                // When bits are inactive, y must be zero (Julia PR #45 fix)
985                if !activebit && y.iter().any(|&yi| yi != 0) {
986                    continue;
987                }
988
989                let y_bits: usize = y
990                    .iter()
991                    .enumerate()
992                    .map(|(i, &yi)| (yi as usize) << i)
993                    .sum();
994
995                // Compute carry_out = (z - scale * y) / 2
996                let carry_out: Vec<i64> = z
997                    .iter()
998                    .zip(y.iter())
999                    .map(|(&zi, &yi)| (zi - scale * yi) >> 1)
1000                    .collect();
1001
1002                // Site index: y bits in lower positions, x bits in upper positions
1003                let site_idx = y_bits | (x_bits << m);
1004
1005                let entry = carry_out_map
1006                    .entry(carry_out)
1007                    .or_insert_with(|| BoolTensor2::from_elem([num_carry_in, site_dim], false));
1008                entry.set([c_idx, site_idx], true);
1009            } else {
1010                // Scale is even: z must be even for valid y
1011                if z.iter().any(|&zi| zi % 2 != 0) {
1012                    continue;
1013                }
1014
1015                // y can be any value
1016                for y_bits in 0..y_range {
1017                    let y: Vec<i64> = (0..m).map(|i| ((y_bits >> i) & 1) as i64).collect();
1018
1019                    // Compute carry_out = (z - scale * y) / 2
1020                    let carry_out: Vec<i64> = z
1021                        .iter()
1022                        .zip(y.iter())
1023                        .map(|(&zi, &yi)| (zi - scale * yi) >> 1)
1024                        .collect();
1025
1026                    let site_idx = y_bits | (x_bits << m);
1027
1028                    let entry = carry_out_map
1029                        .entry(carry_out)
1030                        .or_insert_with(|| BoolTensor2::from_elem([num_carry_in, site_dim], false));
1031                    entry.set([c_idx, site_idx], true);
1032                }
1033            }
1034        }
1035    }
1036
1037    // Convert to sorted vectors for deterministic ordering
1038    let mut carries_out: Vec<Vec<i64>> = carry_out_map.keys().cloned().collect();
1039    carries_out.sort();
1040
1041    let num_carry_out = carries_out.len();
1042
1043    // Build 3D tensor: (num_carry_out, num_carry_in, site_dim)
1044    let mut tensor = BoolTensor3::from_elem([num_carry_out, num_carry_in, site_dim], false);
1045    for (cout_idx, carry) in carries_out.iter().enumerate() {
1046        let data_2d = &carry_out_map[carry];
1047        for cin_idx in 0..num_carry_in {
1048            for site_idx in 0..site_dim {
1049                tensor.set(
1050                    [cout_idx, cin_idx, site_idx],
1051                    data_2d.get([cin_idx, site_idx]),
1052                );
1053            }
1054        }
1055    }
1056
1057    Ok(AffineCoreData {
1058        carries_out,
1059        tensor,
1060    })
1061}
1062
1063#[cfg(test)]
1064mod tests;