tenferro_prims/families/
metadata_cast.rs1use num_traits::{NumCast, ToPrimitive};
2use tenferro_algebra::Scalar;
3use tenferro_device::{Error, Result};
4use tenferro_tensor::Tensor;
5
6use crate::shape_helpers::validate_shape_broadcastable;
7use crate::{validate_shape_count, MetadataDType, MetadataTensorRef};
8#[cfg(feature = "cuda")]
9use crate::{ScalarPrimsDescriptor, ScalarTernaryOp};
10
11#[derive(Debug, Clone, PartialEq, Eq, Hash)]
24pub enum MetadataCastPrimsDescriptor {
25 PointwiseCast {
27 input_dtype: MetadataDType,
29 },
30 Where {
32 cond_dtype: MetadataDType,
34 },
35}
36
37#[derive(Debug, Clone, Copy)]
54pub enum MetadataScalarTensorRef<'a, S: Scalar> {
55 Metadata(MetadataTensorRef<'a>),
57 Scalar(&'a Tensor<S>),
59}
60
61pub trait TensorMetadataCastPrims<S: Scalar> {
87 type Plan;
89 type Context;
91
92 fn plan(
94 ctx: &mut Self::Context,
95 desc: &MetadataCastPrimsDescriptor,
96 shapes: &[&[usize]],
97 ) -> Result<Self::Plan>;
98
99 fn execute(
104 ctx: &mut Self::Context,
105 plan: &Self::Plan,
106 alpha: S,
107 inputs: &[MetadataScalarTensorRef<'_, S>],
108 beta: S,
109 output: &mut Tensor<S>,
110 ) -> Result<()>;
111
112 fn has_metadata_cast_support(desc: MetadataCastPrimsDescriptor) -> bool;
114}
115
116pub(crate) fn supports_metadata_cast(desc: &MetadataCastPrimsDescriptor) -> bool {
118 match desc {
119 MetadataCastPrimsDescriptor::PointwiseCast { input_dtype } => {
120 matches!(input_dtype, MetadataDType::Bool | MetadataDType::I32)
121 }
122 MetadataCastPrimsDescriptor::Where { cond_dtype } => {
123 matches!(cond_dtype, MetadataDType::Bool)
124 }
125 }
126}
127
128pub(crate) fn validate_metadata_cast_shapes(
130 desc: &MetadataCastPrimsDescriptor,
131 shapes: &[&[usize]],
132 op_name: &str,
133) -> Result<()> {
134 match desc {
135 MetadataCastPrimsDescriptor::PointwiseCast { .. } => {
136 validate_shape_count(shapes, 2, op_name)?;
137 validate_shape_broadcastable(shapes[0], shapes[1], op_name)?;
138 Ok(())
139 }
140 MetadataCastPrimsDescriptor::Where { .. } => {
141 validate_shape_count(shapes, 4, op_name)?;
142 validate_shape_broadcastable(shapes[0], shapes[3], op_name)?;
143 validate_shape_broadcastable(shapes[1], shapes[3], op_name)?;
144 validate_shape_broadcastable(shapes[2], shapes[3], op_name)?;
145 Ok(())
146 }
147 }
148}
149
150pub(crate) fn cast_metadata_value<S, T>(value: T, label: &str) -> Result<S>
151where
152 S: Scalar + NumCast,
153 T: ToPrimitive + Copy,
154{
155 NumCast::from(value).ok_or_else(|| {
156 Error::InvalidArgument(format!(
157 "{label} cannot be represented as {}",
158 std::any::type_name::<S>()
159 ))
160 })
161}
162
163pub(crate) fn for_each_index_result(
164 dims: &[usize],
165 mut f: impl FnMut(&[usize]) -> Result<()>,
166) -> Result<()> {
167 let mut result = Ok(());
168 crate::for_each_index(dims, |idx| {
169 if result.is_ok() {
170 result = f(idx);
171 }
172 });
173 result
174}
175
176pub(crate) fn validate_where_bridge_inputs<'a, S: Scalar>(
177 inputs: &'a [MetadataScalarTensorRef<'a, S>],
178) -> Result<(MetadataTensorRef<'a>, &'a Tensor<S>, &'a Tensor<S>)> {
179 if inputs.len() != 3 {
180 return Err(Error::InvalidArgument(format!(
181 "MetadataCastWhere expects 3 input(s) (got {})",
182 inputs.len()
183 )));
184 }
185 let MetadataScalarTensorRef::Metadata(cond) = inputs[0] else {
186 return Err(Error::InvalidArgument(
187 "MetadataCastWhere expects metadata condition input".into(),
188 ));
189 };
190 let MetadataScalarTensorRef::Scalar(on_true) = inputs[1] else {
191 return Err(Error::InvalidArgument(
192 "MetadataCastWhere expects scalar on_true input".into(),
193 ));
194 };
195 let MetadataScalarTensorRef::Scalar(on_false) = inputs[2] else {
196 return Err(Error::InvalidArgument(
197 "MetadataCastWhere expects scalar on_false input".into(),
198 ));
199 };
200 Ok((cond, on_true, on_false))
201}
202
203pub(crate) fn validate_pointwise_cast_bridge_inputs<'a, S: Scalar>(
204 inputs: &'a [MetadataScalarTensorRef<'a, S>],
205) -> Result<MetadataTensorRef<'a>> {
206 if inputs.len() != 1 {
207 return Err(Error::InvalidArgument(format!(
208 "MetadataCastPointwise expects 1 input(s) (got {})",
209 inputs.len()
210 )));
211 }
212 let MetadataScalarTensorRef::Metadata(input) = inputs[0] else {
213 return Err(Error::InvalidArgument(
214 "MetadataCastPointwise expects metadata input".into(),
215 ));
216 };
217 Ok(input)
218}
219
/// Builds the scalar descriptor for the pointwise ternary `Where` operation.
///
/// Only compiled when the `cuda` feature is enabled.
#[cfg(feature = "cuda")]
pub(crate) fn scalar_where_desc() -> ScalarPrimsDescriptor {
    let op = ScalarTernaryOp::Where;
    ScalarPrimsDescriptor::PointwiseTernary { op }
}