tensor4all_itensorlike/
linsolve.rs1use std::collections::HashMap;
10
11use tensor4all_core::{DynIndex, IndexLike};
12use tensor4all_treetn::{square_linsolve, IndexMapping, TruncationOptions};
13
14use crate::error::{Result, TensorTrainError};
15use crate::options::{validate_svd_truncation_options, CanonicalForm, LinsolveOptions};
16use crate::tensortrain::TensorTrain;
17
18pub fn linsolve(
35 operator: &TensorTrain,
36 rhs: &TensorTrain,
37 init: TensorTrain,
38 options: &LinsolveOptions,
39) -> Result<TensorTrain> {
40 if operator.is_empty() || rhs.is_empty() || init.is_empty() {
41 return Err(TensorTrainError::InvalidStructure {
42 message: "Cannot linsolve with empty tensor trains".to_string(),
43 });
44 }
45
46 if !options.nhalfsweeps().is_multiple_of(2) {
47 return Err(TensorTrainError::OperationError {
48 message: format!(
49 "nhalfsweeps must be a multiple of 2, got {}",
50 options.nhalfsweeps()
51 ),
52 });
53 }
54
55 validate_svd_truncation_options(options.max_rank(), options.svd_policy())?;
56
57 if !options.krylov_tol().is_finite() || options.krylov_tol() <= 0.0 {
58 return Err(TensorTrainError::OperationError {
59 message: format!(
60 "krylov_tol must be finite and > 0, got {}",
61 options.krylov_tol()
62 ),
63 });
64 }
65
66 if options.krylov_maxiter() == 0 {
67 return Err(TensorTrainError::OperationError {
68 message: "krylov_maxiter must be >= 1".to_string(),
69 });
70 }
71
72 if options.krylov_dim() == 0 {
73 return Err(TensorTrainError::OperationError {
74 message: "krylov_dim must be >= 1".to_string(),
75 });
76 }
77
78 if let Some(tol) = options.convergence_tol() {
79 if !tol.is_finite() || tol < 0.0 {
80 return Err(TensorTrainError::OperationError {
81 message: format!("convergence_tol must be finite and >= 0, got {}", tol),
82 });
83 }
84 }
85
86 let nfullsweeps = options.nhalfsweeps() / 2;
88
89 let treetn_options = tensor4all_treetn::LinsolveOptions::new(nfullsweeps)
90 .with_truncation(TruncationOptions::new())
91 .with_krylov_tol(options.krylov_tol())
92 .with_krylov_maxiter(options.krylov_maxiter())
93 .with_krylov_dim(options.krylov_dim())
94 .with_coefficients(options.coefficients().0, options.coefficients().1);
95
96 let treetn_options = if let Some(policy) = options.svd_policy() {
97 treetn_options.with_svd_policy(policy)
98 } else {
99 treetn_options
100 };
101
102 let treetn_options = if let Some(max_rank) = options.max_rank() {
103 treetn_options.with_max_rank(max_rank)
104 } else {
105 treetn_options
106 };
107
108 let treetn_options = if let Some(tol) = options.convergence_tol() {
109 treetn_options.with_convergence_tol(tol)
110 } else {
111 treetn_options
112 };
113
114 let center = init.len() - 1;
116
117 let (input_mapping, output_mapping) =
122 infer_index_mappings(operator, &init).map_err(|e| TensorTrainError::OperationError {
123 message: format!("Failed to infer index mappings: {}", e),
124 })?;
125
126 let result = square_linsolve(
127 operator.as_treetn(),
128 rhs.as_treetn(),
129 init.treetn,
130 ¢er,
131 treetn_options,
132 input_mapping,
133 output_mapping,
134 )
135 .map_err(|e| TensorTrainError::OperationError {
136 message: format!("Linsolve failed: {}", e),
137 })?;
138
139 TensorTrain::from_inner(result.solution, Some(CanonicalForm::Unitary))
140}
141
142type SiteMappings = (
143 Option<HashMap<usize, IndexMapping<DynIndex>>>,
144 Option<HashMap<usize, IndexMapping<DynIndex>>>,
145);
146
147fn infer_index_mappings(
154 operator: &TensorTrain,
155 init: &TensorTrain,
156) -> std::result::Result<SiteMappings, String> {
157 let op_treetn = operator.as_treetn();
158 let init_treetn = init.as_treetn();
159 let nsites = init.len();
160
161 let mut needs_mapping = false;
162
163 for site in 0..nsites {
165 let op_site = op_treetn.site_space(&site);
166 let init_site = init_treetn.site_space(&site);
167
168 if let (Some(op_indices), Some(init_indices)) = (op_site, init_site) {
169 if op_indices.len() == 2 && init_indices.len() == 1 {
170 let init_idx = init_indices.iter().next().unwrap();
172 let has_shared = op_indices.iter().any(|idx| idx.same_id(init_idx));
173 if has_shared {
174 let input_idx = op_indices.iter().find(|idx| idx.same_id(init_idx));
176 if let Some(input_idx) = input_idx {
177 if input_idx != init_idx {
178 needs_mapping = true;
179 }
180 }
181 let output_idx = op_indices.iter().find(|idx| !idx.same_id(init_idx));
183 if output_idx.is_some() {
184 needs_mapping = true;
185 }
186 } else {
187 return Err(format!(
188 "Site {}: operator has 2 site indices but none share an ID with init's \
189 site index {:?}. Cannot auto-infer index mappings. \
190 Use the treetn-level API with explicit IndexMapping.",
191 site,
192 init_idx.id()
193 ));
194 }
195 }
196 }
197 }
198
199 if !needs_mapping {
200 return Ok((None, None));
201 }
202
203 let mut input_mapping = HashMap::new();
205 let mut output_mapping = HashMap::new();
206
207 for site in 0..nsites {
208 let op_site = op_treetn.site_space(&site);
209 let init_site = init_treetn.site_space(&site);
210
211 if let (Some(op_indices), Some(init_indices)) = (op_site, init_site) {
212 if op_indices.len() == 2 && init_indices.len() == 1 {
213 let init_idx = init_indices.iter().next().unwrap();
214
215 let op_input = op_indices.iter().find(|idx| idx.same_id(init_idx)).unwrap();
216 let op_output = op_indices
217 .iter()
218 .find(|idx| !idx.same_id(init_idx))
219 .unwrap();
220
221 input_mapping.insert(
222 site,
223 IndexMapping {
224 true_index: init_idx.clone(),
225 internal_index: op_input.clone(),
226 },
227 );
228 output_mapping.insert(
229 site,
230 IndexMapping {
231 true_index: init_idx.clone(),
232 internal_index: op_output.clone(),
233 },
234 );
235 }
236 }
237 }
238
239 Ok((Some(input_mapping), Some(output_mapping)))
240}
241
242impl TensorTrain {
243 pub fn linsolve(&self, rhs: &Self, init: Self, options: &LinsolveOptions) -> Result<Self> {
249 linsolve(self, rhs, init, options)
250 }
251}