pub fn broadcast_input_plan(
input: &[usize],
output: &[usize],
) -> Result<BroadcastInputPlan, BroadcastError>
Plans how a single input should be lowered to a `BroadcastInDim` operation.
Singleton (size-1) input axes that are expanded to a larger extent in the
output are omitted from `source_shape`, so downstream VJP rules reduce over
those axes explicitly.
§ Examples
use tenferro_ops::broadcast::broadcast_input_plan;
let plan = broadcast_input_plan(&[3, 1], &[3, 4]).unwrap();
assert_eq!(plan.source_shape, vec![3]);
assert_eq!(plan.dims, vec![0]);