// tensor4all_quanticstransform/common.rs
1//! Common types and helper functions for quantics transformations.
2
3use std::collections::HashMap;
4
5use anyhow::Result;
6use num_complex::Complex64;
7use num_traits::One;
8use tensor4all_core::index::{DynId, Index, TagSet};
9use tensor4all_core::TensorDynLen;
10use tensor4all_simplett::{types::tensor3_zeros, AbstractTensorTrain, Tensor3Ops, TensorTrain};
11use tensor4all_treetn::{IndexMapping, LinearOperator, TreeTN};
12
/// Type alias for the default index type.
///
/// Combines dynamically assigned [`DynId`] identifiers with a [`TagSet`],
/// which is the index flavor used throughout this crate.
pub type DynIndex = Index<DynId, TagSet>;
15
/// Boundary condition for quantics transformations.
///
/// Controls how operators handle values that exceed the representable range
/// `[0, 2^R)`, where `R` is the number of quantics bits (sites).
///
/// # Variants
///
/// - **`Periodic`** (default): Results wrap around modulo 2^R.
///   Use when functions are periodic or when wraparound is acceptable.
/// - **`Open`**: Out-of-range results produce zeros.
///   Use when the function has compact support or when boundary effects matter.
///
/// # Examples
///
/// ```
/// use tensor4all_quanticstransform::BoundaryCondition;
///
/// // Default is Periodic
/// let bc = BoundaryCondition::default();
/// assert_eq!(bc, BoundaryCondition::Periodic);
///
/// // Periodic: shift(7, 2) in 3-bit (mod 8) wraps to 1
/// // Open: shift(7, 2) in 3-bit goes to 9 >= 8, produces zero
/// ```
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum BoundaryCondition {
    /// Periodic boundary: operations wrap around mod 2^R.
    ///
    /// Use for periodic functions or when wraparound is desired.
    #[default]
    Periodic,
    /// Open boundary: operations beyond `[0, 2^R)` return zero.
    ///
    /// Use when the function has compact support or boundary effects matter.
    Open,
}
52
/// Direction for carry propagation in binary arithmetic operations.
///
/// This is an internal detail of how binary arithmetic (addition, subtraction)
/// is implemented in the MPO construction. Most users do not need to set this
/// directly. The choice must match the bit ordering of the quantics
/// representation in use (MSB-first vs. LSB-first site ordering).
///
/// # Variants
///
/// - **`LeftToRight`** (default): Carry propagates from MSB to LSB.
/// - **`RightToLeft`**: Carry propagates from LSB to MSB.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum CarryDirection {
    /// Carry propagates from left (MSB) to right (LSB).
    #[default]
    LeftToRight,
    /// Carry propagates from right (LSB) to left (MSB).
    RightToLeft,
}
71
/// Type alias for the standard LinearOperator used in this crate.
///
/// Uses `TensorDynLen` as the tensor type and `usize` as the node name type
/// (nodes are numbered `0..n` in site order by the conversion functions below).
pub type QuanticsOperator = LinearOperator<TensorDynLen, usize>;
75
76/// Convert a TensorTrain (MPO form) to a LinearOperator.
77///
78/// The TensorTrain is assumed to be an MPO with site dimension 4 (2x2 for input/output).
79/// Each site tensor has shape (left_bond, site_dim=4, right_bond) where site_dim
80/// encodes (s_out, s_in) = (2, 2).
81///
82/// # Arguments
83/// * `tt` - TensorTrain representing an MPO
84/// * `site_dims` - Site dimensions for input/output (typically all 2s)
85///
86/// # Returns
87/// LinearOperator wrapping the MPO as a TreeTN
88pub fn tensortrain_to_linear_operator(
89    tt: &TensorTrain<Complex64>,
90    site_dims: &[usize],
91) -> Result<QuanticsOperator> {
92    let n = tt.len();
93    if n == 0 {
94        return Err(anyhow::anyhow!("Empty tensor train"));
95    }
96
97    // Create site indices for input and output
98    let mut site_in_indices: Vec<DynIndex> = Vec::with_capacity(n);
99    let mut site_out_indices: Vec<DynIndex> = Vec::with_capacity(n);
100    let mut internal_in_indices: Vec<DynIndex> = Vec::with_capacity(n);
101    let mut internal_out_indices: Vec<DynIndex> = Vec::with_capacity(n);
102
103    for &dim in site_dims.iter() {
104        // True site indices (for state)
105        site_in_indices.push(Index::new_dyn(dim));
106        site_out_indices.push(Index::new_dyn(dim));
107        // Internal MPO indices
108        internal_in_indices.push(Index::new_dyn(dim));
109        internal_out_indices.push(Index::new_dyn(dim));
110    }
111
112    // Create bond indices
113    let mut bond_indices: Vec<DynIndex> = Vec::with_capacity(n + 1);
114
115    for i in 0..=n {
116        let dim = if i == 0 {
117            1
118        } else {
119            tt.site_tensor(i - 1).right_dim()
120        };
121        bond_indices.push(Index::new_dyn(dim));
122    }
123
124    // Build tensors for TreeTN
125    let mut tensors: Vec<TensorDynLen> = Vec::with_capacity(n);
126    let mut node_names: Vec<usize> = Vec::with_capacity(n);
127
128    for i in 0..n {
129        let tensor = tt.site_tensor(i);
130        let left_dim = tensor.left_dim();
131        let site_dim = tensor.site_dim();
132        let right_dim = tensor.right_dim();
133
134        // Expected site_dim is product of input and output dimensions
135        let expected_site_dim = site_dims[i] * site_dims[i];
136        if site_dim != expected_site_dim {
137            return Err(anyhow::anyhow!(
138                "Site {} has dimension {} but expected {} ({}x{})",
139                i,
140                site_dim,
141                expected_site_dim,
142                site_dims[i],
143                site_dims[i]
144            ));
145        }
146
147        // Create indices for this tensor: (left_bond, site_out, site_in, right_bond)
148        // For first tensor: (site_out, site_in, right_bond)
149        // For last tensor: (left_bond, site_out, site_in)
150        // For middle: (left_bond, site_out, site_in, right_bond)
151        let mut indices: Vec<DynIndex> = Vec::with_capacity(4);
152        let mut dims_vec: Vec<usize> = Vec::with_capacity(4);
153
154        if i > 0 {
155            indices.push(bond_indices[i].clone());
156            dims_vec.push(left_dim);
157        }
158        indices.push(internal_out_indices[i].clone());
159        dims_vec.push(site_dims[i]);
160        indices.push(internal_in_indices[i].clone());
161        dims_vec.push(site_dims[i]);
162        if i < n - 1 {
163            indices.push(bond_indices[i + 1].clone());
164            dims_vec.push(right_dim);
165        }
166
167        // Reshape tensor data: (left, site_out*site_in, right) -> (left, site_out, site_in, right)
168        // or appropriate variant for boundary tensors
169        let total_size: usize = dims_vec.iter().product();
170        let mut data: Vec<Complex64> = vec![Complex64::new(0.0, 0.0); total_size];
171
172        // Map from TT format to TreeTN format
173        if i == 0 && n == 1 {
174            // Single tensor: (site_out, site_in)
175            for s_out in 0..site_dims[i] {
176                for s_in in 0..site_dims[i] {
177                    let s = s_out * site_dims[i] + s_in;
178                    let idx = s_out + site_dims[i] * s_in;
179                    data[idx] = *tensor.get3(0, s, 0);
180                }
181            }
182        } else if i == 0 {
183            // First tensor: (site_out, site_in, right_bond)
184            for s_out in 0..site_dims[i] {
185                for s_in in 0..site_dims[i] {
186                    for r in 0..right_dim {
187                        let s = s_out * site_dims[i] + s_in;
188                        let idx = s_out + site_dims[i] * (s_in + site_dims[i] * r);
189                        data[idx] = *tensor.get3(0, s, r);
190                    }
191                }
192            }
193        } else if i == n - 1 {
194            // Last tensor: (left_bond, site_out, site_in)
195            for l in 0..left_dim {
196                for s_out in 0..site_dims[i] {
197                    for s_in in 0..site_dims[i] {
198                        let s = s_out * site_dims[i] + s_in;
199                        let idx = l + left_dim * (s_out + site_dims[i] * s_in);
200                        data[idx] = *tensor.get3(l, s, 0);
201                    }
202                }
203            }
204        } else {
205            // Middle tensor: (left_bond, site_out, site_in, right_bond)
206            for l in 0..left_dim {
207                for s_out in 0..site_dims[i] {
208                    for s_in in 0..site_dims[i] {
209                        for r in 0..right_dim {
210                            let s = s_out * site_dims[i] + s_in;
211                            let idx =
212                                l + left_dim * (s_out + site_dims[i] * (s_in + site_dims[i] * r));
213                            data[idx] = *tensor.get3(l, s, r);
214                        }
215                    }
216                }
217            }
218        }
219
220        let tensor_dyn = TensorDynLen::from_dense(indices, data).unwrap();
221        tensors.push(tensor_dyn);
222        node_names.push(i);
223    }
224
225    // Build TreeTN from tensors
226    let treetn = TreeTN::from_tensors(tensors, node_names)?;
227
228    // Build index mappings
229    let mut input_mapping: HashMap<usize, IndexMapping<DynIndex>> = HashMap::new();
230    let mut output_mapping: HashMap<usize, IndexMapping<DynIndex>> = HashMap::new();
231
232    for i in 0..n {
233        input_mapping.insert(
234            i,
235            IndexMapping {
236                true_index: site_in_indices[i].clone(),
237                internal_index: internal_in_indices[i].clone(),
238            },
239        );
240        output_mapping.insert(
241            i,
242            IndexMapping {
243                true_index: site_out_indices[i].clone(),
244                internal_index: internal_out_indices[i].clone(),
245            },
246        );
247    }
248
249    Ok(LinearOperator::new(treetn, input_mapping, output_mapping))
250}
251
/// Convert a TensorTrain (MPO form) to a LinearOperator with asymmetric dimensions.
///
/// This variant supports different input and output dimensions, useful for
/// multi-variable transformations like affine transforms.
///
/// The fused TT site index at site `i` encodes
/// `s = s_out * input_dims[i] + s_in` (output major, input minor); it is
/// unfused here into separate output/input indices whose dense data is laid
/// out column-major (first listed index varies fastest).
///
/// # Arguments
/// * `tt` - TensorTrain representing an MPO
/// * `input_dims` - Input dimensions per site (one entry per site)
/// * `output_dims` - Output dimensions per site (one entry per site)
///
/// # Returns
/// LinearOperator wrapping the MPO as a TreeTN
///
/// # Errors
/// Fails if `tt` is empty, if the dimension slices do not have one entry per
/// site, or if any fused site dimension is not `input_dims[i] * output_dims[i]`.
pub fn tensortrain_to_linear_operator_asymmetric(
    tt: &TensorTrain<Complex64>,
    input_dims: &[usize],
    output_dims: &[usize],
) -> Result<QuanticsOperator> {
    let n = tt.len();
    if n == 0 {
        return Err(anyhow::anyhow!("Empty tensor train"));
    }
    if input_dims.len() != n || output_dims.len() != n {
        return Err(anyhow::anyhow!("Dimension arrays must have length {}", n));
    }

    // Create site indices for input and output
    let mut site_in_indices: Vec<DynIndex> = Vec::with_capacity(n);
    let mut site_out_indices: Vec<DynIndex> = Vec::with_capacity(n);
    let mut internal_in_indices: Vec<DynIndex> = Vec::with_capacity(n);
    let mut internal_out_indices: Vec<DynIndex> = Vec::with_capacity(n);

    for i in 0..n {
        // True site indices (for state)
        site_in_indices.push(Index::new_dyn(input_dims[i]));
        site_out_indices.push(Index::new_dyn(output_dims[i]));
        // Internal MPO indices
        internal_in_indices.push(Index::new_dyn(input_dims[i]));
        internal_out_indices.push(Index::new_dyn(output_dims[i]));
    }

    // Create bond indices; bond i sits between sites i-1 and i. The two outer
    // bonds (0 and n) have dimension 1 via the TT boundary tensors.
    let mut bond_indices: Vec<DynIndex> = Vec::with_capacity(n + 1);

    for i in 0..=n {
        let dim = if i == 0 {
            1
        } else {
            tt.site_tensor(i - 1).right_dim()
        };
        bond_indices.push(Index::new_dyn(dim));
    }

    // Build tensors for TreeTN
    let mut tensors: Vec<TensorDynLen> = Vec::with_capacity(n);
    let mut node_names: Vec<usize> = Vec::with_capacity(n);

    for i in 0..n {
        let tensor = tt.site_tensor(i);
        let left_dim = tensor.left_dim();
        let site_dim = tensor.site_dim();
        let right_dim = tensor.right_dim();

        let in_dim = input_dims[i];
        let out_dim = output_dims[i];

        // Expected site_dim is product of input and output dimensions
        let expected_site_dim = in_dim * out_dim;
        if site_dim != expected_site_dim {
            return Err(anyhow::anyhow!(
                "Site {} has dimension {} but expected {} ({}x{})",
                i,
                site_dim,
                expected_site_dim,
                out_dim,
                in_dim
            ));
        }

        // Create indices for this tensor: (left_bond, site_out, site_in, right_bond);
        // boundary tensors omit the dangling dimension-1 bond.
        let mut indices: Vec<DynIndex> = Vec::with_capacity(4);
        let mut dims_vec: Vec<usize> = Vec::with_capacity(4);

        if i > 0 {
            indices.push(bond_indices[i].clone());
            dims_vec.push(left_dim);
        }
        indices.push(internal_out_indices[i].clone());
        dims_vec.push(out_dim);
        indices.push(internal_in_indices[i].clone());
        dims_vec.push(in_dim);
        if i < n - 1 {
            indices.push(bond_indices[i + 1].clone());
            dims_vec.push(right_dim);
        }

        // Reshape tensor data: (left, site_out*site_in, right) -> (left, site_out, site_in, right)
        let total_size: usize = dims_vec.iter().product();
        let mut data: Vec<Complex64> = vec![Complex64::new(0.0, 0.0); total_size];

        // Map from TT format to TreeTN format
        // TT format has site index = s_out * in_dim + s_in (output major, input minor)
        // The destination index `idx` is column-major over the indices pushed
        // above (first index varies fastest).
        if i == 0 && n == 1 {
            // Single tensor: (site_out, site_in)
            for s_out in 0..out_dim {
                for s_in in 0..in_dim {
                    let s = s_out * in_dim + s_in;
                    let idx = s_out + out_dim * s_in;
                    data[idx] = *tensor.get3(0, s, 0);
                }
            }
        } else if i == 0 {
            // First tensor: (site_out, site_in, right_bond)
            for s_out in 0..out_dim {
                for s_in in 0..in_dim {
                    for r in 0..right_dim {
                        let s = s_out * in_dim + s_in;
                        let idx = s_out + out_dim * (s_in + in_dim * r);
                        data[idx] = *tensor.get3(0, s, r);
                    }
                }
            }
        } else if i == n - 1 {
            // Last tensor: (left_bond, site_out, site_in)
            for l in 0..left_dim {
                for s_out in 0..out_dim {
                    for s_in in 0..in_dim {
                        let s = s_out * in_dim + s_in;
                        let idx = l + left_dim * (s_out + out_dim * s_in);
                        data[idx] = *tensor.get3(l, s, 0);
                    }
                }
            }
        } else {
            // Middle tensor: (left_bond, site_out, site_in, right_bond)
            for l in 0..left_dim {
                for s_out in 0..out_dim {
                    for s_in in 0..in_dim {
                        for r in 0..right_dim {
                            let s = s_out * in_dim + s_in;
                            let idx = l + left_dim * (s_out + out_dim * (s_in + in_dim * r));
                            data[idx] = *tensor.get3(l, s, r);
                        }
                    }
                }
            }
        }

        // NOTE(review): `data` is sized from `dims_vec`, which matches `indices`,
        // so this unwrap can only fire on an internal bug in the loops above.
        let tensor_dyn = TensorDynLen::from_dense(indices, data).unwrap();
        tensors.push(tensor_dyn);
        node_names.push(i);
    }

    // Build TreeTN from tensors
    let treetn = TreeTN::from_tensors(tensors, node_names)?;

    // Build index mappings connecting each true (state-facing) index to the
    // corresponding internal MPO index at the same site.
    let mut input_mapping: HashMap<usize, IndexMapping<DynIndex>> = HashMap::new();
    let mut output_mapping: HashMap<usize, IndexMapping<DynIndex>> = HashMap::new();

    for i in 0..n {
        input_mapping.insert(
            i,
            IndexMapping {
                true_index: site_in_indices[i].clone(),
                internal_index: internal_in_indices[i].clone(),
            },
        );
        output_mapping.insert(
            i,
            IndexMapping {
                true_index: site_out_indices[i].clone(),
                internal_index: internal_out_indices[i].clone(),
            },
        );
    }

    Ok(LinearOperator::new(treetn, input_mapping, output_mapping))
}
430
431/// Embed a single-variable MPO into a multi-variable context.
432///
433/// The original MPO acts on one variable (site_dim = d*d for d=2, i.e., in/out dim 2).
434/// The embedded MPO acts on `nvariables` variables, applying the original
435/// operator to `target_var` and identity on all others.
436///
437/// Site index encoding in the embedded MPO:
438/// `s = s_out * (2^nvariables) + s_in` where
439/// `s_out = var0_out + 2*var1_out + ...` and similarly for `s_in`.
440///
441/// # Arguments
442/// * `mpo` - Single-variable MPO (R sites, site_dim = 4)
443/// * `nvariables` - Total number of variables (must be >= 2)
444/// * `target_var` - Which variable to apply the operator to (0-indexed)
445pub(crate) fn embed_single_var_mpo(
446    mpo: &TensorTrain<Complex64>,
447    nvariables: usize,
448    target_var: usize,
449) -> Result<TensorTrain<Complex64>> {
450    if target_var >= nvariables {
451        return Err(anyhow::anyhow!(
452            "target_var {} must be less than nvariables {}",
453            target_var,
454            nvariables
455        ));
456    }
457    if nvariables < 2 {
458        return Err(anyhow::anyhow!("nvariables must be at least 2"));
459    }
460
461    let r = mpo.len();
462    let dim_multi = 1usize << nvariables; // 2^nvars
463    let site_dim_new = dim_multi * dim_multi;
464
465    let mut new_tensors = Vec::with_capacity(r);
466
467    for i in 0..r {
468        let tensor = mpo.site_tensor(i);
469        let left_dim = tensor.left_dim();
470        let right_dim = tensor.right_dim();
471
472        assert_eq!(
473            tensor.site_dim(),
474            4,
475            "Input MPO must have site_dim=4 (single variable)"
476        );
477
478        let mut t = tensor3_zeros(left_dim, site_dim_new, right_dim);
479
480        for s_out_multi in 0..dim_multi {
481            for s_in_multi in 0..dim_multi {
482                // Check identity constraint on non-target variables
483                let mut identity_ok = true;
484                for v in 0..nvariables {
485                    if v != target_var {
486                        let out_bit = (s_out_multi >> v) & 1;
487                        let in_bit = (s_in_multi >> v) & 1;
488                        if out_bit != in_bit {
489                            identity_ok = false;
490                            break;
491                        }
492                    }
493                }
494                if !identity_ok {
495                    continue;
496                }
497
498                // Extract target variable bits
499                let target_out = (s_out_multi >> target_var) & 1;
500                let target_in = (s_in_multi >> target_var) & 1;
501                let s_orig = target_out * 2 + target_in;
502
503                // New fused site index
504                let s_new = s_out_multi * dim_multi + s_in_multi;
505
506                for l in 0..left_dim {
507                    for rr in 0..right_dim {
508                        let val = *tensor.get3(l, s_orig, rr);
509                        if val != Complex64::new(0.0, 0.0) {
510                            t.set3(l, s_new, rr, val);
511                        }
512                    }
513                }
514            }
515        }
516
517        new_tensors.push(t);
518    }
519
520    TensorTrain::new(new_tensors)
521        .map_err(|e| anyhow::anyhow!("Failed to create embedded MPO: {}", e))
522}
523
524/// Create an identity MPO for r sites with dimension 2.
525#[allow(dead_code)]
526pub fn identity_mpo(r: usize) -> Result<TensorTrain<Complex64>> {
527    if r == 0 {
528        return Err(anyhow::anyhow!("Number of sites must be positive"));
529    }
530
531    let mut tensors = Vec::with_capacity(r);
532
533    for _ in 0..r {
534        // Identity tensor: delta_{s_out, s_in}
535        // Shape: (1, 4, 1) where 4 = 2*2 for (s_out, s_in)
536        let mut t = tensor3_zeros(1, 4, 1);
537        // s = s_out * 2 + s_in
538        // Identity: s_out == s_in
539        t.set3(0, 0, 0, Complex64::one()); // (0, 0)
540        t.set3(0, 3, 0, Complex64::one()); // (1, 1)
541        tensors.push(t);
542    }
543
544    TensorTrain::new(tensors).map_err(|e| anyhow::anyhow!("Failed to create identity MPO: {}", e))
545}
546
547/// Create a scalar MPO (constant times identity).
548#[allow(dead_code)]
549pub fn scalar_mpo(r: usize, value: Complex64) -> Result<TensorTrain<Complex64>> {
550    let mut mpo = identity_mpo(r)?;
551    mpo.scale(value);
552    Ok(mpo)
553}
554
555#[cfg(test)]
556mod tests;