tenferro_tensor/tensor/structural.rs

use num_traits::ToPrimitive;
use tenferro_algebra::Scalar;
use tenferro_device::{flatten_col_major_index, Error, LogicalMemorySpace, Result};

use super::Tensor;
use crate::MemoryOrder;

/// Scalar types accepted as `keep_counts` in [`Tensor::zero_trailing_by_counts`].
///
/// Implemented for `f32` and `f64`.
///
/// # Examples
///
/// ```ignore
/// use tenferro_tensor::KeepCountScalar;
///
/// fn accepts_keep_counts<T: KeepCountScalar>() {}
/// accepts_keep_counts::<f32>();
/// accepts_keep_counts::<f64>();
/// ```
#[cfg(not(feature = "cuda"))]
pub trait KeepCountScalar: Scalar + ToPrimitive {}

#[cfg(not(feature = "cuda"))]
impl KeepCountScalar for f32 {}
#[cfg(not(feature = "cuda"))]
impl KeepCountScalar for f64 {}

/// Scalar types accepted as `keep_counts` in [`Tensor::zero_trailing_by_counts`].
///
/// Implemented for `f32` and `f64`. On CUDA builds the trait additionally
/// requires the device runtime's keep-count support,
/// [`tenferro_device::cuda::runtime::RuntimeKeepCountScalar`].
///
/// # Examples
///
/// ```ignore
/// use tenferro_tensor::KeepCountScalar;
///
/// fn accepts_keep_counts<T: KeepCountScalar>() {}
/// accepts_keep_counts::<f32>();
/// accepts_keep_counts::<f64>();
/// ```
#[cfg(feature = "cuda")]
pub trait KeepCountScalar:
    Scalar + ToPrimitive + tenferro_device::cuda::runtime::RuntimeKeepCountScalar
{
}

#[cfg(feature = "cuda")]
impl KeepCountScalar for f32 {}
#[cfg(feature = "cuda")]
impl KeepCountScalar for f64 {}

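/// Validate a single `keep_counts` entry: it must be finite, non-negative,
/// integer-valued, and no larger than `axis_len`; `index` is used only for
/// error messages. Returns the count as a `usize`.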
fn parse_keep_count<R: ToPrimitive>(value: &R, axis_len: usize, index: usize) -> Result<usize> {
    let raw = value.to_f64().ok_or_else(|| {
        Error::InvalidArgument(format!(
            "keep_counts[{index}] must be representable as a real scalar"
        ))
    })?;
    if !raw.is_finite() {
        return Err(Error::InvalidArgument(format!(
            "keep_counts[{index}] must be finite"
        )));
    }
    if raw < 0.0 {
        return Err(Error::InvalidArgument(format!(
            "keep_counts[{index}] must be non-negative"
        )));
    }
    if raw.fract() != 0.0 {
        return Err(Error::InvalidArgument(format!(
            "keep_counts[{index}] must be integer-valued"
        )));
    }
    let count = raw as usize;
    if count > axis_len {
        return Err(Error::InvalidArgument(format!(
            "keep_counts[{index}]={count} exceeds axis length {axis_len}"
        )));
    }
    Ok(count)
}

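/// Check the preconditions for [`Tensor::merge_strict_lower_and_upper`]:
/// both inputs have equal rank >= 2, matching memory spaces and batch dims,
/// no conjugation, and `lower.cols == upper.rows == min(m, n)`. Returns
/// `(m, k, n, batch_dims)`.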
fn validate_triangular_merge_shapes<T: Scalar>(
    lower: &Tensor<T>,
    upper: &Tensor<T>,
) -> Result<(usize, usize, usize, Vec<usize>)> {
    if lower.ndim() < 2 || upper.ndim() < 2 {
        return Err(Error::InvalidArgument(
            "merge_strict_lower_and_upper requires rank >= 2".into(),
        ));
    }
    if lower.ndim() != upper.ndim() {
        return Err(Error::RankMismatch {
            expected: lower.ndim(),
            got: upper.ndim(),
        });
    }
    if lower.logical_memory_space() != upper.logical_memory_space() {
        return Err(Error::InvalidArgument(format!(
            "merge_strict_lower_and_upper requires matching memory spaces, got {:?} and {:?}",
            lower.logical_memory_space(),
            upper.logical_memory_space()
        )));
    }
    if lower.is_conjugated() || upper.is_conjugated() {
        return Err(Error::InvalidArgument(
            "merge_strict_lower_and_upper does not support conjugated tensors".into(),
        ));
    }
    if lower.dims()[2..] != upper.dims()[2..] {
        return Err(Error::ShapeMismatch {
            expected: lower.dims()[2..].to_vec(),
            got: upper.dims()[2..].to_vec(),
        });
    }

    let m = lower.dims()[0];
    let k = lower.dims()[1];
    let upper_k = upper.dims()[0];
    let n = upper.dims()[1];
    if k != m.min(n) || upper_k != k {
        return Err(Error::InvalidArgument(format!(
            "merge_strict_lower_and_upper requires lower.cols == upper.rows == min(m, n); got m={m}, lower.cols={k}, upper.rows={upper_k}, n={n}"
        )));
    }

    Ok((m, k, n, lower.dims()[2..].to_vec()))
}

impl<T: Scalar> Tensor<T> {
    /// Return a contiguous tensor with trailing elements zeroed according to
    /// batch-local keep counts.
    ///
    /// `structural_rank` splits the leading payload dims from the trailing
    /// batch dims; `axis` is interpreted within the structural prefix
    /// `[0, structural_rank)`. Along `axis`, each batch keeps its first
    /// `keep_counts[batch]` entries and has the rest zeroed.
    ///
    /// Phase 1 supports main-memory tensors only; with the `cuda` feature,
    /// GPU tensors dispatch to the CUDA runtime instead.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// use tenferro_tensor::{MemoryOrder, Tensor};
    ///
    /// let payload = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], &[2, 2], MemoryOrder::ColumnMajor)?;
    /// let keep_counts = Tensor::from_slice(&[1.0], &[], MemoryOrder::ColumnMajor)?;
    /// let trimmed = payload.zero_trailing_by_counts(&keep_counts, 1, 2)?;
    /// assert_eq!(trimmed.buffer().as_slice().unwrap(), &[1.0, 2.0, 0.0, 0.0]);
    /// # Ok::<(), tenferro_device::Error>(())
    /// ```
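    ///
    /// With a trailing batch axis, each batch gets its own keep count. A
    /// sketch using the same APIs as above (values chosen for illustration):
    ///
    /// ```ignore
    /// # use tenferro_tensor::{MemoryOrder, Tensor};
    /// // Dims [2, 3] with structural_rank = 1: payload prefix [2], batch dims [3].
    /// let payload = Tensor::from_slice(
    ///     &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
    ///     &[2, 3],
    ///     MemoryOrder::ColumnMajor,
    /// )?;
    /// let keep_counts = Tensor::from_slice(&[2.0, 1.0, 0.0], &[3], MemoryOrder::ColumnMajor)?;
    /// let trimmed = payload.zero_trailing_by_counts(&keep_counts, 0, 1)?;
    /// // The three batch columns keep 2, 1, and 0 leading entries along axis 0.
    /// assert_eq!(
    ///     trimmed.buffer().as_slice().unwrap(),
    ///     &[1.0, 2.0, 3.0, 0.0, 0.0, 0.0]
    /// );
    /// # Ok::<(), tenferro_device::Error>(())
    /// ```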
    pub fn zero_trailing_by_counts<R>(
        &self,
        keep_counts: &Tensor<R>,
        axis: usize,
        structural_rank: usize,
    ) -> Result<Tensor<T>>
    where
        R: KeepCountScalar,
    {
        self.wait();
        keep_counts.wait();

        if self.logical_memory_space != keep_counts.logical_memory_space() {
            return Err(Error::InvalidArgument(format!(
                "zero_trailing_by_counts requires matching memory spaces, got {:?} and {:?}",
                self.logical_memory_space,
                keep_counts.logical_memory_space()
            )));
        }
        #[cfg(feature = "cuda")]
        if matches!(
            self.logical_memory_space,
            LogicalMemorySpace::GpuMemory { .. }
        ) {
            return crate::cuda_runtime::zero_trailing_by_counts_tensor(
                self,
                keep_counts,
                axis,
                structural_rank,
            );
        }

        if self.logical_memory_space != LogicalMemorySpace::MainMemory {
            return Err(Error::DeviceError(
                "zero_trailing_by_counts only supports main-memory tensors in Phase 1".into(),
            ));
        }
        if structural_rank == 0 {
            return Err(Error::InvalidArgument(
                "zero_trailing_by_counts requires structural_rank >= 1".into(),
            ));
        }
        if structural_rank > self.ndim() {
            return Err(Error::RankMismatch {
                expected: self.ndim(),
                got: structural_rank,
            });
        }
        if axis >= structural_rank {
            return Err(Error::InvalidArgument(format!(
                "axis {axis} out of range for structural_rank {structural_rank}"
            )));
        }

        let expected_batch_dims = &self.dims[structural_rank..];
        if keep_counts.dims() != expected_batch_dims {
            return Err(Error::ShapeMismatch {
                expected: expected_batch_dims.to_vec(),
                got: keep_counts.dims().to_vec(),
            });
        }

        let keep_counts = keep_counts.contiguous(MemoryOrder::ColumnMajor);
        let keep_counts_slice = keep_counts.buffer().as_slice().ok_or_else(|| {
            Error::DeviceError(
                "zero_trailing_by_counts requires CPU-accessible keep_counts in Phase 1".into(),
            )
        })?;
        let axis_len = self.dims[axis];
        let parsed_counts = keep_counts_slice
            .iter()
            .enumerate()
            .map(|(index, value)| parse_keep_count(value, axis_len, index))
            .collect::<Result<Vec<_>>>()?;

        let src = self.buffer().as_slice().ok_or_else(|| {
            Error::DeviceError(
                "zero_trailing_by_counts requires a CPU-accessible payload in Phase 1".into(),
            )
        })?;
        let mut out = vec![T::zero(); self.len()];
        if out.is_empty() {
            return Ok(self.materialized_from_vec(out, MemoryOrder::ColumnMajor));
        }

        // Walk the output in column-major order, maintaining the multi-index
        // alongside the flat destination position.
        let out_len = out.len();
        let mut index = vec![0usize; self.ndim()];
        for (dst_pos, dst) in out.iter_mut().enumerate() {
            let batch_pos =
                flatten_col_major_index(expected_batch_dims, &index[structural_rank..])?;
            let keep = parsed_counts.get(batch_pos).copied().ok_or_else(|| {
                Error::StrideError(format!("batch position {batch_pos} out of bounds"))
            })?;
            if index[axis] < keep {
                // Copy from the (possibly strided) source; positions at or past
                // the keep count stay at the zero the buffer started with.
                let src_pos = self
                    .offset
                    .checked_add(
                        index
                            .iter()
                            .zip(self.strides.iter())
                            .try_fold(0isize, |acc, (&coord, &stride)| {
                                (coord as isize)
                                    .checked_mul(stride)
                                    .and_then(|term| acc.checked_add(term))
                            })
                            .ok_or_else(|| {
                                Error::StrideError(format!(
                                    "source offset overflow for index {:?} and strides {:?}",
                                    index, self.strides
                                ))
                            })?,
                    )
                    .and_then(|pos| usize::try_from(pos).ok())
                    .ok_or_else(|| {
                        Error::StrideError(format!(
                            "source position overflow for index {:?} and offset {}",
                            index, self.offset
                        ))
                    })?;
                *dst = src[src_pos];
            }

            if dst_pos + 1 == out_len {
                continue;
            }
            // Odometer-style column-major increment of the multi-index.
            for (axis_index, coord) in index.iter_mut().enumerate().take(self.ndim()) {
                *coord += 1;
                if *coord < self.dims[axis_index] {
                    break;
                }
                *coord = 0;
            }
        }

        Ok(self.materialized_from_vec(out, MemoryOrder::ColumnMajor))
    }

    /// Merge a strict-lower source and an upper-with-diagonal source into one packed matrix.
    ///
    /// `lower` must have shape `[m, k, *batch]` and `upper` must have shape `[k, n, *batch]`,
    /// where `k = min(m, n)`. The output has shape `[m, n, *batch]` with entries selected
    /// from `lower` when `row > col` and from `upper` otherwise.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// use tenferro_tensor::{MemoryOrder, Tensor};
    ///
    /// let lower = Tensor::from_slice(&[1.0, 2.0, 1.0, 3.0], &[2, 2], MemoryOrder::ColumnMajor)?;
    /// let upper = Tensor::from_slice(&[4.0, 0.0, 5.0, 6.0], &[2, 2], MemoryOrder::ColumnMajor)?;
    /// let packed = Tensor::merge_strict_lower_and_upper(&lower, &upper)?;
    /// assert_eq!(packed.buffer().as_slice().unwrap(), &[4.0, 2.0, 5.0, 6.0]);
    /// # Ok::<(), tenferro_device::Error>(())
    /// ```
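    ///
    /// A rectangular sketch (`m = 3`, `n = 2`, so `k = 2`); the `9.0` entries
    /// mark positions that are never read from either source:
    ///
    /// ```ignore
    /// # use tenferro_tensor::{MemoryOrder, Tensor};
    /// let lower =
    ///     Tensor::from_slice(&[9.0, 2.0, 3.0, 9.0, 9.0, 4.0], &[3, 2], MemoryOrder::ColumnMajor)?;
    /// let upper = Tensor::from_slice(&[5.0, 9.0, 6.0, 7.0], &[2, 2], MemoryOrder::ColumnMajor)?;
    /// let packed = Tensor::merge_strict_lower_and_upper(&lower, &upper)?;
    /// assert_eq!(
    ///     packed.buffer().as_slice().unwrap(),
    ///     &[5.0, 2.0, 3.0, 6.0, 7.0, 4.0]
    /// );
    /// # Ok::<(), tenferro_device::Error>(())
    /// ```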
    pub fn merge_strict_lower_and_upper(lower: &Tensor<T>, upper: &Tensor<T>) -> Result<Tensor<T>> {
        lower.wait();
        upper.wait();

        let (m, _k, n, batch_dims) = validate_triangular_merge_shapes(lower, upper)?;

        #[cfg(feature = "cuda")]
        if matches!(
            lower.logical_memory_space(),
            LogicalMemorySpace::GpuMemory { .. }
        ) {
            return crate::cuda_runtime::merge_strict_lower_and_upper_tensor(lower, upper);
        }

        if lower.logical_memory_space() != LogicalMemorySpace::MainMemory {
            return Err(Error::DeviceError(
                "merge_strict_lower_and_upper only supports main-memory tensors in Phase 1".into(),
            ));
        }

        let lower_src = lower.buffer().as_slice().ok_or_else(|| {
            Error::DeviceError(
                "merge_strict_lower_and_upper requires a CPU-accessible lower tensor".into(),
            )
        })?;
        let upper_src = upper.buffer().as_slice().ok_or_else(|| {
            Error::DeviceError(
                "merge_strict_lower_and_upper requires a CPU-accessible upper tensor".into(),
            )
        })?;
        let mut output_dims = vec![m, n];
        output_dims.extend(batch_dims.iter().copied());
        let mut out = vec![T::zero(); output_dims.iter().product()];
        if out.is_empty() {
            return Ok(Tensor::from_owned_contiguous_data(
                out,
                output_dims.into(),
                MemoryOrder::ColumnMajor,
                lower.logical_memory_space(),
                None,
                lower.is_conjugated(),
            ));
        }

        // Walk the output in column-major order; each element reads from the
        // strict-lower source when row > col and from the upper source otherwise.
        let out_len = out.len();
        let mut index = vec![0usize; output_dims.len()];
        for (dst_pos, dst) in out.iter_mut().enumerate() {
            let (src_tensor, src_slice) = if index[0] > index[1] {
                (lower, lower_src)
            } else {
                (upper, upper_src)
            };
            let src_pos = src_tensor
                .offset()
                .checked_add(
                    index
                        .iter()
                        .zip(src_tensor.strides().iter())
                        .try_fold(0isize, |acc, (&coord, &stride)| {
                            (coord as isize)
                                .checked_mul(stride)
                                .and_then(|term| acc.checked_add(term))
                        })
                        .ok_or_else(|| {
                            Error::StrideError(format!(
                                "source offset overflow for index {:?} and strides {:?}",
                                index,
                                src_tensor.strides()
                            ))
                        })?,
                )
                .and_then(|pos| usize::try_from(pos).ok())
                .ok_or_else(|| {
                    Error::StrideError(format!(
                        "source position overflow for index {:?} and offset {}",
                        index,
                        src_tensor.offset()
                    ))
                })?;
            *dst = src_slice[src_pos];

            if dst_pos + 1 == out_len {
                continue;
            }
            // Odometer-style column-major increment of the multi-index.
            for (axis_index, coord) in index.iter_mut().enumerate() {
                *coord += 1;
                if *coord < output_dims[axis_index] {
                    break;
                }
                *coord = 0;
            }
        }

        Ok(Tensor::from_owned_contiguous_data(
            out,
            output_dims.into(),
            MemoryOrder::ColumnMajor,
            lower.logical_memory_space(),
            None,
            lower.is_conjugated(),
        ))
    }
}
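
// A minimal CPU-path test sketch mirroring the doc examples above. It assumes
// the `Tensor::from_slice` constructor and `buffer().as_slice()` accessor used
// in those examples; expected buffers are spelled out in column-major order.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn zero_trailing_by_counts_keeps_leading_column() -> Result<()> {
        let payload =
            Tensor::from_slice(&[1.0f64, 2.0, 3.0, 4.0], &[2, 2], MemoryOrder::ColumnMajor)?;
        // Scalar (rank-0) keep_counts: batch dims are empty when structural_rank = 2.
        let keep_counts = Tensor::from_slice(&[1.0f64], &[], MemoryOrder::ColumnMajor)?;
        let trimmed = payload.zero_trailing_by_counts(&keep_counts, 1, 2)?;
        // Only the first column (keep count 1 along axis 1) survives.
        assert_eq!(trimmed.buffer().as_slice().unwrap(), &[1.0, 2.0, 0.0, 0.0]);
        Ok(())
    }

    #[test]
    fn merge_strict_lower_and_upper_selects_by_triangle() -> Result<()> {
        let lower =
            Tensor::from_slice(&[1.0f64, 2.0, 1.0, 3.0], &[2, 2], MemoryOrder::ColumnMajor)?;
        let upper =
            Tensor::from_slice(&[4.0f64, 0.0, 5.0, 6.0], &[2, 2], MemoryOrder::ColumnMajor)?;
        let packed = Tensor::merge_strict_lower_and_upper(&lower, &upper)?;
        // Strict-lower entry (1, 0) comes from `lower`; the rest from `upper`.
        assert_eq!(packed.buffer().as_slice().unwrap(), &[4.0, 2.0, 5.0, 6.0]);
        Ok(())
    }
}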