

Function broadcast_input_plan 

pub fn broadcast_input_plan(
    input: &[usize],
    output: &[usize],
) -> Result<BroadcastInputPlan, BroadcastError>

Plan how one input should lower to BroadcastInDim.

Singleton (size-1) input axes that expand to a larger output extent are omitted from source_shape, so downstream VJP rules must reduce over those axes explicitly.
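The planning rule described above can be approximated by the following standalone sketch. It is not the tenferro_ops implementation. Plan, PlanError and plan_sketch are hypothetical names, and the sketch assumes NumPy-style right alignment when the input has fewer axes than the output. It also assumes that dims records, for each kept axis of source_shape, the output axis it lands on (as in BroadcastInDim). Axes whose extents match are kept, and expanding singletons are dropped.

```rust
/// Hypothetical stand-in for `BroadcastInputPlan`; field names mirror the
/// documented example, but the real struct may carry more information.
#[derive(Debug, PartialEq)]
struct Plan {
    source_shape: Vec<usize>,
    dims: Vec<usize>,
}

/// Hypothetical error type; the real `BroadcastError` variants are not shown here.
#[derive(Debug, PartialEq)]
enum PlanError {
    RankTooLarge { input: usize, output: usize },
    Incompatible { axis: usize, input: usize, output: usize },
}

/// Sketch only: assumes NumPy-style right alignment of `input` against `output`.
fn plan_sketch(input: &[usize], output: &[usize]) -> Result<Plan, PlanError> {
    if input.len() > output.len() {
        return Err(PlanError::RankTooLarge { input: input.len(), output: output.len() });
    }
    // Leading output axes with no input counterpart are pure broadcast axes.
    let offset = output.len() - input.len();
    let mut source_shape = Vec::new();
    let mut dims = Vec::new();
    for (axis, &extent) in input.iter().enumerate() {
        let out_axis = axis + offset;
        let out_extent = output[out_axis];
        if extent == out_extent {
            // Axis passes through unchanged: keep it and record its output position.
            source_shape.push(extent);
            dims.push(out_axis);
        } else if extent == 1 {
            // Expanding singleton: omitted, so a VJP must sum over `out_axis`.
        } else {
            return Err(PlanError::Incompatible { axis, input: extent, output: out_extent });
        }
    }
    Ok(Plan { source_shape, dims })
}

fn main() {
    let plan = plan_sketch(&[3, 1], &[3, 4]).unwrap();
    println!("{:?}", plan);
}
```

With these assumptions, a [3, 1] input broadcast to [3, 4] keeps only axis 0. This matches the documented example below.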

Examples

use tenferro_ops::broadcast::broadcast_input_plan;

let plan = broadcast_input_plan(&[3, 1], &[3, 4]).unwrap();
assert_eq!(plan.source_shape, vec![3]);
assert_eq!(plan.dims, vec![0]);
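The example above leaves the expanded axis 1 out of source_shape, so a VJP rule has to sum the output cotangent over it. The sketch below shows one way such a reduction could look. reduce_cotangent is a hypothetical helper, not part of tenferro_ops. It assumes a dense row-major Vec<f64> cotangent and a dims list in ascending order. It keeps only the output axes listed in dims and sums over all the others. The result has source_shape and can then be reshaped to the original input shape, because every dropped input axis had size 1.

```rust
/// Hypothetical helper: sum a row-major cotangent of shape `output` down to the
/// axes listed in `dims` (assumed ascending), i.e. down to `source_shape`.
fn reduce_cotangent(cot: &[f64], output: &[usize], dims: &[usize]) -> Vec<f64> {
    // Extents of the kept axes; their product is the reduced element count.
    let kept: Vec<usize> = dims.iter().map(|&d| output[d]).collect();
    let mut result = vec![0.0; kept.iter().product()];
    // Multi-index over `output`, advanced in row-major order.
    let mut idx = vec![0usize; output.len()];
    for &value in cot {
        // Flatten only the kept coordinates into an offset into `result`.
        let mut offset = 0;
        for (&d, &extent) in dims.iter().zip(&kept) {
            offset = offset * extent + idx[d];
        }
        result[offset] += value;
        // Increment the multi-index, carrying into earlier axes as needed.
        for axis in (0..output.len()).rev() {
            idx[axis] += 1;
            if idx[axis] < output[axis] {
                break;
            }
            idx[axis] = 0;
        }
    }
    result
}

fn main() {
    // Cotangent for the [3, 4] output from the example, filled with 0..12.
    let cot: Vec<f64> = (0..12).map(|x| x as f64).collect();
    let grad = reduce_cotangent(&cot, &[3, 4], &[0]);
    println!("{:?}", grad); // prints [6.0, 22.0, 38.0]
}
```

The three row sums have source_shape [3]. Reshaping them to [3, 1] gives the gradient with respect to the original input.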