1use std::collections::HashSet;
25
26use crate::any_scalar::AnyScalar;
27use crate::index_like::IndexLike;
28use crate::tensor_index::TensorIndex;
29use crate::tensor_like::{
30 AllowedPairs, DirectSumResult, FactorizeError, FactorizeOptions, FactorizeResult, TensorLike,
31};
32use anyhow::Result;
33
/// A rectangular grid of tensor blocks, stored in row-major order.
///
/// Block `(row, col)` is stored at position `row * cols + col` of the block
/// vector. Construction through `try_new`/`new` checks that the block count
/// matches the shape; index consistency between blocks is checked separately
/// by `validate_indices`.
#[derive(Debug, Clone)]
pub struct BlockTensor<T: TensorLike> {
    // Blocks in row-major order: block (row, col) lives at `row * cols + col`.
    blocks: Vec<T>,
    // Grid shape as (rows, cols); `try_new` ensures rows * cols == blocks.len().
    shape: (usize, usize),
}
50
51impl<T: TensorLike> BlockTensor<T> {
52 pub fn try_new(blocks: Vec<T>, shape: (usize, usize)) -> Result<Self> {
63 let (rows, cols) = shape;
64 anyhow::ensure!(
65 rows * cols == blocks.len(),
66 "Block count mismatch: shape ({}, {}) requires {} blocks, but got {}",
67 rows,
68 cols,
69 rows * cols,
70 blocks.len()
71 );
72 Ok(Self { blocks, shape })
73 }
74
75 pub fn new(blocks: Vec<T>, shape: (usize, usize)) -> Self {
86 Self::try_new(blocks, shape).expect("Invalid block tensor shape")
87 }
88
89 pub fn shape(&self) -> (usize, usize) {
91 self.shape
92 }
93
94 pub fn num_blocks(&self) -> usize {
96 self.blocks.len()
97 }
98
99 pub fn get(&self, row: usize, col: usize) -> &T {
105 let (rows, cols) = self.shape;
106 assert!(row < rows && col < cols, "Block index out of bounds");
107 &self.blocks[row * cols + col]
108 }
109
110 pub fn get_mut(&mut self, row: usize, col: usize) -> &mut T {
116 let (rows, cols) = self.shape;
117 assert!(row < rows && col < cols, "Block index out of bounds");
118 &mut self.blocks[row * cols + col]
119 }
120
121 pub fn blocks(&self) -> &[T] {
123 &self.blocks
124 }
125
126 pub fn blocks_mut(&mut self) -> &mut [T] {
128 &mut self.blocks
129 }
130
131 pub fn into_blocks(self) -> Vec<T> {
133 self.blocks
134 }
135
136 pub fn validate_indices(&self) -> Result<()> {
147 let (rows, cols) = self.shape;
148
149 if cols <= 1 {
150 return Ok(());
153 }
154
155 let first_count = self.blocks[0].num_external_indices();
157 for (i, block) in self.blocks.iter().enumerate().skip(1) {
158 let n = block.num_external_indices();
159 anyhow::ensure!(
160 n == first_count,
161 "Block {} has {} external indices, but block 0 has {}",
162 i,
163 n,
164 first_count
165 );
166 }
167
168 for row in 0..rows {
170 let ref_ids: HashSet<_> = self
171 .get(row, 0)
172 .external_indices()
173 .iter()
174 .map(|idx| idx.id().clone())
175 .collect();
176 for col in 1..cols {
177 let ids: HashSet<_> = self
178 .get(row, col)
179 .external_indices()
180 .iter()
181 .map(|idx| idx.id().clone())
182 .collect();
183 let common_count = ref_ids.intersection(&ids).count();
184 anyhow::ensure!(
185 common_count > 0,
186 "Matrix row {}: blocks ({},{}) and ({},{}) share no index IDs",
187 row,
188 row,
189 0,
190 row,
191 col
192 );
193 }
194 }
195
196 for col in 0..cols {
198 let ref_ids: HashSet<_> = self
199 .get(0, col)
200 .external_indices()
201 .iter()
202 .map(|idx| idx.id().clone())
203 .collect();
204 for row in 1..rows {
205 let ids: HashSet<_> = self
206 .get(row, col)
207 .external_indices()
208 .iter()
209 .map(|idx| idx.id().clone())
210 .collect();
211 let common_count = ref_ids.intersection(&ids).count();
212 anyhow::ensure!(
213 common_count > 0,
214 "Matrix col {}: blocks ({},{}) and ({},{}) share no index IDs",
215 col,
216 0,
217 col,
218 row,
219 col
220 );
221 }
222 }
223
224 Ok(())
225 }
226}
227
228impl<T: TensorLike> TensorIndex for BlockTensor<T> {
233 type Index = T::Index;
234
235 fn external_indices(&self) -> Vec<Self::Index> {
236 let mut seen = HashSet::new();
238 let mut result = Vec::new();
239 for block in &self.blocks {
240 for idx in block.external_indices() {
241 if seen.insert(idx.id().clone()) {
242 result.push(idx);
243 }
244 }
245 }
246 result
247 }
248
249 fn replaceind(&self, old_index: &Self::Index, new_index: &Self::Index) -> Result<Self> {
250 let replaced: Result<Vec<T>> = self
251 .blocks
252 .iter()
253 .map(|b| b.replaceind(old_index, new_index))
254 .collect();
255 Ok(Self {
256 blocks: replaced?,
257 shape: self.shape,
258 })
259 }
260
261 fn replaceinds(
262 &self,
263 old_indices: &[Self::Index],
264 new_indices: &[Self::Index],
265 ) -> Result<Self> {
266 let replaced: Result<Vec<T>> = self
267 .blocks
268 .iter()
269 .map(|b| b.replaceinds(old_indices, new_indices))
270 .collect();
271 Ok(Self {
272 blocks: replaced?,
273 shape: self.shape,
274 })
275 }
276}
277
278impl<T: TensorLike> TensorLike for BlockTensor<T> {
283 fn norm_squared(&self) -> f64 {
288 self.blocks.iter().map(|b| b.norm_squared()).sum()
289 }
290
291 fn maxabs(&self) -> f64 {
292 self.blocks
293 .iter()
294 .map(|b| b.maxabs())
295 .fold(0.0_f64, f64::max)
296 }
297
298 fn scale(&self, scalar: AnyScalar) -> Result<Self> {
299 let scaled: Result<Vec<T>> = self
300 .blocks
301 .iter()
302 .map(|b| b.scale(scalar.clone()))
303 .collect();
304 Ok(Self {
305 blocks: scaled?,
306 shape: self.shape,
307 })
308 }
309
310 fn axpby(&self, a: AnyScalar, other: &Self, b: AnyScalar) -> Result<Self> {
311 anyhow::ensure!(
312 self.shape == other.shape,
313 "Block shapes must match: {:?} vs {:?}",
314 self.shape,
315 other.shape
316 );
317 let result: Result<Vec<T>> = self
318 .blocks
319 .iter()
320 .zip(other.blocks.iter())
321 .map(|(s, o)| s.axpby(a.clone(), o, b.clone()))
322 .collect();
323 Ok(Self {
324 blocks: result?,
325 shape: self.shape,
326 })
327 }
328
329 fn inner_product(&self, other: &Self) -> Result<AnyScalar> {
330 anyhow::ensure!(
331 self.shape == other.shape,
332 "Block shapes must match for inner product: {:?} vs {:?}",
333 self.shape,
334 other.shape
335 );
336 let mut sum = AnyScalar::new_real(0.0);
337 for (s, o) in self.blocks.iter().zip(other.blocks.iter()) {
338 sum = sum + s.inner_product(o)?;
339 }
340 Ok(sum)
341 }
342
343 fn conj(&self) -> Self {
344 let conjugated: Vec<T> = self.blocks.iter().map(|b| b.conj()).collect();
345 Self {
346 blocks: conjugated,
347 shape: self.shape,
348 }
349 }
350
351 fn validate(&self) -> Result<()> {
352 self.validate_indices()
353 }
354
355 fn factorize(
360 &self,
361 _left_inds: &[<Self as TensorIndex>::Index],
362 _options: &FactorizeOptions,
363 ) -> std::result::Result<FactorizeResult<Self>, FactorizeError> {
364 Err(FactorizeError::ComputationError(anyhow::anyhow!(
365 "BlockTensor does not support factorize"
366 )))
367 }
368
369 fn direct_sum(
370 &self,
371 _other: &Self,
372 _pairs: &[(<Self as TensorIndex>::Index, <Self as TensorIndex>::Index)],
373 ) -> Result<DirectSumResult<Self>> {
374 anyhow::bail!("BlockTensor does not support direct_sum")
375 }
376
377 fn outer_product(&self, _other: &Self) -> Result<Self> {
378 anyhow::bail!("BlockTensor does not support outer_product")
379 }
380
381 fn permuteinds(&self, _new_order: &[<Self as TensorIndex>::Index]) -> Result<Self> {
382 anyhow::bail!("BlockTensor does not support permuteinds")
383 }
384
385 fn contract(_tensors: &[&Self], _allowed: AllowedPairs<'_>) -> Result<Self> {
386 anyhow::bail!("BlockTensor does not support contract")
387 }
388
389 fn contract_connected(_tensors: &[&Self], _allowed: AllowedPairs<'_>) -> Result<Self> {
390 anyhow::bail!("BlockTensor does not support contract_connected")
391 }
392
393 fn diagonal(
394 _input_index: &<Self as TensorIndex>::Index,
395 _output_index: &<Self as TensorIndex>::Index,
396 ) -> Result<Self> {
397 anyhow::bail!("BlockTensor does not support diagonal")
398 }
399
400 fn scalar_one() -> Result<Self> {
401 anyhow::bail!("BlockTensor does not support scalar_one")
402 }
403
404 fn ones(_indices: &[<Self as TensorIndex>::Index]) -> Result<Self> {
405 anyhow::bail!("BlockTensor does not support ones")
406 }
407
408 fn onehot(_index_vals: &[(<Self as TensorIndex>::Index, usize)]) -> Result<Self> {
409 anyhow::bail!("BlockTensor does not support onehot")
410 }
411}
412
// Unit tests live in a separate `tests` module file and are compiled only for test builds.
#[cfg(test)]
mod tests;