Skip to main content

tensor4all_core/
block_tensor.rs

1//! Block tensor type for GMRES with block matrices.
2//!
3//! This module provides [`BlockTensor`], a collection of tensors organized
4//! in a block structure. It implements [`TensorLike`] for the vector space
5//! operations required by GMRES, allowing block matrix linear equations
6//! `Ax = b` to be solved using the existing GMRES implementation.
7//!
8//! # Example
9//!
10//! ```
11//! use tensor4all_core::block_tensor::BlockTensor;
12//! use tensor4all_core::krylov::{gmres, GmresOptions};
13//! use tensor4all_core::{DynIndex, TensorDynLen};
14//!
15//! # fn main() -> anyhow::Result<()> {
16//! let i = DynIndex::new_dyn(2);
17//! let b_block = TensorDynLen::from_dense(vec![i.clone()], vec![1.0, 2.0])?;
18//! let zero_block = TensorDynLen::from_dense(vec![i.clone()], vec![0.0, 0.0])?;
19//!
20//! let b = BlockTensor::new(vec![b_block], (1, 1))?;
21//! let x0 = BlockTensor::new(vec![zero_block], (1, 1))?;
22//!
23//! let apply_a = |x: &BlockTensor<TensorDynLen>| Ok(x.clone());
24//! let result = gmres(apply_a, &b, &x0, &GmresOptions::default())?;
25//!
26//! assert!(result.converged);
27//! assert_eq!(result.solution.shape(), (1, 1));
28//! # Ok(())
29//! # }
30//! ```
31
32use std::collections::HashSet;
33
34use crate::any_scalar::AnyScalar;
35use crate::tensor_index::TensorIndex;
36use crate::tensor_like::{
37    DirectSumResult, FactorizeError, FactorizeOptions, FactorizeResult, LinearizationOrder,
38    TensorConstructionLike, TensorContractionLike, TensorFactorizationLike, TensorLike,
39    TensorVectorSpace,
40};
41use anyhow::Result;
42
/// A collection of tensors organized in a block structure.
///
/// Each block is a tensor of type `T` implementing [`TensorLike`].
/// The flattened block list is ordered row-by-row (row-major):
/// `(0, 0), (0, 1), ..., (1, 0), (1, 1), ...`, so block `(row, col)` lives at
/// flat position `row * cols + col`.
///
/// The fields are private; the constructors in this module check that
/// `blocks.len() == rows * cols`. Element-wise operations such as `axpby`
/// and `inner_product` only compare `shape` before zipping the block lists,
/// so they rely on that invariant.
///
/// # Type Parameters
///
/// * `T` - The tensor type for each block, must implement `TensorLike`
#[derive(Debug, Clone)]
pub struct BlockTensor<T: TensorLike> {
    /// Blocks flattened row-by-row in block-matrix order
    blocks: Vec<T>,
    /// Block structure (rows, cols)
    shape: (usize, usize),
}
59
60impl<T: TensorLike> BlockTensor<T> {
61    /// Create a new block tensor with validation.
62    ///
63    /// # Arguments
64    ///
65    /// * `blocks` - Vector of blocks flattened row-by-row
66    /// * `shape` - Block structure as (rows, cols)
67    ///
68    /// # Errors
69    ///
70    /// Returns an error if `rows * cols != blocks.len()`.
71    ///
72    /// # Examples
73    ///
74    /// ```
75    /// use tensor4all_core::block_tensor::BlockTensor;
76    /// use tensor4all_core::{DynIndex, TensorDynLen};
77    ///
78    /// let i = DynIndex::new_dyn(2);
79    /// let t1 = TensorDynLen::from_dense(vec![i.clone()], vec![1.0, 2.0]).unwrap();
80    /// let t2 = TensorDynLen::from_dense(vec![i.clone()], vec![3.0, 4.0]).unwrap();
81    ///
82    /// let bt = BlockTensor::try_new(vec![t1, t2], (2, 1)).unwrap();
83    /// assert_eq!(bt.shape(), (2, 1));
84    ///
85    /// // Wrong number of blocks returns an error
86    /// let t3 = TensorDynLen::from_dense(vec![i], vec![5.0, 6.0]).unwrap();
87    /// assert!(BlockTensor::try_new(vec![t3], (2, 1)).is_err());
88    /// ```
89    pub fn try_new(blocks: Vec<T>, shape: (usize, usize)) -> Result<Self> {
90        let (rows, cols) = shape;
91        anyhow::ensure!(
92            rows * cols == blocks.len(),
93            "Block count mismatch: shape ({}, {}) requires {} blocks, but got {}",
94            rows,
95            cols,
96            rows * cols,
97            blocks.len()
98        );
99        Ok(Self { blocks, shape })
100    }
101
102    /// Create a new block tensor.
103    ///
104    /// # Arguments
105    ///
106    /// * `blocks` - Vector of blocks flattened row-by-row
107    /// * `shape` - Block structure as (rows, cols)
108    ///
109    /// # Errors
110    ///
111    /// Returns an error if `rows * cols != blocks.len()`.
112    ///
113    /// # Examples
114    ///
115    /// ```
116    /// use tensor4all_core::block_tensor::BlockTensor;
117    /// use tensor4all_core::{DynIndex, TensorDynLen};
118    ///
119    /// let i = DynIndex::new_dyn(2);
120    /// let t = TensorDynLen::from_dense(vec![i], vec![1.0, 2.0]).unwrap();
121    /// let bt = BlockTensor::new(vec![t], (1, 1)).unwrap();
122    /// assert_eq!(bt.shape(), (1, 1));
123    /// ```
124    pub fn new(blocks: Vec<T>, shape: (usize, usize)) -> Result<Self> {
125        Self::try_new(blocks, shape)
126    }
127
128    /// Get the block structure (rows, cols).
129    ///
130    /// # Examples
131    ///
132    /// ```
133    /// use tensor4all_core::block_tensor::BlockTensor;
134    /// use tensor4all_core::{DynIndex, TensorDynLen};
135    ///
136    /// let i = DynIndex::new_dyn(2);
137    /// let blocks: Vec<TensorDynLen> = (0..6)
138    ///     .map(|_| TensorDynLen::from_dense(vec![i.clone()], vec![0.0, 0.0]).unwrap())
139    ///     .collect();
140    /// let bt = BlockTensor::new(blocks, (2, 3)).unwrap();
141    /// assert_eq!(bt.shape(), (2, 3));
142    /// ```
143    pub fn shape(&self) -> (usize, usize) {
144        self.shape
145    }
146
147    /// Get the total number of blocks.
148    pub fn num_blocks(&self) -> usize {
149        self.blocks.len()
150    }
151
152    /// Get a reference to the block at (row, col).
153    ///
154    /// # Examples
155    ///
156    /// ```
157    /// use tensor4all_core::block_tensor::BlockTensor;
158    /// use tensor4all_core::{DynIndex, TensorDynLen};
159    ///
160    /// let i = DynIndex::new_dyn(2);
161    /// let t1 = TensorDynLen::from_dense(vec![i.clone()], vec![1.0, 2.0]).unwrap();
162    /// let t2 = TensorDynLen::from_dense(vec![i], vec![3.0, 4.0]).unwrap();
163    /// let bt = BlockTensor::new(vec![t1, t2], (2, 1)).unwrap();
164    /// assert_eq!(bt.get(0, 0).unwrap().dims(), vec![2]);
165    /// assert_eq!(bt.get(1, 0).unwrap().dims(), vec![2]);
166    /// ```
167    pub fn get(&self, row: usize, col: usize) -> Option<&T> {
168        let (rows, cols) = self.shape;
169        if row >= rows || col >= cols {
170            return None;
171        }
172        self.blocks.get(row * cols + col)
173    }
174
175    /// Get a mutable reference to the block at (row, col).
176    ///
177    /// # Examples
178    ///
179    /// ```
180    /// use tensor4all_core::block_tensor::BlockTensor;
181    /// use tensor4all_core::{DynIndex, TensorDynLen};
182    ///
183    /// let i = DynIndex::new_dyn(2);
184    /// let t = TensorDynLen::from_dense(vec![i.clone()], vec![1.0, 2.0]).unwrap();
185    /// let mut bt = BlockTensor::new(vec![t], (1, 1)).unwrap();
186    /// let block = bt.get_mut(0, 0).unwrap();
187    /// assert_eq!(block.dims(), vec![2]);
188    /// ```
189    pub fn get_mut(&mut self, row: usize, col: usize) -> Option<&mut T> {
190        let (rows, cols) = self.shape;
191        if row >= rows || col >= cols {
192            return None;
193        }
194        self.blocks.get_mut(row * cols + col)
195    }
196
197    /// Get all blocks as a slice.
198    ///
199    /// # Examples
200    ///
201    /// ```
202    /// use tensor4all_core::block_tensor::BlockTensor;
203    /// use tensor4all_core::{DynIndex, TensorDynLen};
204    ///
205    /// let i = DynIndex::new_dyn(2);
206    /// let t1 = TensorDynLen::from_dense(vec![i.clone()], vec![1.0, 2.0]).unwrap();
207    /// let t2 = TensorDynLen::from_dense(vec![i], vec![3.0, 4.0]).unwrap();
208    /// let bt = BlockTensor::new(vec![t1, t2], (1, 2)).unwrap();
209    /// assert_eq!(bt.blocks().len(), 2);
210    /// ```
211    pub fn blocks(&self) -> &[T] {
212        &self.blocks
213    }
214
215    /// Get all blocks as a mutable slice.
216    ///
217    /// # Examples
218    ///
219    /// ```
220    /// use tensor4all_core::block_tensor::BlockTensor;
221    /// use tensor4all_core::{DynIndex, TensorDynLen};
222    ///
223    /// let i = DynIndex::new_dyn(2);
224    /// let t = TensorDynLen::from_dense(vec![i], vec![1.0, 2.0]).unwrap();
225    /// let mut bt = BlockTensor::new(vec![t], (1, 1)).unwrap();
226    /// assert_eq!(bt.blocks_mut().len(), 1);
227    /// ```
228    pub fn blocks_mut(&mut self) -> &mut [T] {
229        &mut self.blocks
230    }
231
232    /// Consume self and return the blocks.
233    ///
234    /// # Examples
235    ///
236    /// ```
237    /// use tensor4all_core::block_tensor::BlockTensor;
238    /// use tensor4all_core::{DynIndex, TensorDynLen};
239    ///
240    /// let i = DynIndex::new_dyn(2);
241    /// let t = TensorDynLen::from_dense(vec![i], vec![1.0, 2.0]).unwrap();
242    /// let bt = BlockTensor::new(vec![t], (1, 1)).unwrap();
243    /// let blocks = bt.into_blocks();
244    /// assert_eq!(blocks.len(), 1);
245    /// assert_eq!(blocks[0].dims(), vec![2]);
246    /// ```
247    pub fn into_blocks(self) -> Vec<T> {
248        self.blocks
249    }
250
251    /// Validate that blocks share external indices consistently.
252    ///
253    /// For column vectors (cols=1), no index sharing is required between
254    /// blocks. Different rows can have independent physical indices
255    /// (the operator determines their relationship).
256    ///
257    /// For matrices (rows x cols), checks that:
258    /// - All blocks have the same number of external indices.
259    /// - Blocks in the same row share some common index IDs (output indices).
260    /// - Blocks in the same column share some common index IDs (input indices).
261    ///
262    /// # Examples
263    ///
264    /// ```
265    /// use tensor4all_core::block_tensor::BlockTensor;
266    /// use tensor4all_core::{DynIndex, TensorDynLen};
267    ///
268    /// // Column vector: always valid regardless of index sharing
269    /// let i = DynIndex::new_dyn(2);
270    /// let j = DynIndex::new_dyn(2);
271    /// let t1 = TensorDynLen::from_dense(vec![i], vec![1.0, 2.0]).unwrap();
272    /// let t2 = TensorDynLen::from_dense(vec![j], vec![3.0, 4.0]).unwrap();
273    /// let bt = BlockTensor::new(vec![t1, t2], (2, 1)).unwrap();
274    /// assert!(bt.validate_indices().is_ok());
275    /// ```
276    pub fn validate_indices(&self) -> Result<()> {
277        let (rows, cols) = self.shape;
278
279        if cols <= 1 {
280            // Column vector: blocks in different rows can have independent indices.
281            // The operator determines the relationship between blocks.
282            return Ok(());
283        }
284
285        // Matrix: check all blocks have the same number of external indices
286        let first_count = self.blocks[0].num_external_indices();
287        for (i, block) in self.blocks.iter().enumerate().skip(1) {
288            let n = block.num_external_indices();
289            anyhow::ensure!(
290                n == first_count,
291                "Block {} has {} external indices, but block 0 has {}",
292                i,
293                n,
294                first_count
295            );
296        }
297
298        // Same row: blocks should share at least one full output index.
299        for row in 0..rows {
300            let ref_indices: HashSet<_> = self
301                .get(row, 0)
302                .ok_or_else(|| anyhow::anyhow!("block index ({row}, 0) is out of bounds"))?
303                .external_indices()
304                .iter()
305                .cloned()
306                .collect();
307            for col in 1..cols {
308                let indices: HashSet<_> = self
309                    .get(row, col)
310                    .ok_or_else(|| anyhow::anyhow!("block index ({row}, {col}) is out of bounds"))?
311                    .external_indices()
312                    .iter()
313                    .cloned()
314                    .collect();
315                let common_count = ref_indices.intersection(&indices).count();
316                anyhow::ensure!(
317                    common_count > 0,
318                    "Matrix row {}: blocks ({},{}) and ({},{}) share no common indices",
319                    row,
320                    row,
321                    0,
322                    row,
323                    col
324                );
325            }
326        }
327
328        // Same column: blocks should share at least one full input index.
329        for col in 0..cols {
330            let ref_indices: HashSet<_> = self
331                .get(0, col)
332                .ok_or_else(|| anyhow::anyhow!("block index (0, {col}) is out of bounds"))?
333                .external_indices()
334                .iter()
335                .cloned()
336                .collect();
337            for row in 1..rows {
338                let indices: HashSet<_> = self
339                    .get(row, col)
340                    .ok_or_else(|| anyhow::anyhow!("block index ({row}, {col}) is out of bounds"))?
341                    .external_indices()
342                    .iter()
343                    .cloned()
344                    .collect();
345                let common_count = ref_indices.intersection(&indices).count();
346                anyhow::ensure!(
347                    common_count > 0,
348                    "Matrix col {}: blocks ({},{}) and ({},{}) share no common indices",
349                    col,
350                    0,
351                    col,
352                    row,
353                    col
354                );
355            }
356        }
357
358        Ok(())
359    }
360}
361
362// ============================================================================
363// TensorIndex implementation
364// ============================================================================
365
366impl<T: TensorLike> TensorIndex for BlockTensor<T> {
367    type Index = T::Index;
368
369    fn external_indices(&self) -> Vec<Self::Index> {
370        // Collect unique external indices across all blocks (deduplicated by full index).
371        let mut seen = HashSet::new();
372        let mut result = Vec::new();
373        for block in &self.blocks {
374            for idx in block.external_indices() {
375                if seen.insert(idx.clone()) {
376                    result.push(idx);
377                }
378            }
379        }
380        result
381    }
382
383    fn replaceind(&self, old_index: &Self::Index, new_index: &Self::Index) -> Result<Self> {
384        let replaced: Result<Vec<T>> = self
385            .blocks
386            .iter()
387            .map(|b| b.replaceind(old_index, new_index))
388            .collect();
389        Ok(Self {
390            blocks: replaced?,
391            shape: self.shape,
392        })
393    }
394
395    fn replaceinds(
396        &self,
397        old_indices: &[Self::Index],
398        new_indices: &[Self::Index],
399    ) -> Result<Self> {
400        let replaced: Result<Vec<T>> = self
401            .blocks
402            .iter()
403            .map(|b| b.replaceinds(old_indices, new_indices))
404            .collect();
405        Ok(Self {
406            blocks: replaced?,
407            shape: self.shape,
408        })
409    }
410}
411
412// ============================================================================
// TensorVectorSpace, TensorContractionLike, TensorFactorizationLike and
// TensorConstructionLike implementations
414// ============================================================================
415
416impl<T: TensorLike> TensorVectorSpace for BlockTensor<T> {
417    // ------------------------------------------------------------------------
418    // Vector space operations (required for GMRES)
419    // ------------------------------------------------------------------------
420
421    fn norm_squared(&self) -> f64 {
422        self.blocks.iter().map(|b| b.norm_squared()).sum()
423    }
424
425    fn try_maxabs(&self) -> Result<f64> {
426        self.blocks
427            .iter()
428            .try_fold(0.0_f64, |acc, block| Ok(acc.max(block.try_maxabs()?)))
429    }
430
431    fn maxabs(&self) -> f64 {
432        self.try_maxabs().unwrap_or(f64::NAN)
433    }
434
435    fn scale(&self, scalar: AnyScalar) -> Result<Self> {
436        let scaled: Result<Vec<T>> = self
437            .blocks
438            .iter()
439            .map(|b| b.scale(scalar.clone()))
440            .collect();
441        Ok(Self {
442            blocks: scaled?,
443            shape: self.shape,
444        })
445    }
446
447    fn axpby(&self, a: AnyScalar, other: &Self, b: AnyScalar) -> Result<Self> {
448        anyhow::ensure!(
449            self.shape == other.shape,
450            "Block shapes must match: {:?} vs {:?}",
451            self.shape,
452            other.shape
453        );
454        let result: Result<Vec<T>> = self
455            .blocks
456            .iter()
457            .zip(other.blocks.iter())
458            .map(|(s, o)| s.axpby(a.clone(), o, b.clone()))
459            .collect();
460        Ok(Self {
461            blocks: result?,
462            shape: self.shape,
463        })
464    }
465
466    fn inner_product(&self, other: &Self) -> Result<AnyScalar> {
467        anyhow::ensure!(
468            self.shape == other.shape,
469            "Block shapes must match for inner product: {:?} vs {:?}",
470            self.shape,
471            other.shape
472        );
473        let mut sum = AnyScalar::new_real(0.0);
474        for (s, o) in self.blocks.iter().zip(other.blocks.iter()) {
475            sum = sum + s.inner_product(o)?;
476        }
477        Ok(sum)
478    }
479
480    fn validate(&self) -> Result<()> {
481        self.validate_indices()
482    }
483}
484
485impl<T: TensorLike> TensorContractionLike for BlockTensor<T> {
486    // ------------------------------------------------------------------------
487    // Tensor network operations
488    // ------------------------------------------------------------------------
489
490    fn conj(&self) -> Self {
491        let conjugated: Vec<T> = self.blocks.iter().map(|b| b.conj()).collect();
492        Self {
493            blocks: conjugated,
494            shape: self.shape,
495        }
496    }
497
498    fn direct_sum(
499        &self,
500        _other: &Self,
501        _pairs: &[(<Self as TensorIndex>::Index, <Self as TensorIndex>::Index)],
502    ) -> Result<DirectSumResult<Self>> {
503        anyhow::bail!("BlockTensor does not support direct_sum")
504    }
505
506    fn outer_product(&self, _other: &Self) -> Result<Self> {
507        anyhow::bail!("BlockTensor does not support outer_product")
508    }
509
510    fn permuteinds(&self, _new_order: &[<Self as TensorIndex>::Index]) -> Result<Self> {
511        anyhow::bail!("BlockTensor does not support permuteinds")
512    }
513
514    fn fuse_indices(
515        &self,
516        old_indices: &[Self::Index],
517        new_index: Self::Index,
518        order: LinearizationOrder,
519    ) -> Result<Self> {
520        let blocks: Result<Vec<T>> = self
521            .blocks
522            .iter()
523            .map(|block| block.fuse_indices(old_indices, new_index.clone(), order))
524            .collect();
525        Ok(Self {
526            blocks: blocks?,
527            shape: self.shape,
528        })
529    }
530
531    fn contract(_tensors: &[&Self]) -> Result<Self> {
532        anyhow::bail!("BlockTensor does not support contract")
533    }
534}
535
536impl<T: TensorLike> TensorFactorizationLike for BlockTensor<T> {
537    fn factorize(
538        &self,
539        _left_inds: &[<Self as TensorIndex>::Index],
540        _options: &FactorizeOptions,
541    ) -> std::result::Result<FactorizeResult<Self>, FactorizeError> {
542        Err(FactorizeError::ComputationError(anyhow::anyhow!(
543            "BlockTensor does not support factorize"
544        )))
545    }
546
547    fn factorize_full_rank(
548        &self,
549        _left_inds: &[<Self as TensorIndex>::Index],
550        _alg: crate::FactorizeAlg,
551        _canonical: crate::Canonical,
552    ) -> std::result::Result<FactorizeResult<Self>, FactorizeError> {
553        Err(FactorizeError::ComputationError(anyhow::anyhow!(
554            "BlockTensor does not support factorize_full_rank"
555        )))
556    }
557}
558
559impl<T: TensorLike> TensorConstructionLike for BlockTensor<T> {
560    fn diagonal(
561        _input_index: &<Self as TensorIndex>::Index,
562        _output_index: &<Self as TensorIndex>::Index,
563    ) -> Result<Self> {
564        anyhow::bail!("BlockTensor does not support diagonal")
565    }
566
567    fn scalar_one() -> Result<Self> {
568        anyhow::bail!("BlockTensor does not support scalar_one")
569    }
570
571    fn ones(_indices: &[<Self as TensorIndex>::Index]) -> Result<Self> {
572        anyhow::bail!("BlockTensor does not support ones")
573    }
574
575    fn onehot(_index_vals: &[(<Self as TensorIndex>::Index, usize)]) -> Result<Self> {
576        anyhow::bail!("BlockTensor does not support onehot")
577    }
578}
579
580#[cfg(test)]
581mod tests;