use num_complex::{Complex32, Complex64};
use num_traits::Zero;
use strided_kernel::{copy_into, map_into, Identity, StridedView};

use crate::{
    types::{flat_to_multi, Tensor, TypedTensor},
    DType,
};

use super::{tensor_from_array, typed_array_uninit, typed_view};

fn backend_failure(op: &'static str, err: impl ToString) -> crate::Error {
    crate::Error::BackendFailure {
        op,
        message: err.to_string(),
    }
}

fn validate_rank(op: &'static str, expected: usize, actual: usize) -> crate::Result<()> {
    if expected != actual {
        return Err(crate::Error::RankMismatch {
            op,
            expected,
            actual,
        });
    }
    Ok(())
}

fn validate_axis(op: &'static str, axis: usize, rank: usize) -> crate::Result<()> {
    if axis >= rank {
        return Err(crate::Error::AxisOutOfBounds { op, axis, rank });
    }
    Ok(())
}

fn validate_axes_distinct(op: &'static str, axis_a: usize, axis_b: usize) -> crate::Result<()> {
    if axis_a == axis_b {
        return Err(crate::Error::DuplicateAxis {
            op,
            axis: axis_a,
            role: "axes",
        });
    }
    Ok(())
}

fn validate_permutation(op: &'static str, perm: &[usize], rank: usize) -> crate::Result<()> {
    validate_rank(op, rank, perm.len())?;
    let mut seen = vec![false; rank];
    for &axis in perm {
        validate_axis(op, axis, rank)?;
        if seen[axis] {
            return Err(crate::Error::DuplicateAxis {
                op,
                axis,
                role: "perm",
            });
        }
        seen[axis] = true;
    }
    Ok(())
}

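/// Builds a column-major [`StridedView`] over the host buffer of `tensor`.
/// Returns a `BackendFailure` for backend buffers and panics for CubeCL
/// buffers, which must be downloaded to the host first.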
fn host_view<T: Copy>(tensor: &TypedTensor<T>) -> crate::Result<StridedView<'_, T, Identity>> {
    match &tensor.buffer {
        crate::Buffer::Host(data) => {
            let strides = crate::col_major_strides(&tensor.shape);
            StridedView::new(data, &tensor.shape, &strides, 0)
                .map_err(|err| backend_failure("structural", err))
        }
        crate::Buffer::Backend(_) => Err(crate::Error::BackendFailure {
            op: "structural",
            message: "backend buffers are not supported for structural CPU helpers".into(),
        }),
        #[cfg(feature = "cubecl")]
        crate::Buffer::Cubecl(_) => panic!(
            "GPU tensor (Buffer::Cubecl) passed to CPU backend. \
             Use cubecl::download_tensor() to transfer to CPU first."
        ),
    }
}

fn copy_view_to_array<T: Copy>(
    op: &'static str,
    mut out: strided_kernel::StridedArray<T>,
    src: &StridedView<'_, T>,
) -> crate::Result<TypedTensor<T>> {
    copy_into(&mut out.view_mut(), src).map_err(|err| backend_failure(op, err))?;
    Ok(tensor_from_array(out))
}

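/// Permutes the axes of `input` according to `perm` and materializes the
/// result into a new contiguous tensor.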
pub fn transpose(input: &Tensor, perm: &[usize]) -> crate::Result<Tensor> {
    match input {
        Tensor::F32(t) => Ok(Tensor::F32(typed_transpose(t, perm)?)),
        Tensor::F64(t) => Ok(Tensor::F64(typed_transpose(t, perm)?)),
        Tensor::C32(t) => Ok(Tensor::C32(typed_transpose(t, perm)?)),
        Tensor::C64(t) => Ok(Tensor::C64(typed_transpose(t, perm)?)),
    }
}

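/// Returns a tensor with the same elements as `input` viewed under a new
/// `shape`; fails with a `ShapeMismatch` error if the element counts differ.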
pub fn reshape(input: &Tensor, shape: &[usize]) -> crate::Result<Tensor> {
    match input {
        Tensor::F32(t) => Ok(Tensor::F32(typed_reshape(t, shape)?)),
        Tensor::F64(t) => Ok(Tensor::F64(typed_reshape(t, shape)?)),
        Tensor::C32(t) => Ok(Tensor::C32(typed_reshape(t, shape)?)),
        Tensor::C64(t) => Ok(Tensor::C64(typed_reshape(t, shape)?)),
    }
}

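/// Broadcasts `input` into `shape`, with source axis `i` mapped to output
/// axis `dims[i]`. Size-1 source axes are expanded, and output axes not
/// named in `dims` are filled by repetition.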
pub fn broadcast_in_dim(input: &Tensor, shape: &[usize], dims: &[usize]) -> crate::Result<Tensor> {
    match input {
        Tensor::F32(t) => Ok(Tensor::F32(typed_broadcast_in_dim(t, shape, dims)?)),
        Tensor::F64(t) => Ok(Tensor::F64(typed_broadcast_in_dim(t, shape, dims)?)),
        Tensor::C32(t) => Ok(Tensor::C32(typed_broadcast_in_dim(t, shape, dims)?)),
        Tensor::C64(t) => Ok(Tensor::C64(typed_broadcast_in_dim(t, shape, dims)?)),
    }
}

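/// Converts `input` to dtype `to`. Real-to-complex conversions use a zero
/// imaginary part; complex-to-real conversions keep the real part and
/// discard the imaginary part. Identity conversions clone the tensor.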
pub fn convert(input: &Tensor, to: DType) -> Tensor {
    match (input, to) {
        (Tensor::F32(t), DType::F32) => Tensor::F32(t.clone()),
        (Tensor::F32(t), DType::F64) => Tensor::F64(typed_convert(t, |x| x as f64)),
        (Tensor::F32(t), DType::C32) => Tensor::C32(typed_convert(t, |x| Complex32::new(x, 0.0))),
        (Tensor::F32(t), DType::C64) => {
            Tensor::C64(typed_convert(t, |x| Complex64::new(x as f64, 0.0)))
        }
        (Tensor::F64(t), DType::F32) => Tensor::F32(typed_convert(t, |x| x as f32)),
        (Tensor::F64(t), DType::F64) => Tensor::F64(t.clone()),
        (Tensor::F64(t), DType::C32) => {
            Tensor::C32(typed_convert(t, |x| Complex32::new(x as f32, 0.0)))
        }
        (Tensor::F64(t), DType::C64) => Tensor::C64(typed_convert(t, |x| Complex64::new(x, 0.0))),
        (Tensor::C32(t), DType::F32) => Tensor::F32(typed_convert(t, |z| z.re)),
        (Tensor::C32(t), DType::F64) => Tensor::F64(typed_convert(t, |z| z.re as f64)),
        (Tensor::C32(t), DType::C32) => Tensor::C32(t.clone()),
        (Tensor::C32(t), DType::C64) => {
            Tensor::C64(typed_convert(t, |z| Complex64::new(z.re as f64, z.im as f64)))
        }
        (Tensor::C64(t), DType::F32) => Tensor::F32(typed_convert(t, |z| z.re as f32)),
        (Tensor::C64(t), DType::F64) => Tensor::F64(typed_convert(t, |z| z.re)),
        (Tensor::C64(t), DType::C32) => {
            Tensor::C32(typed_convert(t, |z| Complex32::new(z.re as f32, z.im as f32)))
        }
        (Tensor::C64(t), DType::C64) => Tensor::C64(t.clone()),
    }
}

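/// Extracts the diagonal of `input` along the distinct axes `axis_a` and
/// `axis_b`, collapsing the pair into a single axis.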
pub fn extract_diagonal(input: &Tensor, axis_a: usize, axis_b: usize) -> crate::Result<Tensor> {
    match input {
        Tensor::F32(t) => Ok(Tensor::F32(typed_extract_diagonal(t, axis_a, axis_b)?)),
        Tensor::F64(t) => Ok(Tensor::F64(typed_extract_diagonal(t, axis_a, axis_b)?)),
        Tensor::C32(t) => Ok(Tensor::C32(typed_extract_diagonal(t, axis_a, axis_b)?)),
        Tensor::C64(t) => Ok(Tensor::C64(typed_extract_diagonal(t, axis_a, axis_b)?)),
    }
}

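/// Embeds `input` diagonally: a new axis with the length of `axis_a` is
/// inserted at position `axis_b`, and all off-diagonal entries are zero.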
pub fn embed_diagonal(input: &Tensor, axis_a: usize, axis_b: usize) -> crate::Result<Tensor> {
    match input {
        Tensor::F32(t) => Ok(Tensor::F32(typed_embed_diagonal(t, axis_a, axis_b)?)),
        Tensor::F64(t) => Ok(Tensor::F64(typed_embed_diagonal(t, axis_a, axis_b)?)),
        Tensor::C32(t) => Ok(Tensor::C32(typed_embed_diagonal(t, axis_a, axis_b)?)),
        Tensor::C64(t) => Ok(Tensor::C64(typed_embed_diagonal(t, axis_a, axis_b)?)),
    }
}

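/// Zeroes every entry above the `k`-th diagonal of the leading two axes,
/// applied independently to each trailing batch slice.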
pub fn tril(input: &Tensor, k: i64) -> crate::Result<Tensor> {
    match input {
        Tensor::F32(t) => Ok(Tensor::F32(typed_tril(t, k)?)),
        Tensor::F64(t) => Ok(Tensor::F64(typed_tril(t, k)?)),
        Tensor::C32(t) => Ok(Tensor::C32(typed_tril(t, k)?)),
        Tensor::C64(t) => Ok(Tensor::C64(typed_tril(t, k)?)),
    }
}

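/// Zeroes every entry below the `k`-th diagonal of the leading two axes,
/// applied independently to each trailing batch slice.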
pub fn triu(input: &Tensor, k: i64) -> crate::Result<Tensor> {
    match input {
        Tensor::F32(t) => Ok(Tensor::F32(typed_triu(t, k)?)),
        Tensor::F64(t) => Ok(Tensor::F64(typed_triu(t, k)?)),
        Tensor::C32(t) => Ok(Tensor::C32(typed_triu(t, k)?)),
        Tensor::C64(t) => Ok(Tensor::C64(typed_triu(t, k)?)),
    }
}

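/// Typed counterpart of [`transpose`]: permutes a host tensor's axes and
/// copies the permuted view into a freshly allocated array.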
pub fn typed_transpose<T: Copy + Zero>(
    tensor: &TypedTensor<T>,
    perm: &[usize],
) -> crate::Result<TypedTensor<T>> {
    validate_permutation("transpose", perm, tensor.shape.len())?;
    let src = host_view(tensor)?;
    let permuted = src
        .permute(perm)
        .map_err(|err| backend_failure("transpose", err))?;
    // SAFETY: every element of `out` is written by the copy below.
    let out = unsafe { typed_array_uninit(permuted.dims()) };
    copy_view_to_array("transpose", out, &permuted)
}

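/// Typed counterpart of [`reshape`]: validates the element count and reuses
/// the existing buffer under the new shape, without touching the data.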
pub fn typed_reshape<T: Clone>(
    tensor: &TypedTensor<T>,
    shape: &[usize],
) -> crate::Result<TypedTensor<T>> {
    let old_n: usize = tensor.shape.iter().product();
    let new_n: usize = shape.iter().product();
    if old_n != new_n {
        return Err(crate::Error::ShapeMismatch {
            op: "reshape",
            lhs: tensor.shape.clone(),
            rhs: shape.to_vec(),
        });
    }
    Ok(TypedTensor {
        buffer: tensor.buffer.clone(),
        shape: shape.to_vec(),
        placement: tensor.placement.clone(),
    })
}

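/// Typed counterpart of [`broadcast_in_dim`]: constructs a zero-stride view
/// over the host buffer and copies it into a dense column-major array.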
pub fn typed_broadcast_in_dim<T: Copy + Zero>(
    tensor: &TypedTensor<T>,
    shape: &[usize],
    dims: &[usize],
) -> crate::Result<TypedTensor<T>> {
    validate_rank("broadcast_in_dim", tensor.shape.len(), dims.len())?;
    let mut seen = vec![false; shape.len()];
    // Destination axes absent from `dims` default to size 1 with stride 0,
    // so the broadcast repeats the source along them.
    let mut base_dims = vec![1usize; shape.len()];
    let mut base_strides = vec![0isize; shape.len()];
    let source_strides = crate::col_major_strides(&tensor.shape);
    for (src_axis, &dst_axis) in dims.iter().enumerate() {
        validate_axis("broadcast_in_dim", dst_axis, shape.len())?;
        if seen[dst_axis] {
            return Err(crate::Error::DuplicateAxis {
                op: "broadcast_in_dim",
                axis: dst_axis,
                role: "dims",
            });
        }
        seen[dst_axis] = true;
        let source_dim = tensor.shape[src_axis];
        let target_dim = shape[dst_axis];
        if source_dim != target_dim && source_dim != 1 {
            return Err(crate::Error::ShapeMismatch {
                op: "broadcast_in_dim",
                lhs: tensor.shape.clone(),
                rhs: shape.to_vec(),
            });
        }
        base_dims[dst_axis] = source_dim;
        base_strides[dst_axis] = source_strides[src_axis];
    }
    let base: StridedView<'_, T, Identity> = match &tensor.buffer {
        crate::Buffer::Host(data) => StridedView::new(data, &base_dims, &base_strides, 0)
            .map_err(|err| backend_failure("broadcast_in_dim", err))?,
        crate::Buffer::Backend(_) => {
            return Err(crate::Error::BackendFailure {
                op: "broadcast_in_dim",
                message: "backend buffers are not supported for structural CPU helpers".into(),
            })
        }
        #[cfg(feature = "cubecl")]
        crate::Buffer::Cubecl(_) => panic!(
            "GPU tensor (Buffer::Cubecl) passed to CPU backend. \
             Use cubecl::download_tensor() to transfer to CPU first."
        ),
    };
    let broadcast: StridedView<'_, T, Identity> = base
        .broadcast(shape)
        .map_err(|err| backend_failure("broadcast_in_dim", err))?;
    // SAFETY: every element of `out` is written by the copy below.
    let mut out = unsafe { typed_array_uninit(shape) };
    copy_into(&mut out.view_mut(), &broadcast)
        .map_err(|err| backend_failure("broadcast_in_dim", err))?;
    Ok(tensor_from_array(out))
}

fn typed_convert<S, T>(tensor: &TypedTensor<S>, f: impl Fn(S) -> T) -> TypedTensor<T>
where
    S: Copy,
    T: Copy + Zero,
{
    // SAFETY: every element of `out` is written by `map_into` below.
    let mut out = unsafe { typed_array_uninit(&tensor.shape) };
    // The input and output shapes match by construction, so `map_into`
    // cannot fail here.
    map_into(&mut out.view_mut(), &typed_view(tensor), f).expect("typed_convert");
    tensor_from_array(out)
}

pub fn typed_extract_diagonal<T: Copy + Zero>(
    tensor: &TypedTensor<T>,
    axis_a: usize,
    axis_b: usize,
) -> crate::Result<TypedTensor<T>> {
    validate_axis("extract_diagonal", axis_a, tensor.shape.len())?;
    validate_axis("extract_diagonal", axis_b, tensor.shape.len())?;
    validate_axes_distinct("extract_diagonal", axis_a, axis_b)?;

    let diag = host_view(tensor)?
        .diagonal_view(&[(axis_a, axis_b)])
        .map_err(|err| backend_failure("extract_diagonal", err))?;
    // SAFETY: every element of `out` is written by the copy below.
    let mut out = unsafe { typed_array_uninit(diag.dims()) };
    copy_into(&mut out.view_mut(), &diag)
        .map_err(|err| backend_failure("extract_diagonal", err))?;
    Ok(tensor_from_array(out))
}

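/// Typed counterpart of [`embed_diagonal`]: allocates a zeroed output with a
/// new axis of length `tensor.shape[axis_a]` at position `axis_b`, then
/// scatters every input element onto the diagonal of the paired axes.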
pub fn typed_embed_diagonal<T: Copy + Zero>(
    tensor: &TypedTensor<T>,
    axis_a: usize,
    axis_b: usize,
) -> crate::Result<TypedTensor<T>> {
    validate_axis("embed_diagonal", axis_a, tensor.shape.len())?;
    // `axis_b` names an insertion point, so it may equal the rank (append
    // the new axis at the end); only larger values are out of bounds.
    if axis_b > tensor.shape.len() {
        return Err(crate::Error::AxisOutOfBounds {
            op: "embed_diagonal",
            axis: axis_b,
            rank: tensor.shape.len(),
        });
    }

    let n = tensor.shape[axis_a];
    let mut out_shape = tensor.shape.clone();
    out_shape.insert(axis_b, n);
    let mut out = TypedTensor::zeros(out_shape);

    let in_rank = tensor.shape.len();
    let out_rank = out.shape.len();
    let mut in_idx = vec![0usize; in_rank];
    let mut out_idx = vec![0usize; out_rank];

    let input_data = match &tensor.buffer {
        crate::Buffer::Host(data) => data,
        crate::Buffer::Backend(_) => {
            return Err(crate::Error::BackendFailure {
                op: "embed_diagonal",
                message: "backend buffers are not supported for structural CPU helpers".into(),
            })
        }
        #[cfg(feature = "cubecl")]
        crate::Buffer::Cubecl(_) => panic!(
            "GPU tensor (Buffer::Cubecl) passed to CPU backend. \
             Use cubecl::download_tensor() to transfer to CPU first."
        ),
    };

    // Copy each input element to the output index obtained by inserting the
    // diagonal coordinate (the value along `axis_a`) at position `axis_b`.
    for flat in 0..tensor.n_elements() {
        flat_to_multi(flat, &tensor.shape, &mut in_idx);
        let diag_val = in_idx[axis_a];
        let mut src_axis = 0usize;
        for out_axis in 0..out_rank {
            if out_axis == axis_b {
                out_idx[out_axis] = diag_val;
            } else {
                out_idx[out_axis] = in_idx[src_axis];
                src_axis += 1;
            }
        }
        *out.get_mut(&out_idx) = input_data[flat];
    }
    Ok(out)
}

pub fn typed_tril<T: Copy + Zero>(
    tensor: &TypedTensor<T>,
    k: i64,
) -> crate::Result<TypedTensor<T>> {
    typed_triangular_mask(tensor, k, false)
}

pub fn typed_triu<T: Copy + Zero>(
    tensor: &TypedTensor<T>,
    k: i64,
) -> crate::Result<TypedTensor<T>> {
    typed_triangular_mask(tensor, k, true)
}

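/// Shared implementation of [`typed_tril`] and [`typed_triu`]: clones the
/// input and zeroes the masked entries of the leading two axes, walking the
/// column-major buffer one `rows * cols` batch block at a time.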
fn typed_triangular_mask<T: Copy + Zero>(
    tensor: &TypedTensor<T>,
    k: i64,
    upper: bool,
) -> crate::Result<TypedTensor<T>> {
    if tensor.shape.len() < 2 {
        return Err(crate::Error::RankMismatch {
            op: if upper { "triu" } else { "tril" },
            expected: 2,
            actual: tensor.shape.len(),
        });
    }

    let rows = tensor.shape[0];
    let cols = tensor.shape[1];
    // Nothing to mask when any axis is empty.
    if tensor.shape.contains(&0) {
        return Ok(tensor.clone());
    }

    let batch_count: usize = tensor.shape[2..].iter().product();
    let block_size = rows * cols;
    let mut out = tensor.clone();
    let data = match &mut out.buffer {
        crate::Buffer::Host(data) => data,
        crate::Buffer::Backend(_) => {
            return Err(crate::Error::BackendFailure {
                op: if upper { "triu" } else { "tril" },
                message: "backend buffers are not supported for structural CPU helpers".into(),
            })
        }
        #[cfg(feature = "cubecl")]
        crate::Buffer::Cubecl(_) => panic!(
            "GPU tensor (Buffer::Cubecl) passed to CPU backend. \
             Use cubecl::download_tensor() to transfer to CPU first."
        ),
    };

    for batch_idx in 0..batch_count {
        let base = batch_idx * block_size;
        for col in 0..cols {
            // Keep (row, col) when col - row >= k for triu and when
            // col - row <= k for tril; zero everything else.
            let boundary = col as i64 - k;
            for row in 0..rows {
                let keep = if upper {
                    (row as i64) <= boundary
                } else {
                    (row as i64) >= boundary
                };
                if !keep {
                    // Column-major layout: element (row, col) lives at
                    // row + col * rows within the batch block.
                    data[base + row + col * rows] = T::zero();
                }
            }
        }
    }

    Ok(out)
}