tenferro_prims/families/metadata.rs
1use tenferro_device::Result;
2use tenferro_tensor::Tensor;
3
/// Logical element type of a metadata tensor.
///
/// Only two dtypes are admitted today: signed 32-bit integers (`I32`) and
/// logical booleans (`Bool`). A `Bool` tensor is currently stored as `u8`;
/// treat that storage choice as provisional, since it may change once native
/// bool tensor support exists.
///
/// # Examples
///
/// ```rust
/// use tenferro_prims::MetadataDType;
///
/// let int_dtype = MetadataDType::I32;
/// let bool_dtype = MetadataDType::Bool;
/// assert_eq!(int_dtype, MetadataDType::I32);
/// assert_ne!(int_dtype, bool_dtype);
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MetadataDType {
    /// Metadata stored as signed 32-bit integers.
    I32,
    /// Metadata interpreted as logical booleans / masks.
    Bool,
}
25
/// Scalar payload used when generating a constant-filled metadata tensor.
///
/// Each variant carries a value whose logical dtype matches one of the
/// metadata dtypes: an `i32` for integer metadata or a `bool` for mask
/// metadata.
///
/// # Examples
///
/// ```rust
/// use tenferro_prims::{MetadataConstantValue, MetadataDType};
///
/// let seven = MetadataConstantValue::I32(7);
/// let yes = MetadataConstantValue::Bool(true);
/// assert_eq!(seven.dtype(), MetadataDType::I32);
/// assert_eq!(yes.dtype(), MetadataDType::Bool);
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MetadataConstantValue {
    /// Constant for integer (`i32`) metadata tensors.
    I32(i32),
    /// Constant for logical bool metadata tensors.
    Bool(bool),
}
45
46impl MetadataConstantValue {
47 /// Return the logical dtype carried by this constant payload.
48 ///
49 /// # Examples
50 ///
51 /// ```rust
52 /// use tenferro_prims::{MetadataConstantValue, MetadataDType};
53 ///
54 /// assert_eq!(MetadataConstantValue::I32(-3).dtype(), MetadataDType::I32);
55 /// assert_eq!(MetadataConstantValue::Bool(false).dtype(), MetadataDType::Bool);
56 /// ```
57 pub const fn dtype(self) -> MetadataDType {
58 match self {
59 Self::I32(_) => MetadataDType::I32,
60 Self::Bool(_) => MetadataDType::Bool,
61 }
62 }
63}
64
65/// Metadata tensor generation operations.
66///
67/// Generation is intentionally separate from pointwise metadata ops so the
68/// contract can describe `iota`-style and constant metadata tensors without
69/// requiring a dummy input tensor.
70///
71/// # Examples
72///
73/// ```rust
74/// use tenferro_prims::{MetadataConstantValue, MetadataGenerateOp};
75///
76/// let op = MetadataGenerateOp::IotaStartZero;
77/// assert!(matches!(op, MetadataGenerateOp::IotaStartZero));
78/// let constant = MetadataGenerateOp::Constant(MetadataConstantValue::Bool(true));
79/// assert!(matches!(
80/// constant,
81/// MetadataGenerateOp::Constant(MetadataConstantValue::Bool(true))
82/// ));
83/// ```
84#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
85pub enum MetadataGenerateOp {
86 /// Generate a zero-based iota/arange tensor.
87 IotaStartZero,
88 /// Generate a tensor filled with a constant payload.
89 Constant(MetadataConstantValue),
90}
91
/// Two-operand operations over integer/bool metadata tensors.
///
/// # Examples
///
/// ```rust
/// use tenferro_prims::MetadataBinaryOp;
///
/// let op = MetadataBinaryOp::BitAnd;
/// assert_ne!(op, MetadataBinaryOp::Add);
/// assert_eq!(op, MetadataBinaryOp::BitAnd);
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MetadataBinaryOp {
    /// Elementwise `==` comparison.
    Equal,
    /// Elementwise `!=` comparison.
    NotEqual,
    /// Addition of metadata values.
    Add,
    /// Subtraction of metadata values.
    Sub,
    /// Multiplication of metadata values.
    Mul,
    /// Elementwise bitwise AND.
    BitAnd,
}
117
/// Three-operand operations over integer/bool metadata tensors.
///
/// # Examples
///
/// ```rust
/// use tenferro_prims::MetadataTernaryOp;
///
/// let op = MetadataTernaryOp::Where;
/// assert!(matches!(op, MetadataTernaryOp::Where));
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MetadataTernaryOp {
    /// Mask-driven selection: the first input is the mask that chooses
    /// between values from the second and third inputs.
    Where,
}
133
/// Reduction operators for integer/bool metadata tensors.
///
/// # Examples
///
/// ```rust
/// use tenferro_prims::MetadataReductionOp;
///
/// let ops = [
///     MetadataReductionOp::Sum,
///     MetadataReductionOp::All,
///     MetadataReductionOp::Any,
/// ];
/// assert_eq!(ops[0], MetadataReductionOp::Sum);
/// assert_ne!(ops[1], ops[2]);
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MetadataReductionOp {
    /// Sum of the metadata values being reduced.
    Sum,
    /// Logical AND across bool-like metadata values.
    All,
    /// Logical OR across bool-like metadata values.
    Any,
}
153
154/// Erased immutable metadata tensor reference.
155///
156/// The contract is intentionally narrow and only admits integer metadata and
157/// logical bool metadata. The bool handle is currently backed by `u8`
158/// storage, but that is a provisional implementation detail.
159///
160/// # Examples
161///
162/// ```ignore
163/// use tenferro_device::LogicalMemorySpace;
164/// use tenferro_prims::{MetadataDType, MetadataTensorRef};
165/// use tenferro_tensor::{MemoryOrder, Tensor};
166///
167/// let tensor = Tensor::<i32>::zeros(
168/// &[2, 2],
169/// LogicalMemorySpace::MainMemory,
170/// MemoryOrder::ColumnMajor,
171/// );
172/// let metadata = MetadataTensorRef::I32(&tensor);
173/// assert_eq!(metadata.dtype(), MetadataDType::I32);
174/// ```
175#[derive(Debug, Clone, Copy)]
176pub enum MetadataTensorRef<'a> {
177 /// A metadata tensor stored as `i32`.
178 I32(&'a Tensor<i32>),
179 /// A logical bool/mask metadata tensor backed by `u8` for now.
180 Bool(&'a Tensor<u8>),
181}
182
183impl<'a> MetadataTensorRef<'a> {
184 /// Return the logical dtype carried by this metadata tensor reference.
185 ///
186 /// # Examples
187 ///
188 /// ```ignore
189 /// use tenferro_device::LogicalMemorySpace;
190 /// use tenferro_prims::{MetadataDType, MetadataTensorRef};
191 /// use tenferro_tensor::{MemoryOrder, Tensor};
192 ///
193 /// let tensor = Tensor::<u8>::zeros(
194 /// &[1],
195 /// LogicalMemorySpace::MainMemory,
196 /// MemoryOrder::ColumnMajor,
197 /// );
198 /// let metadata = MetadataTensorRef::Bool(&tensor);
199 /// assert_eq!(metadata.dtype(), MetadataDType::Bool);
200 /// ```
201 pub const fn dtype(&self) -> MetadataDType {
202 match self {
203 Self::I32(_) => MetadataDType::I32,
204 Self::Bool(_) => MetadataDType::Bool,
205 }
206 }
207}
208
209/// Erased mutable metadata tensor reference.
210///
211/// # Examples
212///
213/// ```ignore
214/// use tenferro_device::LogicalMemorySpace;
215/// use tenferro_prims::{MetadataDType, MetadataTensorMut};
216/// use tenferro_tensor::{MemoryOrder, Tensor};
217///
218/// let mut tensor = Tensor::<u8>::zeros(
219/// &[2, 2],
220/// LogicalMemorySpace::MainMemory,
221/// MemoryOrder::ColumnMajor,
222/// );
223/// let metadata = MetadataTensorMut::Bool(&mut tensor);
224/// assert_eq!(metadata.dtype(), MetadataDType::Bool);
225/// ```
226#[derive(Debug)]
227pub enum MetadataTensorMut<'a> {
228 /// A metadata tensor stored as `i32`.
229 I32(&'a mut Tensor<i32>),
230 /// A logical bool/mask metadata tensor backed by `u8` for now.
231 Bool(&'a mut Tensor<u8>),
232}
233
234impl<'a> MetadataTensorMut<'a> {
235 /// Return the logical dtype carried by this metadata tensor handle.
236 ///
237 /// # Examples
238 ///
239 /// ```ignore
240 /// use tenferro_device::LogicalMemorySpace;
241 /// use tenferro_prims::{MetadataDType, MetadataTensorMut};
242 /// use tenferro_tensor::{MemoryOrder, Tensor};
243 ///
244 /// let mut tensor = Tensor::<i32>::zeros(
245 /// &[1],
246 /// LogicalMemorySpace::MainMemory,
247 /// MemoryOrder::ColumnMajor,
248 /// );
249 /// let metadata = MetadataTensorMut::I32(&mut tensor);
250 /// assert_eq!(metadata.dtype(), MetadataDType::I32);
251 /// ```
252 pub const fn dtype(&self) -> MetadataDType {
253 match self {
254 Self::I32(_) => MetadataDType::I32,
255 Self::Bool(_) => MetadataDType::Bool,
256 }
257 }
258}
259
260/// Descriptor for metadata tensor planning.
261///
262/// # Examples
263///
264/// ```rust
265/// use tenferro_prims::{
266/// MetadataConstantValue, MetadataDType, MetadataGenerateOp, MetadataPrimsDescriptor,
267/// };
268///
269/// let desc = MetadataPrimsDescriptor::Generate {
270/// op: MetadataGenerateOp::Constant(MetadataConstantValue::I32(3)),
271/// output_dtype: MetadataDType::I32,
272/// };
273/// assert!(matches!(desc, MetadataPrimsDescriptor::Generate { .. }));
274/// ```
275#[derive(Debug, Clone, PartialEq, Eq, Hash)]
276pub enum MetadataPrimsDescriptor {
277 /// Generate a metadata tensor without consuming an input tensor.
278 Generate {
279 /// Generation operation to apply.
280 op: MetadataGenerateOp,
281 /// Logical dtype of the generated tensor.
282 output_dtype: MetadataDType,
283 },
284 /// Apply a metadata binary operation to two input tensors.
285 Binary {
286 /// Binary operation to apply.
287 op: MetadataBinaryOp,
288 /// Logical dtype of the left-hand side operand.
289 lhs_dtype: MetadataDType,
290 /// Logical dtype of the right-hand side operand.
291 rhs_dtype: MetadataDType,
292 /// Logical dtype of the output tensor.
293 output_dtype: MetadataDType,
294 },
295 /// Apply a metadata ternary operation to three input tensors.
296 Ternary {
297 /// Ternary operation to apply.
298 op: MetadataTernaryOp,
299 /// Logical dtype of the condition/mask input.
300 cond_dtype: MetadataDType,
301 /// Logical dtype of the first data input.
302 lhs_dtype: MetadataDType,
303 /// Logical dtype of the second data input.
304 rhs_dtype: MetadataDType,
305 /// Logical dtype of the output tensor.
306 output_dtype: MetadataDType,
307 },
308 /// Reduce one tensor into an output tensor over the dropped modes.
309 Reduction {
310 /// Input modes associated with the source tensor.
311 modes_a: Vec<u32>,
312 /// Output modes that remain after reduction.
313 modes_c: Vec<u32>,
314 /// Logical dtype of the input tensor.
315 input_dtype: MetadataDType,
316 /// Logical dtype of the output tensor.
317 output_dtype: MetadataDType,
318 /// Reduction operator to use.
319 op: MetadataReductionOp,
320 },
321}
322
323/// Bridge trait that binds a metadata execution context to its backend.
324///
325/// This mirrors the other family-context bridge traits but is reserved for
326/// integer/bool metadata tensor workflows.
327///
328/// # Examples
329///
330/// ```ignore
331/// use tenferro_prims::TensorMetadataContextFor;
332///
333/// fn accepts_context<C>(_: &mut C)
334/// where
335/// C: TensorMetadataContextFor,
336/// {
337/// }
338///
339/// // A backend context can satisfy this trait once the metadata family is wired.
340/// ```
341pub trait TensorMetadataContextFor {
342 /// Backend associated with this context for the metadata family.
343 type MetadataBackend: TensorMetadataPrims<Context = Self>;
344}
345
346/// Metadata tensor planning and execution protocol.
347///
348/// Metadata execution is overwrite-based and uses erased integer/bool metadata
349/// tensor handles instead of scalar-family `alpha` / `beta` scaling.
350///
351/// # Examples
352///
353/// ```ignore
354/// use tenferro_prims::{
355/// MetadataGenerateOp, MetadataPrimsDescriptor, MetadataTensorMut,
356/// MetadataTensorRef, TensorMetadataPrims,
357/// };
358///
359/// fn accepts_family<F: TensorMetadataPrims>(_: &F) {}
360///
361/// let _ = MetadataPrimsDescriptor::Generate {
362/// op: MetadataGenerateOp::IotaStartZero,
363/// output_dtype: tenferro_prims::MetadataDType::I32,
364/// };
365/// ```
366pub trait TensorMetadataPrims {
367 /// Backend plan type.
368 type Plan;
369 /// Backend execution context.
370 type Context;
371
372 /// Plan a metadata-family operation for the given input and output tensor
373 /// handles.
374 fn plan(
375 ctx: &mut Self::Context,
376 desc: &MetadataPrimsDescriptor,
377 inputs: &[MetadataTensorRef<'_>],
378 output: MetadataTensorMut<'_>,
379 ) -> Result<Self::Plan>;
380
381 /// Execute a previously planned metadata-family operation in overwrite
382 /// mode.
383 fn execute(
384 ctx: &mut Self::Context,
385 plan: &Self::Plan,
386 inputs: &[MetadataTensorRef<'_>],
387 output: MetadataTensorMut<'_>,
388 ) -> Result<()>;
389
390 /// Report whether the backend advertises support for the given descriptor.
391 fn has_metadata_support(desc: MetadataPrimsDescriptor) -> bool;
392}