Skip to main content

tenferro_tensor/cpu/gemm/
mod.rs

1use num_traits::{One, Zero};
2use smallvec::SmallVec;
3
4use crate::buffer_pool::{BufferPool, PoolScalar};
5use crate::config::DotGeneralConfig;
6use crate::cpu::structural::typed_transpose;
7use crate::types::{col_major_strides, Buffer, TypedTensor};
8use crate::Error;
9
10#[cfg(feature = "cpu-blas")]
11mod blas_gemm;
12#[cfg(feature = "cpu-faer")]
13mod faer_gemm;
14
15#[cfg(feature = "cpu-blas")]
16use blas_gemm::BlasGemm;
17#[cfg(feature = "cpu-faer")]
18use faer_gemm::FaerGemm;
19
20struct GemmDims {
21    m: usize,
22    n: usize,
23    k: usize,
24    batch_total: usize,
25    a_rs: isize,
26    a_cs: isize,
27    a_bs: isize,
28    b_rs: isize,
29    b_cs: isize,
30    b_bs: isize,
31    c_rs: isize,
32    c_cs: isize,
33    c_bs: isize,
34    out_shape: SmallVec<[usize; 8]>,
35}
36
37fn validate_axis_list(
38    op: &'static str,
39    role: &'static str,
40    axes: &[usize],
41    rank: usize,
42) -> crate::Result<()> {
43    let mut seen: SmallVec<[bool; 8]> = smallvec::smallvec![false; rank];
44    for &axis in axes {
45        if axis >= rank {
46            return Err(Error::AxisOutOfBounds { op, axis, rank });
47        }
48        if seen[axis] {
49            return Err(Error::DuplicateAxis { op, axis, role });
50        }
51        seen[axis] = true;
52    }
53    Ok(())
54}
55
56fn validate_role_disjoint(
57    op: &'static str,
58    first_role: &'static str,
59    first_axes: &[usize],
60    second_role: &'static str,
61    second_axes: &[usize],
62) -> crate::Result<()> {
63    for &axis in first_axes {
64        if second_axes.contains(&axis) {
65            return Err(Error::AxisRoleConflict {
66                op,
67                axis,
68                first_role,
69                second_role,
70            });
71        }
72    }
73    Ok(())
74}
75
76fn validate_dot_general<T>(
77    lhs: &TypedTensor<T>,
78    rhs: &TypedTensor<T>,
79    config: &DotGeneralConfig,
80) -> crate::Result<()> {
81    const OP: &str = "dot_general";
82
83    if config.lhs_contracting_dims.len() != config.rhs_contracting_dims.len() {
84        return Err(Error::InvalidConfig {
85            op: OP,
86            message: "lhs/rhs contracting dim counts differ".into(),
87        });
88    }
89    if config.lhs_batch_dims.len() != config.rhs_batch_dims.len() {
90        return Err(Error::InvalidConfig {
91            op: OP,
92            message: "lhs/rhs batch dim counts differ".into(),
93        });
94    }
95
96    let lhs_rank = lhs.shape.len();
97    let rhs_rank = rhs.shape.len();
98    validate_axis_list(
99        OP,
100        "lhs_contracting",
101        &config.lhs_contracting_dims,
102        lhs_rank,
103    )?;
104    validate_axis_list(
105        OP,
106        "rhs_contracting",
107        &config.rhs_contracting_dims,
108        rhs_rank,
109    )?;
110    validate_axis_list(OP, "lhs_batch", &config.lhs_batch_dims, lhs_rank)?;
111    validate_axis_list(OP, "rhs_batch", &config.rhs_batch_dims, rhs_rank)?;
112    validate_role_disjoint(
113        OP,
114        "lhs_contracting",
115        &config.lhs_contracting_dims,
116        "lhs_batch",
117        &config.lhs_batch_dims,
118    )?;
119    validate_role_disjoint(
120        OP,
121        "rhs_contracting",
122        &config.rhs_contracting_dims,
123        "rhs_batch",
124        &config.rhs_batch_dims,
125    )?;
126
127    for (&lhs_axis, &rhs_axis) in config
128        .lhs_contracting_dims
129        .iter()
130        .zip(&config.rhs_contracting_dims)
131    {
132        if lhs.shape[lhs_axis] != rhs.shape[rhs_axis] {
133            return Err(Error::InvalidConfig {
134                op: OP,
135                message: format!(
136                    "contracting dim size mismatch: lhs axis {lhs_axis}={} rhs axis {rhs_axis}={}",
137                    lhs.shape[lhs_axis], rhs.shape[rhs_axis]
138                ),
139            });
140        }
141    }
142    for (&lhs_axis, &rhs_axis) in config.lhs_batch_dims.iter().zip(&config.rhs_batch_dims) {
143        if lhs.shape[lhs_axis] != rhs.shape[rhs_axis] {
144            return Err(Error::InvalidConfig {
145                op: OP,
146                message: format!(
147                    "batch dim size mismatch: lhs axis {lhs_axis}={} rhs axis {rhs_axis}={}",
148                    lhs.shape[lhs_axis], rhs.shape[rhs_axis]
149                ),
150            });
151        }
152    }
153
154    Ok(())
155}
156
157fn try_fuse_dims(shapes: &[usize], strides: &[isize]) -> Option<(usize, isize)> {
158    if shapes.is_empty() {
159        return Some((1, 0));
160    }
161    if shapes.len() == 1 {
162        return Some((shapes[0], strides[0]));
163    }
164    let mut dims: SmallVec<[(usize, isize); 8]> = shapes
165        .iter()
166        .copied()
167        .zip(strides.iter().copied())
168        .collect();
169    dims.sort_by_key(|&(_, stride)| stride.unsigned_abs());
170    let base_stride = dims[0].1;
171    let mut expected = base_stride;
172    for (shape, stride) in dims {
173        if stride != expected {
174            return None;
175        }
176        expected = stride.checked_mul(shape as isize)?;
177    }
178    Some((shapes.iter().product(), base_stride))
179}
180
181fn stride_sort_order(strides: &[isize]) -> SmallVec<[usize; 8]> {
182    let mut order: SmallVec<[usize; 8]> = (0..strides.len()).collect();
183    order.sort_by_key(|&idx| strides[idx].unsigned_abs());
184    order
185}
186
/// `true` when `order` is exactly `0, 1, ..., order.len() - 1`.
fn is_identity_order(order: &[usize]) -> bool {
    order.iter().copied().eq(0..order.len())
}
190
191/// Compute permutations that reorder lhs/rhs into canonical GEMM layout.
192///
193/// Canonical col-major layouts (batch trailing):
194/// - lhs: `[free..., contract..., batch...]`
195/// - rhs: `[contract..., free..., batch...]`
196fn canonical_gemm_layout(
197    config: &DotGeneralConfig,
198    lhs_rank: usize,
199    rhs_rank: usize,
200) -> (SmallVec<[usize; 8]>, SmallVec<[usize; 8]>, DotGeneralConfig) {
201    let lhs_free: SmallVec<[usize; 8]> = (0..lhs_rank)
202        .filter(|d| !config.lhs_contracting_dims.contains(d) && !config.lhs_batch_dims.contains(d))
203        .collect();
204    let rhs_free: SmallVec<[usize; 8]> = (0..rhs_rank)
205        .filter(|d| !config.rhs_contracting_dims.contains(d) && !config.rhs_batch_dims.contains(d))
206        .collect();
207
208    let mut lhs_perm = SmallVec::<[usize; 8]>::with_capacity(lhs_rank);
209    lhs_perm.extend_from_slice(&lhs_free);
210    lhs_perm.extend_from_slice(&config.lhs_contracting_dims);
211    lhs_perm.extend_from_slice(&config.lhs_batch_dims);
212
213    let mut rhs_perm = SmallVec::<[usize; 8]>::with_capacity(rhs_rank);
214    rhs_perm.extend_from_slice(&config.rhs_contracting_dims);
215    rhs_perm.extend_from_slice(&rhs_free);
216    rhs_perm.extend_from_slice(&config.rhs_batch_dims);
217
218    let nf_lhs = lhs_free.len();
219    let nc = config.lhs_contracting_dims.len();
220    let nb = config.lhs_batch_dims.len();
221    let nf_rhs = rhs_free.len();
222
223    let new_config = DotGeneralConfig {
224        lhs_contracting_dims: (nf_lhs..nf_lhs + nc).collect(),
225        rhs_contracting_dims: (0..nc).collect(),
226        lhs_batch_dims: (nf_lhs + nc..nf_lhs + nc + nb).collect(),
227        rhs_batch_dims: (nc + nf_rhs..nc + nf_rhs + nb).collect(),
228    };
229
230    (lhs_perm, rhs_perm, new_config)
231}
232
/// `true` when `perm` maps every axis to itself (no transpose needed).
fn is_identity_perm(perm: &[usize]) -> bool {
    (0..perm.len())
        .zip(perm)
        .all(|(expected, &actual)| expected == actual)
}
236
237fn analyse_gemm<T>(
238    lhs: &TypedTensor<T>,
239    rhs: &TypedTensor<T>,
240    config: &DotGeneralConfig,
241) -> Option<GemmDims> {
242    let lhs_rank = lhs.shape.len();
243    let rhs_rank = rhs.shape.len();
244
245    let lhs_free: SmallVec<[usize; 8]> = (0..lhs_rank)
246        .filter(|d| !config.lhs_contracting_dims.contains(d) && !config.lhs_batch_dims.contains(d))
247        .collect();
248    let rhs_free: SmallVec<[usize; 8]> = (0..rhs_rank)
249        .filter(|d| !config.rhs_contracting_dims.contains(d) && !config.rhs_batch_dims.contains(d))
250        .collect();
251
252    let lhs_strides: SmallVec<[isize; 8]> = col_major_strides(&lhs.shape).into_iter().collect();
253    let rhs_strides: SmallVec<[isize; 8]> = col_major_strides(&rhs.shape).into_iter().collect();
254
255    let batch_shapes: SmallVec<[usize; 8]> = config
256        .lhs_batch_dims
257        .iter()
258        .map(|&d| lhs.shape[d])
259        .collect();
260    let batch_total: usize = batch_shapes.iter().product();
261
262    let lhs_free_shapes: SmallVec<[usize; 8]> = lhs_free.iter().map(|&d| lhs.shape[d]).collect();
263    let rhs_free_shapes: SmallVec<[usize; 8]> = rhs_free.iter().map(|&d| rhs.shape[d]).collect();
264    let contract_shapes: SmallVec<[usize; 8]> = config
265        .lhs_contracting_dims
266        .iter()
267        .map(|&d| lhs.shape[d])
268        .collect();
269
270    let m: usize = lhs_free_shapes.iter().product();
271    let n: usize = rhs_free_shapes.iter().product();
272    let k: usize = contract_shapes.iter().product();
273
274    let lhs_free_strides: SmallVec<[isize; 8]> = lhs_free.iter().map(|&d| lhs_strides[d]).collect();
275    let rhs_free_strides: SmallVec<[isize; 8]> = rhs_free.iter().map(|&d| rhs_strides[d]).collect();
276    let lhs_contract_strides: SmallVec<[isize; 8]> = config
277        .lhs_contracting_dims
278        .iter()
279        .map(|&d| lhs_strides[d])
280        .collect();
281    let rhs_contract_strides: SmallVec<[isize; 8]> = config
282        .rhs_contracting_dims
283        .iter()
284        .map(|&d| rhs_strides[d])
285        .collect();
286    let lhs_batch_strides: SmallVec<[isize; 8]> = config
287        .lhs_batch_dims
288        .iter()
289        .map(|&d| lhs_strides[d])
290        .collect();
291    let rhs_batch_strides: SmallVec<[isize; 8]> = config
292        .rhs_batch_dims
293        .iter()
294        .map(|&d| rhs_strides[d])
295        .collect();
296
297    if !is_identity_order(&stride_sort_order(&lhs_free_strides))
298        || !is_identity_order(&stride_sort_order(&rhs_free_strides))
299        || !is_identity_order(&stride_sort_order(&lhs_batch_strides))
300        || !is_identity_order(&stride_sort_order(&rhs_batch_strides))
301        || stride_sort_order(&lhs_contract_strides) != stride_sort_order(&rhs_contract_strides)
302    {
303        return None;
304    }
305
306    let (_, a_rs) = try_fuse_dims(&lhs_free_shapes, &lhs_free_strides)?;
307    let (_, a_cs) = try_fuse_dims(&contract_shapes, &lhs_contract_strides)?;
308    let (_, b_rs) = try_fuse_dims(&contract_shapes, &rhs_contract_strides)?;
309    let (_, b_cs) = try_fuse_dims(&rhs_free_shapes, &rhs_free_strides)?;
310    let (_, a_bs) = try_fuse_dims(&batch_shapes, &lhs_batch_strides)?;
311    let (_, b_bs) = try_fuse_dims(&batch_shapes, &rhs_batch_strides)?;
312
313    let mut out_shape = SmallVec::<[usize; 8]>::new();
314    out_shape.extend_from_slice(&lhs_free_shapes);
315    out_shape.extend_from_slice(&rhs_free_shapes);
316    out_shape.extend_from_slice(&batch_shapes);
317
318    let out_strides: SmallVec<[isize; 8]> = col_major_strides(&out_shape).into_iter().collect();
319    let nm = lhs_free_shapes.len();
320    let nn = rhs_free_shapes.len();
321    let out_m_shapes = &out_shape[..nm];
322    let out_m_strides = &out_strides[..nm];
323    let out_n_shapes = &out_shape[nm..nm + nn];
324    let out_n_strides = &out_strides[nm..nm + nn];
325    let out_b_shapes = &out_shape[nm + nn..];
326    let out_b_strides = &out_strides[nm + nn..];
327
328    let (_, c_rs) = try_fuse_dims(out_m_shapes, out_m_strides)?;
329    let (_, c_cs) = try_fuse_dims(out_n_shapes, out_n_strides)?;
330    let (_, c_bs) = try_fuse_dims(out_b_shapes, out_b_strides)?;
331
332    Some(GemmDims {
333        m,
334        n,
335        k,
336        batch_total,
337        a_rs,
338        a_cs,
339        a_bs,
340        b_rs,
341        b_cs,
342        b_bs,
343        c_rs,
344        c_cs,
345        c_bs,
346        out_shape,
347    })
348}
349
#[cfg(feature = "cpu-faer")]
/// General tensor contraction (`dot_general`) on the CPU via faer.
///
/// The config is validated first. The operands' native layout is then tried
/// as a strided GEMM. If that fails, lhs/rhs are transposed into the
/// canonical layout from `canonical_gemm_layout` and the GEMM is retried.
///
/// # Errors
/// - Validation errors from `validate_dot_general`.
/// - Errors from `typed_transpose`.
/// - `BackendFailure` if even the canonical layout cannot run (for example,
///   non-host buffers).
pub(crate) fn dot_general<T>(
    buffers: &mut BufferPool,
    ctx: &crate::cpu::CpuContext,
    lhs: &TypedTensor<T>,
    rhs: &TypedTensor<T>,
    config: &DotGeneralConfig,
) -> crate::Result<TypedTensor<T>>
where
    T: FaerGemm + PoolScalar + Copy + Clone + Zero + One + PartialEq,
{
    validate_dot_general(lhs, rhs, config)?;

    // Fast path: the existing strides already describe a GEMM.
    if let Some(direct) = typed_faer_gemm(buffers, ctx, lhs, rhs, config) {
        return Ok(direct);
    }

    // Slow path: materialise canonical layouts, then retry once.
    let (lhs_perm, rhs_perm, canonical_config) =
        canonical_gemm_layout(config, lhs.shape.len(), rhs.shape.len());
    let lhs_canon = match is_identity_perm(&lhs_perm) {
        true => std::borrow::Cow::Borrowed(lhs),
        false => std::borrow::Cow::Owned(typed_transpose(lhs, &lhs_perm)?),
    };
    let rhs_canon = match is_identity_perm(&rhs_perm) {
        true => std::borrow::Cow::Borrowed(rhs),
        false => std::borrow::Cow::Owned(typed_transpose(rhs, &rhs_perm)?),
    };

    let retried = typed_faer_gemm(buffers, ctx, &lhs_canon, &rhs_canon, &canonical_config);
    retried.ok_or_else(|| Error::BackendFailure {
        op: "dot_general",
        message: "CPU GEMM requires host-backed canonical inputs".into(),
    })
}
384
#[cfg(feature = "cpu-faer")]
/// Run `dot_general` as a batched faer GEMM over the operands' current layout.
///
/// Returns `None` when `analyse_gemm` cannot map the layout onto a strided
/// GEMM, or when an operand lives in a `Buffer::Backend`. Panics on a
/// `Buffer::Cubecl` operand (GPU data must be downloaded first). Empty
/// problems produce a zero-filled host tensor without calling the kernel.
fn typed_faer_gemm<T>(
    buffers: &mut BufferPool,
    ctx: &crate::cpu::CpuContext,
    lhs: &TypedTensor<T>,
    rhs: &TypedTensor<T>,
    config: &DotGeneralConfig,
) -> Option<TypedTensor<T>>
where
    T: FaerGemm + PoolScalar + Copy + Clone + Zero + One + PartialEq,
{
    let dims = analyse_gemm(lhs, rhs, config)?;
    let element_count: usize = dims.out_shape.iter().product();

    // Degenerate problem: the result is all zeros; no kernel call needed.
    if [dims.m, dims.n, dims.k, dims.batch_total].contains(&0) {
        let zeros = vec![T::zero(); element_count];
        return Some(TypedTensor {
            buffer: Buffer::Host(zeros),
            shape: dims.out_shape.into_vec(),
            placement: lhs.placement.clone(),
        });
    }

    // Only host buffers can be handed to the CPU kernel as raw pointers.
    let lhs_ptr = match &lhs.buffer {
        Buffer::Host(values) => values.as_ptr(),
        Buffer::Backend(_) => return None,
        #[cfg(feature = "cubecl")]
        Buffer::Cubecl(_) => panic!("GPU tensor (Buffer::Cubecl) passed to CPU backend. Use cubecl::download_tensor() to transfer to CPU first."),
    };
    let rhs_ptr = match &rhs.buffer {
        Buffer::Host(values) => values.as_ptr(),
        Buffer::Backend(_) => return None,
        #[cfg(feature = "cubecl")]
        Buffer::Cubecl(_) => panic!("GPU tensor (Buffer::Cubecl) passed to CPU backend. Use cubecl::download_tensor() to transfer to CPU first."),
    };

    // SAFETY: this GEMM path uses beta = 0 and overwrites every output element.
    let mut out_values: Vec<T> = unsafe { T::pool_acquire(buffers, element_count) };
    let out_ptr = out_values.as_mut_ptr();

    for batch in (0..dims.batch_total).map(|index| index as isize) {
        let (a_off, b_off, c_off) = (batch * dims.a_bs, batch * dims.b_bs, batch * dims.c_bs);
        // SAFETY (review note): relies on analyse_gemm deriving sizes and
        // strides from the operands' own shapes, so every access stays inside
        // the lhs, rhs and output buffers. This is not re-checked here.
        unsafe {
            T::strided_gemm(
                ctx,
                T::one(),
                lhs_ptr.offset(a_off),
                dims.m,
                dims.k,
                dims.a_rs,
                dims.a_cs,
                rhs_ptr.offset(b_off),
                dims.n,
                dims.b_rs,
                dims.b_cs,
                T::zero(),
                out_ptr.offset(c_off),
                dims.c_rs,
                dims.c_cs,
            );
        }
    }

    Some(TypedTensor {
        buffer: Buffer::Host(out_values),
        shape: dims.out_shape.into_vec(),
        placement: lhs.placement.clone(),
    })
}
454
#[cfg(feature = "cpu-blas")]
/// General tensor contraction (`dot_general`) on the CPU via BLAS.
///
/// The config is validated first. The operands' native layout is then tried
/// as a strided GEMM. If that fails, lhs/rhs are transposed into the
/// canonical layout from `canonical_gemm_layout` and the GEMM is retried.
///
/// # Errors
/// - Validation errors from `validate_dot_general`.
/// - Errors from `typed_transpose`.
/// - `BackendFailure` if even the canonical layout cannot run.
pub(crate) fn dot_general<T>(
    buffers: &mut BufferPool,
    lhs: &TypedTensor<T>,
    rhs: &TypedTensor<T>,
    config: &DotGeneralConfig,
) -> crate::Result<TypedTensor<T>>
where
    T: BlasGemm + PoolScalar + Copy + Clone + Zero + One,
{
    validate_dot_general(lhs, rhs, config)?;

    // Fast path: the existing strides already describe a BLAS-compatible GEMM.
    if let Some(direct) = typed_blas_gemm(buffers, lhs, rhs, config) {
        return Ok(direct);
    }

    // Slow path: materialise canonical layouts, then retry once.
    let (lhs_perm, rhs_perm, canonical_config) =
        canonical_gemm_layout(config, lhs.shape.len(), rhs.shape.len());
    let lhs_canon = match is_identity_perm(&lhs_perm) {
        true => std::borrow::Cow::Borrowed(lhs),
        false => std::borrow::Cow::Owned(typed_transpose(lhs, &lhs_perm)?),
    };
    let rhs_canon = match is_identity_perm(&rhs_perm) {
        true => std::borrow::Cow::Borrowed(rhs),
        false => std::borrow::Cow::Owned(typed_transpose(rhs, &rhs_perm)?),
    };

    let retried = typed_blas_gemm(buffers, &lhs_canon, &rhs_canon, &canonical_config);
    retried.ok_or_else(|| Error::BackendFailure {
        op: "dot_general",
        message: "CPU GEMM requires host-backed canonical inputs".into(),
    })
}
488
#[cfg(feature = "cpu-blas")]
/// Run `dot_general` as a batched BLAS GEMM over the operands' current layout.
///
/// Returns `None` in these cases:
/// - `analyse_gemm` cannot map the layout onto a strided GEMM;
/// - an operand has no unit stride on either axis;
/// - the output lacks a unit row stride;
/// - an operand lives in a `Buffer::Backend`.
///
/// Panics on a `Buffer::Cubecl` operand. Empty problems produce a zero-filled
/// host tensor without calling the kernel.
///
/// Axes of extent 1 (including fused empty groups, which `analyse_gemm`
/// reports with stride 0) never contribute to an element address. Their
/// strides are therefore re-picked to conventional contiguous values before
/// the layout checks. Without this, an lhs with no free axes (vector x matrix,
/// vector x vector) had `c_rs == 0` and was rejected on every attempt, so
/// `dot_general` failed. Likewise, a missing rhs free group passed stride 0
/// where the kernel may use it as a leading dimension.
fn typed_blas_gemm<T>(
    buffers: &mut BufferPool,
    lhs: &TypedTensor<T>,
    rhs: &TypedTensor<T>,
    config: &DotGeneralConfig,
) -> Option<TypedTensor<T>>
where
    T: BlasGemm + PoolScalar + Copy + Clone + Zero + One,
{
    /// Re-pick the dead stride(s) of a `rows x cols` operand that has extent 1
    /// along an axis. The other axis keeps or gets a unit stride, and the
    /// chosen leading stride is at least that axis' extent.
    fn unit_extent_strides(rows: usize, cols: usize, rs: isize, cs: isize) -> (isize, isize) {
        match (rows == 1, cols == 1) {
            (true, true) => (1, 1),
            // Single row with contiguous columns: a packed row-major row.
            (true, false) if cs == 1 => (cols as isize, 1),
            (true, false) => (1, cs),
            // Single column with contiguous rows: a packed column-major column.
            (false, true) if rs == 1 => (1, rows as isize),
            (false, true) => (rs, 1),
            (false, false) => (rs, cs),
        }
    }

    let dims = analyse_gemm(lhs, rhs, config)?;
    let out_n: usize = dims.out_shape.iter().product();
    if dims.m == 0 || dims.n == 0 || dims.k == 0 || dims.batch_total == 0 {
        return Some(TypedTensor {
            buffer: Buffer::Host(vec![T::zero(); out_n]),
            shape: dims.out_shape.into_vec(),
            placement: lhs.placement.clone(),
        });
    }

    let (a_rs, a_cs) = unit_extent_strides(dims.m, dims.k, dims.a_rs, dims.a_cs);
    let (b_rs, b_cs) = unit_extent_strides(dims.k, dims.n, dims.b_rs, dims.b_cs);
    // This path requires a unit row stride for C (column-major output), so
    // unit extents are normalised towards that, with a leading dimension >= m.
    let c_rs = if dims.m == 1 { 1 } else { dims.c_rs };
    let c_cs = if dims.n == 1 { dims.m as isize } else { dims.c_cs };

    let a_ok = a_rs == 1 || a_cs == 1;
    let b_ok = b_rs == 1 || b_cs == 1;
    let c_ok = c_rs == 1;
    if !a_ok || !b_ok || !c_ok {
        return None;
    }

    let a_data = match &lhs.buffer {
        Buffer::Host(v) => v.as_ptr(),
        Buffer::Backend(_) => return None,
        #[cfg(feature = "cubecl")]
        Buffer::Cubecl(_) => panic!("GPU tensor (Buffer::Cubecl) passed to CPU backend. Use cubecl::download_tensor() to transfer to CPU first."),
    };
    let b_data = match &rhs.buffer {
        Buffer::Host(v) => v.as_ptr(),
        Buffer::Backend(_) => return None,
        #[cfg(feature = "cubecl")]
        Buffer::Cubecl(_) => panic!("GPU tensor (Buffer::Cubecl) passed to CPU backend. Use cubecl::download_tensor() to transfer to CPU first."),
    };

    // SAFETY: each batch GEMM writes its full output block with beta = 0.
    let mut out: Vec<T> = unsafe { T::pool_acquire(buffers, out_n) };
    let c_ptr = out.as_mut_ptr();

    for batch in 0..dims.batch_total {
        let a_off = batch as isize * dims.a_bs;
        let b_off = batch as isize * dims.b_bs;
        let c_off = batch as isize * dims.c_bs;
        // SAFETY (review note): sizes and strides come from analyse_gemm on the
        // operands' own shapes. Only strides of extent-1 axes were changed
        // above, and those are never multiplied by a non-zero index, so every
        // access stays inside the buffers. This is not re-checked here.
        unsafe {
            T::strided_gemm(
                T::one(),
                a_data.offset(a_off),
                dims.m,
                dims.k,
                a_rs,
                a_cs,
                b_data.offset(b_off),
                dims.n,
                b_rs,
                b_cs,
                T::zero(),
                c_ptr.offset(c_off),
                c_rs,
                c_cs,
            );
        }
    }

    Some(TypedTensor {
        buffer: Buffer::Host(out),
        shape: dims.out_shape.into_vec(),
        placement: lhs.placement.clone(),
    })
}
563
564#[cfg(test)]
565mod tests;