// tenferro_prims/cpu/planning.rs

1use std::collections::HashSet;
2use std::marker::PhantomData;
3
4use tenferro_algebra::{Scalar, Standard};
5use tenferro_device::{Error, Result};
6
7use crate::{
8    mode_position, validate_rank, validate_shape_count, validate_shape_eq, SemiringBinaryOp,
9    SemiringCoreDescriptor, SemiringFastPathDescriptor, TensorSemiringCore, TensorSemiringFastPath,
10};
11
12use super::context::CpuBackend;
13use super::context::CpuContext;
14use super::execution::execute_semiring_plan;
15use super::plan::{build_contract_gemm_spec, compute_paired_components, CpuPlan};
16
17impl CpuBackend {
18    pub(super) fn build_semiring_core_plan<T: Scalar>(
19        desc: &SemiringCoreDescriptor,
20        shapes: &[&[usize]],
21    ) -> Result<CpuPlan<T>> {
22        match desc {
23            SemiringCoreDescriptor::BatchedGemm {
24                batch_dims,
25                m,
26                n,
27                k,
28            } => {
29                validate_shape_count(shapes, 3, "BatchedGemm")?;
30                if !Self::supports_batched_gemm_type::<T>() {
31                    return Err(Error::InvalidArgument(format!(
32                        "BatchedGemm supports only f32, f64, Complex32, and Complex64 (got {})",
33                        std::any::type_name::<T>()
34                    )));
35                }
36                let expected_a: Vec<usize> = [*m, *k]
37                    .iter()
38                    .copied()
39                    .chain(batch_dims.iter().copied())
40                    .collect();
41                let expected_b: Vec<usize> = [*k, *n]
42                    .iter()
43                    .copied()
44                    .chain(batch_dims.iter().copied())
45                    .collect();
46                let expected_c: Vec<usize> = [*m, *n]
47                    .iter()
48                    .copied()
49                    .chain(batch_dims.iter().copied())
50                    .collect();
51                validate_shape_eq(shapes[0], &expected_a, "BatchedGemm input A")?;
52                validate_shape_eq(shapes[1], &expected_b, "BatchedGemm input B")?;
53                validate_shape_eq(shapes[2], &expected_c, "BatchedGemm output C")?;
54                Ok(CpuPlan::BatchedGemm {
55                    batch_dims: batch_dims.clone(),
56                    m: *m,
57                    n: *n,
58                    k: *k,
59                    _marker: PhantomData,
60                })
61            }
62            SemiringCoreDescriptor::ReduceAdd { modes_a, modes_c } => {
63                validate_shape_count(shapes, 2, "ReduceAdd")?;
64                validate_rank(shapes[0], modes_a.len(), "ReduceAdd input A")?;
65                validate_rank(shapes[1], modes_c.len(), "ReduceAdd output C")?;
66                let reduced_axes: Vec<usize> = modes_a
67                    .iter()
68                    .enumerate()
69                    .filter(|(_, mode)| !modes_c.contains(mode))
70                    .map(|(idx, _)| idx)
71                    .collect();
72                for window in reduced_axes.windows(2) {
73                    if window[0] >= window[1] {
74                        return Err(Error::InvalidArgument(format!(
75                            "ReduceAdd: reduced_axes must be sorted and unique, got {reduced_axes:?}"
76                        )));
77                    }
78                }
79                if let Some(&last) = reduced_axes.last() {
80                    if last >= modes_a.len() {
81                        return Err(Error::InvalidArgument(format!(
82                            "ReduceAdd: reduced axis {last} out of range for rank {}",
83                            modes_a.len()
84                        )));
85                    }
86                }
87                Ok(CpuPlan::ReduceAdd {
88                    reduced_axes,
89                    _marker: PhantomData,
90                })
91            }
92            SemiringCoreDescriptor::Trace {
93                modes_a,
94                modes_c,
95                paired,
96            } => {
97                validate_shape_count(shapes, 2, "Trace")?;
98                validate_rank(shapes[0], modes_a.len(), "Trace input A")?;
99                validate_rank(shapes[1], modes_c.len(), "Trace output C")?;
100                let paired_axes: Vec<(usize, usize)> = paired
101                    .iter()
102                    .map(|(m1, m2)| {
103                        Ok((mode_position(modes_a, *m1)?, mode_position(modes_a, *m2)?))
104                    })
105                    .collect::<Result<_>>()?;
106                for &(ax1, ax2) in &paired_axes {
107                    if shapes[0][ax1] != shapes[0][ax2] {
108                        return Err(Error::InvalidArgument(format!(
109                            "Trace paired axes ({ax1}, {ax2}) have mismatched dimensions: {} vs {}",
110                            shapes[0][ax1], shapes[0][ax2]
111                        )));
112                    }
113                }
114                let free_axes: Vec<usize> = modes_c
115                    .iter()
116                    .map(|mode| mode_position(modes_a, *mode))
117                    .collect::<Result<_>>()?;
118                let (components, comp_dims) = compute_paired_components(&paired_axes, shapes[0]);
119                Ok(CpuPlan::Trace {
120                    free_axes,
121                    components,
122                    comp_dims,
123                    _marker: PhantomData,
124                })
125            }
126            SemiringCoreDescriptor::AntiTrace {
127                modes_a,
128                modes_c,
129                paired,
130            } => {
131                validate_shape_count(shapes, 2, "AntiTrace")?;
132                validate_rank(shapes[0], modes_a.len(), "AntiTrace input A")?;
133                validate_rank(shapes[1], modes_c.len(), "AntiTrace output C")?;
134                let paired_axes: Vec<(usize, usize)> = paired
135                    .iter()
136                    .map(|(m1, m2)| {
137                        Ok((mode_position(modes_c, *m1)?, mode_position(modes_c, *m2)?))
138                    })
139                    .collect::<Result<_>>()?;
140                for &(ax1, ax2) in &paired_axes {
141                    if shapes[1][ax1] != shapes[1][ax2] {
142                        return Err(Error::InvalidArgument(format!(
143                            "AntiTrace paired axes ({ax1}, {ax2}) have mismatched dimensions: {} vs {}",
144                            shapes[1][ax1], shapes[1][ax2]
145                        )));
146                    }
147                }
148                let free_axes: Vec<usize> = modes_a
149                    .iter()
150                    .map(|mode| mode_position(modes_c, *mode))
151                    .collect::<Result<_>>()?;
152                let (components, comp_dims) = compute_paired_components(&paired_axes, shapes[1]);
153                Ok(CpuPlan::AntiTrace {
154                    paired_axes,
155                    free_axes,
156                    components,
157                    comp_dims,
158                    _marker: PhantomData,
159                })
160            }
161            SemiringCoreDescriptor::AntiDiag {
162                modes_a,
163                modes_c,
164                paired,
165            } => {
166                validate_shape_count(shapes, 2, "AntiDiag")?;
167                validate_rank(shapes[0], modes_a.len(), "AntiDiag input A")?;
168                validate_rank(shapes[1], modes_c.len(), "AntiDiag output C")?;
169                let paired_axes: Vec<(usize, usize)> = paired
170                    .iter()
171                    .map(|(m1, m2)| {
172                        Ok((mode_position(modes_c, *m1)?, mode_position(modes_c, *m2)?))
173                    })
174                    .collect::<Result<_>>()?;
175                let free_axes: Vec<usize> = modes_a
176                    .iter()
177                    .map(|mode| mode_position(modes_c, *mode))
178                    .collect::<Result<_>>()?;
179                let (components, comp_dims) = compute_paired_components(&paired_axes, shapes[1]);
180                let free_ax_set: HashSet<usize> = free_axes.iter().copied().collect();
181                let generative_comps: Vec<usize> = components
182                    .iter()
183                    .enumerate()
184                    .filter(|(_, comp)| comp.iter().all(|ax| !free_ax_set.contains(ax)))
185                    .map(|(idx, _)| idx)
186                    .collect();
187                Ok(CpuPlan::AntiDiag {
188                    paired_axes,
189                    free_axes,
190                    components,
191                    comp_dims,
192                    generative_comps,
193                    _marker: PhantomData,
194                })
195            }
196            SemiringCoreDescriptor::MakeContiguous => {
197                validate_shape_count(shapes, 2, "MakeContiguous")?;
198                validate_shape_eq(shapes[1], shapes[0], "MakeContiguous output")?;
199                Ok(CpuPlan::MakeContiguous {
200                    _marker: PhantomData,
201                })
202            }
203        }
204    }
205
206    pub(super) fn build_semiring_fast_path_plan<T: Scalar>(
207        desc: &SemiringFastPathDescriptor,
208        shapes: &[&[usize]],
209    ) -> Result<CpuPlan<T>> {
210        match desc {
211            SemiringFastPathDescriptor::Contract {
212                modes_a,
213                modes_b,
214                modes_c,
215            } => {
216                validate_shape_count(shapes, 3, "Contract")?;
217                validate_rank(shapes[0], modes_a.len(), "Contract input A")?;
218                validate_rank(shapes[1], modes_b.len(), "Contract input B")?;
219                validate_rank(shapes[2], modes_c.len(), "Contract output C")?;
220                for (a_pos, &mode) in modes_a.iter().enumerate() {
221                    if let Some(b_pos) = modes_b.iter().position(|&m| m == mode) {
222                        if shapes[0][a_pos] != shapes[1][b_pos] {
223                            return Err(Error::InvalidArgument(format!(
224                                "Contract mode {mode} has mismatched dimensions: A={} vs B={}",
225                                shapes[0][a_pos], shapes[1][b_pos]
226                            )));
227                        }
228                    }
229                }
230                let gemm_spec = build_contract_gemm_spec(modes_a, modes_b, modes_c);
231                Ok(CpuPlan::Contract {
232                    modes_a: modes_a.clone(),
233                    modes_b: modes_b.clone(),
234                    modes_c: modes_c.clone(),
235                    gemm_spec,
236                    _marker: PhantomData,
237                })
238            }
239            SemiringFastPathDescriptor::ElementwiseBinary { op } => {
240                validate_shape_count(shapes, 3, "ElementwiseBinary")?;
241                validate_shape_eq(shapes[1], shapes[0], "ElementwiseBinary input B")?;
242                validate_shape_eq(shapes[2], shapes[0], "ElementwiseBinary output C")?;
243                Ok(CpuPlan::ElementwiseBinary {
244                    op: *op,
245                    _marker: PhantomData,
246                })
247            }
248        }
249    }
250}
251
252impl<S: Scalar> TensorSemiringCore<Standard<S>> for CpuBackend {
253    type Plan = CpuPlan<S>;
254    type Context = CpuContext;
255
256    fn plan(
257        ctx: &mut CpuContext,
258        desc: &SemiringCoreDescriptor,
259        shapes: &[&[usize]],
260    ) -> Result<CpuPlan<S>> {
261        if let Some(cached) = ctx
262            .plan_cache
263            .get::<CpuPlan<S>, SemiringCoreDescriptor>(desc, shapes)
264        {
265            return Ok(cached);
266        }
267
268        let plan = Self::build_semiring_core_plan::<S>(desc, shapes)?;
269        ctx.plan_cache.insert(desc, shapes, plan.clone());
270        Ok(plan)
271    }
272
273    fn execute(
274        ctx: &mut CpuContext,
275        plan: &CpuPlan<S>,
276        alpha: S,
277        inputs: &[&tenferro_tensor::Tensor<S>],
278        beta: S,
279        output: &mut tenferro_tensor::Tensor<S>,
280    ) -> Result<()> {
281        execute_semiring_plan(ctx, plan, alpha, inputs, beta, output)
282    }
283}
284
285impl<S: Scalar> TensorSemiringFastPath<Standard<S>> for CpuBackend {
286    type Plan = CpuPlan<S>;
287    type Context = CpuContext;
288
289    fn plan(
290        ctx: &mut CpuContext,
291        desc: &SemiringFastPathDescriptor,
292        shapes: &[&[usize]],
293    ) -> Result<CpuPlan<S>> {
294        if let Some(cached) = ctx
295            .plan_cache
296            .get::<CpuPlan<S>, SemiringFastPathDescriptor>(desc, shapes)
297        {
298            return Ok(cached);
299        }
300
301        let plan = Self::build_semiring_fast_path_plan::<S>(desc, shapes)?;
302        ctx.plan_cache.insert(desc, shapes, plan.clone());
303        Ok(plan)
304    }
305
306    fn execute(
307        ctx: &mut CpuContext,
308        plan: &CpuPlan<S>,
309        alpha: S,
310        inputs: &[&tenferro_tensor::Tensor<S>],
311        beta: S,
312        output: &mut tenferro_tensor::Tensor<S>,
313    ) -> Result<()> {
314        execute_semiring_plan(ctx, plan, alpha, inputs, beta, output)
315    }
316
317    fn has_fast_path(desc: SemiringFastPathDescriptor) -> bool {
318        matches!(
319            desc,
320            SemiringFastPathDescriptor::Contract { .. }
321                | SemiringFastPathDescriptor::ElementwiseBinary {
322                    op: SemiringBinaryOp::Add | SemiringBinaryOp::Mul,
323                }
324        )
325    }
326}