tensor4all_core/
block_tensor.rs

1//! Block tensor type for GMRES with block matrices.
2//!
3//! This module provides [`BlockTensor`], a collection of tensors organized
4//! in a block structure. It implements [`TensorLike`] for the vector space
5//! operations required by GMRES, allowing block matrix linear equations
6//! `Ax = b` to be solved using the existing GMRES implementation.
7//!
8//! # Example
9//!
10//! ```ignore
11//! use tensor4all_core::block_tensor::BlockTensor;
12//! use tensor4all_core::krylov::{gmres, GmresOptions};
13//!
14//! // Create 2x1 block vectors
15//! let b = BlockTensor::new(vec![b1, b2], (2, 1));
16//! let x0 = BlockTensor::new(vec![zero1, zero2], (2, 1));
17//!
18//! // Define block matrix operator
19//! let apply_a = |x: &BlockTensor<T>| { /* ... */ };
20//!
21//! let result = gmres(apply_a, &b, &x0, &GmresOptions::default())?;
22//! ```
23
24use std::collections::HashSet;
25
26use crate::any_scalar::AnyScalar;
27use crate::index_like::IndexLike;
28use crate::tensor_index::TensorIndex;
29use crate::tensor_like::{
30    AllowedPairs, DirectSumResult, FactorizeError, FactorizeOptions, FactorizeResult, TensorLike,
31};
32use anyhow::Result;
33
34/// A collection of tensors organized in a block structure.
35///
36/// Each block is a tensor of type `T` implementing [`TensorLike`].
37/// The flattened block list is ordered row-by-row:
38/// `(0, 0), (0, 1), ..., (1, 0), (1, 1), ...`.
39///
40/// # Type Parameters
41///
42/// * `T` - The tensor type for each block, must implement `TensorLike`
43#[derive(Debug, Clone)]
44pub struct BlockTensor<T: TensorLike> {
45    /// Blocks flattened row-by-row in block-matrix order
46    blocks: Vec<T>,
47    /// Block structure (rows, cols)
48    shape: (usize, usize),
49}
50
51impl<T: TensorLike> BlockTensor<T> {
52    /// Create a new block tensor with validation.
53    ///
54    /// # Arguments
55    ///
56    /// * `blocks` - Vector of blocks flattened row-by-row
57    /// * `shape` - Block structure as (rows, cols)
58    ///
59    /// # Errors
60    ///
61    /// Returns an error if `rows * cols != blocks.len()`.
62    pub fn try_new(blocks: Vec<T>, shape: (usize, usize)) -> Result<Self> {
63        let (rows, cols) = shape;
64        anyhow::ensure!(
65            rows * cols == blocks.len(),
66            "Block count mismatch: shape ({}, {}) requires {} blocks, but got {}",
67            rows,
68            cols,
69            rows * cols,
70            blocks.len()
71        );
72        Ok(Self { blocks, shape })
73    }
74
75    /// Create a new block tensor.
76    ///
77    /// # Arguments
78    ///
79    /// * `blocks` - Vector of blocks flattened row-by-row
80    /// * `shape` - Block structure as (rows, cols)
81    ///
82    /// # Panics
83    ///
84    /// Panics if `rows * cols != blocks.len()`.
85    pub fn new(blocks: Vec<T>, shape: (usize, usize)) -> Self {
86        Self::try_new(blocks, shape).expect("Invalid block tensor shape")
87    }
88
89    /// Get the block structure (rows, cols).
90    pub fn shape(&self) -> (usize, usize) {
91        self.shape
92    }
93
94    /// Get the total number of blocks.
95    pub fn num_blocks(&self) -> usize {
96        self.blocks.len()
97    }
98
99    /// Get a reference to the block at (row, col).
100    ///
101    /// # Panics
102    ///
103    /// Panics if the indices are out of bounds.
104    pub fn get(&self, row: usize, col: usize) -> &T {
105        let (rows, cols) = self.shape;
106        assert!(row < rows && col < cols, "Block index out of bounds");
107        &self.blocks[row * cols + col]
108    }
109
110    /// Get a mutable reference to the block at (row, col).
111    ///
112    /// # Panics
113    ///
114    /// Panics if the indices are out of bounds.
115    pub fn get_mut(&mut self, row: usize, col: usize) -> &mut T {
116        let (rows, cols) = self.shape;
117        assert!(row < rows && col < cols, "Block index out of bounds");
118        &mut self.blocks[row * cols + col]
119    }
120
121    /// Get all blocks as a slice.
122    pub fn blocks(&self) -> &[T] {
123        &self.blocks
124    }
125
126    /// Get all blocks as a mutable slice.
127    pub fn blocks_mut(&mut self) -> &mut [T] {
128        &mut self.blocks
129    }
130
131    /// Consume self and return the blocks.
132    pub fn into_blocks(self) -> Vec<T> {
133        self.blocks
134    }
135
136    /// Validate that blocks share external indices consistently.
137    ///
138    /// For column vectors (cols=1), no index sharing is required between
139    /// blocks. Different rows can have independent physical indices
140    /// (the operator determines their relationship).
141    ///
142    /// For matrices (rows x cols), checks that:
143    /// - All blocks have the same number of external indices.
144    /// - Blocks in the same row share some common index IDs (output indices).
145    /// - Blocks in the same column share some common index IDs (input indices).
146    pub fn validate_indices(&self) -> Result<()> {
147        let (rows, cols) = self.shape;
148
149        if cols <= 1 {
150            // Column vector: blocks in different rows can have independent indices.
151            // The operator determines the relationship between blocks.
152            return Ok(());
153        }
154
155        // Matrix: check all blocks have the same number of external indices
156        let first_count = self.blocks[0].num_external_indices();
157        for (i, block) in self.blocks.iter().enumerate().skip(1) {
158            let n = block.num_external_indices();
159            anyhow::ensure!(
160                n == first_count,
161                "Block {} has {} external indices, but block 0 has {}",
162                i,
163                n,
164                first_count
165            );
166        }
167
168        // Same row: blocks should share some common index IDs (output indices)
169        for row in 0..rows {
170            let ref_ids: HashSet<_> = self
171                .get(row, 0)
172                .external_indices()
173                .iter()
174                .map(|idx| idx.id().clone())
175                .collect();
176            for col in 1..cols {
177                let ids: HashSet<_> = self
178                    .get(row, col)
179                    .external_indices()
180                    .iter()
181                    .map(|idx| idx.id().clone())
182                    .collect();
183                let common_count = ref_ids.intersection(&ids).count();
184                anyhow::ensure!(
185                    common_count > 0,
186                    "Matrix row {}: blocks ({},{}) and ({},{}) share no index IDs",
187                    row,
188                    row,
189                    0,
190                    row,
191                    col
192                );
193            }
194        }
195
196        // Same column: blocks should share some common index IDs (input indices)
197        for col in 0..cols {
198            let ref_ids: HashSet<_> = self
199                .get(0, col)
200                .external_indices()
201                .iter()
202                .map(|idx| idx.id().clone())
203                .collect();
204            for row in 1..rows {
205                let ids: HashSet<_> = self
206                    .get(row, col)
207                    .external_indices()
208                    .iter()
209                    .map(|idx| idx.id().clone())
210                    .collect();
211                let common_count = ref_ids.intersection(&ids).count();
212                anyhow::ensure!(
213                    common_count > 0,
214                    "Matrix col {}: blocks ({},{}) and ({},{}) share no index IDs",
215                    col,
216                    0,
217                    col,
218                    row,
219                    col
220                );
221            }
222        }
223
224        Ok(())
225    }
226}
227
228// ============================================================================
229// TensorIndex implementation
230// ============================================================================
231
232impl<T: TensorLike> TensorIndex for BlockTensor<T> {
233    type Index = T::Index;
234
235    fn external_indices(&self) -> Vec<Self::Index> {
236        // Collect unique external indices across all blocks (deduplicated by ID).
237        let mut seen = HashSet::new();
238        let mut result = Vec::new();
239        for block in &self.blocks {
240            for idx in block.external_indices() {
241                if seen.insert(idx.id().clone()) {
242                    result.push(idx);
243                }
244            }
245        }
246        result
247    }
248
249    fn replaceind(&self, old_index: &Self::Index, new_index: &Self::Index) -> Result<Self> {
250        let replaced: Result<Vec<T>> = self
251            .blocks
252            .iter()
253            .map(|b| b.replaceind(old_index, new_index))
254            .collect();
255        Ok(Self {
256            blocks: replaced?,
257            shape: self.shape,
258        })
259    }
260
261    fn replaceinds(
262        &self,
263        old_indices: &[Self::Index],
264        new_indices: &[Self::Index],
265    ) -> Result<Self> {
266        let replaced: Result<Vec<T>> = self
267            .blocks
268            .iter()
269            .map(|b| b.replaceinds(old_indices, new_indices))
270            .collect();
271        Ok(Self {
272            blocks: replaced?,
273            shape: self.shape,
274        })
275    }
276}
277
278// ============================================================================
279// TensorLike implementation
280// ============================================================================
281
282impl<T: TensorLike> TensorLike for BlockTensor<T> {
283    // ------------------------------------------------------------------------
284    // Vector space operations (required for GMRES)
285    // ------------------------------------------------------------------------
286
287    fn norm_squared(&self) -> f64 {
288        self.blocks.iter().map(|b| b.norm_squared()).sum()
289    }
290
291    fn maxabs(&self) -> f64 {
292        self.blocks
293            .iter()
294            .map(|b| b.maxabs())
295            .fold(0.0_f64, f64::max)
296    }
297
298    fn scale(&self, scalar: AnyScalar) -> Result<Self> {
299        let scaled: Result<Vec<T>> = self
300            .blocks
301            .iter()
302            .map(|b| b.scale(scalar.clone()))
303            .collect();
304        Ok(Self {
305            blocks: scaled?,
306            shape: self.shape,
307        })
308    }
309
310    fn axpby(&self, a: AnyScalar, other: &Self, b: AnyScalar) -> Result<Self> {
311        anyhow::ensure!(
312            self.shape == other.shape,
313            "Block shapes must match: {:?} vs {:?}",
314            self.shape,
315            other.shape
316        );
317        let result: Result<Vec<T>> = self
318            .blocks
319            .iter()
320            .zip(other.blocks.iter())
321            .map(|(s, o)| s.axpby(a.clone(), o, b.clone()))
322            .collect();
323        Ok(Self {
324            blocks: result?,
325            shape: self.shape,
326        })
327    }
328
329    fn inner_product(&self, other: &Self) -> Result<AnyScalar> {
330        anyhow::ensure!(
331            self.shape == other.shape,
332            "Block shapes must match for inner product: {:?} vs {:?}",
333            self.shape,
334            other.shape
335        );
336        let mut sum = AnyScalar::new_real(0.0);
337        for (s, o) in self.blocks.iter().zip(other.blocks.iter()) {
338            sum = sum + s.inner_product(o)?;
339        }
340        Ok(sum)
341    }
342
343    fn conj(&self) -> Self {
344        let conjugated: Vec<T> = self.blocks.iter().map(|b| b.conj()).collect();
345        Self {
346            blocks: conjugated,
347            shape: self.shape,
348        }
349    }
350
351    fn validate(&self) -> Result<()> {
352        self.validate_indices()
353    }
354
355    // ------------------------------------------------------------------------
356    // Operations not supported for BlockTensor (return error, don't panic)
357    // ------------------------------------------------------------------------
358
359    fn factorize(
360        &self,
361        _left_inds: &[<Self as TensorIndex>::Index],
362        _options: &FactorizeOptions,
363    ) -> std::result::Result<FactorizeResult<Self>, FactorizeError> {
364        Err(FactorizeError::ComputationError(anyhow::anyhow!(
365            "BlockTensor does not support factorize"
366        )))
367    }
368
369    fn direct_sum(
370        &self,
371        _other: &Self,
372        _pairs: &[(<Self as TensorIndex>::Index, <Self as TensorIndex>::Index)],
373    ) -> Result<DirectSumResult<Self>> {
374        anyhow::bail!("BlockTensor does not support direct_sum")
375    }
376
377    fn outer_product(&self, _other: &Self) -> Result<Self> {
378        anyhow::bail!("BlockTensor does not support outer_product")
379    }
380
381    fn permuteinds(&self, _new_order: &[<Self as TensorIndex>::Index]) -> Result<Self> {
382        anyhow::bail!("BlockTensor does not support permuteinds")
383    }
384
385    fn contract(_tensors: &[&Self], _allowed: AllowedPairs<'_>) -> Result<Self> {
386        anyhow::bail!("BlockTensor does not support contract")
387    }
388
389    fn contract_connected(_tensors: &[&Self], _allowed: AllowedPairs<'_>) -> Result<Self> {
390        anyhow::bail!("BlockTensor does not support contract_connected")
391    }
392
393    fn diagonal(
394        _input_index: &<Self as TensorIndex>::Index,
395        _output_index: &<Self as TensorIndex>::Index,
396    ) -> Result<Self> {
397        anyhow::bail!("BlockTensor does not support diagonal")
398    }
399
400    fn scalar_one() -> Result<Self> {
401        anyhow::bail!("BlockTensor does not support scalar_one")
402    }
403
404    fn ones(_indices: &[<Self as TensorIndex>::Index]) -> Result<Self> {
405        anyhow::bail!("BlockTensor does not support ones")
406    }
407
408    fn onehot(_index_vals: &[(<Self as TensorIndex>::Index, usize)]) -> Result<Self> {
409        anyhow::bail!("BlockTensor does not support onehot")
410    }
411}
412
413#[cfg(test)]
414mod tests;