tenferro_prims/families/
metadata.rs

1use tenferro_device::Result;
2use tenferro_tensor::Tensor;
3
/// Logical element type of a metadata tensor.
///
/// The metadata family knows exactly two logical dtypes: signed 32-bit
/// integers (`I32`) and booleans / masks (`Bool`). A `Bool` tensor is stored
/// as `u8` today; that storage choice is provisional and may change once
/// native bool tensors are supported.
///
/// # Examples
///
/// ```rust
/// use tenferro_prims::MetadataDType;
///
/// let mask_dtype = MetadataDType::Bool;
/// assert_eq!(mask_dtype, MetadataDType::Bool);
/// assert_ne!(mask_dtype, MetadataDType::I32);
/// ```
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum MetadataDType {
    /// Signed 32-bit integer (`i32`) elements.
    I32,
    /// Logical boolean / mask elements.
    Bool,
}
25
/// Scalar value used to fill a constant metadata tensor.
///
/// Each variant carries a payload whose Rust type corresponds to one
/// [`MetadataDType`]; [`MetadataConstantValue::dtype`] reports which one.
///
/// # Examples
///
/// ```rust
/// use tenferro_prims::{MetadataConstantValue, MetadataDType};
///
/// let seven = MetadataConstantValue::I32(7);
/// let yes = MetadataConstantValue::Bool(true);
/// assert_eq!(seven.dtype(), MetadataDType::I32);
/// assert_eq!(yes.dtype(), MetadataDType::Bool);
/// ```
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum MetadataConstantValue {
    /// Fill value for an `I32` metadata tensor.
    I32(i32),
    /// Fill value for a logical `Bool` metadata tensor.
    Bool(bool),
}
45
46impl MetadataConstantValue {
47    /// Return the logical dtype carried by this constant payload.
48    ///
49    /// # Examples
50    ///
51    /// ```rust
52    /// use tenferro_prims::{MetadataConstantValue, MetadataDType};
53    ///
54    /// assert_eq!(MetadataConstantValue::I32(-3).dtype(), MetadataDType::I32);
55    /// assert_eq!(MetadataConstantValue::Bool(false).dtype(), MetadataDType::Bool);
56    /// ```
57    pub const fn dtype(self) -> MetadataDType {
58        match self {
59            Self::I32(_) => MetadataDType::I32,
60            Self::Bool(_) => MetadataDType::Bool,
61        }
62    }
63}
64
65/// Metadata tensor generation operations.
66///
67/// Generation is intentionally separate from pointwise metadata ops so the
68/// contract can describe `iota`-style and constant metadata tensors without
69/// requiring a dummy input tensor.
70///
71/// # Examples
72///
73/// ```rust
74/// use tenferro_prims::{MetadataConstantValue, MetadataGenerateOp};
75///
76/// let op = MetadataGenerateOp::IotaStartZero;
77/// assert!(matches!(op, MetadataGenerateOp::IotaStartZero));
78/// let constant = MetadataGenerateOp::Constant(MetadataConstantValue::Bool(true));
79/// assert!(matches!(
80///     constant,
81///     MetadataGenerateOp::Constant(MetadataConstantValue::Bool(true))
82/// ));
83/// ```
84#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
85pub enum MetadataGenerateOp {
86    /// Generate a zero-based iota/arange tensor.
87    IotaStartZero,
88    /// Generate a tensor filled with a constant payload.
89    Constant(MetadataConstantValue),
90}
91
/// Two-operand operations over integer / bool metadata tensors.
///
/// # Examples
///
/// ```rust
/// use tenferro_prims::MetadataBinaryOp;
///
/// let neq = MetadataBinaryOp::NotEqual;
/// assert_eq!(neq, MetadataBinaryOp::NotEqual);
/// assert_ne!(neq, MetadataBinaryOp::Equal);
/// ```
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum MetadataBinaryOp {
    /// Elementwise `==` comparison.
    Equal,
    /// Elementwise `!=` comparison.
    NotEqual,
    /// Elementwise addition.
    Add,
    /// Elementwise subtraction.
    Sub,
    /// Elementwise multiplication.
    Mul,
    /// Elementwise bitwise AND.
    BitAnd,
}
117
/// Three-operand operations over integer / bool metadata tensors.
///
/// # Examples
///
/// ```rust
/// use tenferro_prims::MetadataTernaryOp;
///
/// let select = MetadataTernaryOp::Where;
/// assert_eq!(select, MetadataTernaryOp::Where);
/// ```
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum MetadataTernaryOp {
    /// Masked select: the first input is the mask, and each output element
    /// is taken from either the second or the third input.
    Where,
}
133
/// Reduction operators available to the integer / bool metadata family.
///
/// # Examples
///
/// ```rust
/// use tenferro_prims::MetadataReductionOp;
///
/// let total = MetadataReductionOp::Sum;
/// assert_eq!(total, MetadataReductionOp::Sum);
/// assert_ne!(total, MetadataReductionOp::Any);
/// ```
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum MetadataReductionOp {
    /// Sum the reduced metadata values into the output tensor.
    Sum,
    /// Logical AND reduction over bool-like metadata tensors.
    All,
    /// Logical OR reduction over bool-like metadata tensors.
    Any,
}
153
154/// Erased immutable metadata tensor reference.
155///
156/// The contract is intentionally narrow and only admits integer metadata and
157/// logical bool metadata. The bool handle is currently backed by `u8`
158/// storage, but that is a provisional implementation detail.
159///
160/// # Examples
161///
162/// ```ignore
163/// use tenferro_device::LogicalMemorySpace;
164/// use tenferro_prims::{MetadataDType, MetadataTensorRef};
165/// use tenferro_tensor::{MemoryOrder, Tensor};
166///
167/// let tensor = Tensor::<i32>::zeros(
168///     &[2, 2],
169///     LogicalMemorySpace::MainMemory,
170///     MemoryOrder::ColumnMajor,
171/// );
172/// let metadata = MetadataTensorRef::I32(&tensor);
173/// assert_eq!(metadata.dtype(), MetadataDType::I32);
174/// ```
175#[derive(Debug, Clone, Copy)]
176pub enum MetadataTensorRef<'a> {
177    /// A metadata tensor stored as `i32`.
178    I32(&'a Tensor<i32>),
179    /// A logical bool/mask metadata tensor backed by `u8` for now.
180    Bool(&'a Tensor<u8>),
181}
182
183impl<'a> MetadataTensorRef<'a> {
184    /// Return the logical dtype carried by this metadata tensor reference.
185    ///
186    /// # Examples
187    ///
188    /// ```ignore
189    /// use tenferro_device::LogicalMemorySpace;
190    /// use tenferro_prims::{MetadataDType, MetadataTensorRef};
191    /// use tenferro_tensor::{MemoryOrder, Tensor};
192    ///
193    /// let tensor = Tensor::<u8>::zeros(
194    ///     &[1],
195    ///     LogicalMemorySpace::MainMemory,
196    ///     MemoryOrder::ColumnMajor,
197    /// );
198    /// let metadata = MetadataTensorRef::Bool(&tensor);
199    /// assert_eq!(metadata.dtype(), MetadataDType::Bool);
200    /// ```
201    pub const fn dtype(&self) -> MetadataDType {
202        match self {
203            Self::I32(_) => MetadataDType::I32,
204            Self::Bool(_) => MetadataDType::Bool,
205        }
206    }
207}
208
209/// Erased mutable metadata tensor reference.
210///
211/// # Examples
212///
213/// ```ignore
214/// use tenferro_device::LogicalMemorySpace;
215/// use tenferro_prims::{MetadataDType, MetadataTensorMut};
216/// use tenferro_tensor::{MemoryOrder, Tensor};
217///
218/// let mut tensor = Tensor::<u8>::zeros(
219///     &[2, 2],
220///     LogicalMemorySpace::MainMemory,
221///     MemoryOrder::ColumnMajor,
222/// );
223/// let metadata = MetadataTensorMut::Bool(&mut tensor);
224/// assert_eq!(metadata.dtype(), MetadataDType::Bool);
225/// ```
226#[derive(Debug)]
227pub enum MetadataTensorMut<'a> {
228    /// A metadata tensor stored as `i32`.
229    I32(&'a mut Tensor<i32>),
230    /// A logical bool/mask metadata tensor backed by `u8` for now.
231    Bool(&'a mut Tensor<u8>),
232}
233
234impl<'a> MetadataTensorMut<'a> {
235    /// Return the logical dtype carried by this metadata tensor handle.
236    ///
237    /// # Examples
238    ///
239    /// ```ignore
240    /// use tenferro_device::LogicalMemorySpace;
241    /// use tenferro_prims::{MetadataDType, MetadataTensorMut};
242    /// use tenferro_tensor::{MemoryOrder, Tensor};
243    ///
244    /// let mut tensor = Tensor::<i32>::zeros(
245    ///     &[1],
246    ///     LogicalMemorySpace::MainMemory,
247    ///     MemoryOrder::ColumnMajor,
248    /// );
249    /// let metadata = MetadataTensorMut::I32(&mut tensor);
250    /// assert_eq!(metadata.dtype(), MetadataDType::I32);
251    /// ```
252    pub const fn dtype(&self) -> MetadataDType {
253        match self {
254            Self::I32(_) => MetadataDType::I32,
255            Self::Bool(_) => MetadataDType::Bool,
256        }
257    }
258}
259
260/// Descriptor for metadata tensor planning.
261///
262/// # Examples
263///
264/// ```rust
265/// use tenferro_prims::{
266///     MetadataConstantValue, MetadataDType, MetadataGenerateOp, MetadataPrimsDescriptor,
267/// };
268///
269/// let desc = MetadataPrimsDescriptor::Generate {
270///     op: MetadataGenerateOp::Constant(MetadataConstantValue::I32(3)),
271///     output_dtype: MetadataDType::I32,
272/// };
273/// assert!(matches!(desc, MetadataPrimsDescriptor::Generate { .. }));
274/// ```
275#[derive(Debug, Clone, PartialEq, Eq, Hash)]
276pub enum MetadataPrimsDescriptor {
277    /// Generate a metadata tensor without consuming an input tensor.
278    Generate {
279        /// Generation operation to apply.
280        op: MetadataGenerateOp,
281        /// Logical dtype of the generated tensor.
282        output_dtype: MetadataDType,
283    },
284    /// Apply a metadata binary operation to two input tensors.
285    Binary {
286        /// Binary operation to apply.
287        op: MetadataBinaryOp,
288        /// Logical dtype of the left-hand side operand.
289        lhs_dtype: MetadataDType,
290        /// Logical dtype of the right-hand side operand.
291        rhs_dtype: MetadataDType,
292        /// Logical dtype of the output tensor.
293        output_dtype: MetadataDType,
294    },
295    /// Apply a metadata ternary operation to three input tensors.
296    Ternary {
297        /// Ternary operation to apply.
298        op: MetadataTernaryOp,
299        /// Logical dtype of the condition/mask input.
300        cond_dtype: MetadataDType,
301        /// Logical dtype of the first data input.
302        lhs_dtype: MetadataDType,
303        /// Logical dtype of the second data input.
304        rhs_dtype: MetadataDType,
305        /// Logical dtype of the output tensor.
306        output_dtype: MetadataDType,
307    },
308    /// Reduce one tensor into an output tensor over the dropped modes.
309    Reduction {
310        /// Input modes associated with the source tensor.
311        modes_a: Vec<u32>,
312        /// Output modes that remain after reduction.
313        modes_c: Vec<u32>,
314        /// Logical dtype of the input tensor.
315        input_dtype: MetadataDType,
316        /// Logical dtype of the output tensor.
317        output_dtype: MetadataDType,
318        /// Reduction operator to use.
319        op: MetadataReductionOp,
320    },
321}
322
323/// Bridge trait that binds a metadata execution context to its backend.
324///
325/// This mirrors the other family-context bridge traits but is reserved for
326/// integer/bool metadata tensor workflows.
327///
328/// # Examples
329///
330/// ```ignore
331/// use tenferro_prims::TensorMetadataContextFor;
332///
333/// fn accepts_context<C>(_: &mut C)
334/// where
335///     C: TensorMetadataContextFor,
336/// {
337/// }
338///
339/// // A backend context can satisfy this trait once the metadata family is wired.
340/// ```
341pub trait TensorMetadataContextFor {
342    /// Backend associated with this context for the metadata family.
343    type MetadataBackend: TensorMetadataPrims<Context = Self>;
344}
345
346/// Metadata tensor planning and execution protocol.
347///
348/// Metadata execution is overwrite-based and uses erased integer/bool metadata
349/// tensor handles instead of scalar-family `alpha` / `beta` scaling.
350///
351/// # Examples
352///
353/// ```ignore
354/// use tenferro_prims::{
355///     MetadataGenerateOp, MetadataPrimsDescriptor, MetadataTensorMut,
356///     MetadataTensorRef, TensorMetadataPrims,
357/// };
358///
359/// fn accepts_family<F: TensorMetadataPrims>(_: &F) {}
360///
361/// let _ = MetadataPrimsDescriptor::Generate {
362///     op: MetadataGenerateOp::IotaStartZero,
363///     output_dtype: tenferro_prims::MetadataDType::I32,
364/// };
365/// ```
366pub trait TensorMetadataPrims {
367    /// Backend plan type.
368    type Plan;
369    /// Backend execution context.
370    type Context;
371
372    /// Plan a metadata-family operation for the given input and output tensor
373    /// handles.
374    fn plan(
375        ctx: &mut Self::Context,
376        desc: &MetadataPrimsDescriptor,
377        inputs: &[MetadataTensorRef<'_>],
378        output: MetadataTensorMut<'_>,
379    ) -> Result<Self::Plan>;
380
381    /// Execute a previously planned metadata-family operation in overwrite
382    /// mode.
383    fn execute(
384        ctx: &mut Self::Context,
385        plan: &Self::Plan,
386        inputs: &[MetadataTensorRef<'_>],
387        output: MetadataTensorMut<'_>,
388    ) -> Result<()>;
389
390    /// Report whether the backend advertises support for the given descriptor.
391    fn has_metadata_support(desc: MetadataPrimsDescriptor) -> bool;
392}