1use num_complex::{Complex32, Complex64};
2use num_traits::Zero;
3use strided_kernel::{col_major_strides, copy_into, map_into, Identity, StridedView};
4
5use crate::{
6 buffer_pool::{BufferPool, PoolScalar},
7 flat_to_multi,
8};
9use tenferro_tensor::{DType, Tensor, TensorRank, TypedTensor, TypedTensorView};
10
11#[cfg(test)]
12use super::typed_array_uninit;
13use super::{
14 cpu_backend_buffer_error, tensor_from_array, typed_array_uninit_from_pool, typed_view,
15 typed_view_from_view,
16};
17
18fn with_local_pool<T>(f: impl FnOnce(&mut BufferPool) -> T) -> T {
19 let mut buffers = BufferPool::new();
20 f(&mut buffers)
21}
22
23fn validate_rank(op: &'static str, expected: usize, actual: usize) -> crate::Result<()> {
24 if expected != actual {
25 return Err(crate::Error::RankMismatch {
26 op,
27 expected,
28 actual,
29 });
30 }
31 Ok(())
32}
33
34fn validate_axis(op: &'static str, axis: usize, rank: usize) -> crate::Result<()> {
35 if axis >= rank {
36 return Err(crate::Error::AxisOutOfBounds { op, axis, rank });
37 }
38 Ok(())
39}
40
41fn validate_axes_distinct(op: &'static str, axis_a: usize, axis_b: usize) -> crate::Result<()> {
42 if axis_a == axis_b {
43 return Err(crate::Error::DuplicateAxis {
44 op,
45 axis: axis_a,
46 role: "axes",
47 });
48 }
49 Ok(())
50}
51
52fn checked_shape_product(
53 op: &'static str,
54 role: &'static str,
55 shape: &[usize],
56) -> crate::Result<usize> {
57 shape.iter().try_fold(1usize, |acc, &dim| {
58 acc.checked_mul(dim)
59 .ok_or_else(|| crate::Error::InvalidConfig {
60 op,
61 message: format!("{role} element count overflows usize"),
62 })
63 })
64}
65
66fn validate_permutation(op: &'static str, perm: &[usize], rank: usize) -> crate::Result<()> {
67 validate_rank(op, rank, perm.len())?;
68 let mut seen = vec![false; rank];
69 for &axis in perm {
70 validate_axis(op, axis, rank)?;
71 if seen[axis] {
72 return Err(crate::Error::DuplicateAxis {
73 op,
74 axis,
75 role: "perm",
76 });
77 }
78 seen[axis] = true;
79 }
80 Ok(())
81}
82
// Runs a fallible per-dtype expression on whichever `Tensor` variant `$input` holds.
//
// `$work` is evaluated with `$bound` bound to the inner typed tensor. Its error is
// propagated with `?` out of the *enclosing* function. Its success value is rewrapped in
// the same `Tensor` variant the input came from, so the dtype is preserved.
macro_rules! dispatch_tensor_unary_result {
    ($input:expr, |$bound:ident| $work:expr) => {
        match $input {
            Tensor::Bool($bound) => Ok(Tensor::Bool($work?)),
            Tensor::I32($bound) => Ok(Tensor::I32($work?)),
            Tensor::I64($bound) => Ok(Tensor::I64($work?)),
            Tensor::F32($bound) => Ok(Tensor::F32($work?)),
            Tensor::F64($bound) => Ok(Tensor::F64($work?)),
            Tensor::C32($bound) => Ok(Tensor::C32($work?)),
            Tensor::C64($bound) => Ok(Tensor::C64($work?)),
        }
    };
}
96
// Same as `dispatch_tensor_unary_result`, except that `Tensor::Bool` is routed to a
// separate expression (`$bool_work`, binding `$bool_bound`). Every other variant runs
// `$work` with `$bound`. Both expressions propagate errors with `?` from the enclosing
// function and keep the input's variant.
macro_rules! dispatch_tensor_unary_with_bool_special_result {
    ($input:expr, |$bound:ident| $work:expr, bool |$bool_bound:ident| $bool_work:expr) => {
        match $input {
            Tensor::Bool($bool_bound) => Ok(Tensor::Bool($bool_work?)),
            Tensor::I32($bound) => Ok(Tensor::I32($work?)),
            Tensor::I64($bound) => Ok(Tensor::I64($work?)),
            Tensor::F32($bound) => Ok(Tensor::F32($work?)),
            Tensor::F64($bound) => Ok(Tensor::F64($work?)),
            Tensor::C32($bound) => Ok(Tensor::C32($work?)),
            Tensor::C64($bound) => Ok(Tensor::C64($work?)),
        }
    };
}
110
111fn host_view<'a, T: Copy>(
112 op: &'static str,
113 tensor: &'a TypedTensor<T>,
114) -> crate::Result<StridedView<'a, T, Identity>> {
115 match tensor.buffer() {
116 crate::Buffer::Host(data) => {
117 let strides = col_major_strides(tensor.shape());
118 StridedView::new(data.as_slice(), tensor.shape(), &strides, 0)
119 .map_err(|err| crate::Error::backend_failure(op, err))
120 }
121 crate::Buffer::Backend(_) => Err(cpu_backend_buffer_error(op)),
122 }
123}
124
125fn copy_view_to_array<T: Copy + Clone + Send + Sync>(
126 op: &'static str,
127 mut out: strided_kernel::StridedArray<T>,
128 src: &StridedView<'_, T>,
129) -> crate::Result<TypedTensor<T>> {
130 copy_into(&mut out.view_mut(), src).map_err(|err| crate::Error::backend_failure(op, err))?;
131 Ok(tensor_from_array(out))
132}
133
134fn zeroed_tensor_from_pool<T>(
135 buffers: &mut BufferPool,
136 op: &'static str,
137 shape: Vec<usize>,
138) -> crate::Result<TypedTensor<T>>
139where
140 T: Zero + Clone + PoolScalar + 'static,
141{
142 filled_tensor_from_pool(buffers, op, shape, T::zero())
143}
144
145fn filled_tensor_from_pool<T>(
146 buffers: &mut BufferPool,
147 op: &'static str,
148 shape: Vec<usize>,
149 fill: T,
150) -> crate::Result<TypedTensor<T>>
151where
152 T: Copy + Clone + PoolScalar + 'static,
153{
154 let len = checked_shape_product(op, "output shape", &shape)?;
155 let mut data = unsafe { T::pool_acquire(buffers, len) };
157 data.fill(fill);
158 TypedTensor::from_vec_col_major(shape, data)
159}
160
161fn clone_host_tensor_from_pool<T>(
162 buffers: &mut BufferPool,
163 op: &'static str,
164 tensor: &TypedTensor<T>,
165) -> crate::Result<TypedTensor<T>>
166where
167 T: Copy + PoolScalar + 'static,
168{
169 let input = match tensor.buffer() {
170 crate::Buffer::Host(data) => data.as_slice(),
171 crate::Buffer::Backend(_) => return Err(cpu_backend_buffer_error(op)),
172 };
173 let mut data = unsafe { T::pool_acquire(buffers, input.len()) };
175 data.copy_from_slice(input);
176 TypedTensor::from_buffer_col_major(
177 tensor.shape().to_vec(),
178 crate::Buffer::Host(data),
179 tensor.placement().clone(),
180 )
181}
182
183pub fn transpose(input: &Tensor, perm: &[usize]) -> crate::Result<Tensor> {
184 with_local_pool(|buffers| transpose_with_pool(buffers, input, perm))
185}
186
187pub(crate) fn transpose_with_pool(
188 buffers: &mut BufferPool,
189 input: &Tensor,
190 perm: &[usize],
191) -> crate::Result<Tensor> {
192 dispatch_tensor_unary_result!(input, |t| typed_transpose_with_pool(buffers, t, perm))
193}
194
195pub fn reshape(input: &Tensor, shape: &[usize]) -> crate::Result<Tensor> {
196 dispatch_tensor_unary_result!(input, |t| typed_reshape(t, shape))
197}
198
199pub fn broadcast_in_dim(input: &Tensor, shape: &[usize], dims: &[usize]) -> crate::Result<Tensor> {
200 with_local_pool(|buffers| broadcast_in_dim_with_pool(buffers, input, shape, dims))
201}
202
203pub(crate) fn broadcast_in_dim_with_pool(
204 buffers: &mut BufferPool,
205 input: &Tensor,
206 shape: &[usize],
207 dims: &[usize],
208) -> crate::Result<Tensor> {
209 dispatch_tensor_unary_result!(input, |t| typed_broadcast_in_dim_with_pool(
210 buffers, t, shape, dims
211 ))
212}
213
214pub fn convert(input: &Tensor, to: DType) -> crate::Result<Tensor> {
236 with_local_pool(|buffers| convert_with_pool(buffers, input, to))
237}
238
239pub(crate) fn convert_with_pool(
240 buffers: &mut BufferPool,
241 input: &Tensor,
242 to: DType,
243) -> crate::Result<Tensor> {
244 tenferro_tensor::validate::validate_convert_dtype("convert", input.dtype(), to)?;
245 cast_with_pool(buffers, input, to)
246}
247
248pub(crate) fn cast_with_pool(
249 buffers: &mut BufferPool,
250 input: &Tensor,
251 to: DType,
252) -> crate::Result<Tensor> {
253 macro_rules! converted {
254 ($variant:ident, $tensor:expr, $map:expr) => {
255 Ok(Tensor::$variant(typed_convert_with_pool(
256 buffers, $tensor, $map,
257 )?))
258 };
259 }
260
261 match (input, to) {
262 (Tensor::F32(t), DType::F32) => Ok(Tensor::F32(t.clone())),
263 (Tensor::F32(t), DType::F64) => converted!(F64, t, |x| x as f64),
264 (Tensor::F32(t), DType::I32) => converted!(I32, t, |x| x as i32),
265 (Tensor::F32(t), DType::I64) => converted!(I64, t, |x| x as i64),
266 (Tensor::F32(t), DType::Bool) => converted!(Bool, t, |x| x != 0.0),
267 (Tensor::F32(t), DType::C32) => converted!(C32, t, |x| Complex32::new(x, 0.0)),
268 (Tensor::F32(t), DType::C64) => {
269 converted!(C64, t, |x| Complex64::new(x as f64, 0.0))
270 }
271 (Tensor::F64(t), DType::F32) => converted!(F32, t, |x| x as f32),
272 (Tensor::F64(t), DType::F64) => Ok(Tensor::F64(t.clone())),
273 (Tensor::F64(t), DType::I32) => converted!(I32, t, |x| x as i32),
274 (Tensor::F64(t), DType::I64) => converted!(I64, t, |x| x as i64),
275 (Tensor::F64(t), DType::Bool) => converted!(Bool, t, |x| x != 0.0),
276 (Tensor::F64(t), DType::C32) => {
277 converted!(C32, t, |x| Complex32::new(x as f32, 0.0))
278 }
279 (Tensor::F64(t), DType::C64) => converted!(C64, t, |x| Complex64::new(x, 0.0)),
280 (Tensor::I32(t), DType::F32) => converted!(F32, t, |x| x as f32),
281 (Tensor::I32(t), DType::F64) => converted!(F64, t, |x| x as f64),
282 (Tensor::I32(t), DType::I32) => Ok(Tensor::I32(t.clone())),
283 (Tensor::I32(t), DType::I64) => converted!(I64, t, |x| x as i64),
284 (Tensor::I32(t), DType::Bool) => converted!(Bool, t, |x| x != 0),
285 (Tensor::I32(t), DType::C32) => {
286 converted!(C32, t, |x| Complex32::new(x as f32, 0.0))
287 }
288 (Tensor::I32(t), DType::C64) => {
289 converted!(C64, t, |x| Complex64::new(x as f64, 0.0))
290 }
291 (Tensor::I64(t), DType::F32) => converted!(F32, t, |x| x as f32),
292 (Tensor::I64(t), DType::F64) => converted!(F64, t, |x| x as f64),
293 (Tensor::I64(t), DType::I32) => converted!(I32, t, |x| x as i32),
294 (Tensor::I64(t), DType::I64) => Ok(Tensor::I64(t.clone())),
295 (Tensor::I64(t), DType::Bool) => converted!(Bool, t, |x| x != 0),
296 (Tensor::I64(t), DType::C32) => {
297 converted!(C32, t, |x| Complex32::new(x as f32, 0.0))
298 }
299 (Tensor::I64(t), DType::C64) => {
300 converted!(C64, t, |x| Complex64::new(x as f64, 0.0))
301 }
302 (Tensor::Bool(t), DType::F32) => converted!(F32, t, |x| if x { 1.0 } else { 0.0 }),
303 (Tensor::Bool(t), DType::F64) => converted!(F64, t, |x| if x { 1.0 } else { 0.0 }),
304 (Tensor::Bool(t), DType::I32) => converted!(I32, t, |x| if x { 1 } else { 0 }),
305 (Tensor::Bool(t), DType::I64) => converted!(I64, t, |x| if x { 1 } else { 0 }),
306 (Tensor::Bool(t), DType::Bool) => Ok(Tensor::Bool(t.clone())),
307 (Tensor::Bool(t), DType::C32) => {
308 converted!(C32, t, |x| Complex32::new(if x { 1.0 } else { 0.0 }, 0.0))
309 }
310 (Tensor::Bool(t), DType::C64) => {
311 converted!(C64, t, |x| Complex64::new(if x { 1.0 } else { 0.0 }, 0.0))
312 }
313 (Tensor::C32(t), DType::F32) => converted!(F32, t, |z| z.re),
314 (Tensor::C32(t), DType::F64) => converted!(F64, t, |z| z.re as f64),
315 (Tensor::C32(t), DType::I32) => converted!(I32, t, |z| z.re as i32),
316 (Tensor::C32(t), DType::I64) => converted!(I64, t, |z| z.re as i64),
317 (Tensor::C32(t), DType::Bool) => converted!(Bool, t, |z| z.re != 0.0 || z.im != 0.0),
318 (Tensor::C32(t), DType::C32) => Ok(Tensor::C32(t.clone())),
319 (Tensor::C32(t), DType::C64) => {
320 converted!(C64, t, |z| Complex64::new(z.re as f64, z.im as f64))
321 }
322 (Tensor::C64(t), DType::F32) => converted!(F32, t, |z| z.re as f32),
323 (Tensor::C64(t), DType::F64) => converted!(F64, t, |z| z.re),
324 (Tensor::C64(t), DType::I32) => converted!(I32, t, |z| z.re as i32),
325 (Tensor::C64(t), DType::I64) => converted!(I64, t, |z| z.re as i64),
326 (Tensor::C64(t), DType::Bool) => converted!(Bool, t, |z| z.re != 0.0 || z.im != 0.0),
327 (Tensor::C64(t), DType::C32) => {
328 converted!(C32, t, |z| Complex32::new(z.re as f32, z.im as f32))
329 }
330 (Tensor::C64(t), DType::C64) => Ok(Tensor::C64(t.clone())),
331 }
332}
333
334pub fn extract_diagonal(input: &Tensor, axis_a: usize, axis_b: usize) -> crate::Result<Tensor> {
335 with_local_pool(|buffers| extract_diagonal_with_pool(buffers, input, axis_a, axis_b))
336}
337
338pub(crate) fn extract_diagonal_with_pool(
339 buffers: &mut BufferPool,
340 input: &Tensor,
341 axis_a: usize,
342 axis_b: usize,
343) -> crate::Result<Tensor> {
344 dispatch_tensor_unary_result!(input, |t| typed_extract_diagonal_with_pool(
345 buffers, t, axis_a, axis_b
346 ))
347}
348
349pub fn embed_diagonal(input: &Tensor, axis_a: usize, axis_b: usize) -> crate::Result<Tensor> {
350 with_local_pool(|buffers| embed_diagonal_with_pool(buffers, input, axis_a, axis_b))
351}
352
353pub(crate) fn embed_diagonal_with_pool(
354 buffers: &mut BufferPool,
355 input: &Tensor,
356 axis_a: usize,
357 axis_b: usize,
358) -> crate::Result<Tensor> {
359 dispatch_tensor_unary_with_bool_special_result!(
360 input,
361 |t| typed_embed_diagonal_with_pool(buffers, t, axis_a, axis_b),
362 bool | t
363 | typed_embed_diagonal_impl(t, axis_a, axis_b, |shape| {
364 filled_tensor_from_pool(buffers, "embed_diagonal", shape, false)
365 })
366 )
367}
368
369pub fn tril(input: &Tensor, k: i64) -> crate::Result<Tensor> {
370 with_local_pool(|buffers| tril_with_pool(buffers, input, k))
371}
372
373pub(crate) fn tril_with_pool(
374 buffers: &mut BufferPool,
375 input: &Tensor,
376 k: i64,
377) -> crate::Result<Tensor> {
378 dispatch_tensor_unary_with_bool_special_result!(
379 input,
380 |t| typed_tril_with_pool(buffers, t, k),
381 bool | t | typed_triangular_mask_with_fill_pool(buffers, t, k, false, false)
382 )
383}
384
385pub fn triu(input: &Tensor, k: i64) -> crate::Result<Tensor> {
386 with_local_pool(|buffers| triu_with_pool(buffers, input, k))
387}
388
389pub(crate) fn triu_with_pool(
390 buffers: &mut BufferPool,
391 input: &Tensor,
392 k: i64,
393) -> crate::Result<Tensor> {
394 dispatch_tensor_unary_with_bool_special_result!(
395 input,
396 |t| typed_triu_with_pool(buffers, t, k),
397 bool | t | typed_triangular_mask_with_fill_pool(buffers, t, k, true, false)
398 )
399}
400
/// Test-only transpose that allocates its output without a buffer pool.
///
/// # Errors
/// Fails if `perm` is invalid, if `tensor` is not host-resident, or if the strided
/// permute/copy fails.
#[cfg(test)]
pub(crate) fn typed_transpose<T: Copy + Clone + Send + Sync>(
    tensor: &TypedTensor<T>,
    perm: &[usize],
) -> crate::Result<TypedTensor<T>> {
    const OP: &str = "transpose";
    validate_permutation(OP, perm, tensor.shape().len())?;
    let source = host_view(OP, tensor)?;
    let permuted = match source.permute(perm) {
        Ok(view) => view,
        Err(err) => return Err(crate::Error::backend_failure(OP, err)),
    };
    // SAFETY: the array has exactly `permuted.dims()` and is filled by
    // `copy_view_to_array` before it is exposed. NOTE(review): assumes `copy_into` writes
    // every element.
    let out = unsafe { typed_array_uninit(permuted.dims()) };
    copy_view_to_array(OP, out, &permuted)
}
415
416fn typed_transpose_view_impl<T, R>(
417 view: &TypedTensorView<'_, T, R>,
418 perm: &[usize],
419 make_out: impl FnOnce(&[usize]) -> strided_kernel::StridedArray<T>,
420) -> crate::Result<TypedTensor<T>>
421where
422 T: Copy + Clone + Send + Sync + 'static,
423 R: TensorRank,
424{
425 validate_permutation("transpose", perm, view.shape().len())?;
426 let src = typed_view_from_view("transpose", view)?;
427 let permuted = src
428 .permute(perm)
429 .map_err(|err| crate::Error::backend_failure("transpose", err))?;
430 checked_shape_product("transpose", "output shape", permuted.dims())?;
431 let out = make_out(permuted.dims());
433 copy_view_to_array("transpose", out, &permuted)
434}
435
436pub(crate) fn typed_transpose_with_pool<T>(
437 buffers: &mut BufferPool,
438 tensor: &TypedTensor<T>,
439 perm: &[usize],
440) -> crate::Result<TypedTensor<T>>
441where
442 T: Copy + Clone + PoolScalar + 'static,
443{
444 typed_transpose_view_with_pool(buffers, &tensor.as_view(), perm)
445}
446
447pub(crate) fn typed_transpose_view_with_pool<T, R>(
448 buffers: &mut BufferPool,
449 view: &TypedTensorView<'_, T, R>,
450 perm: &[usize],
451) -> crate::Result<TypedTensor<T>>
452where
453 T: Copy + Clone + PoolScalar + 'static,
454 R: TensorRank,
455{
456 typed_transpose_view_impl(view, perm, |shape| unsafe {
457 typed_array_uninit_from_pool(buffers, shape)
459 })
460}
461
462pub fn typed_reshape<T: Clone + 'static>(
463 tensor: &TypedTensor<T>,
464 shape: &[usize],
465) -> crate::Result<TypedTensor<T>> {
466 let old_n = checked_shape_product("reshape", "input shape", tensor.shape())?;
467 let new_n = checked_shape_product("reshape", "output shape", shape)?;
468 if old_n != new_n {
469 return Err(crate::Error::ShapeMismatch {
470 op: "reshape",
471 lhs: tensor.shape().to_vec(),
472 rhs: shape.to_vec(),
473 });
474 }
475 TypedTensor::from_buffer_col_major(
476 shape.to_vec(),
477 tensor.buffer().clone(),
478 tensor.placement().clone(),
479 )
480}
481
/// Test-only `broadcast_in_dim` that allocates its output without a buffer pool.
#[cfg(test)]
pub(crate) fn typed_broadcast_in_dim<T: Copy + Clone + Send + Sync>(
    tensor: &TypedTensor<T>,
    shape: &[usize],
    dims: &[usize],
) -> crate::Result<TypedTensor<T>> {
    typed_broadcast_in_dim_impl(tensor, shape, dims, |out_shape| {
        // SAFETY: the impl overwrites this array with `copy_into` from a view of the same
        // shape before returning it. NOTE(review): assumes `copy_into` writes every element.
        unsafe { typed_array_uninit(out_shape) }
    })
}
493
494pub(crate) fn typed_broadcast_in_dim_with_pool<T>(
495 buffers: &mut BufferPool,
496 tensor: &TypedTensor<T>,
497 shape: &[usize],
498 dims: &[usize],
499) -> crate::Result<TypedTensor<T>>
500where
501 T: Copy + Clone + PoolScalar,
502{
503 typed_broadcast_in_dim_impl(tensor, shape, dims, |shape| unsafe {
504 typed_array_uninit_from_pool(buffers, shape)
506 })
507}
508
509fn typed_broadcast_in_dim_impl<T>(
510 tensor: &TypedTensor<T>,
511 shape: &[usize],
512 dims: &[usize],
513 make_out: impl FnOnce(&[usize]) -> strided_kernel::StridedArray<T>,
514) -> crate::Result<TypedTensor<T>>
515where
516 T: Copy + Clone + Send + Sync,
517{
518 validate_rank("broadcast_in_dim", tensor.shape().len(), dims.len())?;
519 let mut seen = vec![false; shape.len()];
520 let mut base_dims = vec![1usize; shape.len()];
521 let mut base_strides = vec![0isize; shape.len()];
522 let source_strides = col_major_strides(tensor.shape());
523 for (src_axis, &dst_axis) in dims.iter().enumerate() {
524 validate_axis("broadcast_in_dim", dst_axis, shape.len())?;
525 if seen[dst_axis] {
526 return Err(crate::Error::DuplicateAxis {
527 op: "broadcast_in_dim",
528 axis: dst_axis,
529 role: "dims",
530 });
531 }
532 seen[dst_axis] = true;
533 let source_dim = tensor.shape()[src_axis];
534 let target_dim = shape[dst_axis];
535 if source_dim != target_dim && source_dim != 1 {
536 return Err(crate::Error::ShapeMismatch {
537 op: "broadcast_in_dim",
538 lhs: tensor.shape().to_vec(),
539 rhs: shape.to_vec(),
540 });
541 }
542 base_dims[dst_axis] = source_dim;
543 base_strides[dst_axis] = source_strides[src_axis];
544 }
545 let base: StridedView<'_, T, Identity> = match tensor.buffer() {
546 crate::Buffer::Host(data) => {
547 StridedView::new(data.as_slice(), &base_dims, &base_strides, 0)
548 .map_err(|err| crate::Error::backend_failure("broadcast_in_dim", err))?
549 }
550 crate::Buffer::Backend(_) => return Err(cpu_backend_buffer_error("broadcast_in_dim")),
551 };
552 let broadcast: StridedView<'_, T, Identity> = base
553 .broadcast(shape)
554 .map_err(|err| crate::Error::backend_failure("broadcast_in_dim", err))?;
555 checked_shape_product("broadcast_in_dim", "output shape", shape)?;
556 let mut out = make_out(shape);
558 copy_into(&mut out.view_mut(), &broadcast)
559 .map_err(|err| crate::Error::backend_failure("broadcast_in_dim", err))?;
560 Ok(tensor_from_array(out))
561}
562
563fn typed_convert_with_pool<S, T>(
564 buffers: &mut BufferPool,
565 tensor: &TypedTensor<S>,
566 f: impl Fn(S) -> T + Sync,
567) -> crate::Result<TypedTensor<T>>
568where
569 S: Copy + Send + Sync,
570 T: Copy + Clone + PoolScalar,
571{
572 let mut out = unsafe { typed_array_uninit_from_pool(buffers, tensor.shape()) };
574 map_into(&mut out.view_mut(), &typed_view("convert", tensor)?, f)
575 .map_err(|err| crate::Error::backend_failure("convert", err))?;
576 Ok(tensor_from_array(out))
577}
578
579#[cfg(test)]
580pub(crate) fn typed_extract_diagonal<T: Copy + Clone + Send + Sync>(
581 tensor: &TypedTensor<T>,
582 axis_a: usize,
583 axis_b: usize,
584) -> crate::Result<TypedTensor<T>> {
585 validate_axis("extract_diagonal", axis_a, tensor.shape().len())?;
586 validate_axis("extract_diagonal", axis_b, tensor.shape().len())?;
587 validate_axes_distinct("extract_diagonal", axis_a, axis_b)?;
588
589 let diag = host_view("extract_diagonal", tensor)?
590 .diagonal_view(&[(axis_a, axis_b)])
591 .map_err(|err| crate::Error::backend_failure("extract_diagonal", err))?;
592 let mut out = unsafe { typed_array_uninit(diag.dims()) };
594 copy_into(&mut out.view_mut(), &diag)
595 .map_err(|err| crate::Error::backend_failure("extract_diagonal", err))?;
596 Ok(tensor_from_array(out))
597}
598
599pub(crate) fn typed_extract_diagonal_with_pool<T>(
600 buffers: &mut BufferPool,
601 tensor: &TypedTensor<T>,
602 axis_a: usize,
603 axis_b: usize,
604) -> crate::Result<TypedTensor<T>>
605where
606 T: Copy + Clone + PoolScalar,
607{
608 validate_axis("extract_diagonal", axis_a, tensor.shape().len())?;
609 validate_axis("extract_diagonal", axis_b, tensor.shape().len())?;
610 validate_axes_distinct("extract_diagonal", axis_a, axis_b)?;
611
612 let diag = host_view("extract_diagonal", tensor)?
613 .diagonal_view(&[(axis_a, axis_b)])
614 .map_err(|err| crate::Error::backend_failure("extract_diagonal", err))?;
615 let mut out = unsafe { typed_array_uninit_from_pool(buffers, diag.dims()) };
617 copy_into(&mut out.view_mut(), &diag)
618 .map_err(|err| crate::Error::backend_failure("extract_diagonal", err))?;
619 Ok(tensor_from_array(out))
620}
621
/// Test-only `embed_diagonal` whose output starts from `TypedTensor::zeros` rather than
/// pooled storage.
#[cfg(test)]
pub(crate) fn typed_embed_diagonal<T: Copy + Zero + Clone>(
    tensor: &TypedTensor<T>,
    axis_a: usize,
    axis_b: usize,
) -> crate::Result<TypedTensor<T>> {
    let make_zeroed = TypedTensor::zeros;
    typed_embed_diagonal_impl(tensor, axis_a, axis_b, make_zeroed)
}
630
631pub(crate) fn typed_embed_diagonal_with_pool<T>(
632 buffers: &mut BufferPool,
633 tensor: &TypedTensor<T>,
634 axis_a: usize,
635 axis_b: usize,
636) -> crate::Result<TypedTensor<T>>
637where
638 T: Copy + Zero + Clone + PoolScalar + 'static,
639{
640 typed_embed_diagonal_impl(tensor, axis_a, axis_b, |shape| {
641 zeroed_tensor_from_pool(buffers, "embed_diagonal", shape)
642 })
643}
644
645fn typed_embed_diagonal_impl<T>(
646 tensor: &TypedTensor<T>,
647 axis_a: usize,
648 axis_b: usize,
649 make_zeroed: impl FnOnce(Vec<usize>) -> crate::Result<TypedTensor<T>>,
650) -> crate::Result<TypedTensor<T>>
651where
652 T: Copy + Clone,
653{
654 validate_axis("embed_diagonal", axis_a, tensor.shape().len())?;
655 if axis_b > tensor.shape().len() {
656 return Err(crate::Error::AxisOutOfBounds {
657 op: "embed_diagonal",
658 axis: axis_b,
659 rank: tensor.shape().len(),
660 });
661 }
662
663 let n = tensor.shape()[axis_a];
664 let mut out_shape = tensor.shape().to_vec();
665 out_shape.insert(axis_b, n);
666 let mut out = make_zeroed(out_shape)?;
667
668 let in_rank = tensor.shape().len();
669 let out_rank = out.shape().len();
670 let mut in_idx = vec![0usize; in_rank];
671 let mut out_idx = vec![0usize; out_rank];
672
673 let input_data = match tensor.buffer() {
674 crate::Buffer::Host(data) => data.as_slice(),
675 crate::Buffer::Backend(_) => return Err(cpu_backend_buffer_error("embed_diagonal")),
676 };
677
678 for (flat, value) in input_data
681 .iter()
682 .copied()
683 .enumerate()
684 .take(tensor.n_elements())
685 {
686 flat_to_multi(flat, tensor.shape(), &mut in_idx);
687 let diag_val = in_idx[axis_a];
688 let mut src_axis = 0usize;
689 for (out_axis, out_slot) in out_idx.iter_mut().enumerate().take(out_rank) {
690 if out_axis == axis_b {
691 *out_slot = diag_val;
692 } else {
693 *out_slot = in_idx[src_axis];
694 src_axis += 1;
695 }
696 }
697 *out.get_mut(&out_idx)? = value;
698 }
699 Ok(out)
700}
701
/// Test-only `tril`: zeroes entries above diagonal `k` of each matrix block, without a
/// buffer pool.
#[cfg(test)]
pub(crate) fn typed_tril<T: Copy + Zero + Clone>(
    tensor: &TypedTensor<T>,
    k: i64,
) -> crate::Result<TypedTensor<T>> {
    let upper = false;
    typed_triangular_mask(tensor, k, upper)
}
709
710pub(crate) fn typed_tril_with_pool<T>(
711 buffers: &mut BufferPool,
712 tensor: &TypedTensor<T>,
713 k: i64,
714) -> crate::Result<TypedTensor<T>>
715where
716 T: Copy + Zero + Clone + PoolScalar + 'static,
717{
718 typed_triangular_mask_with_fill_pool(buffers, tensor, k, false, T::zero())
719}
720
/// Test-only `triu`: zeroes entries below diagonal `k` of each matrix block, without a
/// buffer pool.
#[cfg(test)]
pub(crate) fn typed_triu<T: Copy + Zero + Clone>(
    tensor: &TypedTensor<T>,
    k: i64,
) -> crate::Result<TypedTensor<T>> {
    let upper = true;
    typed_triangular_mask(tensor, k, upper)
}
728
729pub(crate) fn typed_triu_with_pool<T>(
730 buffers: &mut BufferPool,
731 tensor: &TypedTensor<T>,
732 k: i64,
733) -> crate::Result<TypedTensor<T>>
734where
735 T: Copy + Zero + Clone + PoolScalar + 'static,
736{
737 typed_triangular_mask_with_fill_pool(buffers, tensor, k, true, T::zero())
738}
739
740#[cfg(test)]
741fn typed_triangular_mask<T: Copy + Zero + Clone>(
742 tensor: &TypedTensor<T>,
743 k: i64,
744 upper: bool,
745) -> crate::Result<TypedTensor<T>> {
746 let op = if upper { "triu" } else { "tril" };
747 if tensor.shape().len() < 2 {
748 return Err(crate::Error::RankMismatch {
749 op,
750 expected: 2,
751 actual: tensor.shape().len(),
752 });
753 }
754
755 let rows = tensor.shape()[0];
756 let cols = tensor.shape()[1];
757 if tensor.shape().contains(&0) {
758 return Ok(tensor.clone());
759 }
760
761 let (batch_count, block_size) = checked_triangular_extent(op, tensor.shape(), rows, cols)?;
762 let mut out = tensor.clone();
763 let data = out.host_data_mut()?;
764
765 for batch_idx in 0..batch_count {
768 for col in 0..cols {
769 let boundary = col as i128 - k as i128;
770 for row in 0..rows {
771 let row_idx = row;
772 let row = row_idx as i128;
773 let keep = if upper {
774 row <= boundary
775 } else {
776 row >= boundary
777 };
778 if !keep {
779 let offset =
780 checked_triangular_offset(op, batch_idx, block_size, col, rows, row_idx)?;
781 data[offset] = T::zero();
782 }
783 }
784 }
785 }
786
787 Ok(out)
788}
789
790fn typed_triangular_mask_with_fill_pool<T>(
791 buffers: &mut BufferPool,
792 tensor: &TypedTensor<T>,
793 k: i64,
794 upper: bool,
795 fill: T,
796) -> crate::Result<TypedTensor<T>>
797where
798 T: Copy + Clone + PoolScalar + 'static,
799{
800 let op = if upper { "triu" } else { "tril" };
801 if tensor.shape().len() < 2 {
802 return Err(crate::Error::RankMismatch {
803 op,
804 expected: 2,
805 actual: tensor.shape().len(),
806 });
807 }
808
809 let rows = tensor.shape()[0];
810 let cols = tensor.shape()[1];
811 if tensor.shape().contains(&0) {
812 return Ok(tensor.clone());
813 }
814
815 let (batch_count, block_size) = checked_triangular_extent(op, tensor.shape(), rows, cols)?;
816 let mut out = clone_host_tensor_from_pool(buffers, op, tensor)?;
817 let data = out.host_data_mut()?;
818
819 for batch_idx in 0..batch_count {
822 for col in 0..cols {
823 let boundary = col as i128 - k as i128;
824 for row in 0..rows {
825 let row_idx = row;
826 let row = row_idx as i128;
827 let keep = if upper {
828 row <= boundary
829 } else {
830 row >= boundary
831 };
832 if !keep {
833 let offset =
834 checked_triangular_offset(op, batch_idx, block_size, col, rows, row_idx)?;
835 data[offset] = fill;
836 }
837 }
838 }
839 }
840
841 Ok(out)
842}
843
844fn checked_triangular_extent(
845 op: &'static str,
846 shape: &[usize],
847 rows: usize,
848 cols: usize,
849) -> crate::Result<(usize, usize)> {
850 let batch_count = shape[2..].iter().try_fold(1usize, |acc, &dim| {
851 acc.checked_mul(dim)
852 .ok_or_else(|| crate::Error::InvalidConfig {
853 op,
854 message: format!("batch extent overflows usize: {acc} * {dim}"),
855 })
856 })?;
857 let block_size = rows
858 .checked_mul(cols)
859 .ok_or_else(|| crate::Error::InvalidConfig {
860 op,
861 message: format!("matrix block size overflows usize: {rows} * {cols}"),
862 })?;
863 Ok((batch_count, block_size))
864}
865
866fn checked_triangular_offset(
867 op: &'static str,
868 batch_idx: usize,
869 block_size: usize,
870 col: usize,
871 rows: usize,
872 row_idx: usize,
873) -> crate::Result<usize> {
874 let base = batch_idx
875 .checked_mul(block_size)
876 .ok_or_else(|| crate::Error::InvalidConfig {
877 op,
878 message: format!("batch offset overflows usize: {batch_idx} * {block_size}"),
879 })?;
880 let col_offset = col
881 .checked_mul(rows)
882 .ok_or_else(|| crate::Error::InvalidConfig {
883 op,
884 message: format!("column offset overflows usize: {col} * {rows}"),
885 })?;
886 base.checked_add(col_offset)
887 .and_then(|offset| offset.checked_add(row_idx))
888 .ok_or_else(|| crate::Error::InvalidConfig {
889 op,
890 message: "triangular mask offset overflows usize".to_string(),
891 })
892}