tenferro_prims/cpu/
metadata_cast.rs1use num_traits::NumCast;
2use tenferro_algebra::Scalar;
3use tenferro_device::{Error, Result};
4use tenferro_tensor::Tensor;
5
6use crate::cpu::{tensor_to_view, tensor_to_view_mut};
7use crate::shape_helpers::broadcast_tensor_to_shape;
8use crate::{
9 cast_metadata_value, for_each_index_result, supports_metadata_cast,
10 validate_metadata_cast_shapes, validate_pointwise_cast_bridge_inputs,
11 validate_where_bridge_inputs, CpuBackend, CpuContext, MetadataCastPrimsDescriptor,
12 MetadataTensorRef, TensorMetadataCastPrims,
13};
14
15fn execute_pointwise_cast<S>(
16 input: MetadataTensorRef<'_>,
17 alpha: S,
18 beta: S,
19 output: &mut Tensor<S>,
20) -> Result<()>
21where
22 S: Scalar + NumCast + 'static,
23{
24 match input {
25 MetadataTensorRef::I32(tensor) => {
26 let tensor = broadcast_tensor_to_shape(tensor, output.dims(), "MetadataCast input")?;
27 let input = tensor_to_view(&tensor)?;
28 let mut output_view = tensor_to_view_mut(output)?;
29 let dims = output_view.dims().to_vec();
30 for_each_index_result(&dims, |idx| {
31 let casted = cast_metadata_value::<S, i32>(input.get(idx), "metadata i32 value")?;
32 output_view.set(idx, alpha * casted + beta * output_view.get(idx));
33 Ok(())
34 })
35 }
36 MetadataTensorRef::Bool(tensor) => {
37 let tensor = broadcast_tensor_to_shape(tensor, output.dims(), "MetadataCast input")?;
38 let input = tensor_to_view(&tensor)?;
39 let mut output_view = tensor_to_view_mut(output)?;
40 let dims = output_view.dims().to_vec();
41 for_each_index_result(&dims, |idx| {
42 let casted = cast_metadata_value::<S, u8>(
43 if input.get(idx) != 0 { 1 } else { 0 },
44 "metadata bool value",
45 )?;
46 output_view.set(idx, alpha * casted + beta * output_view.get(idx));
47 Ok(())
48 })
49 }
50 }
51}
52
53fn execute_where<S>(
54 cond: MetadataTensorRef<'_>,
55 on_true: &Tensor<S>,
56 on_false: &Tensor<S>,
57 alpha: S,
58 beta: S,
59 output: &mut Tensor<S>,
60) -> Result<()>
61where
62 S: Scalar + NumCast + 'static,
63{
64 let on_true = broadcast_tensor_to_shape(on_true, output.dims(), "MetadataCastWhere true")?;
65 let on_false = broadcast_tensor_to_shape(on_false, output.dims(), "MetadataCastWhere false")?;
66 let on_true = tensor_to_view(&on_true)?;
67 let on_false = tensor_to_view(&on_false)?;
68 match cond {
69 MetadataTensorRef::Bool(cond) => {
70 let cond = broadcast_tensor_to_shape(cond, output.dims(), "MetadataCastWhere cond")?;
71 let cond = tensor_to_view(&cond)?;
72 let mut output_view = tensor_to_view_mut(output)?;
73 let dims = output_view.dims().to_vec();
74 for_each_index_result(&dims, |idx| {
75 let selected = if cond.get(idx) != 0 {
76 on_true.get(idx)
77 } else {
78 on_false.get(idx)
79 };
80 output_view.set(idx, alpha * selected + beta * output_view.get(idx));
81 Ok(())
82 })
83 }
84 MetadataTensorRef::I32(_) => Err(Error::InvalidArgument(
85 "MetadataCastWhere expects bool condition metadata".into(),
86 )),
87 }
88}
89
90impl<S> TensorMetadataCastPrims<S> for CpuBackend
91where
92 S: Scalar + NumCast + 'static,
93{
94 type Plan = MetadataCastPrimsDescriptor;
95 type Context = CpuContext;
96
97 fn plan(
98 _ctx: &mut Self::Context,
99 desc: &MetadataCastPrimsDescriptor,
100 shapes: &[&[usize]],
101 ) -> Result<Self::Plan> {
102 validate_metadata_cast_shapes(desc, shapes, "MetadataCast")?;
103 if !supports_metadata_cast(desc) {
104 return Err(Error::InvalidArgument(format!(
105 "metadata cast descriptor {desc:?} is not supported on CpuBackend for {}",
106 std::any::type_name::<S>()
107 )));
108 }
109 Ok(desc.clone())
110 }
111
112 fn execute(
113 _ctx: &mut Self::Context,
114 plan: &Self::Plan,
115 alpha: S,
116 inputs: &[crate::MetadataScalarTensorRef<'_, S>],
117 beta: S,
118 output: &mut Tensor<S>,
119 ) -> Result<()> {
120 match plan {
121 MetadataCastPrimsDescriptor::PointwiseCast { .. } => {
122 let input = validate_pointwise_cast_bridge_inputs(inputs)?;
123 execute_pointwise_cast(input, alpha, beta, output)
124 }
125 MetadataCastPrimsDescriptor::Where { .. } => {
126 let (cond, on_true, on_false) = validate_where_bridge_inputs(inputs)?;
127 execute_where(cond, on_true, on_false, alpha, beta, output)
128 }
129 }
130 }
131
132 fn has_metadata_cast_support(desc: MetadataCastPrimsDescriptor) -> bool {
133 supports_metadata_cast(&desc)
134 }
135}