Skip to main content

tensor4all_core/
block_tensor.rs

1//! Block tensor type for GMRES with block matrices.
2//!
3//! This module provides [`BlockTensor`], a collection of tensors organized
4//! in a block structure. It implements [`TensorLike`] for the vector space
5//! operations required by GMRES, allowing block matrix linear equations
6//! `Ax = b` to be solved using the existing GMRES implementation.
7//!
8//! # Example
9//!
10//! ```
11//! use tensor4all_core::block_tensor::BlockTensor;
12//! use tensor4all_core::krylov::{gmres, GmresOptions};
13//! use tensor4all_core::{DynIndex, TensorDynLen};
14//!
15//! # fn main() -> anyhow::Result<()> {
16//! let i = DynIndex::new_dyn(2);
17//! let b_block = TensorDynLen::from_dense(vec![i.clone()], vec![1.0, 2.0])?;
18//! let zero_block = TensorDynLen::from_dense(vec![i.clone()], vec![0.0, 0.0])?;
19//!
20//! let b = BlockTensor::new(vec![b_block], (1, 1));
21//! let x0 = BlockTensor::new(vec![zero_block], (1, 1));
22//!
23//! let apply_a = |x: &BlockTensor<TensorDynLen>| Ok(x.clone());
24//! let result = gmres(apply_a, &b, &x0, &GmresOptions::default())?;
25//!
26//! assert!(result.converged);
27//! assert_eq!(result.solution.shape(), (1, 1));
28//! # Ok(())
29//! # }
30//! ```
31
32use std::collections::HashSet;
33
34use crate::any_scalar::AnyScalar;
35use crate::tensor_index::TensorIndex;
36use crate::tensor_like::{
37    AllowedPairs, DirectSumResult, FactorizeError, FactorizeOptions, FactorizeResult,
38    LinearizationOrder, TensorLike,
39};
40use anyhow::Result;
41
42/// A collection of tensors organized in a block structure.
43///
44/// Each block is a tensor of type `T` implementing [`TensorLike`].
45/// The flattened block list is ordered row-by-row:
46/// `(0, 0), (0, 1), ..., (1, 0), (1, 1), ...`.
47///
48/// # Type Parameters
49///
50/// * `T` - The tensor type for each block, must implement `TensorLike`
51#[derive(Debug, Clone)]
52pub struct BlockTensor<T: TensorLike> {
53    /// Blocks flattened row-by-row in block-matrix order
54    blocks: Vec<T>,
55    /// Block structure (rows, cols)
56    shape: (usize, usize),
57}
58
59impl<T: TensorLike> BlockTensor<T> {
60    /// Create a new block tensor with validation.
61    ///
62    /// # Arguments
63    ///
64    /// * `blocks` - Vector of blocks flattened row-by-row
65    /// * `shape` - Block structure as (rows, cols)
66    ///
67    /// # Errors
68    ///
69    /// Returns an error if `rows * cols != blocks.len()`.
70    ///
71    /// # Examples
72    ///
73    /// ```
74    /// use tensor4all_core::block_tensor::BlockTensor;
75    /// use tensor4all_core::{DynIndex, TensorDynLen};
76    ///
77    /// let i = DynIndex::new_dyn(2);
78    /// let t1 = TensorDynLen::from_dense(vec![i.clone()], vec![1.0, 2.0]).unwrap();
79    /// let t2 = TensorDynLen::from_dense(vec![i.clone()], vec![3.0, 4.0]).unwrap();
80    ///
81    /// let bt = BlockTensor::try_new(vec![t1, t2], (2, 1)).unwrap();
82    /// assert_eq!(bt.shape(), (2, 1));
83    ///
84    /// // Wrong number of blocks returns an error
85    /// let t3 = TensorDynLen::from_dense(vec![i], vec![5.0, 6.0]).unwrap();
86    /// assert!(BlockTensor::try_new(vec![t3], (2, 1)).is_err());
87    /// ```
88    pub fn try_new(blocks: Vec<T>, shape: (usize, usize)) -> Result<Self> {
89        let (rows, cols) = shape;
90        anyhow::ensure!(
91            rows * cols == blocks.len(),
92            "Block count mismatch: shape ({}, {}) requires {} blocks, but got {}",
93            rows,
94            cols,
95            rows * cols,
96            blocks.len()
97        );
98        Ok(Self { blocks, shape })
99    }
100
101    /// Create a new block tensor.
102    ///
103    /// # Arguments
104    ///
105    /// * `blocks` - Vector of blocks flattened row-by-row
106    /// * `shape` - Block structure as (rows, cols)
107    ///
108    /// # Panics
109    ///
110    /// Panics if `rows * cols != blocks.len()`.
111    ///
112    /// # Examples
113    ///
114    /// ```
115    /// use tensor4all_core::block_tensor::BlockTensor;
116    /// use tensor4all_core::{DynIndex, TensorDynLen};
117    ///
118    /// let i = DynIndex::new_dyn(2);
119    /// let t = TensorDynLen::from_dense(vec![i], vec![1.0, 2.0]).unwrap();
120    /// let bt = BlockTensor::new(vec![t], (1, 1));
121    /// assert_eq!(bt.shape(), (1, 1));
122    /// ```
123    pub fn new(blocks: Vec<T>, shape: (usize, usize)) -> Self {
124        Self::try_new(blocks, shape).expect("Invalid block tensor shape")
125    }
126
127    /// Get the block structure (rows, cols).
128    ///
129    /// # Examples
130    ///
131    /// ```
132    /// use tensor4all_core::block_tensor::BlockTensor;
133    /// use tensor4all_core::{DynIndex, TensorDynLen};
134    ///
135    /// let i = DynIndex::new_dyn(2);
136    /// let blocks: Vec<TensorDynLen> = (0..6)
137    ///     .map(|_| TensorDynLen::from_dense(vec![i.clone()], vec![0.0, 0.0]).unwrap())
138    ///     .collect();
139    /// let bt = BlockTensor::new(blocks, (2, 3));
140    /// assert_eq!(bt.shape(), (2, 3));
141    /// ```
142    pub fn shape(&self) -> (usize, usize) {
143        self.shape
144    }
145
146    /// Get the total number of blocks.
147    pub fn num_blocks(&self) -> usize {
148        self.blocks.len()
149    }
150
151    /// Get a reference to the block at (row, col).
152    ///
153    /// # Panics
154    ///
155    /// Panics if the indices are out of bounds.
156    ///
157    /// # Examples
158    ///
159    /// ```
160    /// use tensor4all_core::block_tensor::BlockTensor;
161    /// use tensor4all_core::{DynIndex, TensorDynLen};
162    ///
163    /// let i = DynIndex::new_dyn(2);
164    /// let t1 = TensorDynLen::from_dense(vec![i.clone()], vec![1.0, 2.0]).unwrap();
165    /// let t2 = TensorDynLen::from_dense(vec![i], vec![3.0, 4.0]).unwrap();
166    /// let bt = BlockTensor::new(vec![t1, t2], (2, 1));
167    /// assert_eq!(bt.get(0, 0).dims(), vec![2]);
168    /// assert_eq!(bt.get(1, 0).dims(), vec![2]);
169    /// ```
170    pub fn get(&self, row: usize, col: usize) -> &T {
171        let (rows, cols) = self.shape;
172        assert!(row < rows && col < cols, "Block index out of bounds");
173        &self.blocks[row * cols + col]
174    }
175
176    /// Get a mutable reference to the block at (row, col).
177    ///
178    /// # Panics
179    ///
180    /// Panics if the indices are out of bounds.
181    ///
182    /// # Examples
183    ///
184    /// ```
185    /// use tensor4all_core::block_tensor::BlockTensor;
186    /// use tensor4all_core::{DynIndex, TensorDynLen};
187    ///
188    /// let i = DynIndex::new_dyn(2);
189    /// let t = TensorDynLen::from_dense(vec![i.clone()], vec![1.0, 2.0]).unwrap();
190    /// let mut bt = BlockTensor::new(vec![t], (1, 1));
191    /// let block = bt.get_mut(0, 0);
192    /// assert_eq!(block.dims(), vec![2]);
193    /// ```
194    pub fn get_mut(&mut self, row: usize, col: usize) -> &mut T {
195        let (rows, cols) = self.shape;
196        assert!(row < rows && col < cols, "Block index out of bounds");
197        &mut self.blocks[row * cols + col]
198    }
199
200    /// Get all blocks as a slice.
201    ///
202    /// # Examples
203    ///
204    /// ```
205    /// use tensor4all_core::block_tensor::BlockTensor;
206    /// use tensor4all_core::{DynIndex, TensorDynLen};
207    ///
208    /// let i = DynIndex::new_dyn(2);
209    /// let t1 = TensorDynLen::from_dense(vec![i.clone()], vec![1.0, 2.0]).unwrap();
210    /// let t2 = TensorDynLen::from_dense(vec![i], vec![3.0, 4.0]).unwrap();
211    /// let bt = BlockTensor::new(vec![t1, t2], (1, 2));
212    /// assert_eq!(bt.blocks().len(), 2);
213    /// ```
214    pub fn blocks(&self) -> &[T] {
215        &self.blocks
216    }
217
218    /// Get all blocks as a mutable slice.
219    ///
220    /// # Examples
221    ///
222    /// ```
223    /// use tensor4all_core::block_tensor::BlockTensor;
224    /// use tensor4all_core::{DynIndex, TensorDynLen};
225    ///
226    /// let i = DynIndex::new_dyn(2);
227    /// let t = TensorDynLen::from_dense(vec![i], vec![1.0, 2.0]).unwrap();
228    /// let mut bt = BlockTensor::new(vec![t], (1, 1));
229    /// assert_eq!(bt.blocks_mut().len(), 1);
230    /// ```
231    pub fn blocks_mut(&mut self) -> &mut [T] {
232        &mut self.blocks
233    }
234
235    /// Consume self and return the blocks.
236    ///
237    /// # Examples
238    ///
239    /// ```
240    /// use tensor4all_core::block_tensor::BlockTensor;
241    /// use tensor4all_core::{DynIndex, TensorDynLen};
242    ///
243    /// let i = DynIndex::new_dyn(2);
244    /// let t = TensorDynLen::from_dense(vec![i], vec![1.0, 2.0]).unwrap();
245    /// let bt = BlockTensor::new(vec![t], (1, 1));
246    /// let blocks = bt.into_blocks();
247    /// assert_eq!(blocks.len(), 1);
248    /// assert_eq!(blocks[0].dims(), vec![2]);
249    /// ```
250    pub fn into_blocks(self) -> Vec<T> {
251        self.blocks
252    }
253
254    /// Validate that blocks share external indices consistently.
255    ///
256    /// For column vectors (cols=1), no index sharing is required between
257    /// blocks. Different rows can have independent physical indices
258    /// (the operator determines their relationship).
259    ///
260    /// For matrices (rows x cols), checks that:
261    /// - All blocks have the same number of external indices.
262    /// - Blocks in the same row share some common index IDs (output indices).
263    /// - Blocks in the same column share some common index IDs (input indices).
264    ///
265    /// # Examples
266    ///
267    /// ```
268    /// use tensor4all_core::block_tensor::BlockTensor;
269    /// use tensor4all_core::{DynIndex, TensorDynLen};
270    ///
271    /// // Column vector: always valid regardless of index sharing
272    /// let i = DynIndex::new_dyn(2);
273    /// let j = DynIndex::new_dyn(2);
274    /// let t1 = TensorDynLen::from_dense(vec![i], vec![1.0, 2.0]).unwrap();
275    /// let t2 = TensorDynLen::from_dense(vec![j], vec![3.0, 4.0]).unwrap();
276    /// let bt = BlockTensor::new(vec![t1, t2], (2, 1));
277    /// assert!(bt.validate_indices().is_ok());
278    /// ```
279    pub fn validate_indices(&self) -> Result<()> {
280        let (rows, cols) = self.shape;
281
282        if cols <= 1 {
283            // Column vector: blocks in different rows can have independent indices.
284            // The operator determines the relationship between blocks.
285            return Ok(());
286        }
287
288        // Matrix: check all blocks have the same number of external indices
289        let first_count = self.blocks[0].num_external_indices();
290        for (i, block) in self.blocks.iter().enumerate().skip(1) {
291            let n = block.num_external_indices();
292            anyhow::ensure!(
293                n == first_count,
294                "Block {} has {} external indices, but block 0 has {}",
295                i,
296                n,
297                first_count
298            );
299        }
300
301        // Same row: blocks should share at least one full output index.
302        for row in 0..rows {
303            let ref_indices: HashSet<_> = self
304                .get(row, 0)
305                .external_indices()
306                .iter()
307                .cloned()
308                .collect();
309            for col in 1..cols {
310                let indices: HashSet<_> = self
311                    .get(row, col)
312                    .external_indices()
313                    .iter()
314                    .cloned()
315                    .collect();
316                let common_count = ref_indices.intersection(&indices).count();
317                anyhow::ensure!(
318                    common_count > 0,
319                    "Matrix row {}: blocks ({},{}) and ({},{}) share no common indices",
320                    row,
321                    row,
322                    0,
323                    row,
324                    col
325                );
326            }
327        }
328
329        // Same column: blocks should share at least one full input index.
330        for col in 0..cols {
331            let ref_indices: HashSet<_> = self
332                .get(0, col)
333                .external_indices()
334                .iter()
335                .cloned()
336                .collect();
337            for row in 1..rows {
338                let indices: HashSet<_> = self
339                    .get(row, col)
340                    .external_indices()
341                    .iter()
342                    .cloned()
343                    .collect();
344                let common_count = ref_indices.intersection(&indices).count();
345                anyhow::ensure!(
346                    common_count > 0,
347                    "Matrix col {}: blocks ({},{}) and ({},{}) share no common indices",
348                    col,
349                    0,
350                    col,
351                    row,
352                    col
353                );
354            }
355        }
356
357        Ok(())
358    }
359}
360
361// ============================================================================
362// TensorIndex implementation
363// ============================================================================
364
365impl<T: TensorLike> TensorIndex for BlockTensor<T> {
366    type Index = T::Index;
367
368    fn external_indices(&self) -> Vec<Self::Index> {
369        // Collect unique external indices across all blocks (deduplicated by full index).
370        let mut seen = HashSet::new();
371        let mut result = Vec::new();
372        for block in &self.blocks {
373            for idx in block.external_indices() {
374                if seen.insert(idx.clone()) {
375                    result.push(idx);
376                }
377            }
378        }
379        result
380    }
381
382    fn replaceind(&self, old_index: &Self::Index, new_index: &Self::Index) -> Result<Self> {
383        let replaced: Result<Vec<T>> = self
384            .blocks
385            .iter()
386            .map(|b| b.replaceind(old_index, new_index))
387            .collect();
388        Ok(Self {
389            blocks: replaced?,
390            shape: self.shape,
391        })
392    }
393
394    fn replaceinds(
395        &self,
396        old_indices: &[Self::Index],
397        new_indices: &[Self::Index],
398    ) -> Result<Self> {
399        let replaced: Result<Vec<T>> = self
400            .blocks
401            .iter()
402            .map(|b| b.replaceinds(old_indices, new_indices))
403            .collect();
404        Ok(Self {
405            blocks: replaced?,
406            shape: self.shape,
407        })
408    }
409}
410
411// ============================================================================
412// TensorLike implementation
413// ============================================================================
414
415impl<T: TensorLike> TensorLike for BlockTensor<T> {
416    // ------------------------------------------------------------------------
417    // Vector space operations (required for GMRES)
418    // ------------------------------------------------------------------------
419
420    fn norm_squared(&self) -> f64 {
421        self.blocks.iter().map(|b| b.norm_squared()).sum()
422    }
423
424    fn maxabs(&self) -> f64 {
425        self.blocks
426            .iter()
427            .map(|b| b.maxabs())
428            .fold(0.0_f64, f64::max)
429    }
430
431    fn scale(&self, scalar: AnyScalar) -> Result<Self> {
432        let scaled: Result<Vec<T>> = self
433            .blocks
434            .iter()
435            .map(|b| b.scale(scalar.clone()))
436            .collect();
437        Ok(Self {
438            blocks: scaled?,
439            shape: self.shape,
440        })
441    }
442
443    fn axpby(&self, a: AnyScalar, other: &Self, b: AnyScalar) -> Result<Self> {
444        anyhow::ensure!(
445            self.shape == other.shape,
446            "Block shapes must match: {:?} vs {:?}",
447            self.shape,
448            other.shape
449        );
450        let result: Result<Vec<T>> = self
451            .blocks
452            .iter()
453            .zip(other.blocks.iter())
454            .map(|(s, o)| s.axpby(a.clone(), o, b.clone()))
455            .collect();
456        Ok(Self {
457            blocks: result?,
458            shape: self.shape,
459        })
460    }
461
462    fn inner_product(&self, other: &Self) -> Result<AnyScalar> {
463        anyhow::ensure!(
464            self.shape == other.shape,
465            "Block shapes must match for inner product: {:?} vs {:?}",
466            self.shape,
467            other.shape
468        );
469        let mut sum = AnyScalar::new_real(0.0);
470        for (s, o) in self.blocks.iter().zip(other.blocks.iter()) {
471            sum = sum + s.inner_product(o)?;
472        }
473        Ok(sum)
474    }
475
476    fn conj(&self) -> Self {
477        let conjugated: Vec<T> = self.blocks.iter().map(|b| b.conj()).collect();
478        Self {
479            blocks: conjugated,
480            shape: self.shape,
481        }
482    }
483
484    fn validate(&self) -> Result<()> {
485        self.validate_indices()
486    }
487
488    // ------------------------------------------------------------------------
489    // Operations not supported for BlockTensor (return error, don't panic)
490    // ------------------------------------------------------------------------
491
492    fn factorize(
493        &self,
494        _left_inds: &[<Self as TensorIndex>::Index],
495        _options: &FactorizeOptions,
496    ) -> std::result::Result<FactorizeResult<Self>, FactorizeError> {
497        Err(FactorizeError::ComputationError(anyhow::anyhow!(
498            "BlockTensor does not support factorize"
499        )))
500    }
501
502    fn factorize_full_rank(
503        &self,
504        _left_inds: &[<Self as TensorIndex>::Index],
505        _alg: crate::FactorizeAlg,
506        _canonical: crate::Canonical,
507    ) -> std::result::Result<FactorizeResult<Self>, FactorizeError> {
508        Err(FactorizeError::ComputationError(anyhow::anyhow!(
509            "BlockTensor does not support factorize_full_rank"
510        )))
511    }
512
513    fn direct_sum(
514        &self,
515        _other: &Self,
516        _pairs: &[(<Self as TensorIndex>::Index, <Self as TensorIndex>::Index)],
517    ) -> Result<DirectSumResult<Self>> {
518        anyhow::bail!("BlockTensor does not support direct_sum")
519    }
520
521    fn outer_product(&self, _other: &Self) -> Result<Self> {
522        anyhow::bail!("BlockTensor does not support outer_product")
523    }
524
525    fn permuteinds(&self, _new_order: &[<Self as TensorIndex>::Index]) -> Result<Self> {
526        anyhow::bail!("BlockTensor does not support permuteinds")
527    }
528
529    fn fuse_indices(
530        &self,
531        old_indices: &[Self::Index],
532        new_index: Self::Index,
533        order: LinearizationOrder,
534    ) -> Result<Self> {
535        let blocks: Result<Vec<T>> = self
536            .blocks
537            .iter()
538            .map(|block| block.fuse_indices(old_indices, new_index.clone(), order))
539            .collect();
540        Ok(Self {
541            blocks: blocks?,
542            shape: self.shape,
543        })
544    }
545
546    fn contract(_tensors: &[&Self], _allowed: AllowedPairs<'_>) -> Result<Self> {
547        anyhow::bail!("BlockTensor does not support contract")
548    }
549
550    fn contract_connected(_tensors: &[&Self], _allowed: AllowedPairs<'_>) -> Result<Self> {
551        anyhow::bail!("BlockTensor does not support contract_connected")
552    }
553
554    fn diagonal(
555        _input_index: &<Self as TensorIndex>::Index,
556        _output_index: &<Self as TensorIndex>::Index,
557    ) -> Result<Self> {
558        anyhow::bail!("BlockTensor does not support diagonal")
559    }
560
561    fn scalar_one() -> Result<Self> {
562        anyhow::bail!("BlockTensor does not support scalar_one")
563    }
564
565    fn ones(_indices: &[<Self as TensorIndex>::Index]) -> Result<Self> {
566        anyhow::bail!("BlockTensor does not support ones")
567    }
568
569    fn onehot(_index_vals: &[(<Self as TensorIndex>::Index, usize)]) -> Result<Self> {
570        anyhow::bail!("BlockTensor does not support onehot")
571    }
572}
573
574#[cfg(test)]
575mod tests;