tenferro_prims/families/metadata_cast.rs

1use num_traits::{NumCast, ToPrimitive};
2use tenferro_algebra::Scalar;
3use tenferro_device::{Error, Result};
4use tenferro_tensor::Tensor;
5
6use crate::shape_helpers::validate_shape_broadcastable;
7use crate::{validate_shape_count, MetadataDType, MetadataTensorRef};
8#[cfg(feature = "cuda")]
9use crate::{ScalarPrimsDescriptor, ScalarTernaryOp};
10
11/// Metadata-to-scalar bridge planning operations.
12///
13/// # Examples
14///
15/// ```rust
16/// use tenferro_prims::{MetadataCastPrimsDescriptor, MetadataDType};
17///
18/// let cast = MetadataCastPrimsDescriptor::PointwiseCast {
19///     input_dtype: MetadataDType::Bool,
20/// };
21/// assert!(matches!(cast, MetadataCastPrimsDescriptor::PointwiseCast { .. }));
22/// ```
23#[derive(Debug, Clone, PartialEq, Eq, Hash)]
24pub enum MetadataCastPrimsDescriptor {
25    /// Cast a metadata tensor into a scalar tensor of the backend's scalar dtype.
26    PointwiseCast {
27        /// Logical dtype of the metadata input.
28        input_dtype: MetadataDType,
29    },
30    /// Select between scalar tensors using a bool metadata mask.
31    Where {
32        /// Logical dtype of the condition/mask metadata input.
33        cond_dtype: MetadataDType,
34    },
35}
36
37/// Erased inputs for metadata-to-scalar bridge execution.
38///
39/// # Examples
40///
41/// ```ignore
42/// use tenferro_device::LogicalMemorySpace;
43/// use tenferro_prims::{MetadataScalarTensorRef, MetadataTensorRef};
44/// use tenferro_tensor::{MemoryOrder, Tensor};
45///
46/// let mask = Tensor::<u8>::zeros(
47///     &[2],
48///     LogicalMemorySpace::MainMemory,
49///     MemoryOrder::ColumnMajor,
50/// ).unwrap();
51/// let input = MetadataScalarTensorRef::Metadata(MetadataTensorRef::Bool(&mask));
52/// ```
53#[derive(Debug, Clone, Copy)]
54pub enum MetadataScalarTensorRef<'a, S: Scalar> {
55    /// Metadata tensor input.
56    Metadata(MetadataTensorRef<'a>),
57    /// Scalar tensor input.
58    Scalar(&'a Tensor<S>),
59}
60
61/// Metadata-to-scalar bridge protocol.
62///
63/// This family bridges integer/bool metadata tensors into scalar tensors so
64/// higher-level crates can reuse scalar `where` and similar dense eager paths.
65///
66/// # Examples
67///
68/// ```ignore
69/// use tenferro_algebra::Standard;
70/// use tenferro_prims::{
71///     CpuBackend, CpuContext, MetadataCastPrimsDescriptor, MetadataDType,
72///     TensorMetadataCastPrims,
73/// };
74///
75/// let mut ctx = CpuContext::new(1);
76/// let desc = MetadataCastPrimsDescriptor::PointwiseCast {
77///     input_dtype: MetadataDType::I32,
78/// };
79/// let _plan = <CpuBackend as TensorMetadataCastPrims<f32>>::plan(
80///     &mut ctx,
81///     &desc,
82///     &[&[2], &[2]],
83/// )
84/// .unwrap();
85/// ```
86pub trait TensorMetadataCastPrims<S: Scalar> {
87    /// Backend plan type.
88    type Plan;
89    /// Backend execution context.
90    type Context;
91
92    /// Plan a metadata-to-scalar bridge operation.
93    fn plan(
94        ctx: &mut Self::Context,
95        desc: &MetadataCastPrimsDescriptor,
96        shapes: &[&[usize]],
97    ) -> Result<Self::Plan>;
98
99    /// Execute a previously planned metadata-to-scalar bridge operation.
100    ///
101    /// The execution contract matches the rest of tenferro prims:
102    /// `output <- alpha * op(inputs) + beta * output`.
103    fn execute(
104        ctx: &mut Self::Context,
105        plan: &Self::Plan,
106        alpha: S,
107        inputs: &[MetadataScalarTensorRef<'_, S>],
108        beta: S,
109        output: &mut Tensor<S>,
110    ) -> Result<()>;
111
112    /// Report whether the backend advertises support for the given descriptor.
113    fn has_metadata_cast_support(desc: MetadataCastPrimsDescriptor) -> bool;
114}
115
116/// Return whether a metadata-to-scalar descriptor is supported by phase 1.
117pub(crate) fn supports_metadata_cast(desc: &MetadataCastPrimsDescriptor) -> bool {
118    match desc {
119        MetadataCastPrimsDescriptor::PointwiseCast { input_dtype } => {
120            matches!(input_dtype, MetadataDType::Bool | MetadataDType::I32)
121        }
122        MetadataCastPrimsDescriptor::Where { cond_dtype } => {
123            matches!(cond_dtype, MetadataDType::Bool)
124        }
125    }
126}
127
128/// Validate the shapes used by a metadata-to-scalar bridge descriptor.
129pub(crate) fn validate_metadata_cast_shapes(
130    desc: &MetadataCastPrimsDescriptor,
131    shapes: &[&[usize]],
132    op_name: &str,
133) -> Result<()> {
134    match desc {
135        MetadataCastPrimsDescriptor::PointwiseCast { .. } => {
136            validate_shape_count(shapes, 2, op_name)?;
137            validate_shape_broadcastable(shapes[0], shapes[1], op_name)?;
138            Ok(())
139        }
140        MetadataCastPrimsDescriptor::Where { .. } => {
141            validate_shape_count(shapes, 4, op_name)?;
142            validate_shape_broadcastable(shapes[0], shapes[3], op_name)?;
143            validate_shape_broadcastable(shapes[1], shapes[3], op_name)?;
144            validate_shape_broadcastable(shapes[2], shapes[3], op_name)?;
145            Ok(())
146        }
147    }
148}
149
150pub(crate) fn cast_metadata_value<S, T>(value: T, label: &str) -> Result<S>
151where
152    S: Scalar + NumCast,
153    T: ToPrimitive + Copy,
154{
155    NumCast::from(value).ok_or_else(|| {
156        Error::InvalidArgument(format!(
157            "{label} cannot be represented as {}",
158            std::any::type_name::<S>()
159        ))
160    })
161}
162
163pub(crate) fn for_each_index_result(
164    dims: &[usize],
165    mut f: impl FnMut(&[usize]) -> Result<()>,
166) -> Result<()> {
167    let mut result = Ok(());
168    crate::for_each_index(dims, |idx| {
169        if result.is_ok() {
170            result = f(idx);
171        }
172    });
173    result
174}
175
176pub(crate) fn validate_where_bridge_inputs<'a, S: Scalar>(
177    inputs: &'a [MetadataScalarTensorRef<'a, S>],
178) -> Result<(MetadataTensorRef<'a>, &'a Tensor<S>, &'a Tensor<S>)> {
179    if inputs.len() != 3 {
180        return Err(Error::InvalidArgument(format!(
181            "MetadataCastWhere expects 3 input(s) (got {})",
182            inputs.len()
183        )));
184    }
185    let MetadataScalarTensorRef::Metadata(cond) = inputs[0] else {
186        return Err(Error::InvalidArgument(
187            "MetadataCastWhere expects metadata condition input".into(),
188        ));
189    };
190    let MetadataScalarTensorRef::Scalar(on_true) = inputs[1] else {
191        return Err(Error::InvalidArgument(
192            "MetadataCastWhere expects scalar on_true input".into(),
193        ));
194    };
195    let MetadataScalarTensorRef::Scalar(on_false) = inputs[2] else {
196        return Err(Error::InvalidArgument(
197            "MetadataCastWhere expects scalar on_false input".into(),
198        ));
199    };
200    Ok((cond, on_true, on_false))
201}
202
203pub(crate) fn validate_pointwise_cast_bridge_inputs<'a, S: Scalar>(
204    inputs: &'a [MetadataScalarTensorRef<'a, S>],
205) -> Result<MetadataTensorRef<'a>> {
206    if inputs.len() != 1 {
207        return Err(Error::InvalidArgument(format!(
208            "MetadataCastPointwise expects 1 input(s) (got {})",
209            inputs.len()
210        )));
211    }
212    let MetadataScalarTensorRef::Metadata(input) = inputs[0] else {
213        return Err(Error::InvalidArgument(
214            "MetadataCastPointwise expects metadata input".into(),
215        ));
216    };
217    Ok(input)
218}
219
/// Scalar-family pointwise ternary `where` descriptor.
///
/// Lets the metadata bridge reuse the scalar family's `where` path.
#[cfg(feature = "cuda")]
pub(crate) fn scalar_where_desc() -> ScalarPrimsDescriptor {
    let op = ScalarTernaryOp::Where;
    ScalarPrimsDescriptor::PointwiseTernary { op }
}