Skip to main content

strided_einsum2/
contiguous.rs

1//! GEMM-ready operand types and preparation functions for contiguous data.
2//!
3//! These types encapsulate the logic for preparing strided operands for GEMM:
4//! checking fusability, copying to col-major buffers when needed, and managing
5//! the writeback for borrowed output operands.
6
7use crate::ScalarBase;
8use std::any::{Any, TypeId};
9use std::cell::RefCell;
10use std::collections::HashMap;
11use strided_view::{RawStridedMut, RawStridedRef, StridedArray, StridedView, StridedViewMut};
12
13/// GEMM-ready input operand with contiguous data.
14pub struct ContiguousOperand<T: Copy + 'static> {
15    ptr: *const T,
16    row_stride: isize,
17    col_stride: isize,
18    batch_strides: Vec<isize>,
19    conj: bool,
20    /// Owns the buffer if a copy was made or input was consumed.
21    pub(crate) _buf: Option<StridedArray<T>>,
22    buf_is_pooled: bool,
23}
24
25/// GEMM-ready output operand with contiguous data.
26pub struct ContiguousOperandMut<T: Copy + 'static> {
27    ptr: *mut T,
28    row_stride: isize,
29    col_stride: isize,
30    batch_strides: Vec<isize>,
31    /// Whether the caller must copy the buffer back to the original destination
32    /// after GEMM completes (true only for borrowed non-contiguous C).
33    needs_writeback: bool,
34    /// Owns the buffer if a copy was made.
35    pub(crate) _buf: Option<StridedArray<T>>,
36    buf_is_pooled: bool,
37}
38
thread_local! {
    /// Per-thread cache of reusable scratch buffers. Each entry maps the
    /// `TypeId` of an element type `T` to a boxed `Vec<Vec<T>>`.
    static BUFFER_POOL: RefCell<HashMap<TypeId, Box<dyn Any>>> = RefCell::new(HashMap::new());
}

/// Maximum number of cached buffers retained per element type.
const MAX_POOL_PER_TYPE: usize = 16;
/// Buffers whose capacity exceeds this many bytes (64 MiB) are never pooled.
const MAX_POOLED_BYTES: usize = 64 * 1024 * 1024;
45
46fn take_pooled_vec_uninit<T: Copy + 'static>(len: usize) -> Vec<T> {
47    BUFFER_POOL.with(|pool| {
48        let mut pool = pool.borrow_mut();
49        let entry = pool
50            .entry(TypeId::of::<T>())
51            .or_insert_with(|| Box::new(Vec::<Vec<T>>::new()));
52        let vecs = entry
53            .downcast_mut::<Vec<Vec<T>>>()
54            .expect("buffer pool type mismatch");
55
56        let mut best_idx = None;
57        let mut best_cap = usize::MAX;
58        for (idx, v) in vecs.iter().enumerate() {
59            let cap = v.capacity();
60            if cap >= len && cap < best_cap {
61                best_idx = Some(idx);
62                best_cap = cap;
63            }
64        }
65
66        let mut data = best_idx
67            .map(|idx| vecs.swap_remove(idx))
68            .unwrap_or_else(|| Vec::with_capacity(len));
69        if data.capacity() < len {
70            data.reserve(len - data.capacity());
71        }
72        unsafe { data.set_len(len) };
73        data
74    })
75}
76
77fn return_pooled_vec<T: Copy + 'static>(mut data: Vec<T>) {
78    let bytes = data.capacity().saturating_mul(std::mem::size_of::<T>());
79    if bytes == 0 || bytes > MAX_POOLED_BYTES {
80        return;
81    }
82    data.clear();
83    BUFFER_POOL.with(|pool| {
84        let mut pool = pool.borrow_mut();
85        let entry = pool
86            .entry(TypeId::of::<T>())
87            .or_insert_with(|| Box::new(Vec::<Vec<T>>::new()));
88        let vecs = entry
89            .downcast_mut::<Vec<Vec<T>>>()
90            .expect("buffer pool type mismatch");
91        if vecs.len() >= MAX_POOL_PER_TYPE {
92            if let Some((min_idx, min_cap)) = vecs
93                .iter()
94                .enumerate()
95                .map(|(i, v)| (i, v.capacity()))
96                .min_by_key(|(_, cap)| *cap)
97            {
98                if min_cap < data.capacity() {
99                    vecs.swap_remove(min_idx);
100                    vecs.push(data);
101                }
102            }
103        } else {
104            vecs.push(data);
105        }
106    });
107}
108
109fn alloc_col_major_uninit_with_pool<T: Copy + 'static>(dims: &[usize]) -> (StridedArray<T>, bool) {
110    let total: usize = dims.iter().product::<usize>().max(1);
111    let bytes = total.saturating_mul(std::mem::size_of::<T>());
112    if bytes == 0 || bytes > MAX_POOLED_BYTES {
113        return (alloc_col_major_uninit(dims), false);
114    }
115    let data = take_pooled_vec_uninit::<T>(total);
116    let arr = unsafe { StridedArray::col_major_from_buffer_uninit(data, dims) };
117    (arr, true)
118}
119
120/// Allocate a col-major buffer, optionally reusing from the thread-local pool.
121fn alloc_maybe_pooled<T: Copy + 'static>(
122    dims: &[usize],
123    use_pool: bool,
124) -> (StridedArray<T>, bool) {
125    if use_pool {
126        alloc_col_major_uninit_with_pool(dims)
127    } else {
128        (alloc_col_major_uninit(dims), false)
129    }
130}
131
/// Test helper: number of buffers currently cached for element type `T`
/// (0 when the type has no pool entry or the entry has an unexpected type).
#[cfg(test)]
fn pooled_count_for_type<T: 'static>() -> usize {
    BUFFER_POOL.with(|cell| {
        cell.borrow_mut()
            .get_mut(&TypeId::of::<T>())
            .and_then(|entry| entry.downcast_mut::<Vec<Vec<T>>>())
            .map_or(0, |vecs| vecs.len())
    })
}
144
145impl<T: Copy + 'static> ContiguousOperand<T> {
146    /// Raw const pointer to the operand data at the base offset.
147    #[inline]
148    pub fn ptr(&self) -> *const T {
149        self.ptr
150    }
151
152    /// Row (lo-group) stride for the fused 2D matrix.
153    #[inline]
154    pub fn row_stride(&self) -> isize {
155        self.row_stride
156    }
157
158    /// Column (sum/ro-group) stride for the fused 2D matrix.
159    #[inline]
160    pub fn col_stride(&self) -> isize {
161        self.col_stride
162    }
163
164    /// Batch dimension strides.
165    #[inline]
166    pub fn batch_strides(&self) -> &[isize] {
167        &self.batch_strides
168    }
169
170    /// Whether this operand requires conjugation.
171    #[inline]
172    pub fn conj(&self) -> bool {
173        self.conj
174    }
175
176    /// Returns `true` if this operand owns a buffer (copy was made or ownership transferred).
177    #[cfg(test)]
178    #[inline]
179    pub(crate) fn has_buf(&self) -> bool {
180        self._buf.is_some()
181    }
182}
183
184impl<T: Copy + 'static> ContiguousOperandMut<T> {
185    /// Raw mutable pointer to the operand data at the base offset.
186    #[inline]
187    pub fn ptr(&self) -> *mut T {
188        self.ptr
189    }
190
191    /// Row (lo-group) stride for the fused 2D matrix.
192    #[inline]
193    pub fn row_stride(&self) -> isize {
194        self.row_stride
195    }
196
197    /// Column (ro-group) stride for the fused 2D matrix.
198    #[inline]
199    pub fn col_stride(&self) -> isize {
200        self.col_stride
201    }
202
203    /// Batch dimension strides.
204    #[inline]
205    pub fn batch_strides(&self) -> &[isize] {
206        &self.batch_strides
207    }
208
209    /// Returns `true` if this operand owns a buffer (copy was made).
210    #[cfg(test)]
211    #[inline]
212    pub(crate) fn has_buf(&self) -> bool {
213        self._buf.is_some()
214    }
215
216    /// Returns `true` if the caller must copy the buffer back to the original
217    /// destination after GEMM completes.
218    #[cfg(test)]
219    #[inline]
220    pub(crate) fn needs_writeback(&self) -> bool {
221        self.needs_writeback
222    }
223}
224
225impl<T: Copy + Send + Sync> ContiguousOperandMut<T> {
226    /// After GEMM: copy the internal buffer back to `dest` if needed.
227    ///
228    /// This is a no-op when the GEMM wrote directly to the destination
229    /// (contiguous case or owned output).
230    pub fn finalize_into(self, dest: &mut StridedViewMut<T>) -> crate::Result<()> {
231        if self.needs_writeback {
232            if let Some(ref buf) = self._buf {
233                strided_perm::copy_into(dest, &buf.view())?;
234            }
235        }
236        Ok(())
237    }
238
239    /// After GEMM: copy the internal buffer back to a raw destination if needed.
240    ///
241    /// This mirrors [`Self::finalize_into`] for prepared replay paths that use
242    /// borrowed raw layout metadata instead of owned-metadata views.
243    pub fn finalize_raw_into(self, dest: &mut RawStridedMut<'_, T>) -> crate::Result<()> {
244        if self.needs_writeback {
245            if let Some(ref buf) = self._buf {
246                let mut dest_view = dest.as_view_mut();
247                strided_perm::copy_into(&mut dest_view, &buf.view())?;
248            }
249        }
250        Ok(())
251    }
252}
253
254impl<T: Copy + 'static> Drop for ContiguousOperand<T> {
255    fn drop(&mut self) {
256        if self.buf_is_pooled {
257            if let Some(arr) = self._buf.take() {
258                return_pooled_vec(arr.into_data());
259            }
260        }
261    }
262}
263
264impl<T: Copy + 'static> Drop for ContiguousOperandMut<T> {
265    fn drop(&mut self) {
266        if self.buf_is_pooled {
267            if let Some(arr) = self._buf.take() {
268                return_pooled_vec(arr.into_data());
269            }
270        }
271    }
272}
273
/// Result of checking whether dimension groups are contiguous enough for GEMM.
struct ContiguityCheck {
    /// `(total_len, stride)` of group 1 when it fuses into one index.
    fused_g1: Option<(usize, isize)>,
    /// `(total_len, stride)` of group 2 when it fuses into one index.
    fused_g2: Option<(usize, isize)>,
    /// `true` when the operand must be copied into a col-major buffer.
    needs_copy: bool,
}
280
/// Try to fuse a GEMM dimension group while preserving canonical col-major
/// logical order. GEMM sees each group as a single 1D index, so independently
/// fusing row-major and col-major groups would pair different logical indices.
///
/// Returns `Some((total_len, stride))` when every non-trivial axis (size > 1)
/// starts exactly where the previous one ends (`stride_k == stride_{k-1} *
/// dim_{k-1}`) and none has stride 0. Size-0/1 axes are ignored. If no axis is
/// non-trivial, the reported stride is the group stride with the smallest
/// magnitude (0 for an empty group). Returns `None` on a length mismatch,
/// a non-fusable layout, or arithmetic overflow.
fn try_fuse_col_major_group(dims: &[usize], strides: &[isize]) -> Option<(usize, isize)> {
    if dims.len() != strides.len() {
        return None;
    }
    let mut total = 1usize;
    for &extent in dims {
        total = total.checked_mul(extent)?;
    }
    if dims.is_empty() {
        return Some((1, 0));
    }

    let mut lead_stride: Option<isize> = None;
    let mut next_stride: Option<isize> = None;
    for (&extent, &stride) in dims
        .iter()
        .zip(strides)
        .filter(|&(&extent, _)| extent > 1)
    {
        if stride == 0 {
            return None;
        }
        match next_stride {
            None => lead_stride = Some(stride),
            Some(expected) if expected != stride => return None,
            Some(_) => {}
        }
        let extent = isize::try_from(extent).ok()?;
        next_stride = Some(stride.checked_mul(extent)?);
    }

    let stride = match lead_stride {
        Some(first) => first,
        None => strides
            .iter()
            .copied()
            .min_by_key(|s| s.unsigned_abs())
            .unwrap_or(0),
    };
    Some((total, stride))
}
324
325/// Check if two dimension groups are fusable (contiguous) for GEMM.
326///
327/// The fused logical dimension order must match the canonical axis order used
328/// by the plan. When `requires_unit_stride` is true (e.g., CBLAS backend), also
329/// checks that at least one of the fused strides is 0 or 1.
330fn check_contiguity(
331    group1_dims: &[usize],
332    group1_strides: &[isize],
333    group2_dims: &[usize],
334    group2_strides: &[isize],
335    requires_unit_stride: bool,
336) -> ContiguityCheck {
337    let fused_g1 = try_fuse_col_major_group(group1_dims, group1_strides);
338    let fused_g2 = try_fuse_col_major_group(group2_dims, group2_strides);
339
340    let mut needs_copy = fused_g1.is_none() || fused_g2.is_none();
341
342    if requires_unit_stride && !needs_copy {
343        let (_, rs) = fused_g1.unwrap();
344        let (_, cs) = fused_g2.unwrap();
345        if rs != 0 && rs != 1 && cs != 0 && cs != 1 {
346            needs_copy = true;
347        }
348    }
349
350    ContiguityCheck {
351        fused_g1,
352        fused_g2,
353        needs_copy,
354    }
355}
356
357/// Compute col-major layout parameters from a freshly-allocated col-major buffer.
358///
359/// Returns `(row_stride, col_stride, batch_strides)`.
360fn col_major_layout(
361    buf: &StridedArray<impl Copy>,
362    n_group1: usize,
363    n_inner: usize,
364) -> (isize, isize, Vec<isize>) {
365    let m: usize = buf.dims()[..n_group1].iter().product::<usize>().max(1);
366    let row_stride = if m == 0 { 0 } else { 1isize };
367    let col_stride = m as isize;
368    let batch_strides = buf.strides()[n_inner..].to_vec();
369    (row_stride, col_stride, batch_strides)
370}
371
372/// Allocate a column-major StridedArray with uninitialized data.
373///
374/// With batch-last canonical order `[inner..., batch...]`, pure column-major
375/// naturally gives batch dims the largest strides — each batch slice is a
376/// contiguous column-major matrix.
377pub(crate) fn alloc_col_major_uninit<T: Copy>(dims: &[usize]) -> StridedArray<T> {
378    let total: usize = dims.iter().product::<usize>().max(1);
379    // SAFETY: `T: Copy` guarantees no drop glue, so leaving elements
380    // uninitialised is safe. Every call-site writes all elements before
381    // reading: A and B via `copy_into`, C via `copy_into` (beta != 0)
382    // or GEMM with replace semantics (beta == 0).
383    let mut data = Vec::with_capacity(total);
384    unsafe { data.set_len(total) };
385
386    // Pure column-major: stride 1 for first dim, each subsequent dim
387    // has stride = previous stride * previous dim size.
388    let mut strides = vec![0isize; dims.len()];
389    if !dims.is_empty() {
390        strides[0] = 1;
391        for i in 1..dims.len() {
392            strides[i] = strides[i - 1] * dims[i - 1] as isize;
393        }
394    }
395
396    let arr = StridedArray::from_parts(data, dims, &strides, 0).expect("col-major allocation");
397    arr
398}
399
400/// Prepare a borrowed input view for GEMM.
401///
402/// Expects batch-last canonical order: `[group1..., group2..., batch...]`.
403/// Checks if the two inner dimension groups are fusable.
404/// If not, copies to a contiguous col-major buffer.
405///
406/// - `requires_unit_stride`: backend needs at least one unit stride (e.g. CBLAS).
407/// - `use_pool`: reuse thread-local buffers to avoid repeated allocation.
408/// - `materialize_conj_fn`: when `Some(f)` and `conj == true`, applies `f` to each
409///   element during copy (for backends that cannot pass conj flags to GEMM).
410pub fn prepare_input_view<T: ScalarBase + 'static>(
411    view: &StridedView<T>,
412    n_group1: usize,
413    n_group2: usize,
414    conj: bool,
415    requires_unit_stride: bool,
416    use_pool: bool,
417    materialize_conj_fn: Option<fn(T) -> T>,
418) -> crate::Result<ContiguousOperand<T>> {
419    let dims = view.dims();
420    let strides = view.strides();
421    let n_inner = n_group1 + n_group2;
422
423    // For backends that cannot pass conjugation flags to GEMM (e.g., CBLAS),
424    // materialize conj into the data before the GEMM call.
425    if let Some(conj_fn) = materialize_conj_fn {
426        if conj {
427            let (mut buf, buf_is_pooled) = alloc_maybe_pooled(dims, use_pool);
428            strided_kernel::map_into(&mut buf.view_mut(), view, conj_fn)?;
429            let ptr = buf.view().ptr();
430            let (row_stride, col_stride, batch_strides) = col_major_layout(&buf, n_group1, n_inner);
431            return Ok(ContiguousOperand {
432                ptr,
433                row_stride,
434                col_stride,
435                batch_strides,
436                conj: false,
437                _buf: Some(buf),
438                buf_is_pooled,
439            });
440        }
441    }
442
443    let check = check_contiguity(
444        &dims[..n_group1],
445        &strides[..n_group1],
446        &dims[n_group1..n_inner],
447        &strides[n_group1..n_inner],
448        requires_unit_stride,
449    );
450
451    if check.needs_copy {
452        let (mut buf, buf_is_pooled) = alloc_maybe_pooled(dims, use_pool);
453        strided_kernel::copy_into_col_major(&mut buf.view_mut(), view)?;
454        let ptr = buf.view().ptr();
455        let (row_stride, col_stride, batch_strides) = col_major_layout(&buf, n_group1, n_inner);
456        Ok(ContiguousOperand {
457            ptr,
458            row_stride,
459            col_stride,
460            batch_strides,
461            conj,
462            _buf: Some(buf),
463            buf_is_pooled,
464        })
465    } else {
466        let (_, rs) = check.fused_g1.unwrap();
467        let (_, cs) = check.fused_g2.unwrap();
468        Ok(ContiguousOperand {
469            ptr: view.ptr(),
470            row_stride: rs,
471            col_stride: cs,
472            batch_strides: strides[n_inner..].to_vec(),
473            conj,
474            _buf: None,
475            buf_is_pooled: false,
476        })
477    }
478}
479
480/// Prepare a borrowed raw input layout for GEMM.
481///
482/// This is the raw-layout counterpart to [`prepare_input_view`]. Direct
483/// fusable layouts do not construct `StridedView` wrappers; a view is created
484/// only when a copy/materialization path needs to call a generic strided kernel.
485pub fn prepare_input_raw<T: ScalarBase + 'static>(
486    view: &RawStridedRef<'_, T>,
487    n_group1: usize,
488    n_group2: usize,
489    conj: bool,
490    requires_unit_stride: bool,
491    use_pool: bool,
492    materialize_conj_fn: Option<fn(T) -> T>,
493) -> crate::Result<ContiguousOperand<T>> {
494    let dims = view.dims();
495    let strides = view.strides();
496    let n_inner = n_group1 + n_group2;
497
498    if let Some(conj_fn) = materialize_conj_fn {
499        if conj {
500            let (mut buf, buf_is_pooled) = alloc_maybe_pooled(dims, use_pool);
501            strided_kernel::map_into(&mut buf.view_mut(), &view.as_view(), conj_fn)?;
502            let ptr = buf.view().ptr();
503            let (row_stride, col_stride, batch_strides) = col_major_layout(&buf, n_group1, n_inner);
504            return Ok(ContiguousOperand {
505                ptr,
506                row_stride,
507                col_stride,
508                batch_strides,
509                conj: false,
510                _buf: Some(buf),
511                buf_is_pooled,
512            });
513        }
514    }
515
516    let check = check_contiguity(
517        &dims[..n_group1],
518        &strides[..n_group1],
519        &dims[n_group1..n_inner],
520        &strides[n_group1..n_inner],
521        requires_unit_stride,
522    );
523
524    if check.needs_copy {
525        let (mut buf, buf_is_pooled) = alloc_maybe_pooled(dims, use_pool);
526        strided_kernel::copy_into_col_major(&mut buf.view_mut(), &view.as_view())?;
527        let ptr = buf.view().ptr();
528        let (row_stride, col_stride, batch_strides) = col_major_layout(&buf, n_group1, n_inner);
529        Ok(ContiguousOperand {
530            ptr,
531            row_stride,
532            col_stride,
533            batch_strides,
534            conj,
535            _buf: Some(buf),
536            buf_is_pooled,
537        })
538    } else {
539        let (_, rs) = check.fused_g1.unwrap();
540        let (_, cs) = check.fused_g2.unwrap();
541        Ok(ContiguousOperand {
542            ptr: view.ptr(),
543            row_stride: rs,
544            col_stride: cs,
545            batch_strides: strides[n_inner..].to_vec(),
546            conj,
547            _buf: None,
548            buf_is_pooled: false,
549        })
550    }
551}
552
553/// Prepare an owned input array for GEMM.
554///
555/// Expects batch-last canonical order: `[group1..., group2..., batch...]`.
556/// If already contiguous after dimension grouping, transfers ownership without copying.
557/// Otherwise, copies to a new col-major buffer.
558///
559/// Parameters are the same as [`prepare_input_view`].
560pub fn prepare_input_owned<T: ScalarBase + 'static>(
561    arr: StridedArray<T>,
562    n_group1: usize,
563    n_group2: usize,
564    conj: bool,
565    requires_unit_stride: bool,
566    use_pool: bool,
567    materialize_conj_fn: Option<fn(T) -> T>,
568) -> crate::Result<ContiguousOperand<T>> {
569    let dims = arr.dims().to_vec();
570    let strides = arr.strides().to_vec();
571    let n_inner = n_group1 + n_group2;
572
573    // For backends that cannot pass conjugation flags to GEMM,
574    // materialize conj into the data before the GEMM call.
575    if let Some(conj_fn) = materialize_conj_fn {
576        if conj {
577            let (mut buf, buf_is_pooled) = alloc_maybe_pooled(&dims, use_pool);
578            strided_kernel::map_into(&mut buf.view_mut(), &arr.view(), conj_fn)?;
579            let ptr = buf.view().ptr();
580            let (row_stride, col_stride, batch_strides) = col_major_layout(&buf, n_group1, n_inner);
581            return Ok(ContiguousOperand {
582                ptr,
583                row_stride,
584                col_stride,
585                batch_strides,
586                conj: false,
587                _buf: Some(buf),
588                buf_is_pooled,
589            });
590        }
591    }
592
593    let check = check_contiguity(
594        &dims[..n_group1],
595        &strides[..n_group1],
596        &dims[n_group1..n_inner],
597        &strides[n_group1..n_inner],
598        requires_unit_stride,
599    );
600
601    if check.needs_copy {
602        let (mut buf, buf_is_pooled) = alloc_maybe_pooled(&dims, use_pool);
603        strided_kernel::copy_into_col_major(&mut buf.view_mut(), &arr.view())?;
604        let ptr = buf.view().ptr();
605        let (row_stride, col_stride, batch_strides) = col_major_layout(&buf, n_group1, n_inner);
606        Ok(ContiguousOperand {
607            ptr,
608            row_stride,
609            col_stride,
610            batch_strides,
611            conj,
612            _buf: Some(buf),
613            buf_is_pooled,
614        })
615    } else {
616        let (_, rs) = check.fused_g1.unwrap();
617        let (_, cs) = check.fused_g2.unwrap();
618        let ptr = arr.view().ptr();
619        Ok(ContiguousOperand {
620            ptr,
621            row_stride: rs,
622            col_stride: cs,
623            batch_strides: strides[n_inner..].to_vec(),
624            conj,
625            _buf: Some(arr),
626            buf_is_pooled: false,
627        })
628    }
629}
630
631/// Prepare a borrowed mutable output view for GEMM.
632///
633/// Expects batch-last canonical order: `[group1..., group2..., batch...]`.
634/// Checks if the two inner dimension groups (lo, ro) are fusable.
635/// If not, allocates a col-major buffer and copies the existing data into it
636/// when `beta` is non-zero (so the GEMM accumulation is correct).
637///
638/// After GEMM, call [`ContiguousOperandMut::finalize_into`] with the original
639/// view to copy results back if needed.
640///
641/// # Safety contract
642///
643/// When inner dims are fusable (no copy needed), the returned `ContiguousOperandMut`
644/// holds a raw pointer into `view`'s data. The caller must ensure `view` outlives
645/// the returned operand and that no aliasing mutable references exist during GEMM.
646pub fn prepare_output_view<T: ScalarBase + 'static>(
647    view: &mut StridedViewMut<T>,
648    n_group1: usize,
649    n_group2: usize,
650    beta: T,
651    requires_unit_stride: bool,
652    use_pool: bool,
653) -> crate::Result<ContiguousOperandMut<T>> {
654    let dims = view.dims().to_vec();
655    let strides = view.strides().to_vec();
656    let n_inner = n_group1 + n_group2;
657
658    let check = check_contiguity(
659        &dims[..n_group1],
660        &strides[..n_group1],
661        &dims[n_group1..n_inner],
662        &strides[n_group1..n_inner],
663        requires_unit_stride,
664    );
665
666    if check.needs_copy {
667        let (mut buf, buf_is_pooled) = alloc_maybe_pooled(&dims, use_pool);
668        if beta != T::zero() {
669            strided_kernel::copy_into_col_major(&mut buf.view_mut(), &view.as_view())?;
670        }
671        let ptr = buf.view_mut().as_mut_ptr();
672        let (row_stride, col_stride, batch_strides) = col_major_layout(&buf, n_group1, n_inner);
673        Ok(ContiguousOperandMut {
674            ptr,
675            row_stride,
676            col_stride,
677            batch_strides,
678            needs_writeback: true,
679            _buf: Some(buf),
680            buf_is_pooled,
681        })
682    } else {
683        let (_, rs) = check.fused_g1.unwrap();
684        let (_, cs) = check.fused_g2.unwrap();
685        Ok(ContiguousOperandMut {
686            ptr: view.as_mut_ptr(),
687            row_stride: rs,
688            col_stride: cs,
689            batch_strides: strides[n_inner..].to_vec(),
690            needs_writeback: false,
691            _buf: None,
692            buf_is_pooled: false,
693        })
694    }
695}
696
697/// Prepare a borrowed raw output layout for GEMM.
698///
699/// This is the raw-layout counterpart to [`prepare_output_view`]. Direct
700/// fusable layouts do not construct `StridedViewMut` wrappers; a view is
701/// created only when a copy/writeback path needs a generic strided kernel.
702pub fn prepare_output_raw<T: ScalarBase + 'static>(
703    view: &mut RawStridedMut<'_, T>,
704    n_group1: usize,
705    n_group2: usize,
706    beta: T,
707    requires_unit_stride: bool,
708    use_pool: bool,
709) -> crate::Result<ContiguousOperandMut<T>> {
710    let dims = view.dims().to_vec();
711    let strides = view.strides().to_vec();
712    let n_inner = n_group1 + n_group2;
713
714    let check = check_contiguity(
715        &dims[..n_group1],
716        &strides[..n_group1],
717        &dims[n_group1..n_inner],
718        &strides[n_group1..n_inner],
719        requires_unit_stride,
720    );
721
722    if check.needs_copy {
723        let (mut buf, buf_is_pooled) = alloc_maybe_pooled(&dims, use_pool);
724        if beta != T::zero() {
725            strided_kernel::copy_into_col_major(&mut buf.view_mut(), &view.as_view())?;
726        }
727        let ptr = buf.view_mut().as_mut_ptr();
728        let (row_stride, col_stride, batch_strides) = col_major_layout(&buf, n_group1, n_inner);
729        Ok(ContiguousOperandMut {
730            ptr,
731            row_stride,
732            col_stride,
733            batch_strides,
734            needs_writeback: true,
735            _buf: Some(buf),
736            buf_is_pooled,
737        })
738    } else {
739        let (_, rs) = check.fused_g1.unwrap();
740        let (_, cs) = check.fused_g2.unwrap();
741        Ok(ContiguousOperandMut {
742            ptr: view.as_mut_ptr(),
743            row_stride: rs,
744            col_stride: cs,
745            batch_strides: strides[n_inner..].to_vec(),
746            needs_writeback: false,
747            _buf: None,
748            buf_is_pooled: false,
749        })
750    }
751}
752
#[cfg(test)]
mod tests_generic_backend {
    use super::*;
    use crate::backend::{Backend, NaiveBackend};

    /// Unit-stride requirement of the naive backend, shared by every test here.
    const UNIT_STRIDE: bool = <NaiveBackend as Backend<f64>>::REQUIRES_UNIT_STRIDE;

    #[test]
    fn test_input_for_backend_contiguous() {
        let a = StridedArray::<f64>::col_major(&[2, 3]);
        let view = a.view();
        let op = prepare_input_view(&view, 1, 1, false, UNIT_STRIDE, false, None).unwrap();
        assert!(op._buf.is_none());
        assert_eq!(op.row_stride(), 1);
        assert_eq!(op.col_stride(), 2);
        assert!(!op.conj());
    }

    #[test]
    fn test_input_for_backend_non_contiguous() {
        // Group 1 = dims [2, 3] with strides [20, 4]: not col-major fusable.
        let data = vec![0.0f64; 100];
        let a = StridedArray::<f64>::from_parts(data, &[2, 3, 4], &[20, 4, 1], 0).unwrap();
        let view = a.view();
        let op = prepare_input_view(&view, 2, 1, false, UNIT_STRIDE, false, None).unwrap();
        assert!(op._buf.is_some());
        assert_eq!(op.row_stride(), 1);
        assert_eq!(op.col_stride(), 6);
    }

    #[test]
    fn test_output_for_backend_contiguous() {
        let mut c = StridedArray::<f64>::col_major(&[2, 3]);
        let mut view = c.view_mut();
        let op = prepare_output_view(&mut view, 1, 1, 0.0, UNIT_STRIDE, false).unwrap();
        assert!(!op.needs_writeback);
        assert!(op._buf.is_none());
        assert_eq!(op.row_stride(), 1);
        assert_eq!(op.col_stride(), 2);
    }

    #[test]
    fn test_output_for_backend_non_contiguous_beta_zero() {
        let data = vec![0.0f64; 100];
        let mut c = StridedArray::<f64>::from_parts(data, &[2, 3, 4], &[20, 4, 1], 0).unwrap();
        let mut view = c.view_mut();
        let op = prepare_output_view(&mut view, 2, 1, 0.0, UNIT_STRIDE, false).unwrap();
        assert!(op.needs_writeback);
        assert!(op._buf.is_some());
        assert_eq!(op.row_stride(), 1);
        assert_eq!(op.col_stride(), 6);
    }

    #[test]
    fn test_output_for_backend_non_contiguous_beta_nonzero_and_finalize() {
        let mut data = vec![0.0f64; 30];
        data[0] = 10.0;
        data[1] = 20.0;
        data[10] = 40.0;
        let mut c = StridedArray::<f64>::from_parts(data, &[2, 3, 1], &[10, 1, 1], 0).unwrap();
        {
            let mut view = c.view_mut();
            let op = prepare_output_view(&mut view, 2, 1, 1.0, UNIT_STRIDE, false).unwrap();
            assert!(op.needs_writeback);

            // beta != 0: the scratch buffer must start out holding C's values.
            let buf = op._buf.as_ref().unwrap();
            assert_eq!(buf.get(&[0, 0, 0]), 10.0);
            assert_eq!(buf.get(&[0, 1, 0]), 20.0);
            assert_eq!(buf.get(&[1, 0, 0]), 40.0);

            // Simulate GEMM writing through the operand pointer. The scratch
            // buffer is col-major, so offsets 0 and 1 are [0,0,0] and [1,0,0].
            unsafe {
                *op.ptr() = 99.0;
                *op.ptr().add(1) = 77.0;
            }
            op.finalize_into(&mut view).unwrap();
        }
        // finalize_into must have copied the scratch buffer back into C.
        assert_eq!(c.get(&[0, 0, 0]), 99.0);
        assert_eq!(c.get(&[1, 0, 0]), 77.0);
        assert_eq!(c.get(&[0, 1, 0]), 20.0);
    }
}
862
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::{ActiveBackend, Backend};

    /// The active backend's `REQUIRES_UNIT_STRIDE` flag for `f64`.
    ///
    /// It is passed to `prepare_input_view`, `prepare_input_owned` and
    /// `prepare_output_view` in the same argument position used for the backend's
    /// unit-stride requirement. These tests therefore follow the same preparation path
    /// the active backend uses. Every test in this module uses `f64`.
    const UNIT_STRIDE: bool = <ActiveBackend as Backend<f64>>::REQUIRES_UNIT_STRIDE;

    /// A column-major 2x3 view is used in place.
    ///
    /// No buffer is allocated, the strides are (1, 2) and conjugation stays off.
    #[test]
    fn test_borrowed_contiguous_no_copy() {
        let a = StridedArray::<f64>::col_major(&[2, 3]);
        let view = a.view();

        let op = prepare_input_view(&view, 1, 1, false, UNIT_STRIDE, true, None).unwrap();

        assert!(!op.has_buf());
        assert_eq!(op.row_stride(), 1);
        assert_eq!(op.col_stride(), 2);
        assert!(!op.conj());
    }

    /// A 2x3 view with strides [3, 1] (row-major) is used in place.
    ///
    /// The operand's row stride is 3 and its column stride is 1.
    #[test]
    fn test_borrowed_transposed_matrix_no_copy() {
        let data = vec![0.0f64; 6];
        let a_t = StridedArray::<f64>::from_parts(data, &[2, 3], &[3, 1], 0).unwrap();
        let view = a_t.view();

        let op = prepare_input_view(&view, 1, 1, false, UNIT_STRIDE, true, None).unwrap();

        assert!(!op.has_buf());
        assert_eq!(op.row_stride(), 3);
        assert_eq!(op.col_stride(), 1);
    }

    /// Same as the transposed case, but with a trailing batch dimension of stride 6.
    ///
    /// The view is still used in place, and that stride is reported as the single
    /// batch stride.
    #[test]
    fn test_borrowed_batched_transposed_matrix_no_copy() {
        let data = vec![0.0f64; 2 * 3 * 5];
        let a_t = StridedArray::<f64>::from_parts(data, &[2, 3, 5], &[3, 1, 6], 0).unwrap();
        let view = a_t.view();

        let op = prepare_input_view(&view, 1, 1, false, UNIT_STRIDE, true, None).unwrap();

        assert!(!op.has_buf());
        assert_eq!(op.row_stride(), 3);
        assert_eq!(op.col_stride(), 1);
        assert_eq!(op.batch_strides(), &[6]);
    }

    /// A [2, 3, 4] view with strides [20, 4, 1] is copied into an owned buffer.
    ///
    /// The first two dimensions form the row group and cannot be fused in place. The
    /// resulting buffer has row stride 1 and column stride 6 (= 2 * 3).
    #[test]
    fn test_borrowed_non_contiguous_copies() {
        let data = vec![0.0f64; 100];
        let a = StridedArray::<f64>::from_parts(data, &[2, 3, 4], &[20, 4, 1], 0).unwrap();
        let view = a.view();

        let op = prepare_input_view(&view, 2, 1, false, UNIT_STRIDE, true, None).unwrap();

        assert!(op.has_buf());
        assert_eq!(op.row_stride(), 1);
        assert_eq!(op.col_stride(), 6);
    }

    /// An owned column-major array is consumed rather than copied.
    ///
    /// It is kept as the operand's buffer, so `has_buf()` is true, and its
    /// column-major strides are preserved.
    #[test]
    fn test_owned_contiguous_no_copy() {
        let a = StridedArray::<f64>::col_major(&[2, 3]);

        let op = prepare_input_owned(a, 1, 1, false, UNIT_STRIDE, true, None).unwrap();

        assert!(op.has_buf());
        assert_eq!(op.row_stride(), 1);
        assert_eq!(op.col_stride(), 2);
    }

    /// An owned non-fusable array is repacked into a column-major buffer.
    ///
    /// The resulting strides are the same as in the borrowed case.
    #[test]
    fn test_owned_non_contiguous_copies() {
        let data = vec![0.0f64; 100];
        let a = StridedArray::<f64>::from_parts(data, &[2, 3, 4], &[20, 4, 1], 0).unwrap();

        let op = prepare_input_owned(a, 2, 1, false, UNIT_STRIDE, true, None).unwrap();

        assert!(op.has_buf());
        assert_eq!(op.row_stride(), 1);
        assert_eq!(op.col_stride(), 6);
    }

    /// A column-major output view is written directly.
    ///
    /// No temporary buffer is allocated and no writeback is needed.
    #[test]
    fn test_output_view_contiguous() {
        let mut c = StridedArray::<f64>::col_major(&[2, 3]);
        let mut view = c.view_mut();

        let op = prepare_output_view(&mut view, 1, 1, 0.0, UNIT_STRIDE, true).unwrap();

        assert!(!op.needs_writeback());
        assert!(!op.has_buf());
        assert_eq!(op.row_stride(), 1);
        assert_eq!(op.col_stride(), 2);
    }

    /// A non-fusable output view with `beta == 0` gets a column-major scratch buffer.
    ///
    /// The buffer must be written back to the view afterwards.
    #[test]
    fn test_output_view_non_contiguous_beta_zero() {
        let data = vec![0.0f64; 100];
        let mut c = StridedArray::<f64>::from_parts(data, &[2, 3, 4], &[20, 4, 1], 0).unwrap();
        let mut view = c.view_mut();

        let op = prepare_output_view(&mut view, 2, 1, 0.0, UNIT_STRIDE, true).unwrap();

        assert!(op.needs_writeback());
        assert!(op.has_buf());
        assert_eq!(op.row_stride(), 1);
        assert_eq!(op.col_stride(), 6);
    }

    /// With a non-zero `beta`, the scratch buffer starts from the current contents of C.
    ///
    /// After a simulated GEMM result is written into the buffer, `finalize_into`
    /// copies it back into every element of the original strided view.
    #[test]
    fn test_output_view_non_contiguous_beta_nonzero_and_finalize() {
        let mut data = vec![0.0f64; 30];
        data[0] = 10.0;
        data[1] = 20.0;
        data[2] = 30.0;
        data[10] = 40.0;
        data[11] = 50.0;
        data[12] = 60.0;
        let mut c = StridedArray::<f64>::from_parts(data, &[2, 3, 1], &[10, 1, 1], 0).unwrap();

        assert_eq!(c.get(&[0, 0, 0]), 10.0);
        assert_eq!(c.get(&[1, 1, 0]), 50.0);

        let mut view = c.view_mut();

        let mut op = prepare_output_view(&mut view, 2, 1, 1.0, UNIT_STRIDE, true).unwrap();

        assert!(op.needs_writeback());
        assert!(op.has_buf());

        // beta != 0, so the existing values of C must have been carried into the buffer.
        let buf = op._buf.as_ref().unwrap();
        assert_eq!(buf.get(&[0, 0, 0]), 10.0);
        assert_eq!(buf.get(&[1, 1, 0]), 50.0);

        {
            // Stand-in for the GEMM kernel: overwrite the buffer with a known result.
            let result_data = vec![100.0f64; 6];
            let result =
                StridedArray::<f64>::from_parts(result_data, &[2, 3, 1], &[3, 1, 1], 0).unwrap();
            strided_kernel::copy_into(&mut op._buf.as_mut().unwrap().view_mut(), &result.view())
                .unwrap();
            // Re-point `op.ptr` at the buffer's data so the raw pointer and the owned
            // buffer stay in sync after the mutable access above.
            op.ptr = op._buf.as_mut().unwrap().view_mut().as_mut_ptr();
        }

        op.finalize_into(&mut view).unwrap();

        assert_eq!(c.get(&[0, 0, 0]), 100.0);
        assert_eq!(c.get(&[0, 1, 0]), 100.0);
        assert_eq!(c.get(&[0, 2, 0]), 100.0);
        assert_eq!(c.get(&[1, 0, 0]), 100.0);
        assert_eq!(c.get(&[1, 1, 0]), 100.0);
        assert_eq!(c.get(&[1, 2, 0]), 100.0);
    }

    /// Dropping an input operand that allocated a temporary buffer returns that buffer
    /// to the thread-local `f64` pool.
    #[test]
    fn test_prepare_input_view_temp_buffer_is_recycled() {
        let before = pooled_count_for_type::<f64>();
        let data = vec![0.0f64; 100];
        let a = StridedArray::<f64>::from_parts(data, &[2, 3, 4], &[20, 4, 1], 0).unwrap();
        let view = a.view();

        {
            let op = prepare_input_view(&view, 2, 1, false, UNIT_STRIDE, true, None).unwrap();
            assert!(op.has_buf());
        } // `op` is dropped here, releasing its temporary buffer.

        let after = pooled_count_for_type::<f64>();
        // Assumes MAX_POOL_PER_TYPE caps how many buffers are pooled per type. If the pool
        // was already at that cap, the count cannot grow, so only require growth when
        // there was room for one more buffer.
        let expected_min = before.saturating_add(1).min(MAX_POOL_PER_TYPE);
        assert!(
            after >= expected_min,
            "expected at least {} pooled f64 buffers after drop, found {} (before: {})",
            expected_min,
            after,
            before
        );
    }
}
1033}