// tenferro_prims/cpu/metadata_cast.rs

1use num_traits::NumCast;
2use tenferro_algebra::Scalar;
3use tenferro_device::{Error, Result};
4use tenferro_tensor::Tensor;
5
6use crate::cpu::{tensor_to_view, tensor_to_view_mut};
7use crate::shape_helpers::broadcast_tensor_to_shape;
8use crate::{
9    cast_metadata_value, for_each_index_result, supports_metadata_cast,
10    validate_metadata_cast_shapes, validate_pointwise_cast_bridge_inputs,
11    validate_where_bridge_inputs, CpuBackend, CpuContext, MetadataCastPrimsDescriptor,
12    MetadataTensorRef, TensorMetadataCastPrims,
13};
14
15fn execute_pointwise_cast<S>(
16    input: MetadataTensorRef<'_>,
17    alpha: S,
18    beta: S,
19    output: &mut Tensor<S>,
20) -> Result<()>
21where
22    S: Scalar + NumCast + 'static,
23{
24    match input {
25        MetadataTensorRef::I32(tensor) => {
26            let tensor = broadcast_tensor_to_shape(tensor, output.dims(), "MetadataCast input")?;
27            let input = tensor_to_view(&tensor)?;
28            let mut output_view = tensor_to_view_mut(output)?;
29            let dims = output_view.dims().to_vec();
30            for_each_index_result(&dims, |idx| {
31                let casted = cast_metadata_value::<S, i32>(input.get(idx), "metadata i32 value")?;
32                output_view.set(idx, alpha * casted + beta * output_view.get(idx));
33                Ok(())
34            })
35        }
36        MetadataTensorRef::Bool(tensor) => {
37            let tensor = broadcast_tensor_to_shape(tensor, output.dims(), "MetadataCast input")?;
38            let input = tensor_to_view(&tensor)?;
39            let mut output_view = tensor_to_view_mut(output)?;
40            let dims = output_view.dims().to_vec();
41            for_each_index_result(&dims, |idx| {
42                let casted = cast_metadata_value::<S, u8>(
43                    if input.get(idx) != 0 { 1 } else { 0 },
44                    "metadata bool value",
45                )?;
46                output_view.set(idx, alpha * casted + beta * output_view.get(idx));
47                Ok(())
48            })
49        }
50    }
51}
52
53fn execute_where<S>(
54    cond: MetadataTensorRef<'_>,
55    on_true: &Tensor<S>,
56    on_false: &Tensor<S>,
57    alpha: S,
58    beta: S,
59    output: &mut Tensor<S>,
60) -> Result<()>
61where
62    S: Scalar + NumCast + 'static,
63{
64    let on_true = broadcast_tensor_to_shape(on_true, output.dims(), "MetadataCastWhere true")?;
65    let on_false = broadcast_tensor_to_shape(on_false, output.dims(), "MetadataCastWhere false")?;
66    let on_true = tensor_to_view(&on_true)?;
67    let on_false = tensor_to_view(&on_false)?;
68    match cond {
69        MetadataTensorRef::Bool(cond) => {
70            let cond = broadcast_tensor_to_shape(cond, output.dims(), "MetadataCastWhere cond")?;
71            let cond = tensor_to_view(&cond)?;
72            let mut output_view = tensor_to_view_mut(output)?;
73            let dims = output_view.dims().to_vec();
74            for_each_index_result(&dims, |idx| {
75                let selected = if cond.get(idx) != 0 {
76                    on_true.get(idx)
77                } else {
78                    on_false.get(idx)
79                };
80                output_view.set(idx, alpha * selected + beta * output_view.get(idx));
81                Ok(())
82            })
83        }
84        MetadataTensorRef::I32(_) => Err(Error::InvalidArgument(
85            "MetadataCastWhere expects bool condition metadata".into(),
86        )),
87    }
88}
89
90impl<S> TensorMetadataCastPrims<S> for CpuBackend
91where
92    S: Scalar + NumCast + 'static,
93{
94    type Plan = MetadataCastPrimsDescriptor;
95    type Context = CpuContext;
96
97    fn plan(
98        _ctx: &mut Self::Context,
99        desc: &MetadataCastPrimsDescriptor,
100        shapes: &[&[usize]],
101    ) -> Result<Self::Plan> {
102        validate_metadata_cast_shapes(desc, shapes, "MetadataCast")?;
103        if !supports_metadata_cast(desc) {
104            return Err(Error::InvalidArgument(format!(
105                "metadata cast descriptor {desc:?} is not supported on CpuBackend for {}",
106                std::any::type_name::<S>()
107            )));
108        }
109        Ok(desc.clone())
110    }
111
112    fn execute(
113        _ctx: &mut Self::Context,
114        plan: &Self::Plan,
115        alpha: S,
116        inputs: &[crate::MetadataScalarTensorRef<'_, S>],
117        beta: S,
118        output: &mut Tensor<S>,
119    ) -> Result<()> {
120        match plan {
121            MetadataCastPrimsDescriptor::PointwiseCast { .. } => {
122                let input = validate_pointwise_cast_bridge_inputs(inputs)?;
123                execute_pointwise_cast(input, alpha, beta, output)
124            }
125            MetadataCastPrimsDescriptor::Where { .. } => {
126                let (cond, on_true, on_false) = validate_where_bridge_inputs(inputs)?;
127                execute_where(cond, on_true, on_false, alpha, beta, output)
128            }
129        }
130    }
131
132    fn has_metadata_cast_support(desc: MetadataCastPrimsDescriptor) -> bool {
133        supports_metadata_cast(&desc)
134    }
135}