strided_einsum2/contiguous.rs

//! GEMM-ready operand types and preparation functions for contiguous data.
//!
//! These types encapsulate the logic for preparing strided operands for GEMM:
//! checking fusability, copying to col-major buffers when needed, and managing
//! the writeback for borrowed output operands.
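//!
//! A typical call sequence, sketched with placeholder variables
//! (`a_view`, `b_view`, `c_view`, `n_batch`, `n_lo`, `n_sum`, `n_ro`, `beta`)
//! and with the backend GEMM dispatch elided; not compiled as a doctest:
//!
//! ```ignore
//! // Prepare the two inputs and the output of a batched GEMM.
//! let a_op = prepare_input_view(&a_view, n_batch, n_lo, n_sum, false)?;
//! let b_op = prepare_input_view(&b_view, n_batch, n_sum, n_ro, false)?;
//! let c_op = prepare_output_view(&mut c_view, n_batch, n_lo, n_ro, beta)?;
//!
//! // ... run the backend GEMM against the operands' pointers and strides ...
//!
//! // Copy results back to the original C only if a temporary buffer was used.
//! c_op.finalize_into(&mut c_view)?;
//! ```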

use crate::backend::{ActiveBackend, BackendConfig};
use crate::util::try_fuse_group;
use crate::{Scalar, ScalarBase};
use strided_view::{StridedArray, StridedView, StridedViewMut};

/// GEMM-ready input operand with contiguous data.
pub struct ContiguousOperand<T> {
    ptr: *const T,
    row_stride: isize,
    col_stride: isize,
    batch_strides: Vec<isize>,
    conj: bool,
    /// Owns the buffer if a copy was made or input was consumed.
    pub(crate) _buf: Option<StridedArray<T>>,
}

/// GEMM-ready output operand with contiguous data.
pub struct ContiguousOperandMut<T> {
    ptr: *mut T,
    row_stride: isize,
    col_stride: isize,
    batch_strides: Vec<isize>,
    /// Whether the caller must copy the buffer back to the original destination
    /// after GEMM completes (true only for borrowed non-contiguous C).
    needs_writeback: bool,
    /// Owns the buffer if a copy was made.
    pub(crate) _buf: Option<StridedArray<T>>,
}

impl<T> ContiguousOperand<T> {
    /// Raw const pointer to the operand data at the base offset.
    #[inline]
    pub fn ptr(&self) -> *const T {
        self.ptr
    }

    /// Row (lo-group) stride for the fused 2D matrix.
    #[inline]
    pub fn row_stride(&self) -> isize {
        self.row_stride
    }

    /// Column (sum/ro-group) stride for the fused 2D matrix.
    #[inline]
    pub fn col_stride(&self) -> isize {
        self.col_stride
    }

    /// Batch dimension strides.
    #[inline]
    pub fn batch_strides(&self) -> &[isize] {
        &self.batch_strides
    }

    /// Whether this operand requires conjugation.
    #[inline]
    pub fn conj(&self) -> bool {
        self.conj
    }

    /// Returns `true` if this operand owns a buffer (copy was made or ownership transferred).
    #[cfg(test)]
    #[inline]
    pub(crate) fn has_buf(&self) -> bool {
        self._buf.is_some()
    }
}

impl<T> ContiguousOperandMut<T> {
    /// Raw mutable pointer to the operand data at the base offset.
    #[inline]
    pub fn ptr(&self) -> *mut T {
        self.ptr
    }

    /// Row (lo-group) stride for the fused 2D matrix.
    #[inline]
    pub fn row_stride(&self) -> isize {
        self.row_stride
    }

    /// Column (ro-group) stride for the fused 2D matrix.
    #[inline]
    pub fn col_stride(&self) -> isize {
        self.col_stride
    }

    /// Batch dimension strides.
    #[inline]
    pub fn batch_strides(&self) -> &[isize] {
        &self.batch_strides
    }

    /// Returns `true` if this operand owns a buffer (copy was made).
    #[cfg(test)]
    #[inline]
    pub(crate) fn has_buf(&self) -> bool {
        self._buf.is_some()
    }

    /// Returns `true` if the caller must copy the buffer back to the original
    /// destination after GEMM completes.
    #[cfg(test)]
    #[inline]
    pub(crate) fn needs_writeback(&self) -> bool {
        self.needs_writeback
    }
}

impl<T: Copy + Send + Sync> ContiguousOperandMut<T> {
    /// After GEMM: copy the internal buffer back to `dest` if needed.
    ///
    /// This is a no-op when the GEMM wrote directly to the destination
    /// (contiguous case or owned output).
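    ///
    /// A minimal sketch (not compiled as a doctest; `c_op` and `c_view` stand in
    /// for the results of an earlier `prepare_output_view` call):
    ///
    /// ```ignore
    /// // After the backend GEMM has run, this either copies the temporary
    /// // buffer back into `c_view` or does nothing.
    /// c_op.finalize_into(&mut c_view)?;
    /// ```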
    pub fn finalize_into(self, dest: &mut StridedViewMut<T>) -> crate::Result<()> {
        if self.needs_writeback {
            if let Some(ref buf) = self._buf {
                strided_kernel::copy_into(dest, &buf.view())?;
            }
        }
        Ok(())
    }
}

/// Allocate a column-major StridedArray with uninitialized data.
///
/// With batch-last canonical order `[inner..., batch...]`, pure column-major
/// naturally gives batch dims the largest strides — each batch slice is a
/// contiguous column-major matrix.
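///
/// For example (illustrative only, not compiled as a doctest), `dims = [2, 3, 4]`
/// yields strides `[1, 2, 6]` over a backing buffer of 24 elements:
///
/// ```ignore
/// let buf = alloc_col_major_uninit::<f64>(&[2, 3, 4]);
/// assert_eq!(buf.strides(), &[1, 2, 6]);
/// ```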
pub(crate) fn alloc_col_major_uninit<T: Copy>(dims: &[usize]) -> StridedArray<T> {
    let total: usize = dims.iter().product::<usize>().max(1);
    // SAFETY: `T: Copy` guarantees no drop glue, so leaving elements
    // uninitialised is safe. Every call-site writes all elements before
    // reading: A and B via `copy_into`, C via `copy_into` (beta != 0)
    // or GEMM with replace semantics (beta == 0).
    let mut data = Vec::with_capacity(total);
    unsafe { data.set_len(total) };

    // Pure column-major: stride 1 for first dim, each subsequent dim
    // has stride = previous stride * previous dim size.
    let mut strides = vec![0isize; dims.len()];
    if !dims.is_empty() {
        strides[0] = 1;
        for i in 1..dims.len() {
            strides[i] = strides[i - 1] * dims[i - 1] as isize;
        }
    }

    StridedArray::from_parts(data, dims, &strides, 0).expect("col-major allocation")
}

/// Prepare a borrowed input view for GEMM.
///
/// Expects batch-last canonical order: `[group1..., group2..., batch...]`.
/// Checks if the two inner dimension groups are fusable.
/// If not, copies to a contiguous col-major buffer.
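///
/// A minimal sketch mirroring the contiguous case in the tests (not compiled
/// as a doctest):
///
/// ```ignore
/// // Col-major [2, 3] with strides [1, 2]: both inner groups fuse, no copy.
/// let a = StridedArray::<f64>::col_major(&[2, 3]);
/// let op = prepare_input_view(&a.view(), 0, 1, 1, false)?;
/// assert_eq!(op.row_stride(), 1);
/// assert_eq!(op.col_stride(), 2);
/// ```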
pub fn prepare_input_view<T: Scalar>(
    view: &StridedView<T>,
    _n_batch: usize,
    n_group1: usize,
    n_group2: usize,
    conj: bool,
) -> crate::Result<ContiguousOperand<T>> {
    let dims = view.dims();
    let strides = view.strides();
    let n_inner = n_group1 + n_group2;

    // Extract dimension/stride groups (batch-last: inner first, batch at end)
    let group1_dims = &dims[..n_group1];
    let group1_strides = &strides[..n_group1];
    let group2_dims = &dims[n_group1..n_inner];
    let group2_strides = &strides[n_group1..n_inner];

    // For backends that cannot pass conjugation flags to GEMM (e.g., CBLAS),
    // materialize conj into the data before the GEMM call.
    if ActiveBackend::MATERIALIZES_CONJ && conj {
        use strided_view::Conj as ConjOp;
        use strided_view::ElementOp;

        let m: usize = group1_dims.iter().product::<usize>().max(1);
        let mut buf = alloc_col_major_uninit(dims);
        strided_kernel::map_into(&mut buf.view_mut(), view, |x| ConjOp::apply(x))?;
        let ptr = buf.view().ptr();
        let batch_strides = buf.strides()[n_inner..].to_vec();
        let row_stride = if m == 0 { 0 } else { 1isize };
        let col_stride = m as isize;
        return Ok(ContiguousOperand {
            ptr,
            row_stride,
            col_stride,
            batch_strides,
            conj: false,
            _buf: Some(buf),
        });
    }

    let fused_g1 = try_fuse_group(group1_dims, group1_strides);
    let fused_g2 = try_fuse_group(group2_dims, group2_strides);

    let mut needs_copy = fused_g1.is_none() || fused_g2.is_none();

    // Backends requiring unit stride (e.g., CBLAS) need one of {row_stride, col_stride}
    // to be 1 (or 0 for size-1 dims). Batched multi-dim arrays may fuse successfully
    // but still have non-unit strides in both groups. Force a copy in that case.
    if ActiveBackend::REQUIRES_UNIT_STRIDE && !needs_copy {
        let (_, rs) = fused_g1.unwrap();
        let (_, cs) = fused_g2.unwrap();
        if rs != 0 && rs != 1 && cs != 0 && cs != 1 {
            needs_copy = true;
        }
    }

    if needs_copy {
        let m: usize = group1_dims.iter().product::<usize>().max(1);
        let mut buf = alloc_col_major_uninit(dims);
        strided_kernel::copy_into(&mut buf.view_mut(), view)?;
        let ptr = buf.view().ptr();
        let batch_strides = buf.strides()[n_inner..].to_vec();
        let row_stride = if m == 0 { 0 } else { 1isize };
        let col_stride = m as isize;
        Ok(ContiguousOperand {
            ptr,
            row_stride,
            col_stride,
            batch_strides,
            conj,
            _buf: Some(buf),
        })
    } else {
        let (_, rs) = fused_g1.unwrap();
        let (_, cs) = fused_g2.unwrap();
        let batch_strides = strides[n_inner..].to_vec();
        Ok(ContiguousOperand {
            ptr: view.ptr(),
            row_stride: rs,
            col_stride: cs,
            batch_strides,
            conj,
            _buf: None,
        })
    }
}

/// Prepare an owned input array for GEMM.
///
/// Expects batch-last canonical order: `[group1..., group2..., batch...]`.
/// If already contiguous after dimension grouping, transfers ownership without copying.
/// Otherwise, copies to a new col-major buffer.
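///
/// A minimal sketch mirroring the owned contiguous case in the tests (not
/// compiled as a doctest):
///
/// ```ignore
/// // Fusable col-major input: ownership moves into the operand, no data is copied.
/// let a = StridedArray::<f64>::col_major(&[2, 3]);
/// let op = prepare_input_owned(a, 0, 1, 1, false)?;
/// assert_eq!(op.row_stride(), 1);
/// assert_eq!(op.col_stride(), 2);
/// ```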
pub fn prepare_input_owned<T: Scalar>(
    arr: StridedArray<T>,
    _n_batch: usize,
    n_group1: usize,
    n_group2: usize,
    conj: bool,
) -> crate::Result<ContiguousOperand<T>> {
    let dims = arr.dims().to_vec();
    let strides = arr.strides().to_vec();
    let n_inner = n_group1 + n_group2;

    // Extract dimension/stride groups (batch-last)
    let group1_dims = &dims[..n_group1];
    let group1_strides = &strides[..n_group1];
    let group2_dims = &dims[n_group1..n_inner];
    let group2_strides = &strides[n_group1..n_inner];

    // For backends that cannot pass conjugation flags to GEMM,
    // materialize conj into the data before the GEMM call.
    if ActiveBackend::MATERIALIZES_CONJ && conj {
        use strided_view::Conj as ConjOp;
        use strided_view::ElementOp;

        let m: usize = group1_dims.iter().product::<usize>().max(1);
        let mut buf = alloc_col_major_uninit(&dims);
        strided_kernel::map_into(&mut buf.view_mut(), &arr.view(), |x| ConjOp::apply(x))?;
        let ptr = buf.view().ptr();
        let batch_strides = buf.strides()[n_inner..].to_vec();
        let row_stride = if m == 0 { 0 } else { 1isize };
        let col_stride = m as isize;
        return Ok(ContiguousOperand {
            ptr,
            row_stride,
            col_stride,
            batch_strides,
            conj: false,
            _buf: Some(buf),
        });
    }

    let fused_g1 = try_fuse_group(group1_dims, group1_strides);
    let fused_g2 = try_fuse_group(group2_dims, group2_strides);

    let mut needs_copy = fused_g1.is_none() || fused_g2.is_none();

    // Backends requiring unit stride need one of {row_stride, col_stride} to be 1 (or 0).
    if ActiveBackend::REQUIRES_UNIT_STRIDE && !needs_copy {
        let (_, rs) = fused_g1.unwrap();
        let (_, cs) = fused_g2.unwrap();
        if rs != 0 && rs != 1 && cs != 0 && cs != 1 {
            needs_copy = true;
        }
    }

    if needs_copy {
        let m: usize = group1_dims.iter().product::<usize>().max(1);
        let mut buf = alloc_col_major_uninit(&dims);
        strided_kernel::copy_into(&mut buf.view_mut(), &arr.view())?;
        let ptr = buf.view().ptr();
        let batch_strides = buf.strides()[n_inner..].to_vec();
        let row_stride = if m == 0 { 0 } else { 1isize };
        let col_stride = m as isize;
        Ok(ContiguousOperand {
            ptr,
            row_stride,
            col_stride,
            batch_strides,
            conj,
            _buf: Some(buf),
        })
    } else {
        let (_, rs) = fused_g1.unwrap();
        let (_, cs) = fused_g2.unwrap();
        let batch_strides = strides[n_inner..].to_vec();
        let ptr = arr.view().ptr();
        Ok(ContiguousOperand {
            ptr,
            row_stride: rs,
            col_stride: cs,
            batch_strides,
            conj,
            _buf: Some(arr),
        })
    }
}

/// Prepare a borrowed mutable output view for GEMM.
///
/// Expects batch-last canonical order: `[group1..., group2..., batch...]`.
/// Checks if the two inner dimension groups (lo, ro) are fusable.
/// If not, allocates a col-major buffer and copies the existing data into it
/// when `beta` is non-zero (so the GEMM accumulation is correct).
///
/// After GEMM, call [`ContiguousOperandMut::finalize_into`] with the original
/// view to copy results back if needed.
///
/// # Safety contract
///
/// When inner dims are fusable (no copy needed), the returned `ContiguousOperandMut`
/// holds a raw pointer into `view`'s data. The caller must ensure `view` outlives
/// the returned operand and that no aliasing mutable references exist during GEMM.
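///
/// # Example
///
/// A minimal sketch of the non-contiguous path (mirrors the tests; `c` is an
/// existing `StridedArray<f64>` and the GEMM call itself is elided; not
/// compiled as a doctest):
///
/// ```ignore
/// // dims [2, 3, 4] with strides [20, 4, 1]: group1 = [2, 3] does not fuse,
/// // so a col-major buffer is allocated and writeback is required.
/// let mut view = c.view_mut();
/// let c_op = prepare_output_view(&mut view, 0, 2, 1, 0.0)?;
/// // ... run the backend GEMM into c_op.ptr() ...
/// c_op.finalize_into(&mut view)?;
/// ```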
pub fn prepare_output_view<T: Scalar>(
    view: &mut StridedViewMut<T>,
    _n_batch: usize,
    n_group1: usize,
    n_group2: usize,
    beta: T,
) -> crate::Result<ContiguousOperandMut<T>> {
    let dims = view.dims().to_vec();
    let strides = view.strides().to_vec();
    let n_inner = n_group1 + n_group2;

    let group1_dims = &dims[..n_group1];
    let group1_strides = &strides[..n_group1];
    let group2_dims = &dims[n_group1..n_inner];
    let group2_strides = &strides[n_group1..n_inner];

    let fused_g1 = try_fuse_group(group1_dims, group1_strides);
    let fused_g2 = try_fuse_group(group2_dims, group2_strides);

    let mut needs_copy = fused_g1.is_none() || fused_g2.is_none();

    // Backends requiring unit stride need one of {row_stride, col_stride} to be 1 (or 0).
    if ActiveBackend::REQUIRES_UNIT_STRIDE && !needs_copy {
        let (_, rs) = fused_g1.unwrap();
        let (_, cs) = fused_g2.unwrap();
        if rs != 0 && rs != 1 && cs != 0 && cs != 1 {
            needs_copy = true;
        }
    }

    if needs_copy {
        let m: usize = group1_dims.iter().product::<usize>().max(1);
        let mut buf = alloc_col_major_uninit(&dims);
        if beta != T::zero() {
            // Need to preserve existing values for accumulation
            strided_kernel::copy_into(&mut buf.view_mut(), &view.as_view())?;
        }
        let ptr = buf.view_mut().as_mut_ptr();
        let batch_strides = buf.strides()[n_inner..].to_vec();
        let row_stride = if m == 0 { 0 } else { 1isize };
        let col_stride = m as isize;
        Ok(ContiguousOperandMut {
            ptr,
            row_stride,
            col_stride,
            batch_strides,
            needs_writeback: true,
            _buf: Some(buf),
        })
    } else {
        let (_, rs) = fused_g1.unwrap();
        let (_, cs) = fused_g2.unwrap();
        let batch_strides = strides[n_inner..].to_vec();
        Ok(ContiguousOperandMut {
            ptr: view.as_mut_ptr(),
            row_stride: rs,
            col_stride: cs,
            batch_strides,
            needs_writeback: false,
            _buf: None,
        })
    }
}

/// Prepare an owned mutable output array for GEMM.
///
/// Expects batch-last canonical order: `[group1..., group2..., batch...]`.
/// If already contiguous after dimension grouping, uses the array in-place.
/// Otherwise, allocates a col-major buffer and copies existing data when
/// `beta` is non-zero.
///
/// Unlike [`prepare_output_view`], `needs_writeback` is always `false` for owned
/// arrays because the caller owns the buffer and can use it directly.
///
/// Currently unused in production (C is always a `StridedViewMut` from the caller).
/// Kept for future use when `einsum2_into` accepts owned output arrays.
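///
/// A minimal sketch mirroring `test_output_owned_no_writeback` (not compiled
/// as a doctest; `c` is a non-fusable `StridedArray<f64>` built by the caller):
///
/// ```ignore
/// let op = prepare_output_owned(&mut c, 0, 2, 1, 0.0)?;
/// // A col-major buffer was allocated (`op._buf` is `Some`), but the caller
/// // owns it, so no writeback step is required afterwards.
/// ```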
#[allow(dead_code)]
pub fn prepare_output_owned<T: Scalar>(
    arr: &mut StridedArray<T>,
    _n_batch: usize,
    n_group1: usize,
    n_group2: usize,
    beta: T,
) -> crate::Result<ContiguousOperandMut<T>> {
    let dims = arr.dims().to_vec();
    let strides = arr.strides().to_vec();
    let n_inner = n_group1 + n_group2;

    let group1_dims = &dims[..n_group1];
    let group1_strides = &strides[..n_group1];
    let group2_dims = &dims[n_group1..n_inner];
    let group2_strides = &strides[n_group1..n_inner];

    let fused_g1 = try_fuse_group(group1_dims, group1_strides);
    let fused_g2 = try_fuse_group(group2_dims, group2_strides);

    let mut needs_copy = fused_g1.is_none() || fused_g2.is_none();

    // Backends requiring unit stride need one of {row_stride, col_stride} to be 1 (or 0).
    if ActiveBackend::REQUIRES_UNIT_STRIDE && !needs_copy {
        let (_, rs) = fused_g1.unwrap();
        let (_, cs) = fused_g2.unwrap();
        if rs != 0 && rs != 1 && cs != 0 && cs != 1 {
            needs_copy = true;
        }
    }

    if needs_copy {
        let m: usize = group1_dims.iter().product::<usize>().max(1);
        let mut buf = alloc_col_major_uninit(&dims);
        if beta != T::zero() {
            strided_kernel::copy_into(&mut buf.view_mut(), &arr.view())?;
        }
        let ptr = buf.view_mut().as_mut_ptr();
        let batch_strides = buf.strides()[n_inner..].to_vec();
        let row_stride = if m == 0 { 0 } else { 1isize };
        let col_stride = m as isize;
        Ok(ContiguousOperandMut {
            ptr,
            row_stride,
            col_stride,
            batch_strides,
            needs_writeback: false,
            _buf: Some(buf),
        })
    } else {
        let (_, rs) = fused_g1.unwrap();
        let (_, cs) = fused_g2.unwrap();
        let batch_strides = strides[n_inner..].to_vec();
        Ok(ContiguousOperandMut {
            ptr: arr.view_mut().as_mut_ptr(),
            row_stride: rs,
            col_stride: cs,
            batch_strides,
            needs_writeback: false,
            _buf: None,
        })
    }
}

/// Prepare a borrowed input view for a generic GEMM backend.
///
/// Expects batch-last canonical order: `[group1..., group2..., batch...]`.
/// Like [`prepare_input_view`] but works with any `ScalarBase` type and
/// does not handle conjugation materialization. The `conj` field of the
/// returned operand is always `false`.
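///
/// A minimal sketch with the crate's `NaiveBackend` (mirrors the
/// generic-backend tests; not compiled as a doctest):
///
/// ```ignore
/// use crate::backend::NaiveBackend;
///
/// let a = StridedArray::<f64>::col_major(&[2, 3]);
/// let op = prepare_input_view_for_backend::<f64, NaiveBackend>(&a.view(), 0, 1, 1)?;
/// assert_eq!(op.row_stride(), 1);
/// assert_eq!(op.col_stride(), 2);
/// ```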
pub fn prepare_input_view_for_backend<T: ScalarBase, B: BackendConfig>(
    view: &StridedView<T>,
    _n_batch: usize,
    n_group1: usize,
    n_group2: usize,
) -> crate::Result<ContiguousOperand<T>> {
    let dims = view.dims();
    let strides = view.strides();
    let n_inner = n_group1 + n_group2;

    let group1_dims = &dims[..n_group1];
    let group1_strides = &strides[..n_group1];
    let group2_dims = &dims[n_group1..n_inner];
    let group2_strides = &strides[n_group1..n_inner];

    let fused_g1 = try_fuse_group(group1_dims, group1_strides);
    let fused_g2 = try_fuse_group(group2_dims, group2_strides);

    let mut needs_copy = fused_g1.is_none() || fused_g2.is_none();

    if B::REQUIRES_UNIT_STRIDE && !needs_copy {
        let (_, rs) = fused_g1.unwrap();
        let (_, cs) = fused_g2.unwrap();
        if rs != 0 && rs != 1 && cs != 0 && cs != 1 {
            needs_copy = true;
        }
    }

    if needs_copy {
        let m: usize = group1_dims.iter().product::<usize>().max(1);
        let mut buf = alloc_col_major_uninit(dims);
        strided_kernel::copy_into(&mut buf.view_mut(), view)?;
        let ptr = buf.view().ptr();
        let batch_strides = buf.strides()[n_inner..].to_vec();
        let row_stride = if m == 0 { 0 } else { 1isize };
        let col_stride = m as isize;
        Ok(ContiguousOperand {
            ptr,
            row_stride,
            col_stride,
            batch_strides,
            conj: false,
            _buf: Some(buf),
        })
    } else {
        let (_, rs) = fused_g1.unwrap();
        let (_, cs) = fused_g2.unwrap();
        let batch_strides = strides[n_inner..].to_vec();
        Ok(ContiguousOperand {
            ptr: view.ptr(),
            row_stride: rs,
            col_stride: cs,
            batch_strides,
            conj: false,
            _buf: None,
        })
    }
}

/// Prepare a borrowed mutable output view for a generic GEMM backend.
///
/// Expects batch-last canonical order: `[group1..., group2..., batch...]`.
/// Like [`prepare_output_view`] but works with any `ScalarBase` type.
pub fn prepare_output_view_for_backend<T: ScalarBase, B: BackendConfig>(
    view: &mut StridedViewMut<T>,
    _n_batch: usize,
    n_group1: usize,
    n_group2: usize,
    beta: T,
) -> crate::Result<ContiguousOperandMut<T>> {
    let dims = view.dims().to_vec();
    let strides = view.strides().to_vec();
    let n_inner = n_group1 + n_group2;

    let group1_dims = &dims[..n_group1];
    let group1_strides = &strides[..n_group1];
    let group2_dims = &dims[n_group1..n_inner];
    let group2_strides = &strides[n_group1..n_inner];

    let fused_g1 = try_fuse_group(group1_dims, group1_strides);
    let fused_g2 = try_fuse_group(group2_dims, group2_strides);

    let mut needs_copy = fused_g1.is_none() || fused_g2.is_none();

    if B::REQUIRES_UNIT_STRIDE && !needs_copy {
        let (_, rs) = fused_g1.unwrap();
        let (_, cs) = fused_g2.unwrap();
        if rs != 0 && rs != 1 && cs != 0 && cs != 1 {
            needs_copy = true;
        }
    }

    if needs_copy {
        let m: usize = group1_dims.iter().product::<usize>().max(1);
        let mut buf = alloc_col_major_uninit(&dims);
        if beta != T::zero() {
            strided_kernel::copy_into(&mut buf.view_mut(), &view.as_view())?;
        }
        let ptr = buf.view_mut().as_mut_ptr();
        let batch_strides = buf.strides()[n_inner..].to_vec();
        let row_stride = if m == 0 { 0 } else { 1isize };
        let col_stride = m as isize;
        Ok(ContiguousOperandMut {
            ptr,
            row_stride,
            col_stride,
            batch_strides,
            needs_writeback: true,
            _buf: Some(buf),
        })
    } else {
        let (_, rs) = fused_g1.unwrap();
        let (_, cs) = fused_g2.unwrap();
        let batch_strides = strides[n_inner..].to_vec();
        Ok(ContiguousOperandMut {
            ptr: view.as_mut_ptr(),
            row_stride: rs,
            col_stride: cs,
            batch_strides,
            needs_writeback: false,
            _buf: None,
        })
    }
}

#[cfg(test)]
mod tests_generic_backend {
    use super::*;
    use crate::backend::NaiveBackend;

    #[test]
    fn test_input_for_backend_contiguous() {
        let a = StridedArray::<f64>::col_major(&[2, 3]);
        let view = a.view();
        let op = prepare_input_view_for_backend::<f64, NaiveBackend>(&view, 0, 1, 1).unwrap();
        // Contiguous col-major: no copy
        assert!(op._buf.is_none());
        assert_eq!(op.row_stride(), 1);
        assert_eq!(op.col_stride(), 2);
        assert!(!op.conj());
    }

    #[test]
    fn test_input_for_backend_non_contiguous() {
        let data = vec![0.0f64; 100];
        let a = StridedArray::<f64>::from_parts(data, &[2, 3, 4], &[20, 4, 1], 0).unwrap();
        let view = a.view();
        let op = prepare_input_view_for_backend::<f64, NaiveBackend>(&view, 0, 2, 1).unwrap();
        assert!(op._buf.is_some());
        assert_eq!(op.row_stride(), 1);
        assert_eq!(op.col_stride(), 6);
    }

    #[test]
    fn test_output_for_backend_contiguous() {
        let mut c = StridedArray::<f64>::col_major(&[2, 3]);
        let mut view = c.view_mut();
        let op =
            prepare_output_view_for_backend::<f64, NaiveBackend>(&mut view, 0, 1, 1, 0.0).unwrap();
        assert!(!op.needs_writeback);
        assert!(op._buf.is_none());
        assert_eq!(op.row_stride(), 1);
        assert_eq!(op.col_stride(), 2);
    }

    #[test]
    fn test_output_for_backend_non_contiguous_beta_zero() {
        let data = vec![0.0f64; 100];
        let mut c = StridedArray::<f64>::from_parts(data, &[2, 3, 4], &[20, 4, 1], 0).unwrap();
        let mut view = c.view_mut();
        let op =
            prepare_output_view_for_backend::<f64, NaiveBackend>(&mut view, 0, 2, 1, 0.0).unwrap();
        assert!(op.needs_writeback);
        assert!(op._buf.is_some());
        assert_eq!(op.row_stride(), 1);
        assert_eq!(op.col_stride(), 6);
    }

    #[test]
    fn test_output_for_backend_non_contiguous_beta_nonzero_and_finalize() {
        let mut data = vec![0.0f64; 30];
        data[0] = 10.0;
        data[1] = 20.0;
        data[10] = 40.0;
        let mut c = StridedArray::<f64>::from_parts(data, &[2, 3, 1], &[10, 1, 1], 0).unwrap();
        let mut view = c.view_mut();
        let op =
            prepare_output_view_for_backend::<f64, NaiveBackend>(&mut view, 0, 2, 1, 1.0).unwrap();
        assert!(op.needs_writeback);
        // beta != 0 -> existing data copied into buffer
        let buf = op._buf.as_ref().unwrap();
        assert_eq!(buf.get(&[0, 0, 0]), 10.0);
        assert_eq!(buf.get(&[0, 1, 0]), 20.0);
        assert_eq!(buf.get(&[1, 0, 0]), 40.0);
        // finalize copies back
        op.finalize_into(&mut view).unwrap();
    }
}

#[cfg(test)]
#[cfg(any(
    all(feature = "faer", not(any(feature = "blas", feature = "blas-inject"))),
    all(
        not(feature = "faer"),
        any(
            all(feature = "blas", not(feature = "blas-inject")),
            all(feature = "blas-inject", not(feature = "blas"))
        )
    )
))]
mod tests {
    use super::*;

    #[test]
    fn test_borrowed_contiguous_no_copy() {
        // Col-major [2,3]: strides [1,2]. n_batch=0, n_group1=1, n_group2=1.
        // Group1 = dim [2], stride [1] -> fuses to (2, 1).
        // Group2 = dim [3], stride [2] -> fuses to (3, 2).
        // Both fuse -> no copy needed.
        let a = StridedArray::<f64>::col_major(&[2, 3]);
        let view = a.view();

        let op = prepare_input_view(&view, 0, 1, 1, false).unwrap();

        assert!(!op.has_buf());
        assert_eq!(op.row_stride(), 1);
        assert_eq!(op.col_stride(), 2);
        assert!(!op.conj());
    }

    #[test]
    fn test_borrowed_non_contiguous_copies() {
        // dims [2,3,4] with strides [20,4,1], n_batch=0, n_group1=2, n_group2=1.
        // Group1 = dims [2,3], strides [20,4]. Try fuse: sorted by |stride| -> [(3,4),(2,20)].
        // Check: 4*3=12 != 20, so fusion fails -> needs copy.
        let data = vec![0.0f64; 100];
        let a = StridedArray::<f64>::from_parts(data, &[2, 3, 4], &[20, 4, 1], 0).unwrap();
        let view = a.view();

        let op = prepare_input_view(&view, 0, 2, 1, false).unwrap();

        assert!(op.has_buf());
        // After copy to col-major: row_stride=1, col_stride = m = 2*3 = 6
        assert_eq!(op.row_stride(), 1);
        assert_eq!(op.col_stride(), 6);
    }

    #[test]
    fn test_owned_contiguous_no_copy() {
        // Col-major [2,3]: strides [1,2]. n_batch=0, n_group1=1, n_group2=1.
        // Fusable -> ownership transferred, no copy.
        let a = StridedArray::<f64>::col_major(&[2, 3]);

        let op = prepare_input_owned(a, 0, 1, 1, false).unwrap();

        // Ownership transferred: _buf = Some (the original array).
        assert!(op.has_buf());
        assert_eq!(op.row_stride(), 1);
        assert_eq!(op.col_stride(), 2);
    }

    #[test]
    fn test_owned_non_contiguous_copies() {
        // dims [2,3,4] with strides [20,4,1], n_batch=0, n_group1=2, n_group2=1.
        // Non-fusable -> copies to new buffer.
        let data = vec![0.0f64; 100];
        let a = StridedArray::<f64>::from_parts(data, &[2, 3, 4], &[20, 4, 1], 0).unwrap();

        let op = prepare_input_owned(a, 0, 2, 1, false).unwrap();

        assert!(op.has_buf());
        // After copy: row_stride=1, col_stride = m = 2*3 = 6
        assert_eq!(op.row_stride(), 1);
        assert_eq!(op.col_stride(), 6);
    }

    // ---- Output preparation tests ----

    #[test]
    fn test_output_view_contiguous() {
        // Col-major [2,3]: strides [1,2]. n_batch=0, n_group1=1, n_group2=1.
        // Both groups fuse -> no copy, no writeback.
        let mut c = StridedArray::<f64>::col_major(&[2, 3]);
        let mut view = c.view_mut();

        let op = prepare_output_view(&mut view, 0, 1, 1, 0.0).unwrap();

        assert!(!op.needs_writeback());
        assert!(!op.has_buf());
        assert_eq!(op.row_stride(), 1);
        assert_eq!(op.col_stride(), 2);
    }

    #[test]
    fn test_output_view_non_contiguous_beta_zero() {
        // dims [2,3,4] with strides [20,4,1], n_batch=0, n_group1=2, n_group2=1.
        // Group1 = dims [2,3], strides [20,4] -> non-fusable -> needs copy.
        // beta=0 -> no copy-in of existing data, but writeback needed.
        let data = vec![0.0f64; 100];
        let mut c = StridedArray::<f64>::from_parts(data, &[2, 3, 4], &[20, 4, 1], 0).unwrap();
        let mut view = c.view_mut();

        let op = prepare_output_view(&mut view, 0, 2, 1, 0.0).unwrap();

        assert!(op.needs_writeback());
        assert!(op.has_buf());
        // After alloc col-major: row_stride=1, col_stride = m = 2*3 = 6
        assert_eq!(op.row_stride(), 1);
        assert_eq!(op.col_stride(), 6);
    }

    #[test]
    fn test_output_view_non_contiguous_beta_nonzero_and_finalize() {
        // Use 3D: dims [2,3,1] with strides [10,1,1], group1=2 dims, group2=1 dim.
        // group1 = dims [2,3], strides [10,1]. Sorted: [(3,1),(2,10)]. 1*3=3 != 10 -> non-fusable!

        // Pre-populate a data buffer with known values at the right offsets.
        // With strides [10,1,1], element [i,j,0] is at offset i*10 + j*1.
        let mut data = vec![0.0f64; 30];
        data[0] = 10.0; // [0,0,0]
        data[1] = 20.0; // [0,1,0]
        data[2] = 30.0; // [0,2,0]
        data[10] = 40.0; // [1,0,0]
        data[11] = 50.0; // [1,1,0]
        data[12] = 60.0; // [1,2,0]
        let mut c = StridedArray::<f64>::from_parts(data, &[2, 3, 1], &[10, 1, 1], 0).unwrap();

        // Verify the known values
        assert_eq!(c.get(&[0, 0, 0]), 10.0);
        assert_eq!(c.get(&[1, 1, 0]), 50.0);

        let mut view = c.view_mut();

        // group1 dims [2,3], group2 dims [1] -> group1 is non-fusable.
        let mut op = prepare_output_view(&mut view, 0, 2, 1, 1.0).unwrap();

        assert!(op.needs_writeback());
        assert!(op.has_buf());

        // beta=1.0 -> existing data should have been copied into the buffer.
        // Verify by reading from the buffer via the internal _buf.
        let buf = op._buf.as_ref().unwrap();
        assert_eq!(buf.get(&[0, 0, 0]), 10.0);
        assert_eq!(buf.get(&[1, 1, 0]), 50.0);

        // Simulate GEMM by writing to the buffer through copy_into.
        {
            let result_data = vec![100.0f64; 6];
            let result =
                StridedArray::<f64>::from_parts(result_data, &[2, 3, 1], &[3, 1, 1], 0).unwrap();
            strided_kernel::copy_into(&mut op._buf.as_mut().unwrap().view_mut(), &result.view())
                .unwrap();
            // Refresh the raw pointer from the buffer (defensive; copy_into
            // writes in place and does not reallocate the buffer).
            op.ptr = op._buf.as_mut().unwrap().view_mut().as_mut_ptr();
        }

        // finalize_into should copy the buffer back to the original view.
        op.finalize_into(&mut view).unwrap();

        // All elements should now be 100.0 in the original non-contiguous array.
        assert_eq!(c.get(&[0, 0, 0]), 100.0);
        assert_eq!(c.get(&[0, 1, 0]), 100.0);
        assert_eq!(c.get(&[0, 2, 0]), 100.0);
        assert_eq!(c.get(&[1, 0, 0]), 100.0);
        assert_eq!(c.get(&[1, 1, 0]), 100.0);
        assert_eq!(c.get(&[1, 2, 0]), 100.0);
    }

    #[test]
    fn test_output_owned_no_writeback() {
        // Non-fusable owned array: needs_writeback should be false.
        // dims [2,3,4] with strides [20,4,1], n_batch=0, n_group1=2, n_group2=1.
        let data = vec![0.0f64; 100];
        let mut c = StridedArray::<f64>::from_parts(data, &[2, 3, 4], &[20, 4, 1], 0).unwrap();

        let op = prepare_output_owned(&mut c, 0, 2, 1, 0.0).unwrap();

        // Non-fusable -> has buffer, but owned -> no writeback.
        assert!(!op.needs_writeback());
        assert!(op.has_buf());
        assert_eq!(op.row_stride(), 1);
        assert_eq!(op.col_stride(), 6);
    }
}