// tenferro_tensor/cpu/structural.rs
1use num_complex::{Complex32, Complex64};
2use num_traits::Zero;
3use strided_kernel::{copy_into, map_into, Identity, StridedView};
4
5use crate::{
6    types::{flat_to_multi, Tensor, TypedTensor},
7    DType,
8};
9
10use super::{tensor_from_array, typed_array_uninit, typed_view};
11
12fn backend_failure(op: &'static str, err: impl ToString) -> crate::Error {
13    crate::Error::BackendFailure {
14        op,
15        message: err.to_string(),
16    }
17}
18
19fn validate_rank(op: &'static str, expected: usize, actual: usize) -> crate::Result<()> {
20    if expected != actual {
21        return Err(crate::Error::RankMismatch {
22            op,
23            expected,
24            actual,
25        });
26    }
27    Ok(())
28}
29
30fn validate_axis(op: &'static str, axis: usize, rank: usize) -> crate::Result<()> {
31    if axis >= rank {
32        return Err(crate::Error::AxisOutOfBounds { op, axis, rank });
33    }
34    Ok(())
35}
36
37fn validate_axes_distinct(op: &'static str, axis_a: usize, axis_b: usize) -> crate::Result<()> {
38    if axis_a == axis_b {
39        return Err(crate::Error::DuplicateAxis {
40            op,
41            axis: axis_a,
42            role: "axes",
43        });
44    }
45    Ok(())
46}
47
48fn validate_permutation(op: &'static str, perm: &[usize], rank: usize) -> crate::Result<()> {
49    validate_rank(op, rank, perm.len())?;
50    let mut seen = vec![false; rank];
51    for &axis in perm {
52        validate_axis(op, axis, rank)?;
53        if seen[axis] {
54            return Err(crate::Error::DuplicateAxis {
55                op,
56                axis,
57                role: "perm",
58            });
59        }
60        seen[axis] = true;
61    }
62    Ok(())
63}
64
65fn host_view<T: Copy>(tensor: &TypedTensor<T>) -> crate::Result<StridedView<'_, T, Identity>> {
66    match &tensor.buffer {
67        crate::Buffer::Host(data) => {
68            let strides = crate::col_major_strides(&tensor.shape);
69            StridedView::new(data, &tensor.shape, &strides, 0)
70                .map_err(|err| backend_failure("structural", err))
71        }
72        crate::Buffer::Backend(_) => Err(crate::Error::BackendFailure {
73            op: "structural",
74            message: "backend buffers are not supported for structural CPU helpers".into(),
75        }),
76        #[cfg(feature = "cubecl")]
77        crate::Buffer::Cubecl(_) => panic!("GPU tensor (Buffer::Cubecl) passed to CPU backend. Use cubecl::download_tensor() to transfer to CPU first."),
78    }
79}
80
81fn copy_view_to_array<T: Copy + Clone>(
82    op: &'static str,
83    mut out: strided_kernel::StridedArray<T>,
84    src: &StridedView<'_, T>,
85) -> crate::Result<TypedTensor<T>> {
86    copy_into(&mut out.view_mut(), src).map_err(|err| backend_failure(op, err))?;
87    Ok(tensor_from_array(out))
88}
89
90pub fn transpose(input: &Tensor, perm: &[usize]) -> crate::Result<Tensor> {
91    match input {
92        Tensor::F32(t) => Ok(Tensor::F32(typed_transpose(t, perm)?)),
93        Tensor::F64(t) => Ok(Tensor::F64(typed_transpose(t, perm)?)),
94        Tensor::C32(t) => Ok(Tensor::C32(typed_transpose(t, perm)?)),
95        Tensor::C64(t) => Ok(Tensor::C64(typed_transpose(t, perm)?)),
96    }
97}
98
99pub fn reshape(input: &Tensor, shape: &[usize]) -> crate::Result<Tensor> {
100    match input {
101        Tensor::F32(t) => Ok(Tensor::F32(typed_reshape(t, shape)?)),
102        Tensor::F64(t) => Ok(Tensor::F64(typed_reshape(t, shape)?)),
103        Tensor::C32(t) => Ok(Tensor::C32(typed_reshape(t, shape)?)),
104        Tensor::C64(t) => Ok(Tensor::C64(typed_reshape(t, shape)?)),
105    }
106}
107
108pub fn broadcast_in_dim(input: &Tensor, shape: &[usize], dims: &[usize]) -> crate::Result<Tensor> {
109    match input {
110        Tensor::F32(t) => Ok(Tensor::F32(typed_broadcast_in_dim(t, shape, dims)?)),
111        Tensor::F64(t) => Ok(Tensor::F64(typed_broadcast_in_dim(t, shape, dims)?)),
112        Tensor::C32(t) => Ok(Tensor::C32(typed_broadcast_in_dim(t, shape, dims)?)),
113        Tensor::C64(t) => Ok(Tensor::C64(typed_broadcast_in_dim(t, shape, dims)?)),
114    }
115}
116
117pub fn convert(input: &Tensor, to: DType) -> Tensor {
118    match (input, to) {
119        (Tensor::F32(t), DType::F32) => Tensor::F32(t.clone()),
120        (Tensor::F32(t), DType::F64) => Tensor::F64(typed_convert(t, |x| x as f64)),
121        (Tensor::F32(t), DType::C32) => Tensor::C32(typed_convert(t, |x| Complex32::new(x, 0.0))),
122        (Tensor::F32(t), DType::C64) => {
123            Tensor::C64(typed_convert(t, |x| Complex64::new(x as f64, 0.0)))
124        }
125        (Tensor::F64(t), DType::F32) => Tensor::F32(typed_convert(t, |x| x as f32)),
126        (Tensor::F64(t), DType::F64) => Tensor::F64(t.clone()),
127        (Tensor::F64(t), DType::C32) => {
128            Tensor::C32(typed_convert(t, |x| Complex32::new(x as f32, 0.0)))
129        }
130        (Tensor::F64(t), DType::C64) => Tensor::C64(typed_convert(t, |x| Complex64::new(x, 0.0))),
131        (Tensor::C32(t), DType::F32) => Tensor::F32(typed_convert(t, |z| z.re)),
132        (Tensor::C32(t), DType::F64) => Tensor::F64(typed_convert(t, |z| z.re as f64)),
133        (Tensor::C32(t), DType::C32) => Tensor::C32(t.clone()),
134        (Tensor::C32(t), DType::C64) => Tensor::C64(typed_convert(t, |z| {
135            Complex64::new(z.re as f64, z.im as f64)
136        })),
137        (Tensor::C64(t), DType::F32) => Tensor::F32(typed_convert(t, |z| z.re as f32)),
138        (Tensor::C64(t), DType::F64) => Tensor::F64(typed_convert(t, |z| z.re)),
139        (Tensor::C64(t), DType::C32) => Tensor::C32(typed_convert(t, |z| {
140            Complex32::new(z.re as f32, z.im as f32)
141        })),
142        (Tensor::C64(t), DType::C64) => Tensor::C64(t.clone()),
143    }
144}
145
146pub fn extract_diagonal(input: &Tensor, axis_a: usize, axis_b: usize) -> crate::Result<Tensor> {
147    match input {
148        Tensor::F32(t) => Ok(Tensor::F32(typed_extract_diagonal(t, axis_a, axis_b)?)),
149        Tensor::F64(t) => Ok(Tensor::F64(typed_extract_diagonal(t, axis_a, axis_b)?)),
150        Tensor::C32(t) => Ok(Tensor::C32(typed_extract_diagonal(t, axis_a, axis_b)?)),
151        Tensor::C64(t) => Ok(Tensor::C64(typed_extract_diagonal(t, axis_a, axis_b)?)),
152    }
153}
154
155pub fn embed_diagonal(input: &Tensor, axis_a: usize, axis_b: usize) -> crate::Result<Tensor> {
156    match input {
157        Tensor::F32(t) => Ok(Tensor::F32(typed_embed_diagonal(t, axis_a, axis_b)?)),
158        Tensor::F64(t) => Ok(Tensor::F64(typed_embed_diagonal(t, axis_a, axis_b)?)),
159        Tensor::C32(t) => Ok(Tensor::C32(typed_embed_diagonal(t, axis_a, axis_b)?)),
160        Tensor::C64(t) => Ok(Tensor::C64(typed_embed_diagonal(t, axis_a, axis_b)?)),
161    }
162}
163
164pub fn tril(input: &Tensor, k: i64) -> crate::Result<Tensor> {
165    match input {
166        Tensor::F32(t) => Ok(Tensor::F32(typed_tril(t, k)?)),
167        Tensor::F64(t) => Ok(Tensor::F64(typed_tril(t, k)?)),
168        Tensor::C32(t) => Ok(Tensor::C32(typed_tril(t, k)?)),
169        Tensor::C64(t) => Ok(Tensor::C64(typed_tril(t, k)?)),
170    }
171}
172
173pub fn triu(input: &Tensor, k: i64) -> crate::Result<Tensor> {
174    match input {
175        Tensor::F32(t) => Ok(Tensor::F32(typed_triu(t, k)?)),
176        Tensor::F64(t) => Ok(Tensor::F64(typed_triu(t, k)?)),
177        Tensor::C32(t) => Ok(Tensor::C32(typed_triu(t, k)?)),
178        Tensor::C64(t) => Ok(Tensor::C64(typed_triu(t, k)?)),
179    }
180}
181
182pub fn typed_transpose<T: Copy + Zero + Clone>(
183    tensor: &TypedTensor<T>,
184    perm: &[usize],
185) -> crate::Result<TypedTensor<T>> {
186    validate_permutation("transpose", perm, tensor.shape.len())?;
187    let src = host_view(tensor)?;
188    let permuted = src
189        .permute(perm)
190        .map_err(|err| backend_failure("transpose", err))?;
191    // SAFETY: copy_into overwrites every output element.
192    let out = unsafe { typed_array_uninit(permuted.dims()) };
193    copy_view_to_array("transpose", out, &permuted)
194}
195
196pub fn typed_reshape<T: Clone>(
197    tensor: &TypedTensor<T>,
198    shape: &[usize],
199) -> crate::Result<TypedTensor<T>> {
200    let old_n: usize = tensor.shape.iter().product();
201    let new_n: usize = shape.iter().product();
202    if old_n != new_n {
203        return Err(crate::Error::ShapeMismatch {
204            op: "reshape",
205            lhs: tensor.shape.clone(),
206            rhs: shape.to_vec(),
207        });
208    }
209    Ok(TypedTensor {
210        buffer: tensor.buffer.clone(),
211        shape: shape.to_vec(),
212        placement: tensor.placement.clone(),
213    })
214}
215
/// Broadcast `tensor` into `shape`, mapping source axis `i` to destination
/// axis `dims[i]` (XLA-style `broadcast_in_dim`).
///
/// Each source dimension must equal its target dimension or be 1 (in which
/// case it broadcasts). Destination axes not named in `dims` are set up as
/// size-1, stride-0 base axes, i.e. pure broadcast axes. The result is
/// materialized into a fresh host buffer.
///
/// # Errors
/// Rank mismatch between `tensor` and `dims`, out-of-bounds or duplicate
/// destination axes, an incompatible dimension, or a non-host buffer.
pub fn typed_broadcast_in_dim<T: Copy + Zero + Clone>(
    tensor: &TypedTensor<T>,
    shape: &[usize],
    dims: &[usize],
) -> crate::Result<TypedTensor<T>> {
    validate_rank("broadcast_in_dim", tensor.shape.len(), dims.len())?;
    let mut seen = vec![false; shape.len()];
    // Dims/strides of a base view over the source data, laid out in
    // destination axis order; defaults (1, 0) make unmapped axes broadcast.
    let mut base_dims = vec![1usize; shape.len()];
    let mut base_strides = vec![0isize; shape.len()];
    let source_strides = crate::col_major_strides(&tensor.shape);
    for (src_axis, &dst_axis) in dims.iter().enumerate() {
        validate_axis("broadcast_in_dim", dst_axis, shape.len())?;
        if seen[dst_axis] {
            return Err(crate::Error::DuplicateAxis {
                op: "broadcast_in_dim",
                axis: dst_axis,
                role: "dims",
            });
        }
        seen[dst_axis] = true;
        let source_dim = tensor.shape[src_axis];
        let target_dim = shape[dst_axis];
        // A mismatched source dimension may only broadcast when it is 1.
        if source_dim != target_dim && source_dim != 1 {
            return Err(crate::Error::ShapeMismatch {
                op: "broadcast_in_dim",
                lhs: tensor.shape.clone(),
                rhs: shape.to_vec(),
            });
        }
        base_dims[dst_axis] = source_dim;
        base_strides[dst_axis] = source_strides[src_axis];
    }
    let base: StridedView<'_, T, Identity> = match &tensor.buffer {
        crate::Buffer::Host(data) => StridedView::new(data, &base_dims, &base_strides, 0)
            .map_err(|err| backend_failure("broadcast_in_dim", err))?,
        crate::Buffer::Backend(_) => {
            return Err(crate::Error::BackendFailure {
                op: "broadcast_in_dim",
                message: "backend buffers are not supported for structural CPU helpers".into(),
            })
        }
        #[cfg(feature = "cubecl")]
        crate::Buffer::Cubecl(_) => panic!("GPU tensor (Buffer::Cubecl) passed to CPU backend. Use cubecl::download_tensor() to transfer to CPU first."),
    };
    // Expand the base view to the full target shape before copying out.
    let broadcast: StridedView<'_, T, Identity> = base
        .broadcast(shape)
        .map_err(|err| backend_failure("broadcast_in_dim", err))?;
    // SAFETY: copy_into overwrites every output element.
    let mut out = unsafe { typed_array_uninit(shape) };
    copy_into(&mut out.view_mut(), &broadcast)
        .map_err(|err| backend_failure("broadcast_in_dim", err))?;
    Ok(tensor_from_array(out))
}
269
270fn typed_convert<S, T>(tensor: &TypedTensor<S>, f: impl Fn(S) -> T) -> TypedTensor<T>
271where
272    S: Copy,
273    T: Copy + Clone + Zero,
274{
275    // SAFETY: map_into overwrites every output element.
276    let mut out = unsafe { typed_array_uninit(&tensor.shape) };
277    map_into(&mut out.view_mut(), &typed_view(tensor), f).expect("typed_convert");
278    tensor_from_array(out)
279}
280
281pub fn typed_extract_diagonal<T: Copy + Zero + Clone>(
282    tensor: &TypedTensor<T>,
283    axis_a: usize,
284    axis_b: usize,
285) -> crate::Result<TypedTensor<T>> {
286    validate_axis("extract_diagonal", axis_a, tensor.shape.len())?;
287    validate_axis("extract_diagonal", axis_b, tensor.shape.len())?;
288    validate_axes_distinct("extract_diagonal", axis_a, axis_b)?;
289
290    let diag = host_view(tensor)?
291        .diagonal_view(&[(axis_a, axis_b)])
292        .map_err(|err| backend_failure("extract_diagonal", err))?;
293    // SAFETY: copy_into overwrites every output element.
294    let mut out = unsafe { typed_array_uninit(diag.dims()) };
295    copy_into(&mut out.view_mut(), &diag)
296        .map_err(|err| backend_failure("extract_diagonal", err))?;
297    Ok(tensor_from_array(out))
298}
299
/// Embed `tensor` into a tensor of one higher rank: a new axis of length
/// `tensor.shape[axis_a]` is inserted at position `axis_b`, and each input
/// element is written where the new axis's index equals the input's `axis_a`
/// index. All off-diagonal entries are zero.
///
/// Note: `axis_b` is an *insertion position* into the (rank + 1)-dim output,
/// so `axis_b == rank` (append at the end) is legal; only strictly greater
/// values are rejected.
pub fn typed_embed_diagonal<T: Copy + Zero + Clone>(
    tensor: &TypedTensor<T>,
    axis_a: usize,
    axis_b: usize,
) -> crate::Result<TypedTensor<T>> {
    validate_axis("embed_diagonal", axis_a, tensor.shape.len())?;
    // Deliberately not validate_axis: axis_b indexes the output, whose rank
    // is one higher, so axis_b == rank is allowed.
    if axis_b > tensor.shape.len() {
        return Err(crate::Error::AxisOutOfBounds {
            op: "embed_diagonal",
            axis: axis_b,
            rank: tensor.shape.len(),
        });
    }

    // The new axis mirrors the length of axis_a.
    let n = tensor.shape[axis_a];
    let mut out_shape = tensor.shape.clone();
    out_shape.insert(axis_b, n);
    let mut out = TypedTensor::zeros(out_shape);

    let in_rank = tensor.shape.len();
    let out_rank = out.shape.len();
    // Scratch index buffers reused across all elements.
    let mut in_idx = vec![0usize; in_rank];
    let mut out_idx = vec![0usize; out_rank];

    let input_data = match &tensor.buffer {
        crate::Buffer::Host(data) => data,
        crate::Buffer::Backend(_) => {
            return Err(crate::Error::BackendFailure {
                op: "embed_diagonal",
                message: "backend buffers are not supported for structural CPU helpers".into(),
            })
        }
        #[cfg(feature = "cubecl")]
        crate::Buffer::Cubecl(_) => panic!("GPU tensor (Buffer::Cubecl) passed to CPU backend. Use cubecl::download_tensor() to transfer to CPU first."),
    };

    // Scatter each input element onto the diagonal of the zeroed output.
    // (Assumes flat_to_multi matches the host buffer's flat ordering — it is
    // the crate's own index decoder for host tensors.)
    for flat in 0..tensor.n_elements() {
        flat_to_multi(flat, &tensor.shape, &mut in_idx);
        let diag_val = in_idx[axis_a];
        // Copy the input index, shifting axes at/after axis_b right by one
        // and pinning the inserted axis to the diagonal index.
        let mut src_axis = 0usize;
        for out_axis in 0..out_rank {
            if out_axis == axis_b {
                out_idx[out_axis] = diag_val;
            } else {
                out_idx[out_axis] = in_idx[src_axis];
                src_axis += 1;
            }
        }
        *out.get_mut(&out_idx) = input_data[flat];
    }
    Ok(out)
}
352
353pub fn typed_tril<T: Copy + Zero + Clone>(
354    tensor: &TypedTensor<T>,
355    k: i64,
356) -> crate::Result<TypedTensor<T>> {
357    typed_triangular_mask(tensor, k, false)
358}
359
360pub fn typed_triu<T: Copy + Zero + Clone>(
361    tensor: &TypedTensor<T>,
362    k: i64,
363) -> crate::Result<TypedTensor<T>> {
364    typed_triangular_mask(tensor, k, true)
365}
366
367fn typed_triangular_mask<T: Copy + Zero + Clone>(
368    tensor: &TypedTensor<T>,
369    k: i64,
370    upper: bool,
371) -> crate::Result<TypedTensor<T>> {
372    if tensor.shape.len() < 2 {
373        return Err(crate::Error::RankMismatch {
374            op: if upper { "triu" } else { "tril" },
375            expected: 2,
376            actual: tensor.shape.len(),
377        });
378    }
379
380    let rows = tensor.shape[0];
381    let cols = tensor.shape[1];
382    if tensor.shape.contains(&0) {
383        return Ok(tensor.clone());
384    }
385
386    let batch_count: usize = tensor.shape[2..].iter().product();
387    let block_size = rows * cols;
388    let mut out = tensor.clone();
389    let data = match &mut out.buffer {
390        crate::Buffer::Host(data) => data,
391        crate::Buffer::Backend(_) => {
392            return Err(crate::Error::BackendFailure {
393                op: if upper { "triu" } else { "tril" },
394                message: "backend buffers are not supported for structural CPU helpers".into(),
395            })
396        }
397        #[cfg(feature = "cubecl")]
398        crate::Buffer::Cubecl(_) => panic!("GPU tensor (Buffer::Cubecl) passed to CPU backend. Use cubecl::download_tensor() to transfer to CPU first."),
399    };
400
401    for batch_idx in 0..batch_count {
402        let base = batch_idx * block_size;
403        for col in 0..cols {
404            let boundary = col as i64 - k;
405            for row in 0..rows {
406                let keep = if upper {
407                    (row as i64) <= boundary
408                } else {
409                    (row as i64) >= boundary
410                };
411                if !keep {
412                    data[base + row + col * rows] = T::zero();
413                }
414            }
415        }
416    }
417
418    Ok(out)
419}