Skip to main content

tenferro_cpu/
structural.rs

1use num_complex::{Complex32, Complex64};
2use num_traits::Zero;
3use strided_kernel::{col_major_strides, copy_into, map_into, Identity, StridedView};
4
5use crate::{
6    buffer_pool::{BufferPool, PoolScalar},
7    flat_to_multi,
8};
9use tenferro_tensor::{DType, Tensor, TensorRank, TypedTensor, TypedTensorView};
10
11#[cfg(test)]
12use super::typed_array_uninit;
13use super::{
14    cpu_backend_buffer_error, tensor_from_array, typed_array_uninit_from_pool, typed_view,
15    typed_view_from_view,
16};
17
18fn with_local_pool<T>(f: impl FnOnce(&mut BufferPool) -> T) -> T {
19    let mut buffers = BufferPool::new();
20    f(&mut buffers)
21}
22
23fn validate_rank(op: &'static str, expected: usize, actual: usize) -> crate::Result<()> {
24    if expected != actual {
25        return Err(crate::Error::RankMismatch {
26            op,
27            expected,
28            actual,
29        });
30    }
31    Ok(())
32}
33
34fn validate_axis(op: &'static str, axis: usize, rank: usize) -> crate::Result<()> {
35    if axis >= rank {
36        return Err(crate::Error::AxisOutOfBounds { op, axis, rank });
37    }
38    Ok(())
39}
40
41fn validate_axes_distinct(op: &'static str, axis_a: usize, axis_b: usize) -> crate::Result<()> {
42    if axis_a == axis_b {
43        return Err(crate::Error::DuplicateAxis {
44            op,
45            axis: axis_a,
46            role: "axes",
47        });
48    }
49    Ok(())
50}
51
52fn checked_shape_product(
53    op: &'static str,
54    role: &'static str,
55    shape: &[usize],
56) -> crate::Result<usize> {
57    shape.iter().try_fold(1usize, |acc, &dim| {
58        acc.checked_mul(dim)
59            .ok_or_else(|| crate::Error::InvalidConfig {
60                op,
61                message: format!("{role} element count overflows usize"),
62            })
63    })
64}
65
66fn validate_permutation(op: &'static str, perm: &[usize], rank: usize) -> crate::Result<()> {
67    validate_rank(op, rank, perm.len())?;
68    let mut seen = vec![false; rank];
69    for &axis in perm {
70        validate_axis(op, axis, rank)?;
71        if seen[axis] {
72            return Err(crate::Error::DuplicateAxis {
73                op,
74                axis,
75                role: "perm",
76            });
77        }
78        seen[axis] = true;
79    }
80    Ok(())
81}
82
// Dispatches a fallible unary operation over every `Tensor` variant.
//
// `$binding` is bound to the typed payload of the matched variant, `$compute`
// is evaluated with it, its error is propagated with `?`, and the success value
// is re-wrapped in the same variant.
macro_rules! dispatch_tensor_unary_result {
    ($value:expr, |$binding:ident| $compute:expr) => {
        match $value {
            Tensor::F32($binding) => {
                let typed = $compute?;
                Ok(Tensor::F32(typed))
            }
            Tensor::F64($binding) => {
                let typed = $compute?;
                Ok(Tensor::F64(typed))
            }
            Tensor::I32($binding) => {
                let typed = $compute?;
                Ok(Tensor::I32(typed))
            }
            Tensor::I64($binding) => {
                let typed = $compute?;
                Ok(Tensor::I64(typed))
            }
            Tensor::Bool($binding) => {
                let typed = $compute?;
                Ok(Tensor::Bool(typed))
            }
            Tensor::C32($binding) => {
                let typed = $compute?;
                Ok(Tensor::C32(typed))
            }
            Tensor::C64($binding) => {
                let typed = $compute?;
                Ok(Tensor::C64(typed))
            }
        }
    };
}
96
// Like `dispatch_tensor_unary_result!`, but the `Bool` variant gets its own
// binding and expression (`bool |b| expr`), while every other variant shares
// `$compute`. Errors are propagated with `?` and the result is re-wrapped in
// the matched variant.
macro_rules! dispatch_tensor_unary_with_bool_special_result {
    ($value:expr, |$binding:ident| $compute:expr, bool |$bool_binding:ident| $bool_compute:expr) => {
        match $value {
            Tensor::Bool($bool_binding) => {
                let typed = $bool_compute?;
                Ok(Tensor::Bool(typed))
            }
            Tensor::F32($binding) => {
                let typed = $compute?;
                Ok(Tensor::F32(typed))
            }
            Tensor::F64($binding) => {
                let typed = $compute?;
                Ok(Tensor::F64(typed))
            }
            Tensor::I32($binding) => {
                let typed = $compute?;
                Ok(Tensor::I32(typed))
            }
            Tensor::I64($binding) => {
                let typed = $compute?;
                Ok(Tensor::I64(typed))
            }
            Tensor::C32($binding) => {
                let typed = $compute?;
                Ok(Tensor::C32(typed))
            }
            Tensor::C64($binding) => {
                let typed = $compute?;
                Ok(Tensor::C64(typed))
            }
        }
    };
}
110
111fn host_view<'a, T: Copy>(
112    op: &'static str,
113    tensor: &'a TypedTensor<T>,
114) -> crate::Result<StridedView<'a, T, Identity>> {
115    match tensor.buffer() {
116        crate::Buffer::Host(data) => {
117            let strides = col_major_strides(tensor.shape());
118            StridedView::new(data.as_slice(), tensor.shape(), &strides, 0)
119                .map_err(|err| crate::Error::backend_failure(op, err))
120        }
121        crate::Buffer::Backend(_) => Err(cpu_backend_buffer_error(op)),
122    }
123}
124
125fn copy_view_to_array<T: Copy + Clone + Send + Sync>(
126    op: &'static str,
127    mut out: strided_kernel::StridedArray<T>,
128    src: &StridedView<'_, T>,
129) -> crate::Result<TypedTensor<T>> {
130    copy_into(&mut out.view_mut(), src).map_err(|err| crate::Error::backend_failure(op, err))?;
131    Ok(tensor_from_array(out))
132}
133
134fn zeroed_tensor_from_pool<T>(
135    buffers: &mut BufferPool,
136    op: &'static str,
137    shape: Vec<usize>,
138) -> crate::Result<TypedTensor<T>>
139where
140    T: Zero + Clone + PoolScalar + 'static,
141{
142    filled_tensor_from_pool(buffers, op, shape, T::zero())
143}
144
145fn filled_tensor_from_pool<T>(
146    buffers: &mut BufferPool,
147    op: &'static str,
148    shape: Vec<usize>,
149    fill: T,
150) -> crate::Result<TypedTensor<T>>
151where
152    T: Copy + Clone + PoolScalar + 'static,
153{
154    let len = checked_shape_product(op, "output shape", &shape)?;
155    // SAFETY: every pooled element is initialized with `fill` before returning.
156    let mut data = unsafe { T::pool_acquire(buffers, len) };
157    data.fill(fill);
158    TypedTensor::from_vec_col_major(shape, data)
159}
160
161fn clone_host_tensor_from_pool<T>(
162    buffers: &mut BufferPool,
163    op: &'static str,
164    tensor: &TypedTensor<T>,
165) -> crate::Result<TypedTensor<T>>
166where
167    T: Copy + PoolScalar + 'static,
168{
169    let input = match tensor.buffer() {
170        crate::Buffer::Host(data) => data.as_slice(),
171        crate::Buffer::Backend(_) => return Err(cpu_backend_buffer_error(op)),
172    };
173    // SAFETY: copy_from_slice initializes every pooled element before returning.
174    let mut data = unsafe { T::pool_acquire(buffers, input.len()) };
175    data.copy_from_slice(input);
176    TypedTensor::from_buffer_col_major(
177        tensor.shape().to_vec(),
178        crate::Buffer::Host(data),
179        tensor.placement().clone(),
180    )
181}
182
183pub fn transpose(input: &Tensor, perm: &[usize]) -> crate::Result<Tensor> {
184    with_local_pool(|buffers| transpose_with_pool(buffers, input, perm))
185}
186
187pub(crate) fn transpose_with_pool(
188    buffers: &mut BufferPool,
189    input: &Tensor,
190    perm: &[usize],
191) -> crate::Result<Tensor> {
192    dispatch_tensor_unary_result!(input, |t| typed_transpose_with_pool(buffers, t, perm))
193}
194
195pub fn reshape(input: &Tensor, shape: &[usize]) -> crate::Result<Tensor> {
196    dispatch_tensor_unary_result!(input, |t| typed_reshape(t, shape))
197}
198
199pub fn broadcast_in_dim(input: &Tensor, shape: &[usize], dims: &[usize]) -> crate::Result<Tensor> {
200    with_local_pool(|buffers| broadcast_in_dim_with_pool(buffers, input, shape, dims))
201}
202
203pub(crate) fn broadcast_in_dim_with_pool(
204    buffers: &mut BufferPool,
205    input: &Tensor,
206    shape: &[usize],
207    dims: &[usize],
208) -> crate::Result<Tensor> {
209    dispatch_tensor_unary_result!(input, |t| typed_broadcast_in_dim_with_pool(
210        buffers, t, shape, dims
211    ))
212}
213
214/// Convert a tensor to another dtype using checked dtype conversion.
215///
216/// Use `TensorStructural::cast` when an explicit lossy dtype projection is
217/// intended.
218///
219/// # Examples
220///
221/// ```rust
222/// use tenferro_cpu::CpuBackend;
223/// use tenferro_tensor::{DType, Tensor, TensorStructural};
224///
225/// let mut backend = CpuBackend::new();
226/// let x = Tensor::from_vec_col_major(vec![2], vec![1.0_f32, 2.0]).unwrap();
227/// let y = backend.convert(&x, DType::F64).unwrap();
228/// assert_eq!(y.as_slice::<f64>().unwrap(), &[1.0, 2.0]);
229/// ```
230///
231/// # Errors
232///
233/// Returns an error when the requested conversion is outside tenferro's checked
234/// dtype-promotion lattice.
235pub fn convert(input: &Tensor, to: DType) -> crate::Result<Tensor> {
236    with_local_pool(|buffers| convert_with_pool(buffers, input, to))
237}
238
239pub(crate) fn convert_with_pool(
240    buffers: &mut BufferPool,
241    input: &Tensor,
242    to: DType,
243) -> crate::Result<Tensor> {
244    tenferro_tensor::validate::validate_convert_dtype("convert", input.dtype(), to)?;
245    cast_with_pool(buffers, input, to)
246}
247
248pub(crate) fn cast_with_pool(
249    buffers: &mut BufferPool,
250    input: &Tensor,
251    to: DType,
252) -> crate::Result<Tensor> {
253    macro_rules! converted {
254        ($variant:ident, $tensor:expr, $map:expr) => {
255            Ok(Tensor::$variant(typed_convert_with_pool(
256                buffers, $tensor, $map,
257            )?))
258        };
259    }
260
261    match (input, to) {
262        (Tensor::F32(t), DType::F32) => Ok(Tensor::F32(t.clone())),
263        (Tensor::F32(t), DType::F64) => converted!(F64, t, |x| x as f64),
264        (Tensor::F32(t), DType::I32) => converted!(I32, t, |x| x as i32),
265        (Tensor::F32(t), DType::I64) => converted!(I64, t, |x| x as i64),
266        (Tensor::F32(t), DType::Bool) => converted!(Bool, t, |x| x != 0.0),
267        (Tensor::F32(t), DType::C32) => converted!(C32, t, |x| Complex32::new(x, 0.0)),
268        (Tensor::F32(t), DType::C64) => {
269            converted!(C64, t, |x| Complex64::new(x as f64, 0.0))
270        }
271        (Tensor::F64(t), DType::F32) => converted!(F32, t, |x| x as f32),
272        (Tensor::F64(t), DType::F64) => Ok(Tensor::F64(t.clone())),
273        (Tensor::F64(t), DType::I32) => converted!(I32, t, |x| x as i32),
274        (Tensor::F64(t), DType::I64) => converted!(I64, t, |x| x as i64),
275        (Tensor::F64(t), DType::Bool) => converted!(Bool, t, |x| x != 0.0),
276        (Tensor::F64(t), DType::C32) => {
277            converted!(C32, t, |x| Complex32::new(x as f32, 0.0))
278        }
279        (Tensor::F64(t), DType::C64) => converted!(C64, t, |x| Complex64::new(x, 0.0)),
280        (Tensor::I32(t), DType::F32) => converted!(F32, t, |x| x as f32),
281        (Tensor::I32(t), DType::F64) => converted!(F64, t, |x| x as f64),
282        (Tensor::I32(t), DType::I32) => Ok(Tensor::I32(t.clone())),
283        (Tensor::I32(t), DType::I64) => converted!(I64, t, |x| x as i64),
284        (Tensor::I32(t), DType::Bool) => converted!(Bool, t, |x| x != 0),
285        (Tensor::I32(t), DType::C32) => {
286            converted!(C32, t, |x| Complex32::new(x as f32, 0.0))
287        }
288        (Tensor::I32(t), DType::C64) => {
289            converted!(C64, t, |x| Complex64::new(x as f64, 0.0))
290        }
291        (Tensor::I64(t), DType::F32) => converted!(F32, t, |x| x as f32),
292        (Tensor::I64(t), DType::F64) => converted!(F64, t, |x| x as f64),
293        (Tensor::I64(t), DType::I32) => converted!(I32, t, |x| x as i32),
294        (Tensor::I64(t), DType::I64) => Ok(Tensor::I64(t.clone())),
295        (Tensor::I64(t), DType::Bool) => converted!(Bool, t, |x| x != 0),
296        (Tensor::I64(t), DType::C32) => {
297            converted!(C32, t, |x| Complex32::new(x as f32, 0.0))
298        }
299        (Tensor::I64(t), DType::C64) => {
300            converted!(C64, t, |x| Complex64::new(x as f64, 0.0))
301        }
302        (Tensor::Bool(t), DType::F32) => converted!(F32, t, |x| if x { 1.0 } else { 0.0 }),
303        (Tensor::Bool(t), DType::F64) => converted!(F64, t, |x| if x { 1.0 } else { 0.0 }),
304        (Tensor::Bool(t), DType::I32) => converted!(I32, t, |x| if x { 1 } else { 0 }),
305        (Tensor::Bool(t), DType::I64) => converted!(I64, t, |x| if x { 1 } else { 0 }),
306        (Tensor::Bool(t), DType::Bool) => Ok(Tensor::Bool(t.clone())),
307        (Tensor::Bool(t), DType::C32) => {
308            converted!(C32, t, |x| Complex32::new(if x { 1.0 } else { 0.0 }, 0.0))
309        }
310        (Tensor::Bool(t), DType::C64) => {
311            converted!(C64, t, |x| Complex64::new(if x { 1.0 } else { 0.0 }, 0.0))
312        }
313        (Tensor::C32(t), DType::F32) => converted!(F32, t, |z| z.re),
314        (Tensor::C32(t), DType::F64) => converted!(F64, t, |z| z.re as f64),
315        (Tensor::C32(t), DType::I32) => converted!(I32, t, |z| z.re as i32),
316        (Tensor::C32(t), DType::I64) => converted!(I64, t, |z| z.re as i64),
317        (Tensor::C32(t), DType::Bool) => converted!(Bool, t, |z| z.re != 0.0 || z.im != 0.0),
318        (Tensor::C32(t), DType::C32) => Ok(Tensor::C32(t.clone())),
319        (Tensor::C32(t), DType::C64) => {
320            converted!(C64, t, |z| Complex64::new(z.re as f64, z.im as f64))
321        }
322        (Tensor::C64(t), DType::F32) => converted!(F32, t, |z| z.re as f32),
323        (Tensor::C64(t), DType::F64) => converted!(F64, t, |z| z.re),
324        (Tensor::C64(t), DType::I32) => converted!(I32, t, |z| z.re as i32),
325        (Tensor::C64(t), DType::I64) => converted!(I64, t, |z| z.re as i64),
326        (Tensor::C64(t), DType::Bool) => converted!(Bool, t, |z| z.re != 0.0 || z.im != 0.0),
327        (Tensor::C64(t), DType::C32) => {
328            converted!(C32, t, |z| Complex32::new(z.re as f32, z.im as f32))
329        }
330        (Tensor::C64(t), DType::C64) => Ok(Tensor::C64(t.clone())),
331    }
332}
333
334pub fn extract_diagonal(input: &Tensor, axis_a: usize, axis_b: usize) -> crate::Result<Tensor> {
335    with_local_pool(|buffers| extract_diagonal_with_pool(buffers, input, axis_a, axis_b))
336}
337
338pub(crate) fn extract_diagonal_with_pool(
339    buffers: &mut BufferPool,
340    input: &Tensor,
341    axis_a: usize,
342    axis_b: usize,
343) -> crate::Result<Tensor> {
344    dispatch_tensor_unary_result!(input, |t| typed_extract_diagonal_with_pool(
345        buffers, t, axis_a, axis_b
346    ))
347}
348
349pub fn embed_diagonal(input: &Tensor, axis_a: usize, axis_b: usize) -> crate::Result<Tensor> {
350    with_local_pool(|buffers| embed_diagonal_with_pool(buffers, input, axis_a, axis_b))
351}
352
353pub(crate) fn embed_diagonal_with_pool(
354    buffers: &mut BufferPool,
355    input: &Tensor,
356    axis_a: usize,
357    axis_b: usize,
358) -> crate::Result<Tensor> {
359    dispatch_tensor_unary_with_bool_special_result!(
360        input,
361        |t| typed_embed_diagonal_with_pool(buffers, t, axis_a, axis_b),
362        bool | t
363            | typed_embed_diagonal_impl(t, axis_a, axis_b, |shape| {
364                filled_tensor_from_pool(buffers, "embed_diagonal", shape, false)
365            })
366    )
367}
368
369pub fn tril(input: &Tensor, k: i64) -> crate::Result<Tensor> {
370    with_local_pool(|buffers| tril_with_pool(buffers, input, k))
371}
372
373pub(crate) fn tril_with_pool(
374    buffers: &mut BufferPool,
375    input: &Tensor,
376    k: i64,
377) -> crate::Result<Tensor> {
378    dispatch_tensor_unary_with_bool_special_result!(
379        input,
380        |t| typed_tril_with_pool(buffers, t, k),
381        bool | t | typed_triangular_mask_with_fill_pool(buffers, t, k, false, false)
382    )
383}
384
385pub fn triu(input: &Tensor, k: i64) -> crate::Result<Tensor> {
386    with_local_pool(|buffers| triu_with_pool(buffers, input, k))
387}
388
389pub(crate) fn triu_with_pool(
390    buffers: &mut BufferPool,
391    input: &Tensor,
392    k: i64,
393) -> crate::Result<Tensor> {
394    dispatch_tensor_unary_with_bool_special_result!(
395        input,
396        |t| typed_triu_with_pool(buffers, t, k),
397        bool | t | typed_triangular_mask_with_fill_pool(buffers, t, k, true, false)
398    )
399}
400
/// Test-only transpose that allocates its output without a buffer pool.
///
/// # Errors
///
/// Fails when `perm` is not a permutation of the tensor's axes, when the
/// tensor is not host-resident, or when the strided kernel reports a failure.
#[cfg(test)]
pub(crate) fn typed_transpose<T: Copy + Clone + Send + Sync>(
    tensor: &TypedTensor<T>,
    perm: &[usize],
) -> crate::Result<TypedTensor<T>> {
    const OP: &str = "transpose";
    validate_permutation(OP, perm, tensor.shape().len())?;
    let source = host_view(OP, tensor)?;
    let transposed = source
        .permute(perm)
        .map_err(|err| crate::Error::backend_failure(OP, err))?;
    // SAFETY: `copy_view_to_array` runs `copy_into`, which overwrites every
    // output element.
    let destination = unsafe { typed_array_uninit(transposed.dims()) };
    copy_view_to_array(OP, destination, &transposed)
}
415
416fn typed_transpose_view_impl<T, R>(
417    view: &TypedTensorView<'_, T, R>,
418    perm: &[usize],
419    make_out: impl FnOnce(&[usize]) -> strided_kernel::StridedArray<T>,
420) -> crate::Result<TypedTensor<T>>
421where
422    T: Copy + Clone + Send + Sync + 'static,
423    R: TensorRank,
424{
425    validate_permutation("transpose", perm, view.shape().len())?;
426    let src = typed_view_from_view("transpose", view)?;
427    let permuted = src
428        .permute(perm)
429        .map_err(|err| crate::Error::backend_failure("transpose", err))?;
430    checked_shape_product("transpose", "output shape", permuted.dims())?;
431    // SAFETY: copy_into overwrites every output element.
432    let out = make_out(permuted.dims());
433    copy_view_to_array("transpose", out, &permuted)
434}
435
436pub(crate) fn typed_transpose_with_pool<T>(
437    buffers: &mut BufferPool,
438    tensor: &TypedTensor<T>,
439    perm: &[usize],
440) -> crate::Result<TypedTensor<T>>
441where
442    T: Copy + Clone + PoolScalar + 'static,
443{
444    typed_transpose_view_with_pool(buffers, &tensor.as_view(), perm)
445}
446
447pub(crate) fn typed_transpose_view_with_pool<T, R>(
448    buffers: &mut BufferPool,
449    view: &TypedTensorView<'_, T, R>,
450    perm: &[usize],
451) -> crate::Result<TypedTensor<T>>
452where
453    T: Copy + Clone + PoolScalar + 'static,
454    R: TensorRank,
455{
456    typed_transpose_view_impl(view, perm, |shape| unsafe {
457        // SAFETY: transpose materialization copies every output element before returning.
458        typed_array_uninit_from_pool(buffers, shape)
459    })
460}
461
462pub fn typed_reshape<T: Clone + 'static>(
463    tensor: &TypedTensor<T>,
464    shape: &[usize],
465) -> crate::Result<TypedTensor<T>> {
466    let old_n = checked_shape_product("reshape", "input shape", tensor.shape())?;
467    let new_n = checked_shape_product("reshape", "output shape", shape)?;
468    if old_n != new_n {
469        return Err(crate::Error::ShapeMismatch {
470            op: "reshape",
471            lhs: tensor.shape().to_vec(),
472            rhs: shape.to_vec(),
473        });
474    }
475    TypedTensor::from_buffer_col_major(
476        shape.to_vec(),
477        tensor.buffer().clone(),
478        tensor.placement().clone(),
479    )
480}
481
/// Test-only `broadcast_in_dim` that allocates its output without a buffer
/// pool.
///
/// # Errors
///
/// Forwards any error from `typed_broadcast_in_dim_impl`.
#[cfg(test)]
pub(crate) fn typed_broadcast_in_dim<T: Copy + Clone + Send + Sync>(
    tensor: &TypedTensor<T>,
    shape: &[usize],
    dims: &[usize],
) -> crate::Result<TypedTensor<T>> {
    let allocate = |out_dims: &[usize]| {
        // SAFETY: broadcast materialization writes every output element before
        // returning.
        unsafe { typed_array_uninit(out_dims) }
    };
    typed_broadcast_in_dim_impl(tensor, shape, dims, allocate)
}
493
494pub(crate) fn typed_broadcast_in_dim_with_pool<T>(
495    buffers: &mut BufferPool,
496    tensor: &TypedTensor<T>,
497    shape: &[usize],
498    dims: &[usize],
499) -> crate::Result<TypedTensor<T>>
500where
501    T: Copy + Clone + PoolScalar,
502{
503    typed_broadcast_in_dim_impl(tensor, shape, dims, |shape| unsafe {
504        // SAFETY: broadcast materialization writes every output element before returning.
505        typed_array_uninit_from_pool(buffers, shape)
506    })
507}
508
509fn typed_broadcast_in_dim_impl<T>(
510    tensor: &TypedTensor<T>,
511    shape: &[usize],
512    dims: &[usize],
513    make_out: impl FnOnce(&[usize]) -> strided_kernel::StridedArray<T>,
514) -> crate::Result<TypedTensor<T>>
515where
516    T: Copy + Clone + Send + Sync,
517{
518    validate_rank("broadcast_in_dim", tensor.shape().len(), dims.len())?;
519    let mut seen = vec![false; shape.len()];
520    let mut base_dims = vec![1usize; shape.len()];
521    let mut base_strides = vec![0isize; shape.len()];
522    let source_strides = col_major_strides(tensor.shape());
523    for (src_axis, &dst_axis) in dims.iter().enumerate() {
524        validate_axis("broadcast_in_dim", dst_axis, shape.len())?;
525        if seen[dst_axis] {
526            return Err(crate::Error::DuplicateAxis {
527                op: "broadcast_in_dim",
528                axis: dst_axis,
529                role: "dims",
530            });
531        }
532        seen[dst_axis] = true;
533        let source_dim = tensor.shape()[src_axis];
534        let target_dim = shape[dst_axis];
535        if source_dim != target_dim && source_dim != 1 {
536            return Err(crate::Error::ShapeMismatch {
537                op: "broadcast_in_dim",
538                lhs: tensor.shape().to_vec(),
539                rhs: shape.to_vec(),
540            });
541        }
542        base_dims[dst_axis] = source_dim;
543        base_strides[dst_axis] = source_strides[src_axis];
544    }
545    let base: StridedView<'_, T, Identity> = match tensor.buffer() {
546        crate::Buffer::Host(data) => {
547            StridedView::new(data.as_slice(), &base_dims, &base_strides, 0)
548                .map_err(|err| crate::Error::backend_failure("broadcast_in_dim", err))?
549        }
550        crate::Buffer::Backend(_) => return Err(cpu_backend_buffer_error("broadcast_in_dim")),
551    };
552    let broadcast: StridedView<'_, T, Identity> = base
553        .broadcast(shape)
554        .map_err(|err| crate::Error::backend_failure("broadcast_in_dim", err))?;
555    checked_shape_product("broadcast_in_dim", "output shape", shape)?;
556    // SAFETY: copy_into overwrites every output element.
557    let mut out = make_out(shape);
558    copy_into(&mut out.view_mut(), &broadcast)
559        .map_err(|err| crate::Error::backend_failure("broadcast_in_dim", err))?;
560    Ok(tensor_from_array(out))
561}
562
563fn typed_convert_with_pool<S, T>(
564    buffers: &mut BufferPool,
565    tensor: &TypedTensor<S>,
566    f: impl Fn(S) -> T + Sync,
567) -> crate::Result<TypedTensor<T>>
568where
569    S: Copy + Send + Sync,
570    T: Copy + Clone + PoolScalar,
571{
572    // SAFETY: map_into overwrites every output element.
573    let mut out = unsafe { typed_array_uninit_from_pool(buffers, tensor.shape()) };
574    map_into(&mut out.view_mut(), &typed_view("convert", tensor)?, f)
575        .map_err(|err| crate::Error::backend_failure("convert", err))?;
576    Ok(tensor_from_array(out))
577}
578
/// Test-only diagonal extraction that allocates its output without a buffer
/// pool.
///
/// # Errors
///
/// Fails when either axis is out of bounds, when both axes are equal, when the
/// tensor is not host-resident, or when the strided kernel reports a failure.
#[cfg(test)]
pub(crate) fn typed_extract_diagonal<T: Copy + Clone + Send + Sync>(
    tensor: &TypedTensor<T>,
    axis_a: usize,
    axis_b: usize,
) -> crate::Result<TypedTensor<T>> {
    const OP: &str = "extract_diagonal";
    let rank = tensor.shape().len();
    validate_axis(OP, axis_a, rank)?;
    validate_axis(OP, axis_b, rank)?;
    validate_axes_distinct(OP, axis_a, axis_b)?;

    let diagonal = host_view(OP, tensor)?
        .diagonal_view(&[(axis_a, axis_b)])
        .map_err(|err| crate::Error::backend_failure(OP, err))?;
    // SAFETY: copy_into overwrites every output element.
    let mut destination = unsafe { typed_array_uninit(diagonal.dims()) };
    copy_into(&mut destination.view_mut(), &diagonal)
        .map_err(|err| crate::Error::backend_failure(OP, err))?;
    Ok(tensor_from_array(destination))
}
598
599pub(crate) fn typed_extract_diagonal_with_pool<T>(
600    buffers: &mut BufferPool,
601    tensor: &TypedTensor<T>,
602    axis_a: usize,
603    axis_b: usize,
604) -> crate::Result<TypedTensor<T>>
605where
606    T: Copy + Clone + PoolScalar,
607{
608    validate_axis("extract_diagonal", axis_a, tensor.shape().len())?;
609    validate_axis("extract_diagonal", axis_b, tensor.shape().len())?;
610    validate_axes_distinct("extract_diagonal", axis_a, axis_b)?;
611
612    let diag = host_view("extract_diagonal", tensor)?
613        .diagonal_view(&[(axis_a, axis_b)])
614        .map_err(|err| crate::Error::backend_failure("extract_diagonal", err))?;
615    // SAFETY: copy_into overwrites every output element.
616    let mut out = unsafe { typed_array_uninit_from_pool(buffers, diag.dims()) };
617    copy_into(&mut out.view_mut(), &diag)
618        .map_err(|err| crate::Error::backend_failure("extract_diagonal", err))?;
619    Ok(tensor_from_array(out))
620}
621
/// Test-only `embed_diagonal` whose zeroed output comes from
/// `TypedTensor::zeros` instead of a buffer pool.
///
/// # Errors
///
/// Forwards any error from `typed_embed_diagonal_impl`.
#[cfg(test)]
pub(crate) fn typed_embed_diagonal<T: Copy + Zero + Clone>(
    tensor: &TypedTensor<T>,
    axis_a: usize,
    axis_b: usize,
) -> crate::Result<TypedTensor<T>> {
    let make_zeroed = TypedTensor::zeros;
    typed_embed_diagonal_impl(tensor, axis_a, axis_b, make_zeroed)
}
630
631pub(crate) fn typed_embed_diagonal_with_pool<T>(
632    buffers: &mut BufferPool,
633    tensor: &TypedTensor<T>,
634    axis_a: usize,
635    axis_b: usize,
636) -> crate::Result<TypedTensor<T>>
637where
638    T: Copy + Zero + Clone + PoolScalar + 'static,
639{
640    typed_embed_diagonal_impl(tensor, axis_a, axis_b, |shape| {
641        zeroed_tensor_from_pool(buffers, "embed_diagonal", shape)
642    })
643}
644
645fn typed_embed_diagonal_impl<T>(
646    tensor: &TypedTensor<T>,
647    axis_a: usize,
648    axis_b: usize,
649    make_zeroed: impl FnOnce(Vec<usize>) -> crate::Result<TypedTensor<T>>,
650) -> crate::Result<TypedTensor<T>>
651where
652    T: Copy + Clone,
653{
654    validate_axis("embed_diagonal", axis_a, tensor.shape().len())?;
655    if axis_b > tensor.shape().len() {
656        return Err(crate::Error::AxisOutOfBounds {
657            op: "embed_diagonal",
658            axis: axis_b,
659            rank: tensor.shape().len(),
660        });
661    }
662
663    let n = tensor.shape()[axis_a];
664    let mut out_shape = tensor.shape().to_vec();
665    out_shape.insert(axis_b, n);
666    let mut out = make_zeroed(out_shape)?;
667
668    let in_rank = tensor.shape().len();
669    let out_rank = out.shape().len();
670    let mut in_idx = vec![0usize; in_rank];
671    let mut out_idx = vec![0usize; out_rank];
672
673    let input_data = match tensor.buffer() {
674        crate::Buffer::Host(data) => data.as_slice(),
675        crate::Buffer::Backend(_) => return Err(cpu_backend_buffer_error("embed_diagonal")),
676    };
677
678    // Intentionally sequential: embed_diagonal writes a sparse diagonal subset
679    // into a zeroed output and has no current strided-kernel parallel primitive.
680    for (flat, value) in input_data
681        .iter()
682        .copied()
683        .enumerate()
684        .take(tensor.n_elements())
685    {
686        flat_to_multi(flat, tensor.shape(), &mut in_idx);
687        let diag_val = in_idx[axis_a];
688        let mut src_axis = 0usize;
689        for (out_axis, out_slot) in out_idx.iter_mut().enumerate().take(out_rank) {
690            if out_axis == axis_b {
691                *out_slot = diag_val;
692            } else {
693                *out_slot = in_idx[src_axis];
694                src_axis += 1;
695            }
696        }
697        *out.get_mut(&out_idx)? = value;
698    }
699    Ok(out)
700}
701
/// Test-only lower-triangular mask that does not use a buffer pool.
///
/// # Errors
///
/// Forwards any error from `typed_triangular_mask`.
#[cfg(test)]
pub(crate) fn typed_tril<T: Copy + Zero + Clone>(
    tensor: &TypedTensor<T>,
    k: i64,
) -> crate::Result<TypedTensor<T>> {
    let upper = false;
    typed_triangular_mask(tensor, k, upper)
}
709
710pub(crate) fn typed_tril_with_pool<T>(
711    buffers: &mut BufferPool,
712    tensor: &TypedTensor<T>,
713    k: i64,
714) -> crate::Result<TypedTensor<T>>
715where
716    T: Copy + Zero + Clone + PoolScalar + 'static,
717{
718    typed_triangular_mask_with_fill_pool(buffers, tensor, k, false, T::zero())
719}
720
/// Test-only upper-triangular mask that does not use a buffer pool.
///
/// # Errors
///
/// Forwards any error from `typed_triangular_mask`.
#[cfg(test)]
pub(crate) fn typed_triu<T: Copy + Zero + Clone>(
    tensor: &TypedTensor<T>,
    k: i64,
) -> crate::Result<TypedTensor<T>> {
    let upper = true;
    typed_triangular_mask(tensor, k, upper)
}
728
729pub(crate) fn typed_triu_with_pool<T>(
730    buffers: &mut BufferPool,
731    tensor: &TypedTensor<T>,
732    k: i64,
733) -> crate::Result<TypedTensor<T>>
734where
735    T: Copy + Zero + Clone + PoolScalar + 'static,
736{
737    typed_triangular_mask_with_fill_pool(buffers, tensor, k, true, T::zero())
738}
739
/// Test-only triangular mask over the leading two axes of `tensor`.
///
/// Axis 0 is treated as the row axis and axis 1 as the column axis; all
/// remaining axes are batch axes. With `upper == false` the elements with
/// `row >= col - k` are kept, with `upper == true` the elements with
/// `row <= col - k` are kept; everything else is set to `T::zero()`.
/// A tensor with any zero extent is returned as a plain clone.
///
/// # Errors
///
/// * `Error::RankMismatch` when the tensor has fewer than two axes.
/// * `Error::InvalidConfig` when an extent or offset computation overflows.
/// * Any error from `host_data_mut`.
#[cfg(test)]
fn typed_triangular_mask<T: Copy + Zero + Clone>(
    tensor: &TypedTensor<T>,
    k: i64,
    upper: bool,
) -> crate::Result<TypedTensor<T>> {
    let op = if upper { "triu" } else { "tril" };
    let rank = tensor.shape().len();
    if rank < 2 {
        return Err(crate::Error::RankMismatch {
            op,
            expected: 2,
            actual: rank,
        });
    }

    let rows = tensor.shape()[0];
    let cols = tensor.shape()[1];
    if tensor.shape().contains(&0) {
        return Ok(tensor.clone());
    }

    let (batch_count, block_size) = checked_triangular_extent(op, tensor.shape(), rows, cols)?;
    let mut out = tensor.clone();
    let data = out.host_data_mut()?;

    // Intentionally sequential: triangular masks are index-dependent in the
    // innermost matrix plane and remain a dedicated CPU-kernel exception.
    for batch in 0..batch_count {
        for col in 0..cols {
            // Signed arithmetic in i128 so that `col - k` cannot overflow.
            let boundary = col as i128 - k as i128;
            for row in 0..rows {
                let signed_row = row as i128;
                let discard = if upper {
                    signed_row > boundary
                } else {
                    signed_row < boundary
                };
                if discard {
                    let offset = checked_triangular_offset(op, batch, block_size, col, rows, row)?;
                    data[offset] = T::zero();
                }
            }
        }
    }

    Ok(out)
}
789
790fn typed_triangular_mask_with_fill_pool<T>(
791    buffers: &mut BufferPool,
792    tensor: &TypedTensor<T>,
793    k: i64,
794    upper: bool,
795    fill: T,
796) -> crate::Result<TypedTensor<T>>
797where
798    T: Copy + Clone + PoolScalar + 'static,
799{
800    let op = if upper { "triu" } else { "tril" };
801    if tensor.shape().len() < 2 {
802        return Err(crate::Error::RankMismatch {
803            op,
804            expected: 2,
805            actual: tensor.shape().len(),
806        });
807    }
808
809    let rows = tensor.shape()[0];
810    let cols = tensor.shape()[1];
811    if tensor.shape().contains(&0) {
812        return Ok(tensor.clone());
813    }
814
815    let (batch_count, block_size) = checked_triangular_extent(op, tensor.shape(), rows, cols)?;
816    let mut out = clone_host_tensor_from_pool(buffers, op, tensor)?;
817    let data = out.host_data_mut()?;
818
819    // Intentionally sequential: triangular masks are index-dependent in the
820    // innermost matrix plane and remain a dedicated CPU-kernel exception.
821    for batch_idx in 0..batch_count {
822        for col in 0..cols {
823            let boundary = col as i128 - k as i128;
824            for row in 0..rows {
825                let row_idx = row;
826                let row = row_idx as i128;
827                let keep = if upper {
828                    row <= boundary
829                } else {
830                    row >= boundary
831                };
832                if !keep {
833                    let offset =
834                        checked_triangular_offset(op, batch_idx, block_size, col, rows, row_idx)?;
835                    data[offset] = fill;
836                }
837            }
838        }
839    }
840
841    Ok(out)
842}
843
844fn checked_triangular_extent(
845    op: &'static str,
846    shape: &[usize],
847    rows: usize,
848    cols: usize,
849) -> crate::Result<(usize, usize)> {
850    let batch_count = shape[2..].iter().try_fold(1usize, |acc, &dim| {
851        acc.checked_mul(dim)
852            .ok_or_else(|| crate::Error::InvalidConfig {
853                op,
854                message: format!("batch extent overflows usize: {acc} * {dim}"),
855            })
856    })?;
857    let block_size = rows
858        .checked_mul(cols)
859        .ok_or_else(|| crate::Error::InvalidConfig {
860            op,
861            message: format!("matrix block size overflows usize: {rows} * {cols}"),
862        })?;
863    Ok((batch_count, block_size))
864}
865
866fn checked_triangular_offset(
867    op: &'static str,
868    batch_idx: usize,
869    block_size: usize,
870    col: usize,
871    rows: usize,
872    row_idx: usize,
873) -> crate::Result<usize> {
874    let base = batch_idx
875        .checked_mul(block_size)
876        .ok_or_else(|| crate::Error::InvalidConfig {
877            op,
878            message: format!("batch offset overflows usize: {batch_idx} * {block_size}"),
879        })?;
880    let col_offset = col
881        .checked_mul(rows)
882        .ok_or_else(|| crate::Error::InvalidConfig {
883            op,
884            message: format!("column offset overflows usize: {col} * {rows}"),
885        })?;
886    base.checked_add(col_offset)
887        .and_then(|offset| offset.checked_add(row_idx))
888        .ok_or_else(|| crate::Error::InvalidConfig {
889            op,
890            message: "triangular mask offset overflows usize".to_string(),
891        })
892}