1use std::collections::BTreeSet;
2
3use strided_view::{StridedView, StridedViewMut};
4use tenferro_device::{Error, Result};
5
6use crate::cpu::common::plan_reduction;
7use crate::cpu::{tensor_to_view, tensor_to_view_mut};
8use crate::shape_helpers::{broadcast_tensor_to_shape, validate_shape_broadcastable};
9use crate::{
10 CpuBackend, CpuContext, MetadataBinaryOp, MetadataConstantValue, MetadataDType,
11 MetadataGenerateOp, MetadataPrimsDescriptor, MetadataReductionOp, MetadataTensorMut,
12 MetadataTensorRef, MetadataTernaryOp, TensorMetadataPrims,
13};
14
15fn tensor_dims_ref<'a>(tensor: &'a MetadataTensorRef<'a>) -> &'a [usize] {
16 match tensor {
17 MetadataTensorRef::I32(tensor) => tensor.dims(),
18 MetadataTensorRef::Bool(tensor) => tensor.dims(),
19 }
20}
21
22fn tensor_dims_mut<'a>(tensor: &'a MetadataTensorMut<'a>) -> &'a [usize] {
23 match tensor {
24 MetadataTensorMut::I32(tensor) => tensor.dims(),
25 MetadataTensorMut::Bool(tensor) => tensor.dims(),
26 }
27}
28
29fn validate_supported_generate(op: MetadataGenerateOp, output_dtype: MetadataDType) -> Result<()> {
30 match (op, output_dtype) {
31 (MetadataGenerateOp::IotaStartZero, MetadataDType::I32)
32 | (MetadataGenerateOp::Constant(MetadataConstantValue::I32(_)), MetadataDType::I32)
33 | (MetadataGenerateOp::Constant(MetadataConstantValue::Bool(_)), MetadataDType::Bool) => {
34 Ok(())
35 }
36 _ => Err(Error::InvalidArgument(format!(
37 "unsupported metadata generate dtype combination: op={op:?} dst={output_dtype:?}"
38 ))),
39 }
40}
41
42fn validate_supported_binary(
43 op: MetadataBinaryOp,
44 lhs_dtype: MetadataDType,
45 rhs_dtype: MetadataDType,
46 output_dtype: MetadataDType,
47) -> Result<()> {
48 match (op, lhs_dtype, rhs_dtype, output_dtype) {
49 (
50 MetadataBinaryOp::Equal | MetadataBinaryOp::NotEqual,
51 MetadataDType::I32,
52 MetadataDType::I32,
53 MetadataDType::Bool,
54 )
55 | (
56 MetadataBinaryOp::Equal | MetadataBinaryOp::NotEqual,
57 MetadataDType::Bool,
58 MetadataDType::Bool,
59 MetadataDType::Bool,
60 )
61 | (
62 MetadataBinaryOp::Add | MetadataBinaryOp::Sub | MetadataBinaryOp::Mul,
63 MetadataDType::I32,
64 MetadataDType::I32,
65 MetadataDType::I32,
66 )
67 | (
68 MetadataBinaryOp::BitAnd,
69 MetadataDType::I32,
70 MetadataDType::I32,
71 MetadataDType::I32,
72 )
73 | (
74 MetadataBinaryOp::BitAnd,
75 MetadataDType::Bool,
76 MetadataDType::Bool,
77 MetadataDType::Bool,
78 ) => Ok(()),
79 _ => Err(Error::InvalidArgument(format!(
80 "unsupported metadata binary dtype combination: op={op:?} lhs={lhs_dtype:?} rhs={rhs_dtype:?} dst={output_dtype:?}"
81 ))),
82 }
83}
84
85fn validate_supported_ternary(
86 op: MetadataTernaryOp,
87 cond_dtype: MetadataDType,
88 lhs_dtype: MetadataDType,
89 rhs_dtype: MetadataDType,
90 output_dtype: MetadataDType,
91) -> Result<()> {
92 if !matches!(op, MetadataTernaryOp::Where) {
93 return Err(Error::InvalidArgument(format!(
94 "metadata ternary operation {op:?} is not supported on CpuBackend"
95 )));
96 }
97 match (cond_dtype, lhs_dtype, rhs_dtype, output_dtype) {
98 (MetadataDType::Bool, MetadataDType::I32, MetadataDType::I32, MetadataDType::I32)
99 | (MetadataDType::Bool, MetadataDType::Bool, MetadataDType::Bool, MetadataDType::Bool) =>
100 {
101 Ok(())
102 }
103 _ => Err(Error::InvalidArgument(format!(
104 "unsupported metadata ternary dtype combination: cond={cond_dtype:?} lhs={lhs_dtype:?} rhs={rhs_dtype:?} dst={output_dtype:?}"
105 ))),
106 }
107}
108
109fn validate_supported_reduction(
110 op: MetadataReductionOp,
111 input_dtype: MetadataDType,
112 output_dtype: MetadataDType,
113) -> Result<()> {
114 match (op, input_dtype, output_dtype) {
115 (MetadataReductionOp::Sum, MetadataDType::Bool, MetadataDType::I32)
116 | (MetadataReductionOp::Sum, MetadataDType::I32, MetadataDType::I32)
117 | (MetadataReductionOp::All, MetadataDType::Bool, MetadataDType::Bool)
118 | (MetadataReductionOp::Any, MetadataDType::Bool, MetadataDType::Bool) => Ok(()),
119 _ => Err(Error::InvalidArgument(format!(
120 "unsupported metadata reduction dtype combination: op={op:?} input={input_dtype:?} dst={output_dtype:?}"
121 ))),
122 }
123}
124
125fn validate_unique_mode_labels(modes: &[u32], label: &str) -> Result<()> {
126 let mut seen = BTreeSet::new();
127 for &mode in modes {
128 if !seen.insert(mode) {
129 return Err(Error::InvalidArgument(format!(
130 "{label} contains duplicate mode label {mode}"
131 )));
132 }
133 }
134 Ok(())
135}
136
137fn validate_metadata_handle_count(
138 inputs: &[MetadataTensorRef<'_>],
139 expected: usize,
140 op_name: &str,
141) -> Result<()> {
142 if inputs.len() != expected {
143 return Err(Error::InvalidArgument(format!(
144 "{op_name} expects {expected} input(s) (got {})",
145 inputs.len()
146 )));
147 }
148 Ok(())
149}
150
151fn execute_metadata_generate_i32(output: &mut StridedViewMut<i32>) -> Result<()> {
152 let dims = output.dims().to_vec();
153 let mut value = 0i32;
154 crate::for_each_index(&dims, |idx| {
155 output.set(idx, value);
156 value = value
157 .checked_add(1)
158 .expect("metadata iota overflow should be validated at plan time");
159 });
160 Ok(())
161}
162
163fn execute_metadata_generate_constant<T: Copy>(
164 output: &mut StridedViewMut<T>,
165 value: T,
166) -> Result<()> {
167 let dims = output.dims().to_vec();
168 crate::for_each_index(&dims, |idx| {
169 output.set(idx, value);
170 });
171 Ok(())
172}
173
174fn execute_metadata_binary_map<Lhs, Rhs, OutputT, F>(
175 lhs: &StridedView<Lhs>,
176 rhs: &StridedView<Rhs>,
177 output: &mut StridedViewMut<OutputT>,
178 f: F,
179) -> Result<()>
180where
181 Lhs: Copy,
182 Rhs: Copy,
183 OutputT: Copy,
184 F: Fn(Lhs, Rhs) -> OutputT + Copy,
185{
186 let dims = output.dims().to_vec();
187 crate::for_each_index(&dims, |idx| {
188 output.set(idx, f(lhs.get(idx), rhs.get(idx)));
189 });
190 Ok(())
191}
192
193fn execute_metadata_ternary_map<Lhs, F>(
194 cond: &StridedView<u8>,
195 on_true: &StridedView<Lhs>,
196 on_false: &StridedView<Lhs>,
197 output: &mut StridedViewMut<Lhs>,
198 f: F,
199) -> Result<()>
200where
201 Lhs: Copy,
202 F: Fn(u8, Lhs, Lhs) -> Lhs + Copy,
203{
204 let dims = output.dims().to_vec();
205 crate::for_each_index(&dims, |idx| {
206 output.set(idx, f(cond.get(idx), on_true.get(idx), on_false.get(idx)));
207 });
208 Ok(())
209}
210
211fn execute_metadata_reduce_sum_i32(
212 input: &StridedView<i32>,
213 output: &mut StridedViewMut<i32>,
214 kept_axes: &[usize],
215 reduced_axes: &[usize],
216) -> Result<()> {
217 let in_dims = input.dims().to_vec();
218 let out_dims = output.dims().to_vec();
219 let reduced_dims: Vec<usize> = reduced_axes.iter().map(|&ax| in_dims[ax]).collect();
220 let (kept_axis_positions, reduced_axis_positions) =
221 build_metadata_reduction_axis_positions(in_dims.len(), kept_axes, reduced_axes);
222 let reduced_total: usize = reduced_dims.iter().product();
223 let mut red_idx = vec![0usize; reduced_dims.len()];
224 let mut in_idx = vec![0usize; in_dims.len()];
225
226 crate::for_each_index(&out_dims, |out_idx| {
227 let mut sum = 0i32;
228 for red_flat in 0..reduced_total {
229 crate::cpu::common::unflatten_index_into(red_flat, &reduced_dims, &mut red_idx);
230 build_metadata_reduction_input_index(
231 out_idx,
232 &red_idx,
233 &kept_axis_positions,
234 &reduced_axis_positions,
235 &mut in_idx,
236 );
237 sum += input.get(&in_idx);
238 }
239 output.set(out_idx, sum);
240 });
241
242 Ok(())
243}
244
245fn execute_metadata_reduce_sum_bool(
246 input: &StridedView<u8>,
247 output: &mut StridedViewMut<i32>,
248 kept_axes: &[usize],
249 reduced_axes: &[usize],
250) -> Result<()> {
251 let in_dims = input.dims().to_vec();
252 let out_dims = output.dims().to_vec();
253 let reduced_dims: Vec<usize> = reduced_axes.iter().map(|&ax| in_dims[ax]).collect();
254 let (kept_axis_positions, reduced_axis_positions) =
255 build_metadata_reduction_axis_positions(in_dims.len(), kept_axes, reduced_axes);
256 let reduced_total: usize = reduced_dims.iter().product();
257 let mut red_idx = vec![0usize; reduced_dims.len()];
258 let mut in_idx = vec![0usize; in_dims.len()];
259
260 crate::for_each_index(&out_dims, |out_idx| {
261 let mut sum = 0i32;
262 for red_flat in 0..reduced_total {
263 crate::cpu::common::unflatten_index_into(red_flat, &reduced_dims, &mut red_idx);
264 build_metadata_reduction_input_index(
265 out_idx,
266 &red_idx,
267 &kept_axis_positions,
268 &reduced_axis_positions,
269 &mut in_idx,
270 );
271 sum += if input.get(&in_idx) != 0 { 1 } else { 0 };
272 }
273 output.set(out_idx, sum);
274 });
275
276 Ok(())
277}
278
279fn execute_metadata_reduce_all_bool(
280 input: &StridedView<u8>,
281 output: &mut StridedViewMut<u8>,
282 kept_axes: &[usize],
283 reduced_axes: &[usize],
284) -> Result<()> {
285 let in_dims = input.dims().to_vec();
286 let out_dims = output.dims().to_vec();
287 let reduced_dims: Vec<usize> = reduced_axes.iter().map(|&ax| in_dims[ax]).collect();
288 let (kept_axis_positions, reduced_axis_positions) =
289 build_metadata_reduction_axis_positions(in_dims.len(), kept_axes, reduced_axes);
290 let reduced_total: usize = reduced_dims.iter().product();
291 let mut red_idx = vec![0usize; reduced_dims.len()];
292 let mut in_idx = vec![0usize; in_dims.len()];
293
294 crate::for_each_index(&out_dims, |out_idx| {
295 let mut all_true = true;
296 for red_flat in 0..reduced_total {
297 crate::cpu::common::unflatten_index_into(red_flat, &reduced_dims, &mut red_idx);
298 build_metadata_reduction_input_index(
299 out_idx,
300 &red_idx,
301 &kept_axis_positions,
302 &reduced_axis_positions,
303 &mut in_idx,
304 );
305 if input.get(&in_idx) == 0 {
306 all_true = false;
307 break;
308 }
309 }
310 output.set(out_idx, if all_true { 1 } else { 0 });
311 });
312
313 Ok(())
314}
315
316fn execute_metadata_reduce_any_bool(
317 input: &StridedView<u8>,
318 output: &mut StridedViewMut<u8>,
319 kept_axes: &[usize],
320 reduced_axes: &[usize],
321) -> Result<()> {
322 let in_dims = input.dims().to_vec();
323 let out_dims = output.dims().to_vec();
324 let reduced_dims: Vec<usize> = reduced_axes.iter().map(|&ax| in_dims[ax]).collect();
325 let (kept_axis_positions, reduced_axis_positions) =
326 build_metadata_reduction_axis_positions(in_dims.len(), kept_axes, reduced_axes);
327 let reduced_total: usize = reduced_dims.iter().product();
328 let mut red_idx = vec![0usize; reduced_dims.len()];
329 let mut in_idx = vec![0usize; in_dims.len()];
330
331 crate::for_each_index(&out_dims, |out_idx| {
332 let mut any_true = false;
333 for red_flat in 0..reduced_total {
334 crate::cpu::common::unflatten_index_into(red_flat, &reduced_dims, &mut red_idx);
335 build_metadata_reduction_input_index(
336 out_idx,
337 &red_idx,
338 &kept_axis_positions,
339 &reduced_axis_positions,
340 &mut in_idx,
341 );
342 if input.get(&in_idx) != 0 {
343 any_true = true;
344 break;
345 }
346 }
347 output.set(out_idx, if any_true { 1 } else { 0 });
348 });
349
350 Ok(())
351}
352
/// Inverts the kept/reduced axis lists into per-input-axis lookup tables.
///
/// Returns `(kept, reduced)`, each of length `rank`:
/// - `kept[a] == Some(k)` when input axis `a` is `kept_axes[k]`.
/// - `reduced[a] == Some(r)` when input axis `a` is `reduced_axes[r]`.
///
/// Axes absent from a list map to `None` in that table.
///
/// # Panics
///
/// Panics if any axis in either list is `>= rank`.
fn build_metadata_reduction_axis_positions(
    rank: usize,
    kept_axes: &[usize],
    reduced_axes: &[usize],
) -> (Vec<Option<usize>>, Vec<Option<usize>>) {
    let invert = |axes: &[usize]| -> Vec<Option<usize>> {
        let mut table = vec![None; rank];
        for (slot, &input_axis) in axes.iter().enumerate() {
            table[input_axis] = Some(slot);
        }
        table
    };
    (invert(kept_axes), invert(reduced_axes))
}
370
/// Assembles a full input coordinate from an output coordinate and a
/// reduction-slice coordinate.
///
/// For each input axis, the value is taken from `out_idx` when the axis is
/// kept, otherwise from `red_idx` when it is reduced. The lookup tables come
/// from `build_metadata_reduction_axis_positions`.
///
/// # Panics
///
/// Panics (`unreachable!`) if an axis appears in neither table, and on
/// out-of-range indexing.
fn build_metadata_reduction_input_index(
    out_idx: &[usize],
    red_idx: &[usize],
    kept_axis_positions: &[Option<usize>],
    reduced_axis_positions: &[Option<usize>],
    in_idx: &mut [usize],
) {
    for (axis, slot) in in_idx.iter_mut().enumerate() {
        let from_output = kept_axis_positions[axis].map(|pos| out_idx[pos]);
        // The reduced table is consulted only when the axis is not kept.
        *slot = from_output
            .or_else(|| reduced_axis_positions[axis].map(|pos| red_idx[pos]))
            .unwrap_or_else(|| {
                unreachable!("metadata reduction axis {axis} missing from kept/reduced axis lists")
            });
    }
}
388
389fn plan_metadata_reduction(
390 input_dims: &[usize],
391 output_dims: &[usize],
392 desc: &MetadataPrimsDescriptor,
393) -> Result<(Vec<usize>, Vec<usize>)> {
394 let MetadataPrimsDescriptor::Reduction {
395 modes_a, modes_c, ..
396 } = desc
397 else {
398 return Err(Error::InvalidArgument(
399 "expected metadata reduction descriptor".into(),
400 ));
401 };
402 validate_unique_mode_labels(modes_a, "CpuMetadataReduction input")?;
403 validate_unique_mode_labels(modes_c, "CpuMetadataReduction output")?;
404 let reduction = plan_reduction(
405 modes_a,
406 modes_c,
407 &[input_dims, output_dims],
408 "CpuMetadataReduction",
409 )?;
410 let kept_axes = modes_c
411 .iter()
412 .map(|mode| {
413 modes_a
414 .iter()
415 .position(|&candidate| candidate == *mode)
416 .ok_or_else(|| {
417 Error::InvalidArgument(format!(
418 "CpuMetadataReduction: output mode {mode} not found in input modes {modes_a:?}"
419 ))
420 })
421 })
422 .collect::<Result<Vec<_>>>()?;
423 Ok((kept_axes, reduction.reduced_axes))
424}
425
426impl TensorMetadataPrims for CpuBackend {
427 type Plan = MetadataPrimsDescriptor;
428 type Context = CpuContext;
429
430 fn plan(
431 _ctx: &mut Self::Context,
432 desc: &MetadataPrimsDescriptor,
433 inputs: &[MetadataTensorRef<'_>],
434 output: MetadataTensorMut<'_>,
435 ) -> Result<Self::Plan> {
436 match desc {
437 MetadataPrimsDescriptor::Generate { op, output_dtype } => {
438 validate_supported_generate(*op, *output_dtype)?;
439 if !inputs.is_empty() {
440 return Err(Error::InvalidArgument(
441 "metadata generate expects no inputs".into(),
442 ));
443 }
444 match (*op, output) {
445 (MetadataGenerateOp::IotaStartZero, MetadataTensorMut::I32(_))
446 | (
447 MetadataGenerateOp::Constant(MetadataConstantValue::I32(_)),
448 MetadataTensorMut::I32(_),
449 )
450 | (
451 MetadataGenerateOp::Constant(MetadataConstantValue::Bool(_)),
452 MetadataTensorMut::Bool(_),
453 ) => Ok(desc.clone()),
454 _ => Err(Error::InvalidArgument(
455 "metadata generate output dtype does not match payload".into(),
456 )),
457 }
458 }
459 MetadataPrimsDescriptor::Binary {
460 op,
461 lhs_dtype,
462 rhs_dtype,
463 output_dtype,
464 } => {
465 validate_supported_binary(*op, *lhs_dtype, *rhs_dtype, *output_dtype)?;
466 validate_metadata_handle_count(inputs, 2, "CpuMetadataBinary")?;
467 validate_shape_broadcastable(
468 tensor_dims_ref(&inputs[0]),
469 tensor_dims_mut(&output),
470 "CpuMetadataBinary lhs",
471 )?;
472 validate_shape_broadcastable(
473 tensor_dims_ref(&inputs[1]),
474 tensor_dims_mut(&output),
475 "CpuMetadataBinary rhs",
476 )?;
477 Ok(desc.clone())
478 }
479 MetadataPrimsDescriptor::Ternary {
480 op,
481 cond_dtype,
482 lhs_dtype,
483 rhs_dtype,
484 output_dtype,
485 } => {
486 validate_supported_ternary(
487 *op,
488 *cond_dtype,
489 *lhs_dtype,
490 *rhs_dtype,
491 *output_dtype,
492 )?;
493 validate_metadata_handle_count(inputs, 3, "CpuMetadataTernary")?;
494 if !matches!(*output_dtype, MetadataDType::I32 | MetadataDType::Bool) {
495 return Err(Error::InvalidArgument(
496 "unsupported metadata ternary output dtype".into(),
497 ));
498 }
499 validate_shape_broadcastable(
500 tensor_dims_ref(&inputs[0]),
501 tensor_dims_mut(&output),
502 "CpuMetadataTernary cond",
503 )?;
504 validate_shape_broadcastable(
505 tensor_dims_ref(&inputs[1]),
506 tensor_dims_mut(&output),
507 "CpuMetadataTernary true",
508 )?;
509 validate_shape_broadcastable(
510 tensor_dims_ref(&inputs[2]),
511 tensor_dims_mut(&output),
512 "CpuMetadataTernary false",
513 )?;
514 Ok(desc.clone())
515 }
516 MetadataPrimsDescriptor::Reduction {
517 input_dtype,
518 output_dtype,
519 op,
520 ..
521 } => {
522 validate_supported_reduction(*op, *input_dtype, *output_dtype)?;
523 validate_metadata_handle_count(inputs, 1, "CpuMetadataReduction")?;
524 let _ = plan_metadata_reduction(
525 tensor_dims_ref(&inputs[0]),
526 tensor_dims_mut(&output),
527 desc,
528 )?;
529 Ok(desc.clone())
530 }
531 }
532 }
533
534 fn execute(
535 _ctx: &mut Self::Context,
536 plan: &Self::Plan,
537 inputs: &[MetadataTensorRef<'_>],
538 output: MetadataTensorMut<'_>,
539 ) -> Result<()> {
540 match plan {
541 MetadataPrimsDescriptor::Generate {
542 op: MetadataGenerateOp::IotaStartZero,
543 output_dtype: MetadataDType::I32,
544 } => match output {
545 MetadataTensorMut::I32(output) => {
546 let mut output = tensor_to_view_mut(output)?;
547 execute_metadata_generate_i32(&mut output)?;
548 Ok(())
549 }
550 MetadataTensorMut::Bool(_) => Err(Error::InvalidArgument(
551 "metadata iota currently supports I32 output only".into(),
552 )),
553 },
554 MetadataPrimsDescriptor::Generate {
555 op: MetadataGenerateOp::Constant(MetadataConstantValue::I32(value)),
556 output_dtype: MetadataDType::I32,
557 } => match output {
558 MetadataTensorMut::I32(output) => {
559 let mut output = tensor_to_view_mut(output)?;
560 execute_metadata_generate_constant(&mut output, *value)?;
561 Ok(())
562 }
563 MetadataTensorMut::Bool(_) => Err(Error::InvalidArgument(
564 "metadata constant output dtype does not match payload".into(),
565 )),
566 },
567 MetadataPrimsDescriptor::Generate {
568 op: MetadataGenerateOp::Constant(MetadataConstantValue::Bool(value)),
569 output_dtype: MetadataDType::Bool,
570 } => match output {
571 MetadataTensorMut::Bool(output) => {
572 let mut output = tensor_to_view_mut(output)?;
573 execute_metadata_generate_constant(&mut output, if *value { 1 } else { 0 })?;
574 Ok(())
575 }
576 MetadataTensorMut::I32(_) => Err(Error::InvalidArgument(
577 "metadata constant output dtype does not match payload".into(),
578 )),
579 },
580 MetadataPrimsDescriptor::Binary {
581 op,
582 lhs_dtype,
583 rhs_dtype,
584 output_dtype,
585 } => {
586 validate_metadata_handle_count(inputs, 2, "CpuMetadataBinary")?;
587 match (
588 *lhs_dtype,
589 *rhs_dtype,
590 *output_dtype,
591 inputs[0],
592 inputs[1],
593 output,
594 ) {
595 (
596 MetadataDType::I32,
597 MetadataDType::I32,
598 MetadataDType::I32,
599 MetadataTensorRef::I32(lhs),
600 MetadataTensorRef::I32(rhs),
601 MetadataTensorMut::I32(dst),
602 ) => {
603 let lhs =
604 broadcast_tensor_to_shape(lhs, dst.dims(), "CpuMetadataBinary lhs")?;
605 let rhs =
606 broadcast_tensor_to_shape(rhs, dst.dims(), "CpuMetadataBinary rhs")?;
607 let lhs = tensor_to_view(&lhs)?;
608 let rhs = tensor_to_view(&rhs)?;
609 let mut dst = tensor_to_view_mut(dst)?;
610 execute_metadata_binary_map(&lhs, &rhs, &mut dst, |x, y| match *op {
611 MetadataBinaryOp::Add => x + y,
612 MetadataBinaryOp::Sub => x - y,
613 MetadataBinaryOp::Mul => x * y,
614 MetadataBinaryOp::BitAnd => x & y,
615 _ => unreachable!("unsupported metadata binary op"),
616 })
617 }
618 (
619 MetadataDType::I32,
620 MetadataDType::I32,
621 MetadataDType::Bool,
622 MetadataTensorRef::I32(lhs),
623 MetadataTensorRef::I32(rhs),
624 MetadataTensorMut::Bool(dst),
625 ) => {
626 let lhs =
627 broadcast_tensor_to_shape(lhs, dst.dims(), "CpuMetadataBinary lhs")?;
628 let rhs =
629 broadcast_tensor_to_shape(rhs, dst.dims(), "CpuMetadataBinary rhs")?;
630 let lhs = tensor_to_view(&lhs)?;
631 let rhs = tensor_to_view(&rhs)?;
632 let mut dst = tensor_to_view_mut(dst)?;
633 execute_metadata_binary_map(&lhs, &rhs, &mut dst, |x, y| {
634 let mapped = match *op {
635 MetadataBinaryOp::Equal => x == y,
636 MetadataBinaryOp::NotEqual => x != y,
637 _ => unreachable!("unsupported metadata binary op"),
638 };
639 if mapped {
640 1
641 } else {
642 0
643 }
644 })
645 }
646 (
647 MetadataDType::Bool,
648 MetadataDType::Bool,
649 MetadataDType::Bool,
650 MetadataTensorRef::Bool(lhs),
651 MetadataTensorRef::Bool(rhs),
652 MetadataTensorMut::Bool(dst),
653 ) => {
654 let lhs =
655 broadcast_tensor_to_shape(lhs, dst.dims(), "CpuMetadataBinary lhs")?;
656 let rhs =
657 broadcast_tensor_to_shape(rhs, dst.dims(), "CpuMetadataBinary rhs")?;
658 let lhs = tensor_to_view(&lhs)?;
659 let rhs = tensor_to_view(&rhs)?;
660 let mut dst = tensor_to_view_mut(dst)?;
661 execute_metadata_binary_map(&lhs, &rhs, &mut dst, |x, y| {
662 let mapped = match *op {
663 MetadataBinaryOp::Equal => (x != 0) == (y != 0),
664 MetadataBinaryOp::NotEqual => (x != 0) != (y != 0),
665 MetadataBinaryOp::BitAnd => (x != 0) && (y != 0),
666 _ => unreachable!("unsupported metadata binary op"),
667 };
668 if mapped {
669 1
670 } else {
671 0
672 }
673 })
674 }
675 _ => Err(Error::InvalidArgument(
676 "unsupported metadata binary execution dtype combination".into(),
677 )),
678 }
679 }
680 MetadataPrimsDescriptor::Ternary {
681 op: MetadataTernaryOp::Where,
682 cond_dtype,
683 lhs_dtype,
684 rhs_dtype,
685 output_dtype,
686 } => {
687 validate_metadata_handle_count(inputs, 3, "CpuMetadataTernary")?;
688 match (
689 *cond_dtype,
690 *lhs_dtype,
691 *rhs_dtype,
692 *output_dtype,
693 inputs[0],
694 inputs[1],
695 inputs[2],
696 output,
697 ) {
698 (
699 MetadataDType::Bool,
700 MetadataDType::I32,
701 MetadataDType::I32,
702 MetadataDType::I32,
703 MetadataTensorRef::Bool(cond),
704 MetadataTensorRef::I32(on_true),
705 MetadataTensorRef::I32(on_false),
706 MetadataTensorMut::I32(dst),
707 ) => {
708 let cond =
709 broadcast_tensor_to_shape(cond, dst.dims(), "CpuMetadataTernary cond")?;
710 let on_true = broadcast_tensor_to_shape(
711 on_true,
712 dst.dims(),
713 "CpuMetadataTernary true",
714 )?;
715 let on_false = broadcast_tensor_to_shape(
716 on_false,
717 dst.dims(),
718 "CpuMetadataTernary false",
719 )?;
720 let cond = tensor_to_view(&cond)?;
721 let on_true = tensor_to_view(&on_true)?;
722 let on_false = tensor_to_view(&on_false)?;
723 let mut dst = tensor_to_view_mut(dst)?;
724 execute_metadata_ternary_map(
725 &cond,
726 &on_true,
727 &on_false,
728 &mut dst,
729 |c, t, f| {
730 if c != 0 {
731 t
732 } else {
733 f
734 }
735 },
736 )
737 }
738 (
739 MetadataDType::Bool,
740 MetadataDType::Bool,
741 MetadataDType::Bool,
742 MetadataDType::Bool,
743 MetadataTensorRef::Bool(cond),
744 MetadataTensorRef::Bool(on_true),
745 MetadataTensorRef::Bool(on_false),
746 MetadataTensorMut::Bool(dst),
747 ) => {
748 let cond =
749 broadcast_tensor_to_shape(cond, dst.dims(), "CpuMetadataTernary cond")?;
750 let on_true = broadcast_tensor_to_shape(
751 on_true,
752 dst.dims(),
753 "CpuMetadataTernary true",
754 )?;
755 let on_false = broadcast_tensor_to_shape(
756 on_false,
757 dst.dims(),
758 "CpuMetadataTernary false",
759 )?;
760 let cond = tensor_to_view(&cond)?;
761 let on_true = tensor_to_view(&on_true)?;
762 let on_false = tensor_to_view(&on_false)?;
763 let mut dst = tensor_to_view_mut(dst)?;
764 execute_metadata_ternary_map(
765 &cond,
766 &on_true,
767 &on_false,
768 &mut dst,
769 |c, t, f| {
770 if c != 0 {
771 t
772 } else {
773 f
774 }
775 },
776 )
777 }
778 _ => Err(Error::InvalidArgument(
779 "unsupported metadata ternary execution dtype combination".into(),
780 )),
781 }
782 }
783 MetadataPrimsDescriptor::Reduction {
784 op,
785 input_dtype,
786 output_dtype,
787 ..
788 } => {
789 validate_metadata_handle_count(inputs, 1, "CpuMetadataReduction")?;
790 let (kept_axes, reduced_axes) = plan_metadata_reduction(
791 tensor_dims_ref(&inputs[0]),
792 tensor_dims_mut(&output),
793 plan,
794 )?;
795 match (*op, *input_dtype, *output_dtype, inputs[0], output) {
796 (
797 MetadataReductionOp::Sum,
798 MetadataDType::I32,
799 MetadataDType::I32,
800 MetadataTensorRef::I32(input),
801 MetadataTensorMut::I32(output),
802 ) => {
803 let input = tensor_to_view(input)?;
804 let mut output = tensor_to_view_mut(output)?;
805 execute_metadata_reduce_sum_i32(
806 &input,
807 &mut output,
808 &kept_axes,
809 &reduced_axes,
810 )
811 }
812 (
813 MetadataReductionOp::Sum,
814 MetadataDType::Bool,
815 MetadataDType::I32,
816 MetadataTensorRef::Bool(input),
817 MetadataTensorMut::I32(output),
818 ) => {
819 let input = tensor_to_view(input)?;
820 let mut output = tensor_to_view_mut(output)?;
821 execute_metadata_reduce_sum_bool(
822 &input,
823 &mut output,
824 &kept_axes,
825 &reduced_axes,
826 )
827 }
828 (
829 MetadataReductionOp::All,
830 MetadataDType::Bool,
831 MetadataDType::Bool,
832 MetadataTensorRef::Bool(input),
833 MetadataTensorMut::Bool(output),
834 ) => {
835 let input = tensor_to_view(input)?;
836 let mut output = tensor_to_view_mut(output)?;
837 execute_metadata_reduce_all_bool(
838 &input,
839 &mut output,
840 &kept_axes,
841 &reduced_axes,
842 )
843 }
844 (
845 MetadataReductionOp::Any,
846 MetadataDType::Bool,
847 MetadataDType::Bool,
848 MetadataTensorRef::Bool(input),
849 MetadataTensorMut::Bool(output),
850 ) => {
851 let input = tensor_to_view(input)?;
852 let mut output = tensor_to_view_mut(output)?;
853 execute_metadata_reduce_any_bool(
854 &input,
855 &mut output,
856 &kept_axes,
857 &reduced_axes,
858 )
859 }
860 _ => Err(Error::InvalidArgument(
861 "unsupported metadata reduction execution dtype combination".into(),
862 )),
863 }
864 }
865 _ => Err(Error::InvalidArgument(
866 "unsupported metadata metadata descriptor on CpuBackend".into(),
867 )),
868 }
869 }
870
871 fn has_metadata_support(desc: MetadataPrimsDescriptor) -> bool {
872 match desc {
873 MetadataPrimsDescriptor::Generate { op, output_dtype } => matches!(
874 (op, output_dtype),
875 (MetadataGenerateOp::IotaStartZero, MetadataDType::I32)
876 | (
877 MetadataGenerateOp::Constant(MetadataConstantValue::I32(_)),
878 MetadataDType::I32
879 )
880 | (
881 MetadataGenerateOp::Constant(MetadataConstantValue::Bool(_)),
882 MetadataDType::Bool
883 )
884 ),
885 MetadataPrimsDescriptor::Binary {
886 op,
887 lhs_dtype,
888 rhs_dtype,
889 output_dtype,
890 } => matches!(
891 (op, lhs_dtype, rhs_dtype, output_dtype),
892 (
893 MetadataBinaryOp::Equal | MetadataBinaryOp::NotEqual,
894 MetadataDType::I32,
895 MetadataDType::I32,
896 MetadataDType::Bool
897 ) | (
898 MetadataBinaryOp::Equal | MetadataBinaryOp::NotEqual,
899 MetadataDType::Bool,
900 MetadataDType::Bool,
901 MetadataDType::Bool
902 ) | (
903 MetadataBinaryOp::Add | MetadataBinaryOp::Sub | MetadataBinaryOp::Mul,
904 MetadataDType::I32,
905 MetadataDType::I32,
906 MetadataDType::I32
907 ) | (
908 MetadataBinaryOp::BitAnd,
909 MetadataDType::I32,
910 MetadataDType::I32,
911 MetadataDType::I32
912 ) | (
913 MetadataBinaryOp::BitAnd,
914 MetadataDType::Bool,
915 MetadataDType::Bool,
916 MetadataDType::Bool
917 )
918 ),
919 MetadataPrimsDescriptor::Ternary {
920 op: MetadataTernaryOp::Where,
921 cond_dtype: MetadataDType::Bool,
922 lhs_dtype,
923 rhs_dtype,
924 output_dtype,
925 } => matches!(
926 (lhs_dtype, rhs_dtype, output_dtype),
927 (MetadataDType::I32, MetadataDType::I32, MetadataDType::I32)
928 | (
929 MetadataDType::Bool,
930 MetadataDType::Bool,
931 MetadataDType::Bool
932 )
933 ),
934 MetadataPrimsDescriptor::Reduction {
935 op,
936 input_dtype,
937 output_dtype,
938 ..
939 } => matches!(
940 (op, input_dtype, output_dtype),
941 (
942 MetadataReductionOp::Sum,
943 MetadataDType::Bool,
944 MetadataDType::I32
945 ) | (
946 MetadataReductionOp::Sum,
947 MetadataDType::I32,
948 MetadataDType::I32
949 ) | (
950 MetadataReductionOp::All,
951 MetadataDType::Bool,
952 MetadataDType::Bool
953 ) | (
954 MetadataReductionOp::Any,
955 MetadataDType::Bool,
956 MetadataDType::Bool
957 )
958 ),
959 _ => false,
960 }
961 }
962}