tenferro_prims/cpu/
metadata.rs

1use std::collections::BTreeSet;
2
3use strided_view::{StridedView, StridedViewMut};
4use tenferro_device::{Error, Result};
5
6use crate::cpu::common::plan_reduction;
7use crate::cpu::{tensor_to_view, tensor_to_view_mut};
8use crate::shape_helpers::{broadcast_tensor_to_shape, validate_shape_broadcastable};
9use crate::{
10    CpuBackend, CpuContext, MetadataBinaryOp, MetadataConstantValue, MetadataDType,
11    MetadataGenerateOp, MetadataPrimsDescriptor, MetadataReductionOp, MetadataTensorMut,
12    MetadataTensorRef, MetadataTernaryOp, TensorMetadataPrims,
13};
14
15fn tensor_dims_ref<'a>(tensor: &'a MetadataTensorRef<'a>) -> &'a [usize] {
16    match tensor {
17        MetadataTensorRef::I32(tensor) => tensor.dims(),
18        MetadataTensorRef::Bool(tensor) => tensor.dims(),
19    }
20}
21
22fn tensor_dims_mut<'a>(tensor: &'a MetadataTensorMut<'a>) -> &'a [usize] {
23    match tensor {
24        MetadataTensorMut::I32(tensor) => tensor.dims(),
25        MetadataTensorMut::Bool(tensor) => tensor.dims(),
26    }
27}
28
29fn validate_supported_generate(op: MetadataGenerateOp, output_dtype: MetadataDType) -> Result<()> {
30    match (op, output_dtype) {
31        (MetadataGenerateOp::IotaStartZero, MetadataDType::I32)
32        | (MetadataGenerateOp::Constant(MetadataConstantValue::I32(_)), MetadataDType::I32)
33        | (MetadataGenerateOp::Constant(MetadataConstantValue::Bool(_)), MetadataDType::Bool) => {
34            Ok(())
35        }
36        _ => Err(Error::InvalidArgument(format!(
37            "unsupported metadata generate dtype combination: op={op:?} dst={output_dtype:?}"
38        ))),
39    }
40}
41
42fn validate_supported_binary(
43    op: MetadataBinaryOp,
44    lhs_dtype: MetadataDType,
45    rhs_dtype: MetadataDType,
46    output_dtype: MetadataDType,
47) -> Result<()> {
48    match (op, lhs_dtype, rhs_dtype, output_dtype) {
49        (
50            MetadataBinaryOp::Equal | MetadataBinaryOp::NotEqual,
51            MetadataDType::I32,
52            MetadataDType::I32,
53            MetadataDType::Bool,
54        )
55        | (
56            MetadataBinaryOp::Equal | MetadataBinaryOp::NotEqual,
57            MetadataDType::Bool,
58            MetadataDType::Bool,
59            MetadataDType::Bool,
60        )
61        | (
62            MetadataBinaryOp::Add | MetadataBinaryOp::Sub | MetadataBinaryOp::Mul,
63            MetadataDType::I32,
64            MetadataDType::I32,
65            MetadataDType::I32,
66        )
67        | (
68            MetadataBinaryOp::BitAnd,
69            MetadataDType::I32,
70            MetadataDType::I32,
71            MetadataDType::I32,
72        )
73        | (
74            MetadataBinaryOp::BitAnd,
75            MetadataDType::Bool,
76            MetadataDType::Bool,
77            MetadataDType::Bool,
78        ) => Ok(()),
79        _ => Err(Error::InvalidArgument(format!(
80            "unsupported metadata binary dtype combination: op={op:?} lhs={lhs_dtype:?} rhs={rhs_dtype:?} dst={output_dtype:?}"
81        ))),
82    }
83}
84
85fn validate_supported_ternary(
86    op: MetadataTernaryOp,
87    cond_dtype: MetadataDType,
88    lhs_dtype: MetadataDType,
89    rhs_dtype: MetadataDType,
90    output_dtype: MetadataDType,
91) -> Result<()> {
92    if !matches!(op, MetadataTernaryOp::Where) {
93        return Err(Error::InvalidArgument(format!(
94            "metadata ternary operation {op:?} is not supported on CpuBackend"
95        )));
96    }
97    match (cond_dtype, lhs_dtype, rhs_dtype, output_dtype) {
98        (MetadataDType::Bool, MetadataDType::I32, MetadataDType::I32, MetadataDType::I32)
99        | (MetadataDType::Bool, MetadataDType::Bool, MetadataDType::Bool, MetadataDType::Bool) =>
100        {
101            Ok(())
102        }
103        _ => Err(Error::InvalidArgument(format!(
104            "unsupported metadata ternary dtype combination: cond={cond_dtype:?} lhs={lhs_dtype:?} rhs={rhs_dtype:?} dst={output_dtype:?}"
105        ))),
106    }
107}
108
109fn validate_supported_reduction(
110    op: MetadataReductionOp,
111    input_dtype: MetadataDType,
112    output_dtype: MetadataDType,
113) -> Result<()> {
114    match (op, input_dtype, output_dtype) {
115        (MetadataReductionOp::Sum, MetadataDType::Bool, MetadataDType::I32)
116        | (MetadataReductionOp::Sum, MetadataDType::I32, MetadataDType::I32)
117        | (MetadataReductionOp::All, MetadataDType::Bool, MetadataDType::Bool)
118        | (MetadataReductionOp::Any, MetadataDType::Bool, MetadataDType::Bool) => Ok(()),
119        _ => Err(Error::InvalidArgument(format!(
120            "unsupported metadata reduction dtype combination: op={op:?} input={input_dtype:?} dst={output_dtype:?}"
121        ))),
122    }
123}
124
125fn validate_unique_mode_labels(modes: &[u32], label: &str) -> Result<()> {
126    let mut seen = BTreeSet::new();
127    for &mode in modes {
128        if !seen.insert(mode) {
129            return Err(Error::InvalidArgument(format!(
130                "{label} contains duplicate mode label {mode}"
131            )));
132        }
133    }
134    Ok(())
135}
136
137fn validate_metadata_handle_count(
138    inputs: &[MetadataTensorRef<'_>],
139    expected: usize,
140    op_name: &str,
141) -> Result<()> {
142    if inputs.len() != expected {
143        return Err(Error::InvalidArgument(format!(
144            "{op_name} expects {expected} input(s) (got {})",
145            inputs.len()
146        )));
147    }
148    Ok(())
149}
150
151fn execute_metadata_generate_i32(output: &mut StridedViewMut<i32>) -> Result<()> {
152    let dims = output.dims().to_vec();
153    let mut value = 0i32;
154    crate::for_each_index(&dims, |idx| {
155        output.set(idx, value);
156        value = value
157            .checked_add(1)
158            .expect("metadata iota overflow should be validated at plan time");
159    });
160    Ok(())
161}
162
163fn execute_metadata_generate_constant<T: Copy>(
164    output: &mut StridedViewMut<T>,
165    value: T,
166) -> Result<()> {
167    let dims = output.dims().to_vec();
168    crate::for_each_index(&dims, |idx| {
169        output.set(idx, value);
170    });
171    Ok(())
172}
173
174fn execute_metadata_binary_map<Lhs, Rhs, OutputT, F>(
175    lhs: &StridedView<Lhs>,
176    rhs: &StridedView<Rhs>,
177    output: &mut StridedViewMut<OutputT>,
178    f: F,
179) -> Result<()>
180where
181    Lhs: Copy,
182    Rhs: Copy,
183    OutputT: Copy,
184    F: Fn(Lhs, Rhs) -> OutputT + Copy,
185{
186    let dims = output.dims().to_vec();
187    crate::for_each_index(&dims, |idx| {
188        output.set(idx, f(lhs.get(idx), rhs.get(idx)));
189    });
190    Ok(())
191}
192
193fn execute_metadata_ternary_map<Lhs, F>(
194    cond: &StridedView<u8>,
195    on_true: &StridedView<Lhs>,
196    on_false: &StridedView<Lhs>,
197    output: &mut StridedViewMut<Lhs>,
198    f: F,
199) -> Result<()>
200where
201    Lhs: Copy,
202    F: Fn(u8, Lhs, Lhs) -> Lhs + Copy,
203{
204    let dims = output.dims().to_vec();
205    crate::for_each_index(&dims, |idx| {
206        output.set(idx, f(cond.get(idx), on_true.get(idx), on_false.get(idx)));
207    });
208    Ok(())
209}
210
211fn execute_metadata_reduce_sum_i32(
212    input: &StridedView<i32>,
213    output: &mut StridedViewMut<i32>,
214    kept_axes: &[usize],
215    reduced_axes: &[usize],
216) -> Result<()> {
217    let in_dims = input.dims().to_vec();
218    let out_dims = output.dims().to_vec();
219    let reduced_dims: Vec<usize> = reduced_axes.iter().map(|&ax| in_dims[ax]).collect();
220    let (kept_axis_positions, reduced_axis_positions) =
221        build_metadata_reduction_axis_positions(in_dims.len(), kept_axes, reduced_axes);
222    let reduced_total: usize = reduced_dims.iter().product();
223    let mut red_idx = vec![0usize; reduced_dims.len()];
224    let mut in_idx = vec![0usize; in_dims.len()];
225
226    crate::for_each_index(&out_dims, |out_idx| {
227        let mut sum = 0i32;
228        for red_flat in 0..reduced_total {
229            crate::cpu::common::unflatten_index_into(red_flat, &reduced_dims, &mut red_idx);
230            build_metadata_reduction_input_index(
231                out_idx,
232                &red_idx,
233                &kept_axis_positions,
234                &reduced_axis_positions,
235                &mut in_idx,
236            );
237            sum += input.get(&in_idx);
238        }
239        output.set(out_idx, sum);
240    });
241
242    Ok(())
243}
244
245fn execute_metadata_reduce_sum_bool(
246    input: &StridedView<u8>,
247    output: &mut StridedViewMut<i32>,
248    kept_axes: &[usize],
249    reduced_axes: &[usize],
250) -> Result<()> {
251    let in_dims = input.dims().to_vec();
252    let out_dims = output.dims().to_vec();
253    let reduced_dims: Vec<usize> = reduced_axes.iter().map(|&ax| in_dims[ax]).collect();
254    let (kept_axis_positions, reduced_axis_positions) =
255        build_metadata_reduction_axis_positions(in_dims.len(), kept_axes, reduced_axes);
256    let reduced_total: usize = reduced_dims.iter().product();
257    let mut red_idx = vec![0usize; reduced_dims.len()];
258    let mut in_idx = vec![0usize; in_dims.len()];
259
260    crate::for_each_index(&out_dims, |out_idx| {
261        let mut sum = 0i32;
262        for red_flat in 0..reduced_total {
263            crate::cpu::common::unflatten_index_into(red_flat, &reduced_dims, &mut red_idx);
264            build_metadata_reduction_input_index(
265                out_idx,
266                &red_idx,
267                &kept_axis_positions,
268                &reduced_axis_positions,
269                &mut in_idx,
270            );
271            sum += if input.get(&in_idx) != 0 { 1 } else { 0 };
272        }
273        output.set(out_idx, sum);
274    });
275
276    Ok(())
277}
278
279fn execute_metadata_reduce_all_bool(
280    input: &StridedView<u8>,
281    output: &mut StridedViewMut<u8>,
282    kept_axes: &[usize],
283    reduced_axes: &[usize],
284) -> Result<()> {
285    let in_dims = input.dims().to_vec();
286    let out_dims = output.dims().to_vec();
287    let reduced_dims: Vec<usize> = reduced_axes.iter().map(|&ax| in_dims[ax]).collect();
288    let (kept_axis_positions, reduced_axis_positions) =
289        build_metadata_reduction_axis_positions(in_dims.len(), kept_axes, reduced_axes);
290    let reduced_total: usize = reduced_dims.iter().product();
291    let mut red_idx = vec![0usize; reduced_dims.len()];
292    let mut in_idx = vec![0usize; in_dims.len()];
293
294    crate::for_each_index(&out_dims, |out_idx| {
295        let mut all_true = true;
296        for red_flat in 0..reduced_total {
297            crate::cpu::common::unflatten_index_into(red_flat, &reduced_dims, &mut red_idx);
298            build_metadata_reduction_input_index(
299                out_idx,
300                &red_idx,
301                &kept_axis_positions,
302                &reduced_axis_positions,
303                &mut in_idx,
304            );
305            if input.get(&in_idx) == 0 {
306                all_true = false;
307                break;
308            }
309        }
310        output.set(out_idx, if all_true { 1 } else { 0 });
311    });
312
313    Ok(())
314}
315
316fn execute_metadata_reduce_any_bool(
317    input: &StridedView<u8>,
318    output: &mut StridedViewMut<u8>,
319    kept_axes: &[usize],
320    reduced_axes: &[usize],
321) -> Result<()> {
322    let in_dims = input.dims().to_vec();
323    let out_dims = output.dims().to_vec();
324    let reduced_dims: Vec<usize> = reduced_axes.iter().map(|&ax| in_dims[ax]).collect();
325    let (kept_axis_positions, reduced_axis_positions) =
326        build_metadata_reduction_axis_positions(in_dims.len(), kept_axes, reduced_axes);
327    let reduced_total: usize = reduced_dims.iter().product();
328    let mut red_idx = vec![0usize; reduced_dims.len()];
329    let mut in_idx = vec![0usize; in_dims.len()];
330
331    crate::for_each_index(&out_dims, |out_idx| {
332        let mut any_true = false;
333        for red_flat in 0..reduced_total {
334            crate::cpu::common::unflatten_index_into(red_flat, &reduced_dims, &mut red_idx);
335            build_metadata_reduction_input_index(
336                out_idx,
337                &red_idx,
338                &kept_axis_positions,
339                &reduced_axis_positions,
340                &mut in_idx,
341            );
342            if input.get(&in_idx) != 0 {
343                any_true = true;
344                break;
345            }
346        }
347        output.set(out_idx, if any_true { 1 } else { 0 });
348    });
349
350    Ok(())
351}
352
/// Builds per-input-axis lookup tables for a reduction.
///
/// Returns `(kept, reduced)`, each of length `rank`. `kept[axis]` is `Some(i)` when
/// input `axis` appears at position `i` of `kept_axes`; `reduced[axis]` is the same for
/// `reduced_axes`. Axes absent from a list map to `None`. If an axis is listed twice,
/// the later position wins.
///
/// # Panics
/// Panics if any listed axis is `>= rank`.
fn build_metadata_reduction_axis_positions(
    rank: usize,
    kept_axes: &[usize],
    reduced_axes: &[usize],
) -> (Vec<Option<usize>>, Vec<Option<usize>>) {
    // Inverts an "ordinal -> input axis" list into an "input axis -> ordinal" table.
    let invert = |axes: &[usize]| -> Vec<Option<usize>> {
        let mut table = vec![None; rank];
        axes.iter()
            .enumerate()
            .for_each(|(ordinal, &input_axis)| table[input_axis] = Some(ordinal));
        table
    };

    let kept_table = invert(kept_axes);
    let reduced_table = invert(reduced_axes);
    (kept_table, reduced_table)
}
370
/// Assembles the full input index for one (output index, reduced index) pair.
///
/// For every input axis, the coordinate is taken from `out_idx` when the axis is kept
/// and from `red_idx` when it is reduced, using the lookup tables produced by
/// `build_metadata_reduction_axis_positions`. The kept table takes precedence when an
/// axis appears in both.
///
/// # Panics
/// Panics if an axis is present in neither table, or if a table or index slice is too
/// short for the axes being resolved.
fn build_metadata_reduction_input_index(
    out_idx: &[usize],
    red_idx: &[usize],
    kept_axis_positions: &[Option<usize>],
    reduced_axis_positions: &[Option<usize>],
    in_idx: &mut [usize],
) {
    for (axis, coordinate) in in_idx.iter_mut().enumerate() {
        // The reduced table is consulted only when the axis is not kept.
        *coordinate = match kept_axis_positions[axis] {
            Some(output_axis) => out_idx[output_axis],
            None => match reduced_axis_positions[axis] {
                Some(reduced_axis) => red_idx[reduced_axis],
                None => {
                    unreachable!(
                        "metadata reduction axis {axis} missing from kept/reduced axis lists"
                    )
                }
            },
        };
    }
}
388
389fn plan_metadata_reduction(
390    input_dims: &[usize],
391    output_dims: &[usize],
392    desc: &MetadataPrimsDescriptor,
393) -> Result<(Vec<usize>, Vec<usize>)> {
394    let MetadataPrimsDescriptor::Reduction {
395        modes_a, modes_c, ..
396    } = desc
397    else {
398        return Err(Error::InvalidArgument(
399            "expected metadata reduction descriptor".into(),
400        ));
401    };
402    validate_unique_mode_labels(modes_a, "CpuMetadataReduction input")?;
403    validate_unique_mode_labels(modes_c, "CpuMetadataReduction output")?;
404    let reduction = plan_reduction(
405        modes_a,
406        modes_c,
407        &[input_dims, output_dims],
408        "CpuMetadataReduction",
409    )?;
410    let kept_axes = modes_c
411        .iter()
412        .map(|mode| {
413            modes_a
414                .iter()
415                .position(|&candidate| candidate == *mode)
416                .ok_or_else(|| {
417                    Error::InvalidArgument(format!(
418                    "CpuMetadataReduction: output mode {mode} not found in input modes {modes_a:?}"
419                ))
420                })
421        })
422        .collect::<Result<Vec<_>>>()?;
423    Ok((kept_axes, reduction.reduced_axes))
424}
425
426impl TensorMetadataPrims for CpuBackend {
427    type Plan = MetadataPrimsDescriptor;
428    type Context = CpuContext;
429
430    fn plan(
431        _ctx: &mut Self::Context,
432        desc: &MetadataPrimsDescriptor,
433        inputs: &[MetadataTensorRef<'_>],
434        output: MetadataTensorMut<'_>,
435    ) -> Result<Self::Plan> {
436        match desc {
437            MetadataPrimsDescriptor::Generate { op, output_dtype } => {
438                validate_supported_generate(*op, *output_dtype)?;
439                if !inputs.is_empty() {
440                    return Err(Error::InvalidArgument(
441                        "metadata generate expects no inputs".into(),
442                    ));
443                }
444                match (*op, output) {
445                    (MetadataGenerateOp::IotaStartZero, MetadataTensorMut::I32(_))
446                    | (
447                        MetadataGenerateOp::Constant(MetadataConstantValue::I32(_)),
448                        MetadataTensorMut::I32(_),
449                    )
450                    | (
451                        MetadataGenerateOp::Constant(MetadataConstantValue::Bool(_)),
452                        MetadataTensorMut::Bool(_),
453                    ) => Ok(desc.clone()),
454                    _ => Err(Error::InvalidArgument(
455                        "metadata generate output dtype does not match payload".into(),
456                    )),
457                }
458            }
459            MetadataPrimsDescriptor::Binary {
460                op,
461                lhs_dtype,
462                rhs_dtype,
463                output_dtype,
464            } => {
465                validate_supported_binary(*op, *lhs_dtype, *rhs_dtype, *output_dtype)?;
466                validate_metadata_handle_count(inputs, 2, "CpuMetadataBinary")?;
467                validate_shape_broadcastable(
468                    tensor_dims_ref(&inputs[0]),
469                    tensor_dims_mut(&output),
470                    "CpuMetadataBinary lhs",
471                )?;
472                validate_shape_broadcastable(
473                    tensor_dims_ref(&inputs[1]),
474                    tensor_dims_mut(&output),
475                    "CpuMetadataBinary rhs",
476                )?;
477                Ok(desc.clone())
478            }
479            MetadataPrimsDescriptor::Ternary {
480                op,
481                cond_dtype,
482                lhs_dtype,
483                rhs_dtype,
484                output_dtype,
485            } => {
486                validate_supported_ternary(
487                    *op,
488                    *cond_dtype,
489                    *lhs_dtype,
490                    *rhs_dtype,
491                    *output_dtype,
492                )?;
493                validate_metadata_handle_count(inputs, 3, "CpuMetadataTernary")?;
494                if !matches!(*output_dtype, MetadataDType::I32 | MetadataDType::Bool) {
495                    return Err(Error::InvalidArgument(
496                        "unsupported metadata ternary output dtype".into(),
497                    ));
498                }
499                validate_shape_broadcastable(
500                    tensor_dims_ref(&inputs[0]),
501                    tensor_dims_mut(&output),
502                    "CpuMetadataTernary cond",
503                )?;
504                validate_shape_broadcastable(
505                    tensor_dims_ref(&inputs[1]),
506                    tensor_dims_mut(&output),
507                    "CpuMetadataTernary true",
508                )?;
509                validate_shape_broadcastable(
510                    tensor_dims_ref(&inputs[2]),
511                    tensor_dims_mut(&output),
512                    "CpuMetadataTernary false",
513                )?;
514                Ok(desc.clone())
515            }
516            MetadataPrimsDescriptor::Reduction {
517                input_dtype,
518                output_dtype,
519                op,
520                ..
521            } => {
522                validate_supported_reduction(*op, *input_dtype, *output_dtype)?;
523                validate_metadata_handle_count(inputs, 1, "CpuMetadataReduction")?;
524                let _ = plan_metadata_reduction(
525                    tensor_dims_ref(&inputs[0]),
526                    tensor_dims_mut(&output),
527                    desc,
528                )?;
529                Ok(desc.clone())
530            }
531        }
532    }
533
534    fn execute(
535        _ctx: &mut Self::Context,
536        plan: &Self::Plan,
537        inputs: &[MetadataTensorRef<'_>],
538        output: MetadataTensorMut<'_>,
539    ) -> Result<()> {
540        match plan {
541            MetadataPrimsDescriptor::Generate {
542                op: MetadataGenerateOp::IotaStartZero,
543                output_dtype: MetadataDType::I32,
544            } => match output {
545                MetadataTensorMut::I32(output) => {
546                    let mut output = tensor_to_view_mut(output)?;
547                    execute_metadata_generate_i32(&mut output)?;
548                    Ok(())
549                }
550                MetadataTensorMut::Bool(_) => Err(Error::InvalidArgument(
551                    "metadata iota currently supports I32 output only".into(),
552                )),
553            },
554            MetadataPrimsDescriptor::Generate {
555                op: MetadataGenerateOp::Constant(MetadataConstantValue::I32(value)),
556                output_dtype: MetadataDType::I32,
557            } => match output {
558                MetadataTensorMut::I32(output) => {
559                    let mut output = tensor_to_view_mut(output)?;
560                    execute_metadata_generate_constant(&mut output, *value)?;
561                    Ok(())
562                }
563                MetadataTensorMut::Bool(_) => Err(Error::InvalidArgument(
564                    "metadata constant output dtype does not match payload".into(),
565                )),
566            },
567            MetadataPrimsDescriptor::Generate {
568                op: MetadataGenerateOp::Constant(MetadataConstantValue::Bool(value)),
569                output_dtype: MetadataDType::Bool,
570            } => match output {
571                MetadataTensorMut::Bool(output) => {
572                    let mut output = tensor_to_view_mut(output)?;
573                    execute_metadata_generate_constant(&mut output, if *value { 1 } else { 0 })?;
574                    Ok(())
575                }
576                MetadataTensorMut::I32(_) => Err(Error::InvalidArgument(
577                    "metadata constant output dtype does not match payload".into(),
578                )),
579            },
580            MetadataPrimsDescriptor::Binary {
581                op,
582                lhs_dtype,
583                rhs_dtype,
584                output_dtype,
585            } => {
586                validate_metadata_handle_count(inputs, 2, "CpuMetadataBinary")?;
587                match (
588                    *lhs_dtype,
589                    *rhs_dtype,
590                    *output_dtype,
591                    inputs[0],
592                    inputs[1],
593                    output,
594                ) {
595                    (
596                        MetadataDType::I32,
597                        MetadataDType::I32,
598                        MetadataDType::I32,
599                        MetadataTensorRef::I32(lhs),
600                        MetadataTensorRef::I32(rhs),
601                        MetadataTensorMut::I32(dst),
602                    ) => {
603                        let lhs =
604                            broadcast_tensor_to_shape(lhs, dst.dims(), "CpuMetadataBinary lhs")?;
605                        let rhs =
606                            broadcast_tensor_to_shape(rhs, dst.dims(), "CpuMetadataBinary rhs")?;
607                        let lhs = tensor_to_view(&lhs)?;
608                        let rhs = tensor_to_view(&rhs)?;
609                        let mut dst = tensor_to_view_mut(dst)?;
610                        execute_metadata_binary_map(&lhs, &rhs, &mut dst, |x, y| match *op {
611                            MetadataBinaryOp::Add => x + y,
612                            MetadataBinaryOp::Sub => x - y,
613                            MetadataBinaryOp::Mul => x * y,
614                            MetadataBinaryOp::BitAnd => x & y,
615                            _ => unreachable!("unsupported metadata binary op"),
616                        })
617                    }
618                    (
619                        MetadataDType::I32,
620                        MetadataDType::I32,
621                        MetadataDType::Bool,
622                        MetadataTensorRef::I32(lhs),
623                        MetadataTensorRef::I32(rhs),
624                        MetadataTensorMut::Bool(dst),
625                    ) => {
626                        let lhs =
627                            broadcast_tensor_to_shape(lhs, dst.dims(), "CpuMetadataBinary lhs")?;
628                        let rhs =
629                            broadcast_tensor_to_shape(rhs, dst.dims(), "CpuMetadataBinary rhs")?;
630                        let lhs = tensor_to_view(&lhs)?;
631                        let rhs = tensor_to_view(&rhs)?;
632                        let mut dst = tensor_to_view_mut(dst)?;
633                        execute_metadata_binary_map(&lhs, &rhs, &mut dst, |x, y| {
634                            let mapped = match *op {
635                                MetadataBinaryOp::Equal => x == y,
636                                MetadataBinaryOp::NotEqual => x != y,
637                                _ => unreachable!("unsupported metadata binary op"),
638                            };
639                            if mapped {
640                                1
641                            } else {
642                                0
643                            }
644                        })
645                    }
646                    (
647                        MetadataDType::Bool,
648                        MetadataDType::Bool,
649                        MetadataDType::Bool,
650                        MetadataTensorRef::Bool(lhs),
651                        MetadataTensorRef::Bool(rhs),
652                        MetadataTensorMut::Bool(dst),
653                    ) => {
654                        let lhs =
655                            broadcast_tensor_to_shape(lhs, dst.dims(), "CpuMetadataBinary lhs")?;
656                        let rhs =
657                            broadcast_tensor_to_shape(rhs, dst.dims(), "CpuMetadataBinary rhs")?;
658                        let lhs = tensor_to_view(&lhs)?;
659                        let rhs = tensor_to_view(&rhs)?;
660                        let mut dst = tensor_to_view_mut(dst)?;
661                        execute_metadata_binary_map(&lhs, &rhs, &mut dst, |x, y| {
662                            let mapped = match *op {
663                                MetadataBinaryOp::Equal => (x != 0) == (y != 0),
664                                MetadataBinaryOp::NotEqual => (x != 0) != (y != 0),
665                                MetadataBinaryOp::BitAnd => (x != 0) && (y != 0),
666                                _ => unreachable!("unsupported metadata binary op"),
667                            };
668                            if mapped {
669                                1
670                            } else {
671                                0
672                            }
673                        })
674                    }
675                    _ => Err(Error::InvalidArgument(
676                        "unsupported metadata binary execution dtype combination".into(),
677                    )),
678                }
679            }
680            MetadataPrimsDescriptor::Ternary {
681                op: MetadataTernaryOp::Where,
682                cond_dtype,
683                lhs_dtype,
684                rhs_dtype,
685                output_dtype,
686            } => {
687                validate_metadata_handle_count(inputs, 3, "CpuMetadataTernary")?;
688                match (
689                    *cond_dtype,
690                    *lhs_dtype,
691                    *rhs_dtype,
692                    *output_dtype,
693                    inputs[0],
694                    inputs[1],
695                    inputs[2],
696                    output,
697                ) {
698                    (
699                        MetadataDType::Bool,
700                        MetadataDType::I32,
701                        MetadataDType::I32,
702                        MetadataDType::I32,
703                        MetadataTensorRef::Bool(cond),
704                        MetadataTensorRef::I32(on_true),
705                        MetadataTensorRef::I32(on_false),
706                        MetadataTensorMut::I32(dst),
707                    ) => {
708                        let cond =
709                            broadcast_tensor_to_shape(cond, dst.dims(), "CpuMetadataTernary cond")?;
710                        let on_true = broadcast_tensor_to_shape(
711                            on_true,
712                            dst.dims(),
713                            "CpuMetadataTernary true",
714                        )?;
715                        let on_false = broadcast_tensor_to_shape(
716                            on_false,
717                            dst.dims(),
718                            "CpuMetadataTernary false",
719                        )?;
720                        let cond = tensor_to_view(&cond)?;
721                        let on_true = tensor_to_view(&on_true)?;
722                        let on_false = tensor_to_view(&on_false)?;
723                        let mut dst = tensor_to_view_mut(dst)?;
724                        execute_metadata_ternary_map(
725                            &cond,
726                            &on_true,
727                            &on_false,
728                            &mut dst,
729                            |c, t, f| {
730                                if c != 0 {
731                                    t
732                                } else {
733                                    f
734                                }
735                            },
736                        )
737                    }
738                    (
739                        MetadataDType::Bool,
740                        MetadataDType::Bool,
741                        MetadataDType::Bool,
742                        MetadataDType::Bool,
743                        MetadataTensorRef::Bool(cond),
744                        MetadataTensorRef::Bool(on_true),
745                        MetadataTensorRef::Bool(on_false),
746                        MetadataTensorMut::Bool(dst),
747                    ) => {
748                        let cond =
749                            broadcast_tensor_to_shape(cond, dst.dims(), "CpuMetadataTernary cond")?;
750                        let on_true = broadcast_tensor_to_shape(
751                            on_true,
752                            dst.dims(),
753                            "CpuMetadataTernary true",
754                        )?;
755                        let on_false = broadcast_tensor_to_shape(
756                            on_false,
757                            dst.dims(),
758                            "CpuMetadataTernary false",
759                        )?;
760                        let cond = tensor_to_view(&cond)?;
761                        let on_true = tensor_to_view(&on_true)?;
762                        let on_false = tensor_to_view(&on_false)?;
763                        let mut dst = tensor_to_view_mut(dst)?;
764                        execute_metadata_ternary_map(
765                            &cond,
766                            &on_true,
767                            &on_false,
768                            &mut dst,
769                            |c, t, f| {
770                                if c != 0 {
771                                    t
772                                } else {
773                                    f
774                                }
775                            },
776                        )
777                    }
778                    _ => Err(Error::InvalidArgument(
779                        "unsupported metadata ternary execution dtype combination".into(),
780                    )),
781                }
782            }
783            MetadataPrimsDescriptor::Reduction {
784                op,
785                input_dtype,
786                output_dtype,
787                ..
788            } => {
789                validate_metadata_handle_count(inputs, 1, "CpuMetadataReduction")?;
790                let (kept_axes, reduced_axes) = plan_metadata_reduction(
791                    tensor_dims_ref(&inputs[0]),
792                    tensor_dims_mut(&output),
793                    plan,
794                )?;
795                match (*op, *input_dtype, *output_dtype, inputs[0], output) {
796                    (
797                        MetadataReductionOp::Sum,
798                        MetadataDType::I32,
799                        MetadataDType::I32,
800                        MetadataTensorRef::I32(input),
801                        MetadataTensorMut::I32(output),
802                    ) => {
803                        let input = tensor_to_view(input)?;
804                        let mut output = tensor_to_view_mut(output)?;
805                        execute_metadata_reduce_sum_i32(
806                            &input,
807                            &mut output,
808                            &kept_axes,
809                            &reduced_axes,
810                        )
811                    }
812                    (
813                        MetadataReductionOp::Sum,
814                        MetadataDType::Bool,
815                        MetadataDType::I32,
816                        MetadataTensorRef::Bool(input),
817                        MetadataTensorMut::I32(output),
818                    ) => {
819                        let input = tensor_to_view(input)?;
820                        let mut output = tensor_to_view_mut(output)?;
821                        execute_metadata_reduce_sum_bool(
822                            &input,
823                            &mut output,
824                            &kept_axes,
825                            &reduced_axes,
826                        )
827                    }
828                    (
829                        MetadataReductionOp::All,
830                        MetadataDType::Bool,
831                        MetadataDType::Bool,
832                        MetadataTensorRef::Bool(input),
833                        MetadataTensorMut::Bool(output),
834                    ) => {
835                        let input = tensor_to_view(input)?;
836                        let mut output = tensor_to_view_mut(output)?;
837                        execute_metadata_reduce_all_bool(
838                            &input,
839                            &mut output,
840                            &kept_axes,
841                            &reduced_axes,
842                        )
843                    }
844                    (
845                        MetadataReductionOp::Any,
846                        MetadataDType::Bool,
847                        MetadataDType::Bool,
848                        MetadataTensorRef::Bool(input),
849                        MetadataTensorMut::Bool(output),
850                    ) => {
851                        let input = tensor_to_view(input)?;
852                        let mut output = tensor_to_view_mut(output)?;
853                        execute_metadata_reduce_any_bool(
854                            &input,
855                            &mut output,
856                            &kept_axes,
857                            &reduced_axes,
858                        )
859                    }
860                    _ => Err(Error::InvalidArgument(
861                        "unsupported metadata reduction execution dtype combination".into(),
862                    )),
863                }
864            }
865            _ => Err(Error::InvalidArgument(
866                "unsupported metadata metadata descriptor on CpuBackend".into(),
867            )),
868        }
869    }
870
871    fn has_metadata_support(desc: MetadataPrimsDescriptor) -> bool {
872        match desc {
873            MetadataPrimsDescriptor::Generate { op, output_dtype } => matches!(
874                (op, output_dtype),
875                (MetadataGenerateOp::IotaStartZero, MetadataDType::I32)
876                    | (
877                        MetadataGenerateOp::Constant(MetadataConstantValue::I32(_)),
878                        MetadataDType::I32
879                    )
880                    | (
881                        MetadataGenerateOp::Constant(MetadataConstantValue::Bool(_)),
882                        MetadataDType::Bool
883                    )
884            ),
885            MetadataPrimsDescriptor::Binary {
886                op,
887                lhs_dtype,
888                rhs_dtype,
889                output_dtype,
890            } => matches!(
891                (op, lhs_dtype, rhs_dtype, output_dtype),
892                (
893                    MetadataBinaryOp::Equal | MetadataBinaryOp::NotEqual,
894                    MetadataDType::I32,
895                    MetadataDType::I32,
896                    MetadataDType::Bool
897                ) | (
898                    MetadataBinaryOp::Equal | MetadataBinaryOp::NotEqual,
899                    MetadataDType::Bool,
900                    MetadataDType::Bool,
901                    MetadataDType::Bool
902                ) | (
903                    MetadataBinaryOp::Add | MetadataBinaryOp::Sub | MetadataBinaryOp::Mul,
904                    MetadataDType::I32,
905                    MetadataDType::I32,
906                    MetadataDType::I32
907                ) | (
908                    MetadataBinaryOp::BitAnd,
909                    MetadataDType::I32,
910                    MetadataDType::I32,
911                    MetadataDType::I32
912                ) | (
913                    MetadataBinaryOp::BitAnd,
914                    MetadataDType::Bool,
915                    MetadataDType::Bool,
916                    MetadataDType::Bool
917                )
918            ),
919            MetadataPrimsDescriptor::Ternary {
920                op: MetadataTernaryOp::Where,
921                cond_dtype: MetadataDType::Bool,
922                lhs_dtype,
923                rhs_dtype,
924                output_dtype,
925            } => matches!(
926                (lhs_dtype, rhs_dtype, output_dtype),
927                (MetadataDType::I32, MetadataDType::I32, MetadataDType::I32)
928                    | (
929                        MetadataDType::Bool,
930                        MetadataDType::Bool,
931                        MetadataDType::Bool
932                    )
933            ),
934            MetadataPrimsDescriptor::Reduction {
935                op,
936                input_dtype,
937                output_dtype,
938                ..
939            } => matches!(
940                (op, input_dtype, output_dtype),
941                (
942                    MetadataReductionOp::Sum,
943                    MetadataDType::Bool,
944                    MetadataDType::I32
945                ) | (
946                    MetadataReductionOp::Sum,
947                    MetadataDType::I32,
948                    MetadataDType::I32
949                ) | (
950                    MetadataReductionOp::All,
951                    MetadataDType::Bool,
952                    MetadataDType::Bool
953                ) | (
954                    MetadataReductionOp::Any,
955                    MetadataDType::Bool,
956                    MetadataDType::Bool
957                )
958            ),
959            _ => false,
960        }
961    }
962}