Dynamic Shapes: Truncated SVD
Use traced dynamic-shape operations when an output size is known only after execution starts. This tutorial builds compiled programs that run an SVD, count singular values above a threshold, truncate u, s, and vt with dynamic_truncate, and reconstruct the thresholded matrix.
The traced graph is compiled once against an explicit input placeholder spec, and the resulting programs run twice below: once with two singular values above the threshold and once with three. Neither a re-trace nor a recompile is needed between the two executions.
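Before the full program, the counting and truncation step can be sketched in plain Rust with no tenferro dependency. This is only an illustration of the arithmetic, not the library's implementation: `keep_count` and `truncate` are hypothetical helpers, and the sketch assumes the singular values arrive sorted in descending order so that keeping a prefix is equivalent to thresholding.

```rust
/// Hypothetical helper: count the singular values strictly above `threshold`.
/// Mirrors the compare -> convert -> reduce_sum chain used in the traced graph.
fn keep_count(singular_values: &[f64], threshold: f64) -> usize {
    singular_values
        .iter()
        .map(|&s| if s > threshold { 1.0_f64 } else { 0.0_f64 })
        .sum::<f64>() as usize
}

/// Hypothetical helper: keep the leading `count` entries.
/// Assumes the values are sorted in descending order.
fn truncate(singular_values: &[f64], count: usize) -> &[f64] {
    &singular_values[..count.min(singular_values.len())]
}

fn main() {
    let threshold = 0.5;
    // The same two spectra that the tutorial feeds to the compiled programs.
    for s in [[4.0, 3.0, 0.1, 0.01], [4.0, 3.0, 2.0, 0.01]] {
        let k = keep_count(&s, threshold);
        println!("kept {k} of {}: {:?}", s.len(), truncate(&s, k));
    }
}
```

With a threshold of 0.5 the two spectra keep two and three values respectively, which is why the truncated axis cannot have a single static extent. The full tenferro program follows.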

use tenferro_cpu::CpuBackend;
use tenferro_einsum::GraphCompilerEinsumExt;
use tenferro_linalg::TracedTensorLinalgExt;
use tenferro_runtime::{
    CompareDir, DType, GraphCompiler, GraphExecutor, GraphProgram, Tensor, TracedTensor,
};

// Assert element-wise agreement within an absolute tolerance.
fn assert_close(actual: &[f64], expected: &[f64], tolerance: f64) {
    assert_eq!(actual.len(), expected.len());
    for (index, (actual, expected)) in actual.iter().zip(expected).enumerate() {
        let error = (actual - expected).abs();
        assert!(
            error <= tolerance,
            "value {index}: actual={actual}, expected={expected}, error={error}, tolerance={tolerance}"
        );
    }
}

// Build an n x n diagonal matrix as a column-major tensor.
fn diagonal_matrix(diagonal: &[f64]) -> Result<Tensor, tenferro_runtime::Error> {
    let n = diagonal.len();
    let mut values = vec![0.0_f64; n * n];
    for (index, value) in diagonal.iter().enumerate() {
        // Element (i, i) lives at offset i + i * n in column-major order.
        values[index + index * n] = *value;
    }
    Ok(Tensor::from_vec_col_major(vec![n, n], values)?)
}

// Reference result: the same diagonal matrix with every entry at or below
// the threshold (in absolute value) zeroed out.
fn truncated_expected(diagonal: &[f64], threshold: f64) -> Vec<f64> {
    let n = diagonal.len();
    let mut values = vec![0.0_f64; n * n];
    for (index, value) in diagonal.iter().enumerate() {
        if value.abs() > threshold {
            values[index + index * n] = *value;
        }
    }
    values
}

// Run both compiled programs on one concrete input and check the results.
fn run_case(
    executor: &mut GraphExecutor<CpuBackend>,
    reconstructed_program: &GraphProgram,
    singular_values_program: &GraphProgram,
    x: &TracedTensor,
    input: &Tensor,
    expected_rank: usize,
    expected_values: &[f64],
) -> Result<(), tenferro_runtime::Error> {
    let reconstructed = executor.run_with_inputs(reconstructed_program, &[(x, input)])?;
    let singular_values = executor.run_with_inputs(singular_values_program, &[(x, input)])?;

    // The truncated axis has a different extent per run; the reconstruction does not.
    assert_eq!(singular_values.shape(), &[expected_rank]);
    assert_eq!(reconstructed.shape(), &[4, 4]);
    assert_close(
        reconstructed.as_slice::<f64>().unwrap(),
        expected_values,
        1.0e-10,
    );
    Ok(())
}

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Trace the graph once: a 4x4 input placeholder followed by an SVD.
    let x = TracedTensor::input_concrete_shape(DType::F64, &[4, 4])?;
    let (u, s, vt) = x.svd_with_eps(1.0e-12)?;

    // Count the singular values above the threshold. The count is a traced
    // scalar whose value is known only at run time.
    let threshold = TracedTensor::from_vec_col_major(vec![], vec![0.5_f64])?;
    let keep_count = s
        .compare(&threshold, CompareDir::Gt)?
        .convert(DType::F64)?
        .reduce_sum(&[0])?;

    // Truncate the columns of u (axis 1), the entries of s, and the rows of vt (axis 0).
    let u_truncated = u.dynamic_truncate(&keep_count, 1)?;
    let s_truncated = s.dynamic_truncate(&keep_count, 0)?;
    let vt_truncated = vt.dynamic_truncate(&keep_count, 0)?;

    // Contract the truncated factors back into a 4x4 matrix.
    let mut compiler = GraphCompiler::new();
    let reconstructed = compiler.einsum_with(
        &[&u_truncated, &s_truncated, &vt_truncated],
        "ik,k,kj->ij",
        tenferro_einsum::EinsumOptimize::Path(vec![(0, 1), (0, 1)]),
    )?;

    // Compile each output once against the explicit input spec.
    let input_specs = [(&x, DType::F64, &[4, 4][..])];
    let reconstructed_program = compiler.compile_with_input_specs(&reconstructed, &input_specs)?;
    let singular_values_program = compiler.compile_with_input_specs(&s_truncated, &input_specs)?;

    let mut executor = GraphExecutor::new(CpuBackend::new());
    executor.register_extension(tenferro_linalg::register_runtime)?;
    executor.register_extension(tenferro_einsum::register_runtime)?;

    // First run: two singular values exceed the threshold.
    let rank2 = diagonal_matrix(&[4.0, 3.0, 0.1, 0.01])?;
    run_case(
        &mut executor,
        &reconstructed_program,
        &singular_values_program,
        &x,
        &rank2,
        2,
        &truncated_expected(&[4.0, 3.0, 0.1, 0.01], 0.5),
    )?;

    // Second run: three singular values exceed the threshold, same programs.
    let rank3 = diagonal_matrix(&[4.0, 3.0, 2.0, 0.01])?;
    run_case(
        &mut executor,
        &reconstructed_program,
        &singular_values_program,
        &x,
        &rank3,
        3,
        &truncated_expected(&[4.0, 3.0, 2.0, 0.01], 0.5),
    )?;
    Ok(())
}

The shape metadata for the truncated axis is an upper bound in the compiled program. The concrete extent is resolved at dispatch from the runtime scalar keep_count, and later operations consume the resulting dynamic extent.
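The relationship between the compile-time upper bound and the dispatch-time extent can be pictured with a small standalone model. This is a sketch under stated assumptions, not tenferro's internal representation: `BoundedAxis` and its `resolve` method are hypothetical names, and rejecting a count above the bound is an assumption made for the illustration (the library may handle that case differently).

```rust
/// Hypothetical model of an axis whose static shape is only an upper bound.
#[derive(Debug)]
struct BoundedAxis {
    /// Known when the program is compiled.
    upper_bound: usize,
    /// Filled in at dispatch from a runtime scalar.
    extent: Option<usize>,
}

impl BoundedAxis {
    fn new(upper_bound: usize) -> Self {
        BoundedAxis { upper_bound, extent: None }
    }

    /// Resolve the concrete extent from a runtime count.
    /// Assumption: a count above the bound is treated as an error.
    fn resolve(&mut self, keep_count: usize) -> Result<usize, String> {
        if keep_count > self.upper_bound {
            return Err(format!(
                "keep_count {keep_count} exceeds upper bound {}",
                self.upper_bound
            ));
        }
        self.extent = Some(keep_count);
        Ok(keep_count)
    }
}

fn main() {
    // One "compiled" axis with bound 4, dispatched twice with different counts.
    let mut axis = BoundedAxis::new(4);
    for keep_count in [2, 3] {
        let extent = axis.resolve(keep_count).expect("count within bound");
        println!("bound = {}, resolved extent = {extent}", axis.upper_bound);
    }
}
```

The point of the model is that the bound is fixed once while the extent is rewritten on every dispatch, which matches the two runs in the tutorial reusing the same compiled programs.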
For the implementation contract, see Dynamic and Symbolic Shape Metadata. For the broader eager/traced split, see the execution models guide.