Function contract_storage 

pub fn contract_storage(
    storage_a: &Storage,
    dims_a: &[usize],
    axes_a: &[usize],
    storage_b: &Storage,
    dims_b: &[usize],
    axes_b: &[usize],
    result_dims: &[usize],
) -> StorageResult<Storage>

Contract two storage tensors along specified axes.

All storage is StructuredStorage; contraction is delegated to the native tenferro backend. This is the primary tensor contraction entry point at the storage layer.

§Arguments

  • storage_a - First tensor storage
  • dims_a - Dimensions of the first tensor
  • axes_a - Axes of the first tensor to contract; axes_a[k] is contracted against axes_b[k]
  • storage_b - Second tensor storage
  • dims_b - Dimensions of the second tensor
  • axes_b - Axes of the second tensor to contract; must have the same length as axes_a
  • result_dims - Dimensions of the result tensor (empty for scalar result)
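To make the argument semantics concrete, here is a minimal std-only reference sketch of what a contraction over these arguments computes for dense column-major `f64` data. It is not the tenferro backend and not part of this crate. The names `naive_contract`, `col_major_strides`, and `next_index` are hypothetical, and it assumes that `axes_a[k]` pairs with `axes_b[k]`, that the result axes are the free axes of the first tensor followed by the free axes of the second, and that the inputs were already validated.

```rust
// Hypothetical reference sketch; NOT the tenferro backend.
// Assumptions: dense column-major f64 data, axes_a[k] paired with axes_b[k],
// result axes = free axes of `a` followed by free axes of `b`, inputs already validated.

/// Column-major strides: the first axis varies fastest.
fn col_major_strides(dims: &[usize]) -> Vec<usize> {
    let mut strides = Vec::with_capacity(dims.len());
    let mut step = 1usize;
    for &d in dims {
        strides.push(step);
        step *= d;
    }
    strides
}

/// Advances a column-major multi-index in place; returns false once it wraps around.
fn next_index(idx: &mut [usize], dims: &[usize]) -> bool {
    for (i, &d) in idx.iter_mut().zip(dims) {
        *i += 1;
        if *i < d {
            return true;
        }
        *i = 0;
    }
    false
}

/// Naive contraction returning (column-major data, result dims).
fn naive_contract(
    a: &[f64], dims_a: &[usize], axes_a: &[usize],
    b: &[f64], dims_b: &[usize], axes_b: &[usize],
) -> (Vec<f64>, Vec<usize>) {
    // Free (uncontracted) axes of each operand, in their original order.
    let free_a: Vec<usize> = (0..dims_a.len()).filter(|k| !axes_a.contains(k)).collect();
    let free_b: Vec<usize> = (0..dims_b.len()).filter(|k| !axes_b.contains(k)).collect();
    let result_dims: Vec<usize> = free_a.iter().map(|&k| dims_a[k])
        .chain(free_b.iter().map(|&k| dims_b[k]))
        .collect();
    let summed_dims: Vec<usize> = axes_a.iter().map(|&k| dims_a[k]).collect();
    let (sa, sb) = (col_major_strides(dims_a), col_major_strides(dims_b));

    // An empty result_dims has product 1, i.e. a single scalar slot.
    let mut out = vec![0.0f64; result_dims.iter().product::<usize>()];
    let mut r = vec![0usize; result_dims.len()]; // current result multi-index
    for slot in out.iter_mut() {
        // Offsets contributed by the free indices of each operand.
        let mut base_a = 0usize;
        let mut base_b = 0usize;
        for (pos, &k) in free_a.iter().enumerate() {
            base_a += r[pos] * sa[k];
        }
        for (pos, &k) in free_b.iter().enumerate() {
            base_b += r[free_a.len() + pos] * sb[k];
        }
        // Sum over every combination of the contracted indices.
        let mut sum = 0.0f64;
        let mut c = vec![0usize; summed_dims.len()];
        if summed_dims.iter().all(|&d| d > 0) {
            loop {
                let (mut ia, mut ib) = (base_a, base_b);
                for (pos, (&ka, &kb)) in axes_a.iter().zip(axes_b).enumerate() {
                    ia += c[pos] * sa[ka];
                    ib += c[pos] * sb[kb];
                }
                sum += a[ia] * b[ib];
                if !next_index(&mut c, &summed_dims) {
                    break;
                }
            }
        }
        *slot = sum;
        next_index(&mut r, &result_dims);
    }
    (out, result_dims)
}
```

Called with the matrix-vector arguments from §Examples (`&[2, 3], &[1]` and `&[3], &[0]`), this sketch yields `[9.0, 12.0]` with dims `[2]`. Contracting every axis of both operands gives an empty dims vector and a single scalar value, matching the "empty for scalar result" convention above.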

§Returns

A new Storage containing the contracted result.

§Errors

Returns an error if axes are invalid, contracted dimensions do not match, or the native backend rejects the contraction.
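The axis and dimension conditions above can be checked before any data is touched. The following is a hypothetical pre-flight check (`check_contraction` is not part of this crate, and the exact checks performed by `contract_storage` may differ). It assumes the same positional pairing of `axes_a[k]` with `axes_b[k]` and reports out-of-range axes, repeated axes, and mismatched contracted dimensions.

```rust
/// Hypothetical validation mirroring the documented error conditions.
/// Assumes axes_a[k] is contracted against axes_b[k].
fn check_contraction(
    dims_a: &[usize], axes_a: &[usize],
    dims_b: &[usize], axes_b: &[usize],
) -> Result<(), String> {
    // Every contracted axis of `a` needs a partner in `b`.
    if axes_a.len() != axes_b.len() {
        return Err(format!(
            "axes_a has {} entries but axes_b has {}",
            axes_a.len(),
            axes_b.len()
        ));
    }
    // Each axis must exist and may be contracted only once.
    for (name, dims, axes) in [("a", dims_a, axes_a), ("b", dims_b, axes_b)] {
        for (i, &ax) in axes.iter().enumerate() {
            if ax >= dims.len() {
                return Err(format!(
                    "axis {} out of range for rank-{} tensor {}",
                    ax,
                    dims.len(),
                    name
                ));
            }
            if axes[..i].contains(&ax) {
                return Err(format!("axis {} of tensor {} is listed twice", ax, name));
            }
        }
    }
    // Paired axes must have equal extents.
    for (&ka, &kb) in axes_a.iter().zip(axes_b) {
        if dims_a[ka] != dims_b[kb] {
            return Err(format!(
                "dimension mismatch: a axis {} has size {} but b axis {} has size {}",
                ka, dims_a[ka], kb, dims_b[kb]
            ));
        }
    }
    Ok(())
}
```

For instance, pairing axis 1 of a 2x3 matrix with a length-4 vector fails the final check, because the contracted extents are 3 and 4.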

§Examples

use tensor4all_tensorbackend::{contract_storage, Storage};

// Matrix-vector multiply: A(2x3) * v(3) -> result(2)
let a = Storage::from_dense_col_major(
    vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3],
).unwrap();
let v = Storage::from_dense_col_major(vec![1.0, 1.0, 1.0], &[3]).unwrap();
let result = contract_storage(&a, &[2, 3], &[1], &v, &[3], &[0], &[2]).unwrap();
// Row sums: [1+3+5, 2+4+6] = [9, 12]
let vals = result.to_dense_f64_col_major_vec(&[2]).unwrap();
assert!((vals[0] - 9.0).abs() < 1e-10);
assert!((vals[1] - 12.0).abs() < 1e-10);
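The expected values rely on `from_dense_col_major` storing elements with the first index varying fastest, so element (i, j) of the 2x3 matrix sits at offset i + 2·j. The std-only sketch below spells out that general offset rule and recomputes the row sums. The helpers `col_major_offset` and `row_sums` are hypothetical and not part of this crate.

```rust
/// Offset of a multi-index in column-major order (hypothetical helper):
/// offset = idx[0] + dims[0] * (idx[1] + dims[1] * (idx[2] + ...)).
fn col_major_offset(idx: &[usize], dims: &[usize]) -> usize {
    assert_eq!(idx.len(), dims.len(), "index rank must match tensor rank");
    // Horner-style evaluation, starting from the slowest-varying (last) axis.
    idx.iter().zip(dims).rev().fold(0, |acc, (&i, &d)| {
        assert!(i < d, "index out of bounds");
        acc * d + i
    })
}

/// Row sums of a column-major rows x cols matrix, i.e. A * [1, ..., 1].
fn row_sums(data: &[f64], rows: usize, cols: usize) -> Vec<f64> {
    (0..rows)
        .map(|i| {
            (0..cols)
                .map(|j| data[col_major_offset(&[i, j], &[rows, cols])])
                .sum::<f64>()
        })
        .collect()
}
```

Applied to `[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]` with 2 rows and 3 columns, row 0 reads offsets 0, 2, 4 and row 1 reads offsets 1, 3, 5, which reproduces the `[9, 12]` checked in the example.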