tenferro_tensor/
layout.rs

1use std::ops::Add;
2
3use tenferro_device::{Error, Result};
4
/// Memory ordering for new allocations.
///
/// Specifies how elements are laid out in memory when creating new tensors
/// or copying data into a contiguous buffer. This is **not** stored on the
/// tensor itself. The tensor's strides fully describe the actual layout.
///
/// As a concrete illustration, for a shape `[2, 3]` the contiguous strides
/// produced by `compute_contiguous_strides` are `[1, 2]` for
/// [`MemoryOrder::ColumnMajor`] and `[3, 1]` for [`MemoryOrder::RowMajor`].
///
/// # Examples
///
/// ```ignore
/// use tenferro_tensor::MemoryOrder;
///
/// let order = MemoryOrder::RowMajor;
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MemoryOrder {
    /// Column-major (Fortran/Julia order). First dimension has stride 1;
    /// each later stride is the previous stride times the previous extent.
    ColumnMajor,
    /// Row-major (C/NumPy order). Last dimension has stride 1;
    /// each earlier stride is the next stride times the next extent.
    RowMajor,
}
25
26pub(crate) fn compute_contiguous_strides(dims: &[usize], order: MemoryOrder) -> Vec<isize> {
27    let ndim = dims.len();
28    if ndim == 0 {
29        return vec![];
30    }
31    let mut strides = vec![0isize; ndim];
32    match order {
33        MemoryOrder::ColumnMajor => {
34            strides[0] = 1;
35            for i in 1..ndim {
36                strides[i] = strides[i - 1] * dims[i - 1] as isize;
37            }
38        }
39        MemoryOrder::RowMajor => {
40            strides[ndim - 1] = 1;
41            for i in (0..ndim - 1).rev() {
42                strides[i] = strides[i + 1] * dims[i + 1] as isize;
43            }
44        }
45    }
46    strides
47}
48
49pub(crate) fn is_contiguous_in_order(
50    dims: &[usize],
51    strides: &[isize],
52    order: MemoryOrder,
53) -> bool {
54    if dims.is_empty() || dims.contains(&0) {
55        return true;
56    }
57
58    let expected = compute_contiguous_strides(dims, order);
59    dims.iter()
60        .zip(strides)
61        .zip(expected)
62        .all(|((&dim, &actual), expected)| dim <= 1 || actual == expected)
63}
64
65pub(crate) fn validate_layout_against_len(
66    dims: &[usize],
67    strides: &[isize],
68    offset: isize,
69    data_len: usize,
70) -> Result<()> {
71    if strides.len() != dims.len() {
72        return Err(Error::InvalidArgument(format!(
73            "strides length {} doesn't match dims length {}",
74            strides.len(),
75            dims.len()
76        )));
77    }
78
79    if dims.iter().product::<usize>() == 0 {
80        return Ok(());
81    }
82
83    let mut min_pos = offset;
84    let mut max_pos = offset;
85    for (axis, (&dim, &stride)) in dims.iter().zip(strides).enumerate() {
86        if dim == 0 {
87            continue;
88        }
89        let extent = isize::try_from(dim - 1)
90            .ok()
91            .and_then(|d| d.checked_mul(stride))
92            .ok_or_else(|| {
93                Error::StrideError(format!(
94                    "extent overflow for dimension {axis} (size={dim}, stride={stride})"
95                ))
96            })?;
97        if extent >= 0 {
98            max_pos += extent;
99        } else {
100            min_pos += extent;
101        }
102    }
103    if min_pos < 0 || max_pos >= data_len as isize {
104        return Err(Error::StrideError(format!(
105            "layout accesses buffer positions {}..={} but buffer length is {}",
106            min_pos, max_pos, data_len
107        )));
108    }
109
110    Ok(())
111}
112
/// Copies every element of a strided source view into a strided destination.
///
/// The logical shape is `dims`. For each multi-index the source element at
/// `src_offset + Σ index[k] * src_strides[k]` is written to the destination
/// position `Σ index[k] * dst_strides[k]`. Multi-indices are visited with the
/// first axis varying fastest. A shape with a zero-sized axis copies nothing;
/// a rank-0 shape copies the single element at `src_offset` into `dst[0]`.
///
/// Positions are checked with `debug_assert!` and then used for ordinary
/// slice indexing.
pub(crate) fn copy_strided<T: Copy>(
    src: &[T],
    dims: &[usize],
    src_strides: &[isize],
    src_offset: isize,
    dst: &mut [T],
    dst_strides: &[isize],
) {
    let total = dims.iter().product::<usize>();
    if total == 0 {
        return;
    }
    if dims.is_empty() {
        dst[0] = src[src_offset as usize];
        return;
    }

    // Dot product of a multi-index with a stride vector.
    let displacement = |cursor: &[usize], strides: &[isize]| -> isize {
        cursor
            .iter()
            .zip(strides)
            .map(|(&i, &s)| i as isize * s)
            .sum::<isize>()
    };

    let mut cursor = vec![0usize; dims.len()];
    for _ in 0..total {
        let from = src_offset + displacement(&cursor, src_strides);
        let to = displacement(&cursor, dst_strides);
        debug_assert!(
            from >= 0 && (from as usize) < src.len(),
            "copy_strided: source position {} out of bounds for buffer length {}",
            from,
            src.len()
        );
        debug_assert!(
            to >= 0 && (to as usize) < dst.len(),
            "copy_strided: destination position {} out of bounds for buffer length {}",
            to,
            dst.len()
        );
        let value = src[from as usize];
        dst[to as usize] = value;

        // Advance the multi-index like an odometer, first axis fastest.
        for (digit, &extent) in cursor.iter_mut().zip(dims) {
            *digit += 1;
            if *digit < extent {
                break;
            }
            *digit = 0;
        }
    }
}
166
/// A borrowed, strided operand: a backing buffer together with the strides
/// and starting offset used to locate its elements.
///
/// `add_strided` reads the element for a multi-index `index` at position
/// `offset + Σ index[k] * strides[k]` within `data`.
pub(crate) struct StridedInput<'a, T> {
    /// Backing buffer that computed positions index into.
    pub(crate) data: &'a [T],
    /// Per-axis step, in elements, applied for each unit of the multi-index.
    pub(crate) strides: &'a [isize],
    /// Position in `data` of the element whose multi-index is all zeros.
    pub(crate) offset: isize,
}
172
173pub(crate) fn add_strided<T: Copy + Add<Output = T>>(
174    dims: &[usize],
175    a: StridedInput<'_, T>,
176    b: StridedInput<'_, T>,
177    dst: &mut [T],
178    dst_strides: &[isize],
179) {
180    let n_elements: usize = dims.iter().product();
181    if n_elements == 0 {
182        return;
183    }
184    if dims.is_empty() {
185        dst[0] = a.data[a.offset as usize] + b.data[b.offset as usize];
186        return;
187    }
188
189    let mut index = vec![0usize; dims.len()];
190    for _ in 0..n_elements {
191        let a_pos = a.offset
192            + index
193                .iter()
194                .zip(a.strides.iter())
195                .map(|(&i, &s)| i as isize * s)
196                .sum::<isize>();
197        let b_pos = b.offset
198            + index
199                .iter()
200                .zip(b.strides.iter())
201                .map(|(&i, &s)| i as isize * s)
202                .sum::<isize>();
203        let dst_pos = index
204            .iter()
205            .zip(dst_strides.iter())
206            .map(|(&i, &s)| i as isize * s)
207            .sum::<isize>();
208        debug_assert!(
209            a_pos >= 0 && (a_pos as usize) < a.data.len(),
210            "add_strided: input a position {} out of bounds for buffer length {}",
211            a_pos,
212            a.data.len()
213        );
214        debug_assert!(
215            b_pos >= 0 && (b_pos as usize) < b.data.len(),
216            "add_strided: input b position {} out of bounds for buffer length {}",
217            b_pos,
218            b.data.len()
219        );
220        debug_assert!(
221            dst_pos >= 0 && (dst_pos as usize) < dst.len(),
222            "add_strided: destination position {} out of bounds for buffer length {}",
223            dst_pos,
224            dst.len()
225        );
226        dst[dst_pos as usize] = a.data[a_pos as usize] + b.data[b_pos as usize];
227
228        for axis in 0..dims.len() {
229            index[axis] += 1;
230            if index[axis] < dims[axis] {
231                break;
232            }
233            index[axis] = 0;
234        }
235    }
236}