tensor4all_core/block_tensor.rs
1//! Block tensor type for GMRES with block matrices.
2//!
3//! This module provides [`BlockTensor`], a collection of tensors organized
4//! in a block structure. It implements [`TensorLike`] for the vector space
5//! operations required by GMRES, allowing block matrix linear equations
6//! `Ax = b` to be solved using the existing GMRES implementation.
7//!
8//! # Example
9//!
10//! ```
11//! use tensor4all_core::block_tensor::BlockTensor;
12//! use tensor4all_core::krylov::{gmres, GmresOptions};
13//! use tensor4all_core::{DynIndex, TensorDynLen};
14//!
15//! # fn main() -> anyhow::Result<()> {
16//! let i = DynIndex::new_dyn(2);
17//! let b_block = TensorDynLen::from_dense(vec![i.clone()], vec![1.0, 2.0])?;
18//! let zero_block = TensorDynLen::from_dense(vec![i.clone()], vec![0.0, 0.0])?;
19//!
20//! let b = BlockTensor::new(vec![b_block], (1, 1));
21//! let x0 = BlockTensor::new(vec![zero_block], (1, 1));
22//!
23//! let apply_a = |x: &BlockTensor<TensorDynLen>| Ok(x.clone());
24//! let result = gmres(apply_a, &b, &x0, &GmresOptions::default())?;
25//!
26//! assert!(result.converged);
27//! assert_eq!(result.solution.shape(), (1, 1));
28//! # Ok(())
29//! # }
30//! ```
31
32use std::collections::HashSet;
33
34use crate::any_scalar::AnyScalar;
35use crate::tensor_index::TensorIndex;
36use crate::tensor_like::{
37 AllowedPairs, DirectSumResult, FactorizeError, FactorizeOptions, FactorizeResult,
38 LinearizationOrder, TensorLike,
39};
40use anyhow::Result;
41
/// A collection of tensors organized in a block structure.
///
/// Each block is a tensor of type `T` implementing [`TensorLike`].
/// The flattened block list is ordered row-by-row:
/// `(0, 0), (0, 1), ..., (1, 0), (1, 1), ...`.
///
/// # Type Parameters
///
/// * `T` - The tensor type for each block, must implement `TensorLike`
#[derive(Debug, Clone)]
pub struct BlockTensor<T: TensorLike> {
    /// Blocks flattened row-by-row in block-matrix order.
    /// Invariant: `blocks.len() == shape.0 * shape.1` (enforced by `try_new`).
    blocks: Vec<T>,
    /// Block structure as `(rows, cols)`.
    shape: (usize, usize),
}
58
59impl<T: TensorLike> BlockTensor<T> {
60 /// Create a new block tensor with validation.
61 ///
62 /// # Arguments
63 ///
64 /// * `blocks` - Vector of blocks flattened row-by-row
65 /// * `shape` - Block structure as (rows, cols)
66 ///
67 /// # Errors
68 ///
69 /// Returns an error if `rows * cols != blocks.len()`.
70 ///
71 /// # Examples
72 ///
73 /// ```
74 /// use tensor4all_core::block_tensor::BlockTensor;
75 /// use tensor4all_core::{DynIndex, TensorDynLen};
76 ///
77 /// let i = DynIndex::new_dyn(2);
78 /// let t1 = TensorDynLen::from_dense(vec![i.clone()], vec![1.0, 2.0]).unwrap();
79 /// let t2 = TensorDynLen::from_dense(vec![i.clone()], vec![3.0, 4.0]).unwrap();
80 ///
81 /// let bt = BlockTensor::try_new(vec![t1, t2], (2, 1)).unwrap();
82 /// assert_eq!(bt.shape(), (2, 1));
83 ///
84 /// // Wrong number of blocks returns an error
85 /// let t3 = TensorDynLen::from_dense(vec![i], vec![5.0, 6.0]).unwrap();
86 /// assert!(BlockTensor::try_new(vec![t3], (2, 1)).is_err());
87 /// ```
88 pub fn try_new(blocks: Vec<T>, shape: (usize, usize)) -> Result<Self> {
89 let (rows, cols) = shape;
90 anyhow::ensure!(
91 rows * cols == blocks.len(),
92 "Block count mismatch: shape ({}, {}) requires {} blocks, but got {}",
93 rows,
94 cols,
95 rows * cols,
96 blocks.len()
97 );
98 Ok(Self { blocks, shape })
99 }
100
101 /// Create a new block tensor.
102 ///
103 /// # Arguments
104 ///
105 /// * `blocks` - Vector of blocks flattened row-by-row
106 /// * `shape` - Block structure as (rows, cols)
107 ///
108 /// # Panics
109 ///
110 /// Panics if `rows * cols != blocks.len()`.
111 ///
112 /// # Examples
113 ///
114 /// ```
115 /// use tensor4all_core::block_tensor::BlockTensor;
116 /// use tensor4all_core::{DynIndex, TensorDynLen};
117 ///
118 /// let i = DynIndex::new_dyn(2);
119 /// let t = TensorDynLen::from_dense(vec![i], vec![1.0, 2.0]).unwrap();
120 /// let bt = BlockTensor::new(vec![t], (1, 1));
121 /// assert_eq!(bt.shape(), (1, 1));
122 /// ```
123 pub fn new(blocks: Vec<T>, shape: (usize, usize)) -> Self {
124 Self::try_new(blocks, shape).expect("Invalid block tensor shape")
125 }
126
127 /// Get the block structure (rows, cols).
128 ///
129 /// # Examples
130 ///
131 /// ```
132 /// use tensor4all_core::block_tensor::BlockTensor;
133 /// use tensor4all_core::{DynIndex, TensorDynLen};
134 ///
135 /// let i = DynIndex::new_dyn(2);
136 /// let blocks: Vec<TensorDynLen> = (0..6)
137 /// .map(|_| TensorDynLen::from_dense(vec![i.clone()], vec![0.0, 0.0]).unwrap())
138 /// .collect();
139 /// let bt = BlockTensor::new(blocks, (2, 3));
140 /// assert_eq!(bt.shape(), (2, 3));
141 /// ```
142 pub fn shape(&self) -> (usize, usize) {
143 self.shape
144 }
145
146 /// Get the total number of blocks.
147 pub fn num_blocks(&self) -> usize {
148 self.blocks.len()
149 }
150
151 /// Get a reference to the block at (row, col).
152 ///
153 /// # Panics
154 ///
155 /// Panics if the indices are out of bounds.
156 ///
157 /// # Examples
158 ///
159 /// ```
160 /// use tensor4all_core::block_tensor::BlockTensor;
161 /// use tensor4all_core::{DynIndex, TensorDynLen};
162 ///
163 /// let i = DynIndex::new_dyn(2);
164 /// let t1 = TensorDynLen::from_dense(vec![i.clone()], vec![1.0, 2.0]).unwrap();
165 /// let t2 = TensorDynLen::from_dense(vec![i], vec![3.0, 4.0]).unwrap();
166 /// let bt = BlockTensor::new(vec![t1, t2], (2, 1));
167 /// assert_eq!(bt.get(0, 0).dims(), vec![2]);
168 /// assert_eq!(bt.get(1, 0).dims(), vec![2]);
169 /// ```
170 pub fn get(&self, row: usize, col: usize) -> &T {
171 let (rows, cols) = self.shape;
172 assert!(row < rows && col < cols, "Block index out of bounds");
173 &self.blocks[row * cols + col]
174 }
175
176 /// Get a mutable reference to the block at (row, col).
177 ///
178 /// # Panics
179 ///
180 /// Panics if the indices are out of bounds.
181 ///
182 /// # Examples
183 ///
184 /// ```
185 /// use tensor4all_core::block_tensor::BlockTensor;
186 /// use tensor4all_core::{DynIndex, TensorDynLen};
187 ///
188 /// let i = DynIndex::new_dyn(2);
189 /// let t = TensorDynLen::from_dense(vec![i.clone()], vec![1.0, 2.0]).unwrap();
190 /// let mut bt = BlockTensor::new(vec![t], (1, 1));
191 /// let block = bt.get_mut(0, 0);
192 /// assert_eq!(block.dims(), vec![2]);
193 /// ```
194 pub fn get_mut(&mut self, row: usize, col: usize) -> &mut T {
195 let (rows, cols) = self.shape;
196 assert!(row < rows && col < cols, "Block index out of bounds");
197 &mut self.blocks[row * cols + col]
198 }
199
200 /// Get all blocks as a slice.
201 ///
202 /// # Examples
203 ///
204 /// ```
205 /// use tensor4all_core::block_tensor::BlockTensor;
206 /// use tensor4all_core::{DynIndex, TensorDynLen};
207 ///
208 /// let i = DynIndex::new_dyn(2);
209 /// let t1 = TensorDynLen::from_dense(vec![i.clone()], vec![1.0, 2.0]).unwrap();
210 /// let t2 = TensorDynLen::from_dense(vec![i], vec![3.0, 4.0]).unwrap();
211 /// let bt = BlockTensor::new(vec![t1, t2], (1, 2));
212 /// assert_eq!(bt.blocks().len(), 2);
213 /// ```
214 pub fn blocks(&self) -> &[T] {
215 &self.blocks
216 }
217
218 /// Get all blocks as a mutable slice.
219 ///
220 /// # Examples
221 ///
222 /// ```
223 /// use tensor4all_core::block_tensor::BlockTensor;
224 /// use tensor4all_core::{DynIndex, TensorDynLen};
225 ///
226 /// let i = DynIndex::new_dyn(2);
227 /// let t = TensorDynLen::from_dense(vec![i], vec![1.0, 2.0]).unwrap();
228 /// let mut bt = BlockTensor::new(vec![t], (1, 1));
229 /// assert_eq!(bt.blocks_mut().len(), 1);
230 /// ```
231 pub fn blocks_mut(&mut self) -> &mut [T] {
232 &mut self.blocks
233 }
234
235 /// Consume self and return the blocks.
236 ///
237 /// # Examples
238 ///
239 /// ```
240 /// use tensor4all_core::block_tensor::BlockTensor;
241 /// use tensor4all_core::{DynIndex, TensorDynLen};
242 ///
243 /// let i = DynIndex::new_dyn(2);
244 /// let t = TensorDynLen::from_dense(vec![i], vec![1.0, 2.0]).unwrap();
245 /// let bt = BlockTensor::new(vec![t], (1, 1));
246 /// let blocks = bt.into_blocks();
247 /// assert_eq!(blocks.len(), 1);
248 /// assert_eq!(blocks[0].dims(), vec![2]);
249 /// ```
250 pub fn into_blocks(self) -> Vec<T> {
251 self.blocks
252 }
253
254 /// Validate that blocks share external indices consistently.
255 ///
256 /// For column vectors (cols=1), no index sharing is required between
257 /// blocks. Different rows can have independent physical indices
258 /// (the operator determines their relationship).
259 ///
260 /// For matrices (rows x cols), checks that:
261 /// - All blocks have the same number of external indices.
262 /// - Blocks in the same row share some common index IDs (output indices).
263 /// - Blocks in the same column share some common index IDs (input indices).
264 ///
265 /// # Examples
266 ///
267 /// ```
268 /// use tensor4all_core::block_tensor::BlockTensor;
269 /// use tensor4all_core::{DynIndex, TensorDynLen};
270 ///
271 /// // Column vector: always valid regardless of index sharing
272 /// let i = DynIndex::new_dyn(2);
273 /// let j = DynIndex::new_dyn(2);
274 /// let t1 = TensorDynLen::from_dense(vec![i], vec![1.0, 2.0]).unwrap();
275 /// let t2 = TensorDynLen::from_dense(vec![j], vec![3.0, 4.0]).unwrap();
276 /// let bt = BlockTensor::new(vec![t1, t2], (2, 1));
277 /// assert!(bt.validate_indices().is_ok());
278 /// ```
279 pub fn validate_indices(&self) -> Result<()> {
280 let (rows, cols) = self.shape;
281
282 if cols <= 1 {
283 // Column vector: blocks in different rows can have independent indices.
284 // The operator determines the relationship between blocks.
285 return Ok(());
286 }
287
288 // Matrix: check all blocks have the same number of external indices
289 let first_count = self.blocks[0].num_external_indices();
290 for (i, block) in self.blocks.iter().enumerate().skip(1) {
291 let n = block.num_external_indices();
292 anyhow::ensure!(
293 n == first_count,
294 "Block {} has {} external indices, but block 0 has {}",
295 i,
296 n,
297 first_count
298 );
299 }
300
301 // Same row: blocks should share at least one full output index.
302 for row in 0..rows {
303 let ref_indices: HashSet<_> = self
304 .get(row, 0)
305 .external_indices()
306 .iter()
307 .cloned()
308 .collect();
309 for col in 1..cols {
310 let indices: HashSet<_> = self
311 .get(row, col)
312 .external_indices()
313 .iter()
314 .cloned()
315 .collect();
316 let common_count = ref_indices.intersection(&indices).count();
317 anyhow::ensure!(
318 common_count > 0,
319 "Matrix row {}: blocks ({},{}) and ({},{}) share no common indices",
320 row,
321 row,
322 0,
323 row,
324 col
325 );
326 }
327 }
328
329 // Same column: blocks should share at least one full input index.
330 for col in 0..cols {
331 let ref_indices: HashSet<_> = self
332 .get(0, col)
333 .external_indices()
334 .iter()
335 .cloned()
336 .collect();
337 for row in 1..rows {
338 let indices: HashSet<_> = self
339 .get(row, col)
340 .external_indices()
341 .iter()
342 .cloned()
343 .collect();
344 let common_count = ref_indices.intersection(&indices).count();
345 anyhow::ensure!(
346 common_count > 0,
347 "Matrix col {}: blocks ({},{}) and ({},{}) share no common indices",
348 col,
349 0,
350 col,
351 row,
352 col
353 );
354 }
355 }
356
357 Ok(())
358 }
359}
360
361// ============================================================================
362// TensorIndex implementation
363// ============================================================================
364
365impl<T: TensorLike> TensorIndex for BlockTensor<T> {
366 type Index = T::Index;
367
368 fn external_indices(&self) -> Vec<Self::Index> {
369 // Collect unique external indices across all blocks (deduplicated by full index).
370 let mut seen = HashSet::new();
371 let mut result = Vec::new();
372 for block in &self.blocks {
373 for idx in block.external_indices() {
374 if seen.insert(idx.clone()) {
375 result.push(idx);
376 }
377 }
378 }
379 result
380 }
381
382 fn replaceind(&self, old_index: &Self::Index, new_index: &Self::Index) -> Result<Self> {
383 let replaced: Result<Vec<T>> = self
384 .blocks
385 .iter()
386 .map(|b| b.replaceind(old_index, new_index))
387 .collect();
388 Ok(Self {
389 blocks: replaced?,
390 shape: self.shape,
391 })
392 }
393
394 fn replaceinds(
395 &self,
396 old_indices: &[Self::Index],
397 new_indices: &[Self::Index],
398 ) -> Result<Self> {
399 let replaced: Result<Vec<T>> = self
400 .blocks
401 .iter()
402 .map(|b| b.replaceinds(old_indices, new_indices))
403 .collect();
404 Ok(Self {
405 blocks: replaced?,
406 shape: self.shape,
407 })
408 }
409}
410
411// ============================================================================
412// TensorLike implementation
413// ============================================================================
414
415impl<T: TensorLike> TensorLike for BlockTensor<T> {
416 // ------------------------------------------------------------------------
417 // Vector space operations (required for GMRES)
418 // ------------------------------------------------------------------------
419
420 fn norm_squared(&self) -> f64 {
421 self.blocks.iter().map(|b| b.norm_squared()).sum()
422 }
423
424 fn maxabs(&self) -> f64 {
425 self.blocks
426 .iter()
427 .map(|b| b.maxabs())
428 .fold(0.0_f64, f64::max)
429 }
430
431 fn scale(&self, scalar: AnyScalar) -> Result<Self> {
432 let scaled: Result<Vec<T>> = self
433 .blocks
434 .iter()
435 .map(|b| b.scale(scalar.clone()))
436 .collect();
437 Ok(Self {
438 blocks: scaled?,
439 shape: self.shape,
440 })
441 }
442
443 fn axpby(&self, a: AnyScalar, other: &Self, b: AnyScalar) -> Result<Self> {
444 anyhow::ensure!(
445 self.shape == other.shape,
446 "Block shapes must match: {:?} vs {:?}",
447 self.shape,
448 other.shape
449 );
450 let result: Result<Vec<T>> = self
451 .blocks
452 .iter()
453 .zip(other.blocks.iter())
454 .map(|(s, o)| s.axpby(a.clone(), o, b.clone()))
455 .collect();
456 Ok(Self {
457 blocks: result?,
458 shape: self.shape,
459 })
460 }
461
462 fn inner_product(&self, other: &Self) -> Result<AnyScalar> {
463 anyhow::ensure!(
464 self.shape == other.shape,
465 "Block shapes must match for inner product: {:?} vs {:?}",
466 self.shape,
467 other.shape
468 );
469 let mut sum = AnyScalar::new_real(0.0);
470 for (s, o) in self.blocks.iter().zip(other.blocks.iter()) {
471 sum = sum + s.inner_product(o)?;
472 }
473 Ok(sum)
474 }
475
476 fn conj(&self) -> Self {
477 let conjugated: Vec<T> = self.blocks.iter().map(|b| b.conj()).collect();
478 Self {
479 blocks: conjugated,
480 shape: self.shape,
481 }
482 }
483
484 fn validate(&self) -> Result<()> {
485 self.validate_indices()
486 }
487
488 // ------------------------------------------------------------------------
489 // Operations not supported for BlockTensor (return error, don't panic)
490 // ------------------------------------------------------------------------
491
492 fn factorize(
493 &self,
494 _left_inds: &[<Self as TensorIndex>::Index],
495 _options: &FactorizeOptions,
496 ) -> std::result::Result<FactorizeResult<Self>, FactorizeError> {
497 Err(FactorizeError::ComputationError(anyhow::anyhow!(
498 "BlockTensor does not support factorize"
499 )))
500 }
501
502 fn factorize_full_rank(
503 &self,
504 _left_inds: &[<Self as TensorIndex>::Index],
505 _alg: crate::FactorizeAlg,
506 _canonical: crate::Canonical,
507 ) -> std::result::Result<FactorizeResult<Self>, FactorizeError> {
508 Err(FactorizeError::ComputationError(anyhow::anyhow!(
509 "BlockTensor does not support factorize_full_rank"
510 )))
511 }
512
513 fn direct_sum(
514 &self,
515 _other: &Self,
516 _pairs: &[(<Self as TensorIndex>::Index, <Self as TensorIndex>::Index)],
517 ) -> Result<DirectSumResult<Self>> {
518 anyhow::bail!("BlockTensor does not support direct_sum")
519 }
520
521 fn outer_product(&self, _other: &Self) -> Result<Self> {
522 anyhow::bail!("BlockTensor does not support outer_product")
523 }
524
525 fn permuteinds(&self, _new_order: &[<Self as TensorIndex>::Index]) -> Result<Self> {
526 anyhow::bail!("BlockTensor does not support permuteinds")
527 }
528
529 fn fuse_indices(
530 &self,
531 old_indices: &[Self::Index],
532 new_index: Self::Index,
533 order: LinearizationOrder,
534 ) -> Result<Self> {
535 let blocks: Result<Vec<T>> = self
536 .blocks
537 .iter()
538 .map(|block| block.fuse_indices(old_indices, new_index.clone(), order))
539 .collect();
540 Ok(Self {
541 blocks: blocks?,
542 shape: self.shape,
543 })
544 }
545
546 fn contract(_tensors: &[&Self], _allowed: AllowedPairs<'_>) -> Result<Self> {
547 anyhow::bail!("BlockTensor does not support contract")
548 }
549
550 fn contract_connected(_tensors: &[&Self], _allowed: AllowedPairs<'_>) -> Result<Self> {
551 anyhow::bail!("BlockTensor does not support contract_connected")
552 }
553
554 fn diagonal(
555 _input_index: &<Self as TensorIndex>::Index,
556 _output_index: &<Self as TensorIndex>::Index,
557 ) -> Result<Self> {
558 anyhow::bail!("BlockTensor does not support diagonal")
559 }
560
561 fn scalar_one() -> Result<Self> {
562 anyhow::bail!("BlockTensor does not support scalar_one")
563 }
564
565 fn ones(_indices: &[<Self as TensorIndex>::Index]) -> Result<Self> {
566 anyhow::bail!("BlockTensor does not support ones")
567 }
568
569 fn onehot(_index_vals: &[(<Self as TensorIndex>::Index, usize)]) -> Result<Self> {
570 anyhow::bail!("BlockTensor does not support onehot")
571 }
572}
573
// Unit tests live in the sibling `tests.rs` file; compiled only for `cargo test`.
#[cfg(test)]
mod tests;