// tenferro_tensor/layout.rs
use std::ops::Add;

use tenferro_device::{Error, Result};

/// Axis ordering used when deriving the strides of a densely packed array.
///
/// The variant selects which axis receives stride 1 in
/// `compute_contiguous_strides`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MemoryOrder {
    /// The first axis has stride 1; each later axis has the product of all
    /// earlier extents as its stride.
    ColumnMajor,
    /// The last axis has stride 1; each earlier axis has the product of all
    /// later extents as its stride.
    RowMajor,
}

26pub(crate) fn compute_contiguous_strides(dims: &[usize], order: MemoryOrder) -> Vec<isize> {
27 let ndim = dims.len();
28 if ndim == 0 {
29 return vec![];
30 }
31 let mut strides = vec![0isize; ndim];
32 match order {
33 MemoryOrder::ColumnMajor => {
34 strides[0] = 1;
35 for i in 1..ndim {
36 strides[i] = strides[i - 1] * dims[i - 1] as isize;
37 }
38 }
39 MemoryOrder::RowMajor => {
40 strides[ndim - 1] = 1;
41 for i in (0..ndim - 1).rev() {
42 strides[i] = strides[i + 1] * dims[i + 1] as isize;
43 }
44 }
45 }
46 strides
47}
48
49pub(crate) fn is_contiguous_in_order(
50 dims: &[usize],
51 strides: &[isize],
52 order: MemoryOrder,
53) -> bool {
54 if dims.is_empty() || dims.contains(&0) {
55 return true;
56 }
57
58 let expected = compute_contiguous_strides(dims, order);
59 dims.iter()
60 .zip(strides)
61 .zip(expected)
62 .all(|((&dim, &actual), expected)| dim <= 1 || actual == expected)
63}
64
65pub(crate) fn validate_layout_against_len(
66 dims: &[usize],
67 strides: &[isize],
68 offset: isize,
69 data_len: usize,
70) -> Result<()> {
71 if strides.len() != dims.len() {
72 return Err(Error::InvalidArgument(format!(
73 "strides length {} doesn't match dims length {}",
74 strides.len(),
75 dims.len()
76 )));
77 }
78
79 if dims.iter().product::<usize>() == 0 {
80 return Ok(());
81 }
82
83 let mut min_pos = offset;
84 let mut max_pos = offset;
85 for (axis, (&dim, &stride)) in dims.iter().zip(strides).enumerate() {
86 if dim == 0 {
87 continue;
88 }
89 let extent = isize::try_from(dim - 1)
90 .ok()
91 .and_then(|d| d.checked_mul(stride))
92 .ok_or_else(|| {
93 Error::StrideError(format!(
94 "extent overflow for dimension {axis} (size={dim}, stride={stride})"
95 ))
96 })?;
97 if extent >= 0 {
98 max_pos += extent;
99 } else {
100 min_pos += extent;
101 }
102 }
103 if min_pos < 0 || max_pos >= data_len as isize {
104 return Err(Error::StrideError(format!(
105 "layout accesses buffer positions {}..={} but buffer length is {}",
106 min_pos, max_pos, data_len
107 )));
108 }
109
110 Ok(())
111}
112
/// Copies every element of a strided source view into a strided destination.
///
/// For each index tuple within `dims`, the element at
/// `src_offset + Σ index[k] * src_strides[k]` in `src` is written to position
/// `Σ index[k] * dst_strides[k]` in `dst` (the destination has no offset).
/// Index tuples are visited with the first axis varying fastest.
///
/// * If `dims` contains a zero extent, nothing is copied.
/// * If `dims` is empty, the single element `src[src_offset]` is written to
///   `dst[0]`.
///
/// # Panics
///
/// Panics if a computed position falls outside `src` or `dst`. In debug builds
/// it additionally panics if `src_strides` or `dst_strides` differ in length
/// from `dims`; without that check the position sums would silently ignore
/// the unmatched axes.
pub(crate) fn copy_strided<T: Copy>(
    src: &[T],
    dims: &[usize],
    src_strides: &[isize],
    src_offset: isize,
    dst: &mut [T],
    dst_strides: &[isize],
) {
    /// Dot product of an index tuple with a stride vector.
    fn flat_position(index: &[usize], strides: &[isize]) -> isize {
        index
            .iter()
            .zip(strides)
            .map(|(&i, &s)| i as isize * s)
            .sum::<isize>()
    }

    let n_elements: usize = dims.iter().product();
    if n_elements == 0 {
        return;
    }
    if dims.is_empty() {
        dst[0] = src[src_offset as usize];
        return;
    }
    debug_assert_eq!(
        src_strides.len(),
        dims.len(),
        "copy_strided: source strides length doesn't match dims length"
    );
    debug_assert_eq!(
        dst_strides.len(),
        dims.len(),
        "copy_strided: destination strides length doesn't match dims length"
    );

    let mut index = vec![0usize; dims.len()];
    for _ in 0..n_elements {
        let src_pos = src_offset + flat_position(&index, src_strides);
        let dst_pos = flat_position(&index, dst_strides);
        debug_assert!(
            src_pos >= 0 && (src_pos as usize) < src.len(),
            "copy_strided: source position {} out of bounds for buffer length {}",
            src_pos,
            src.len()
        );
        debug_assert!(
            dst_pos >= 0 && (dst_pos as usize) < dst.len(),
            "copy_strided: destination position {} out of bounds for buffer length {}",
            dst_pos,
            dst.len()
        );
        dst[dst_pos as usize] = src[src_pos as usize];

        // Odometer-style increment: bump the first axis and carry into the
        // next one whenever an axis wraps back to zero.
        for (slot, &extent) in index.iter_mut().zip(dims) {
            *slot += 1;
            if *slot < extent {
                break;
            }
            *slot = 0;
        }
    }
}

/// Borrowed, strided view of one input operand, as consumed by `add_strided`.
///
/// The element for an index tuple is read from `data` at position
/// `offset + Σ index[k] * strides[k]`.
pub(crate) struct StridedInput<'a, T> {
    /// Backing buffer that the computed positions index into.
    pub(crate) data: &'a [T],
    /// Per-axis step in elements, paired positionally with the shape's axes.
    pub(crate) strides: &'a [isize],
    /// Position in `data` of the element at the all-zero index tuple.
    pub(crate) offset: isize,
}

173pub(crate) fn add_strided<T: Copy + Add<Output = T>>(
174 dims: &[usize],
175 a: StridedInput<'_, T>,
176 b: StridedInput<'_, T>,
177 dst: &mut [T],
178 dst_strides: &[isize],
179) {
180 let n_elements: usize = dims.iter().product();
181 if n_elements == 0 {
182 return;
183 }
184 if dims.is_empty() {
185 dst[0] = a.data[a.offset as usize] + b.data[b.offset as usize];
186 return;
187 }
188
189 let mut index = vec![0usize; dims.len()];
190 for _ in 0..n_elements {
191 let a_pos = a.offset
192 + index
193 .iter()
194 .zip(a.strides.iter())
195 .map(|(&i, &s)| i as isize * s)
196 .sum::<isize>();
197 let b_pos = b.offset
198 + index
199 .iter()
200 .zip(b.strides.iter())
201 .map(|(&i, &s)| i as isize * s)
202 .sum::<isize>();
203 let dst_pos = index
204 .iter()
205 .zip(dst_strides.iter())
206 .map(|(&i, &s)| i as isize * s)
207 .sum::<isize>();
208 debug_assert!(
209 a_pos >= 0 && (a_pos as usize) < a.data.len(),
210 "add_strided: input a position {} out of bounds for buffer length {}",
211 a_pos,
212 a.data.len()
213 );
214 debug_assert!(
215 b_pos >= 0 && (b_pos as usize) < b.data.len(),
216 "add_strided: input b position {} out of bounds for buffer length {}",
217 b_pos,
218 b.data.len()
219 );
220 debug_assert!(
221 dst_pos >= 0 && (dst_pos as usize) < dst.len(),
222 "add_strided: destination position {} out of bounds for buffer length {}",
223 dst_pos,
224 dst.len()
225 );
226 dst[dst_pos as usize] = a.data[a_pos as usize] + b.data[b_pos as usize];
227
228 for axis in 0..dims.len() {
229 index[axis] += 1;
230 if index[axis] < dims[axis] {
231 break;
232 }
233 index[axis] = 0;
234 }
235 }
236}