tensor4all_core/block_tensor.rs
1//! Block tensor type for GMRES with block matrices.
2//!
3//! This module provides [`BlockTensor`], a collection of tensors organized
4//! in a block structure. It implements [`TensorLike`] for the vector space
5//! operations required by GMRES, allowing block matrix linear equations
6//! `Ax = b` to be solved using the existing GMRES implementation.
7//!
8//! # Example
9//!
10//! ```
11//! use tensor4all_core::block_tensor::BlockTensor;
12//! use tensor4all_core::krylov::{gmres, GmresOptions};
13//! use tensor4all_core::{DynIndex, TensorDynLen};
14//!
15//! # fn main() -> anyhow::Result<()> {
16//! let i = DynIndex::new_dyn(2);
17//! let b_block = TensorDynLen::from_dense(vec![i.clone()], vec![1.0, 2.0])?;
18//! let zero_block = TensorDynLen::from_dense(vec![i.clone()], vec![0.0, 0.0])?;
19//!
20//! let b = BlockTensor::new(vec![b_block], (1, 1))?;
21//! let x0 = BlockTensor::new(vec![zero_block], (1, 1))?;
22//!
23//! let apply_a = |x: &BlockTensor<TensorDynLen>| Ok(x.clone());
24//! let result = gmres(apply_a, &b, &x0, &GmresOptions::default())?;
25//!
26//! assert!(result.converged);
27//! assert_eq!(result.solution.shape(), (1, 1));
28//! # Ok(())
29//! # }
30//! ```
31
32use std::collections::HashSet;
33
34use crate::any_scalar::AnyScalar;
35use crate::tensor_index::TensorIndex;
36use crate::tensor_like::{
37 DirectSumResult, FactorizeError, FactorizeOptions, FactorizeResult, LinearizationOrder,
38 TensorConstructionLike, TensorContractionLike, TensorFactorizationLike, TensorLike,
39 TensorVectorSpace,
40};
41use anyhow::Result;
42
43/// A collection of tensors organized in a block structure.
44///
45/// Each block is a tensor of type `T` implementing [`TensorLike`].
46/// The flattened block list is ordered row-by-row:
47/// `(0, 0), (0, 1), ..., (1, 0), (1, 1), ...`.
48///
49/// # Type Parameters
50///
51/// * `T` - The tensor type for each block, must implement `TensorLike`
52#[derive(Debug, Clone)]
53pub struct BlockTensor<T: TensorLike> {
54 /// Blocks flattened row-by-row in block-matrix order
55 blocks: Vec<T>,
56 /// Block structure (rows, cols)
57 shape: (usize, usize),
58}
59
60impl<T: TensorLike> BlockTensor<T> {
61 /// Create a new block tensor with validation.
62 ///
63 /// # Arguments
64 ///
65 /// * `blocks` - Vector of blocks flattened row-by-row
66 /// * `shape` - Block structure as (rows, cols)
67 ///
68 /// # Errors
69 ///
70 /// Returns an error if `rows * cols != blocks.len()`.
71 ///
72 /// # Examples
73 ///
74 /// ```
75 /// use tensor4all_core::block_tensor::BlockTensor;
76 /// use tensor4all_core::{DynIndex, TensorDynLen};
77 ///
78 /// let i = DynIndex::new_dyn(2);
79 /// let t1 = TensorDynLen::from_dense(vec![i.clone()], vec![1.0, 2.0]).unwrap();
80 /// let t2 = TensorDynLen::from_dense(vec![i.clone()], vec![3.0, 4.0]).unwrap();
81 ///
82 /// let bt = BlockTensor::try_new(vec![t1, t2], (2, 1)).unwrap();
83 /// assert_eq!(bt.shape(), (2, 1));
84 ///
85 /// // Wrong number of blocks returns an error
86 /// let t3 = TensorDynLen::from_dense(vec![i], vec![5.0, 6.0]).unwrap();
87 /// assert!(BlockTensor::try_new(vec![t3], (2, 1)).is_err());
88 /// ```
89 pub fn try_new(blocks: Vec<T>, shape: (usize, usize)) -> Result<Self> {
90 let (rows, cols) = shape;
91 anyhow::ensure!(
92 rows * cols == blocks.len(),
93 "Block count mismatch: shape ({}, {}) requires {} blocks, but got {}",
94 rows,
95 cols,
96 rows * cols,
97 blocks.len()
98 );
99 Ok(Self { blocks, shape })
100 }
101
102 /// Create a new block tensor.
103 ///
104 /// # Arguments
105 ///
106 /// * `blocks` - Vector of blocks flattened row-by-row
107 /// * `shape` - Block structure as (rows, cols)
108 ///
109 /// # Errors
110 ///
111 /// Returns an error if `rows * cols != blocks.len()`.
112 ///
113 /// # Examples
114 ///
115 /// ```
116 /// use tensor4all_core::block_tensor::BlockTensor;
117 /// use tensor4all_core::{DynIndex, TensorDynLen};
118 ///
119 /// let i = DynIndex::new_dyn(2);
120 /// let t = TensorDynLen::from_dense(vec![i], vec![1.0, 2.0]).unwrap();
121 /// let bt = BlockTensor::new(vec![t], (1, 1)).unwrap();
122 /// assert_eq!(bt.shape(), (1, 1));
123 /// ```
124 pub fn new(blocks: Vec<T>, shape: (usize, usize)) -> Result<Self> {
125 Self::try_new(blocks, shape)
126 }
127
128 /// Get the block structure (rows, cols).
129 ///
130 /// # Examples
131 ///
132 /// ```
133 /// use tensor4all_core::block_tensor::BlockTensor;
134 /// use tensor4all_core::{DynIndex, TensorDynLen};
135 ///
136 /// let i = DynIndex::new_dyn(2);
137 /// let blocks: Vec<TensorDynLen> = (0..6)
138 /// .map(|_| TensorDynLen::from_dense(vec![i.clone()], vec![0.0, 0.0]).unwrap())
139 /// .collect();
140 /// let bt = BlockTensor::new(blocks, (2, 3)).unwrap();
141 /// assert_eq!(bt.shape(), (2, 3));
142 /// ```
143 pub fn shape(&self) -> (usize, usize) {
144 self.shape
145 }
146
147 /// Get the total number of blocks.
148 pub fn num_blocks(&self) -> usize {
149 self.blocks.len()
150 }
151
152 /// Get a reference to the block at (row, col).
153 ///
154 /// # Examples
155 ///
156 /// ```
157 /// use tensor4all_core::block_tensor::BlockTensor;
158 /// use tensor4all_core::{DynIndex, TensorDynLen};
159 ///
160 /// let i = DynIndex::new_dyn(2);
161 /// let t1 = TensorDynLen::from_dense(vec![i.clone()], vec![1.0, 2.0]).unwrap();
162 /// let t2 = TensorDynLen::from_dense(vec![i], vec![3.0, 4.0]).unwrap();
163 /// let bt = BlockTensor::new(vec![t1, t2], (2, 1)).unwrap();
164 /// assert_eq!(bt.get(0, 0).unwrap().dims(), vec![2]);
165 /// assert_eq!(bt.get(1, 0).unwrap().dims(), vec![2]);
166 /// ```
167 pub fn get(&self, row: usize, col: usize) -> Option<&T> {
168 let (rows, cols) = self.shape;
169 if row >= rows || col >= cols {
170 return None;
171 }
172 self.blocks.get(row * cols + col)
173 }
174
175 /// Get a mutable reference to the block at (row, col).
176 ///
177 /// # Examples
178 ///
179 /// ```
180 /// use tensor4all_core::block_tensor::BlockTensor;
181 /// use tensor4all_core::{DynIndex, TensorDynLen};
182 ///
183 /// let i = DynIndex::new_dyn(2);
184 /// let t = TensorDynLen::from_dense(vec![i.clone()], vec![1.0, 2.0]).unwrap();
185 /// let mut bt = BlockTensor::new(vec![t], (1, 1)).unwrap();
186 /// let block = bt.get_mut(0, 0).unwrap();
187 /// assert_eq!(block.dims(), vec![2]);
188 /// ```
189 pub fn get_mut(&mut self, row: usize, col: usize) -> Option<&mut T> {
190 let (rows, cols) = self.shape;
191 if row >= rows || col >= cols {
192 return None;
193 }
194 self.blocks.get_mut(row * cols + col)
195 }
196
197 /// Get all blocks as a slice.
198 ///
199 /// # Examples
200 ///
201 /// ```
202 /// use tensor4all_core::block_tensor::BlockTensor;
203 /// use tensor4all_core::{DynIndex, TensorDynLen};
204 ///
205 /// let i = DynIndex::new_dyn(2);
206 /// let t1 = TensorDynLen::from_dense(vec![i.clone()], vec![1.0, 2.0]).unwrap();
207 /// let t2 = TensorDynLen::from_dense(vec![i], vec![3.0, 4.0]).unwrap();
208 /// let bt = BlockTensor::new(vec![t1, t2], (1, 2)).unwrap();
209 /// assert_eq!(bt.blocks().len(), 2);
210 /// ```
211 pub fn blocks(&self) -> &[T] {
212 &self.blocks
213 }
214
215 /// Get all blocks as a mutable slice.
216 ///
217 /// # Examples
218 ///
219 /// ```
220 /// use tensor4all_core::block_tensor::BlockTensor;
221 /// use tensor4all_core::{DynIndex, TensorDynLen};
222 ///
223 /// let i = DynIndex::new_dyn(2);
224 /// let t = TensorDynLen::from_dense(vec![i], vec![1.0, 2.0]).unwrap();
225 /// let mut bt = BlockTensor::new(vec![t], (1, 1)).unwrap();
226 /// assert_eq!(bt.blocks_mut().len(), 1);
227 /// ```
228 pub fn blocks_mut(&mut self) -> &mut [T] {
229 &mut self.blocks
230 }
231
232 /// Consume self and return the blocks.
233 ///
234 /// # Examples
235 ///
236 /// ```
237 /// use tensor4all_core::block_tensor::BlockTensor;
238 /// use tensor4all_core::{DynIndex, TensorDynLen};
239 ///
240 /// let i = DynIndex::new_dyn(2);
241 /// let t = TensorDynLen::from_dense(vec![i], vec![1.0, 2.0]).unwrap();
242 /// let bt = BlockTensor::new(vec![t], (1, 1)).unwrap();
243 /// let blocks = bt.into_blocks();
244 /// assert_eq!(blocks.len(), 1);
245 /// assert_eq!(blocks[0].dims(), vec![2]);
246 /// ```
247 pub fn into_blocks(self) -> Vec<T> {
248 self.blocks
249 }
250
251 /// Validate that blocks share external indices consistently.
252 ///
253 /// For column vectors (cols=1), no index sharing is required between
254 /// blocks. Different rows can have independent physical indices
255 /// (the operator determines their relationship).
256 ///
257 /// For matrices (rows x cols), checks that:
258 /// - All blocks have the same number of external indices.
259 /// - Blocks in the same row share some common index IDs (output indices).
260 /// - Blocks in the same column share some common index IDs (input indices).
261 ///
262 /// # Examples
263 ///
264 /// ```
265 /// use tensor4all_core::block_tensor::BlockTensor;
266 /// use tensor4all_core::{DynIndex, TensorDynLen};
267 ///
268 /// // Column vector: always valid regardless of index sharing
269 /// let i = DynIndex::new_dyn(2);
270 /// let j = DynIndex::new_dyn(2);
271 /// let t1 = TensorDynLen::from_dense(vec![i], vec![1.0, 2.0]).unwrap();
272 /// let t2 = TensorDynLen::from_dense(vec![j], vec![3.0, 4.0]).unwrap();
273 /// let bt = BlockTensor::new(vec![t1, t2], (2, 1)).unwrap();
274 /// assert!(bt.validate_indices().is_ok());
275 /// ```
276 pub fn validate_indices(&self) -> Result<()> {
277 let (rows, cols) = self.shape;
278
279 if cols <= 1 {
280 // Column vector: blocks in different rows can have independent indices.
281 // The operator determines the relationship between blocks.
282 return Ok(());
283 }
284
285 // Matrix: check all blocks have the same number of external indices
286 let first_count = self.blocks[0].num_external_indices();
287 for (i, block) in self.blocks.iter().enumerate().skip(1) {
288 let n = block.num_external_indices();
289 anyhow::ensure!(
290 n == first_count,
291 "Block {} has {} external indices, but block 0 has {}",
292 i,
293 n,
294 first_count
295 );
296 }
297
298 // Same row: blocks should share at least one full output index.
299 for row in 0..rows {
300 let ref_indices: HashSet<_> = self
301 .get(row, 0)
302 .ok_or_else(|| anyhow::anyhow!("block index ({row}, 0) is out of bounds"))?
303 .external_indices()
304 .iter()
305 .cloned()
306 .collect();
307 for col in 1..cols {
308 let indices: HashSet<_> = self
309 .get(row, col)
310 .ok_or_else(|| anyhow::anyhow!("block index ({row}, {col}) is out of bounds"))?
311 .external_indices()
312 .iter()
313 .cloned()
314 .collect();
315 let common_count = ref_indices.intersection(&indices).count();
316 anyhow::ensure!(
317 common_count > 0,
318 "Matrix row {}: blocks ({},{}) and ({},{}) share no common indices",
319 row,
320 row,
321 0,
322 row,
323 col
324 );
325 }
326 }
327
328 // Same column: blocks should share at least one full input index.
329 for col in 0..cols {
330 let ref_indices: HashSet<_> = self
331 .get(0, col)
332 .ok_or_else(|| anyhow::anyhow!("block index (0, {col}) is out of bounds"))?
333 .external_indices()
334 .iter()
335 .cloned()
336 .collect();
337 for row in 1..rows {
338 let indices: HashSet<_> = self
339 .get(row, col)
340 .ok_or_else(|| anyhow::anyhow!("block index ({row}, {col}) is out of bounds"))?
341 .external_indices()
342 .iter()
343 .cloned()
344 .collect();
345 let common_count = ref_indices.intersection(&indices).count();
346 anyhow::ensure!(
347 common_count > 0,
348 "Matrix col {}: blocks ({},{}) and ({},{}) share no common indices",
349 col,
350 0,
351 col,
352 row,
353 col
354 );
355 }
356 }
357
358 Ok(())
359 }
360}
361
362// ============================================================================
363// TensorIndex implementation
364// ============================================================================
365
366impl<T: TensorLike> TensorIndex for BlockTensor<T> {
367 type Index = T::Index;
368
369 fn external_indices(&self) -> Vec<Self::Index> {
370 // Collect unique external indices across all blocks (deduplicated by full index).
371 let mut seen = HashSet::new();
372 let mut result = Vec::new();
373 for block in &self.blocks {
374 for idx in block.external_indices() {
375 if seen.insert(idx.clone()) {
376 result.push(idx);
377 }
378 }
379 }
380 result
381 }
382
383 fn replaceind(&self, old_index: &Self::Index, new_index: &Self::Index) -> Result<Self> {
384 let replaced: Result<Vec<T>> = self
385 .blocks
386 .iter()
387 .map(|b| b.replaceind(old_index, new_index))
388 .collect();
389 Ok(Self {
390 blocks: replaced?,
391 shape: self.shape,
392 })
393 }
394
395 fn replaceinds(
396 &self,
397 old_indices: &[Self::Index],
398 new_indices: &[Self::Index],
399 ) -> Result<Self> {
400 let replaced: Result<Vec<T>> = self
401 .blocks
402 .iter()
403 .map(|b| b.replaceinds(old_indices, new_indices))
404 .collect();
405 Ok(Self {
406 blocks: replaced?,
407 shape: self.shape,
408 })
409 }
410}
411
412// ============================================================================
// TensorLike component-trait implementations
414// ============================================================================
415
416impl<T: TensorLike> TensorVectorSpace for BlockTensor<T> {
417 // ------------------------------------------------------------------------
418 // Vector space operations (required for GMRES)
419 // ------------------------------------------------------------------------
420
421 fn norm_squared(&self) -> f64 {
422 self.blocks.iter().map(|b| b.norm_squared()).sum()
423 }
424
425 fn try_maxabs(&self) -> Result<f64> {
426 self.blocks
427 .iter()
428 .try_fold(0.0_f64, |acc, block| Ok(acc.max(block.try_maxabs()?)))
429 }
430
431 fn maxabs(&self) -> f64 {
432 self.try_maxabs().unwrap_or(f64::NAN)
433 }
434
435 fn scale(&self, scalar: AnyScalar) -> Result<Self> {
436 let scaled: Result<Vec<T>> = self
437 .blocks
438 .iter()
439 .map(|b| b.scale(scalar.clone()))
440 .collect();
441 Ok(Self {
442 blocks: scaled?,
443 shape: self.shape,
444 })
445 }
446
447 fn axpby(&self, a: AnyScalar, other: &Self, b: AnyScalar) -> Result<Self> {
448 anyhow::ensure!(
449 self.shape == other.shape,
450 "Block shapes must match: {:?} vs {:?}",
451 self.shape,
452 other.shape
453 );
454 let result: Result<Vec<T>> = self
455 .blocks
456 .iter()
457 .zip(other.blocks.iter())
458 .map(|(s, o)| s.axpby(a.clone(), o, b.clone()))
459 .collect();
460 Ok(Self {
461 blocks: result?,
462 shape: self.shape,
463 })
464 }
465
466 fn inner_product(&self, other: &Self) -> Result<AnyScalar> {
467 anyhow::ensure!(
468 self.shape == other.shape,
469 "Block shapes must match for inner product: {:?} vs {:?}",
470 self.shape,
471 other.shape
472 );
473 let mut sum = AnyScalar::new_real(0.0);
474 for (s, o) in self.blocks.iter().zip(other.blocks.iter()) {
475 sum = sum + s.inner_product(o)?;
476 }
477 Ok(sum)
478 }
479
480 fn validate(&self) -> Result<()> {
481 self.validate_indices()
482 }
483}
484
485impl<T: TensorLike> TensorContractionLike for BlockTensor<T> {
486 // ------------------------------------------------------------------------
487 // Tensor network operations
488 // ------------------------------------------------------------------------
489
490 fn conj(&self) -> Self {
491 let conjugated: Vec<T> = self.blocks.iter().map(|b| b.conj()).collect();
492 Self {
493 blocks: conjugated,
494 shape: self.shape,
495 }
496 }
497
498 fn direct_sum(
499 &self,
500 _other: &Self,
501 _pairs: &[(<Self as TensorIndex>::Index, <Self as TensorIndex>::Index)],
502 ) -> Result<DirectSumResult<Self>> {
503 anyhow::bail!("BlockTensor does not support direct_sum")
504 }
505
506 fn outer_product(&self, _other: &Self) -> Result<Self> {
507 anyhow::bail!("BlockTensor does not support outer_product")
508 }
509
510 fn permuteinds(&self, _new_order: &[<Self as TensorIndex>::Index]) -> Result<Self> {
511 anyhow::bail!("BlockTensor does not support permuteinds")
512 }
513
514 fn fuse_indices(
515 &self,
516 old_indices: &[Self::Index],
517 new_index: Self::Index,
518 order: LinearizationOrder,
519 ) -> Result<Self> {
520 let blocks: Result<Vec<T>> = self
521 .blocks
522 .iter()
523 .map(|block| block.fuse_indices(old_indices, new_index.clone(), order))
524 .collect();
525 Ok(Self {
526 blocks: blocks?,
527 shape: self.shape,
528 })
529 }
530
531 fn contract(_tensors: &[&Self]) -> Result<Self> {
532 anyhow::bail!("BlockTensor does not support contract")
533 }
534}
535
536impl<T: TensorLike> TensorFactorizationLike for BlockTensor<T> {
537 fn factorize(
538 &self,
539 _left_inds: &[<Self as TensorIndex>::Index],
540 _options: &FactorizeOptions,
541 ) -> std::result::Result<FactorizeResult<Self>, FactorizeError> {
542 Err(FactorizeError::ComputationError(anyhow::anyhow!(
543 "BlockTensor does not support factorize"
544 )))
545 }
546
547 fn factorize_full_rank(
548 &self,
549 _left_inds: &[<Self as TensorIndex>::Index],
550 _alg: crate::FactorizeAlg,
551 _canonical: crate::Canonical,
552 ) -> std::result::Result<FactorizeResult<Self>, FactorizeError> {
553 Err(FactorizeError::ComputationError(anyhow::anyhow!(
554 "BlockTensor does not support factorize_full_rank"
555 )))
556 }
557}
558
559impl<T: TensorLike> TensorConstructionLike for BlockTensor<T> {
560 fn diagonal(
561 _input_index: &<Self as TensorIndex>::Index,
562 _output_index: &<Self as TensorIndex>::Index,
563 ) -> Result<Self> {
564 anyhow::bail!("BlockTensor does not support diagonal")
565 }
566
567 fn scalar_one() -> Result<Self> {
568 anyhow::bail!("BlockTensor does not support scalar_one")
569 }
570
571 fn ones(_indices: &[<Self as TensorIndex>::Index]) -> Result<Self> {
572 anyhow::bail!("BlockTensor does not support ones")
573 }
574
575 fn onehot(_index_vals: &[(<Self as TensorIndex>::Index, usize)]) -> Result<Self> {
576 anyhow::bail!("BlockTensor does not support onehot")
577 }
578}
579
580#[cfg(test)]
581mod tests;