//! Count-based trailing-zeroing and strict-triangular merging for tensors,
//! with a main-memory (Phase 1) implementation and CUDA dispatch when the
//! `cuda` feature is enabled.

use num_traits::ToPrimitive;
use tenferro_algebra::Scalar;
use tenferro_device::{flatten_col_major_index, Error, LogicalMemorySpace, Result};

use super::Tensor;
use crate::MemoryOrder;

/// Scalar types accepted as `keep_counts` entries for
/// `zero_trailing_by_counts`.
#[cfg(not(feature = "cuda"))]
pub trait KeepCountScalar: Scalar + ToPrimitive {}

#[cfg(not(feature = "cuda"))]
impl KeepCountScalar for f32 {}
#[cfg(not(feature = "cuda"))]
impl KeepCountScalar for f64 {}

/// Scalar types accepted as `keep_counts` entries for
/// `zero_trailing_by_counts`; the CUDA build additionally requires runtime
/// kernel support.
#[cfg(feature = "cuda")]
pub trait KeepCountScalar:
    Scalar + ToPrimitive + tenferro_device::cuda::runtime::RuntimeKeepCountScalar
{
}

#[cfg(feature = "cuda")]
impl KeepCountScalar for f32 {}
#[cfg(feature = "cuda")]
impl KeepCountScalar for f64 {}

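/// Validates one `keep_counts` entry: it must be finite, non-negative,
/// integer-valued, and no larger than the axis length. A behavior sketch
/// (hypothetical `f64` inputs):
///
/// ```ignore
/// assert_eq!(parse_keep_count(&2.0_f64, 4, 0).unwrap(), 2);
/// assert!(parse_keep_count(&2.5_f64, 4, 0).is_err()); // not integer-valued
/// assert!(parse_keep_count(&5.0_f64, 4, 0).is_err()); // exceeds axis length 4
/// ```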
fn parse_keep_count<R: ToPrimitive>(value: &R, axis_len: usize, index: usize) -> Result<usize> {
    let raw = value.to_f64().ok_or_else(|| {
        Error::InvalidArgument(format!(
            "keep_counts[{index}] must be representable as a real scalar"
        ))
    })?;
    if !raw.is_finite() {
        return Err(Error::InvalidArgument(format!(
            "keep_counts[{index}] must be finite"
        )));
    }
    if raw < 0.0 {
        return Err(Error::InvalidArgument(format!(
            "keep_counts[{index}] must be non-negative"
        )));
    }
    if raw.fract() != 0.0 {
        return Err(Error::InvalidArgument(format!(
            "keep_counts[{index}] must be integer-valued"
        )));
    }
    // `raw` is a finite, non-negative integer here; the `as` cast saturates,
    // and the bound check below rejects anything larger than the axis.
    let count = raw as usize;
    if count > axis_len {
        return Err(Error::InvalidArgument(format!(
            "keep_counts[{index}]={count} exceeds axis length {axis_len}"
        )));
    }
    Ok(count)
}

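/// Checks that `lower` (`m x k x batch...`) and `upper` (`k x n x batch...`)
/// can be merged: equal rank of at least 2, matching memory spaces and batch
/// dims, no conjugation, and `k == min(m, n)`. Returns `(m, k, n, batch_dims)`.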
fn validate_triangular_merge_shapes<T: Scalar>(
    lower: &Tensor<T>,
    upper: &Tensor<T>,
) -> Result<(usize, usize, usize, Vec<usize>)> {
    if lower.ndim() < 2 || upper.ndim() < 2 {
        return Err(Error::InvalidArgument(
            "merge_strict_lower_and_upper requires rank >= 2".into(),
        ));
    }
    if lower.ndim() != upper.ndim() {
        return Err(Error::RankMismatch {
            expected: lower.ndim(),
            got: upper.ndim(),
        });
    }
    if lower.logical_memory_space() != upper.logical_memory_space() {
        return Err(Error::InvalidArgument(format!(
            "merge_strict_lower_and_upper requires matching memory spaces, got {:?} and {:?}",
            lower.logical_memory_space(),
            upper.logical_memory_space()
        )));
    }
    if lower.is_conjugated() || upper.is_conjugated() {
        return Err(Error::InvalidArgument(
            "merge_strict_lower_and_upper does not support conjugated tensors".into(),
        ));
    }
    if lower.dims()[2..] != upper.dims()[2..] {
        return Err(Error::ShapeMismatch {
            expected: lower.dims()[2..].to_vec(),
            got: upper.dims()[2..].to_vec(),
        });
    }

    let m = lower.dims()[0];
    let k = lower.dims()[1];
    let upper_k = upper.dims()[0];
    let n = upper.dims()[1];
    if k != m.min(n) || upper_k != k {
        return Err(Error::InvalidArgument(format!(
            "merge_strict_lower_and_upper requires lower.cols == upper.rows == min(m, n); got m={m}, lower.cols={k}, upper.rows={upper_k}, n={n}"
        )));
    }

    Ok((m, k, n, lower.dims()[2..].to_vec()))
}

impl<T: Scalar> Tensor<T> {
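    /// Zeroes trailing entries along `axis` independently per batch slice:
    /// entry `i` along `axis` survives only when `i < keep_counts[batch]`.
    /// The first `structural_rank` axes are structural; the remaining axes
    /// are batch dims, which `keep_counts` must match exactly. A usage sketch
    /// (hypothetical tensors and counts):
    ///
    /// ```ignore
    /// // 3 x 2 tensor, structural_rank = 1, axis = 0: batch dims are [2], so
    /// // keep_counts has shape [2]. With counts [1, 3], column 0 keeps only
    /// // its first entry (the rest become zero) and column 1 keeps all three.
    /// let zeroed = a.zero_trailing_by_counts(&counts, 0, 1)?;
    /// ```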
    pub fn zero_trailing_by_counts<R>(
        &self,
        keep_counts: &Tensor<R>,
        axis: usize,
        structural_rank: usize,
    ) -> Result<Tensor<T>>
    where
        R: KeepCountScalar,
    {
        self.wait();
        keep_counts.wait();

        if self.logical_memory_space != keep_counts.logical_memory_space() {
            return Err(Error::InvalidArgument(format!(
                "zero_trailing_by_counts requires matching memory spaces, got {:?} and {:?}",
                self.logical_memory_space,
                keep_counts.logical_memory_space()
            )));
        }
        #[cfg(feature = "cuda")]
        if matches!(
            self.logical_memory_space,
            LogicalMemorySpace::GpuMemory { .. }
        ) {
            return crate::cuda_runtime::zero_trailing_by_counts_tensor(
                self,
                keep_counts,
                axis,
                structural_rank,
            );
        }

        if self.logical_memory_space != LogicalMemorySpace::MainMemory {
            return Err(Error::DeviceError(
                "zero_trailing_by_counts only supports main-memory tensors in Phase 1".into(),
            ));
        }
        if structural_rank == 0 {
            return Err(Error::InvalidArgument(
                "zero_trailing_by_counts requires structural_rank >= 1".into(),
            ));
        }
        if structural_rank > self.ndim() {
            return Err(Error::RankMismatch {
                expected: self.ndim(),
                got: structural_rank,
            });
        }
        if axis >= structural_rank {
            return Err(Error::InvalidArgument(format!(
                "axis {axis} out of range for structural_rank {structural_rank}"
            )));
        }

        let expected_batch_dims = &self.dims[structural_rank..];
        if keep_counts.dims() != expected_batch_dims {
            return Err(Error::ShapeMismatch {
                expected: expected_batch_dims.to_vec(),
                got: keep_counts.dims().to_vec(),
            });
        }

        let keep_counts = keep_counts.contiguous(MemoryOrder::ColumnMajor);
        let keep_counts_slice = keep_counts.buffer().as_slice().ok_or_else(|| {
            Error::DeviceError(
                "zero_trailing_by_counts requires CPU-accessible keep_counts in Phase 1".into(),
            )
        })?;
        let axis_len = self.dims[axis];
        let parsed_counts = keep_counts_slice
            .iter()
            .enumerate()
            .map(|(index, value)| parse_keep_count(value, axis_len, index))
            .collect::<Result<Vec<_>>>()?;

        let src = self.buffer().as_slice().ok_or_else(|| {
            Error::DeviceError(
                "zero_trailing_by_counts requires a CPU-accessible payload in Phase 1".into(),
            )
        })?;
        let mut out = vec![T::zero(); self.len()];
        if out.is_empty() {
            return Ok(self.materialized_from_vec(out, MemoryOrder::ColumnMajor));
        }

        let out_len = out.len();
        let mut index = vec![0usize; self.ndim()];
        for (dst_pos, dst) in out.iter_mut().enumerate() {
            // Locate this element's batch slice and look up its keep count.
            let batch_pos =
                flatten_col_major_index(expected_batch_dims, &index[structural_rank..])?;
            let keep = parsed_counts.get(batch_pos).copied().ok_or_else(|| {
                Error::StrideError(format!("batch position {batch_pos} out of bounds"))
            })?;
            if index[axis] < keep {
                // Copy the kept element, computing the strided source position
                // with overflow checks; everything else stays zero.
                let src_pos = self
                    .offset
                    .checked_add(
                        index
                            .iter()
                            .zip(self.strides.iter())
                            .try_fold(0isize, |acc, (&coord, &stride)| {
                                (coord as isize)
                                    .checked_mul(stride)
                                    .and_then(|term| acc.checked_add(term))
                            })
                            .ok_or_else(|| {
                                Error::StrideError(format!(
                                    "source offset overflow for index {:?} and strides {:?}",
                                    index, self.strides
                                ))
                            })?,
                    )
                    .and_then(|pos| usize::try_from(pos).ok())
                    .ok_or_else(|| {
                        Error::StrideError(format!(
                            "source position overflow for index {:?} and offset {}",
                            index, self.offset
                        ))
                    })?;
                *dst = src[src_pos];
            }

            if dst_pos + 1 == out_len {
                continue;
            }
            // Advance the multi-index in column-major (odometer) order.
            for (axis_index, coord) in index.iter_mut().enumerate().take(self.ndim()) {
                *coord += 1;
                if *coord < self.dims[axis_index] {
                    break;
                }
                *coord = 0;
            }
        }

        Ok(self.materialized_from_vec(out, MemoryOrder::ColumnMajor))
    }

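    /// Assembles a dense `m x n x batch...` tensor from a strictly
    /// lower-triangular factor (`m x k x batch...`) and an upper-triangular
    /// factor (`k x n x batch...`, diagonal included), where `k = min(m, n)`.
    /// A usage sketch (hypothetical values; the constructor call mirrors the
    /// internal `from_owned_contiguous_data` usage below and assumes it is
    /// callable from the example's context):
    ///
    /// ```ignore
    /// // m = 3, n = 2, k = min(m, n) = 2; payloads are column-major.
    /// let lower = Tensor::from_owned_contiguous_data(
    ///     vec![0.0_f64, 2.0, 3.0, 0.0, 0.0, 6.0], // strict lower of 3 x 2
    ///     vec![3, 2].into(),
    ///     MemoryOrder::ColumnMajor,
    ///     LogicalMemorySpace::MainMemory,
    ///     None,
    ///     false,
    /// );
    /// let upper = Tensor::from_owned_contiguous_data(
    ///     vec![1.0_f64, 0.0, 4.0, 5.0], // upper (with diagonal) of 2 x 2
    ///     vec![2, 2].into(),
    ///     MemoryOrder::ColumnMajor,
    ///     LogicalMemorySpace::MainMemory,
    ///     None,
    ///     false,
    /// );
    /// let merged = Tensor::merge_strict_lower_and_upper(&lower, &upper)?;
    /// // merged's column-major payload is [1, 2, 3, 4, 5, 6].
    /// ```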
    pub fn merge_strict_lower_and_upper(lower: &Tensor<T>, upper: &Tensor<T>) -> Result<Tensor<T>> {
        lower.wait();
        upper.wait();

        let (m, _k, n, batch_dims) = validate_triangular_merge_shapes(lower, upper)?;

        #[cfg(feature = "cuda")]
        if matches!(
            lower.logical_memory_space(),
            LogicalMemorySpace::GpuMemory { .. }
        ) {
            return crate::cuda_runtime::merge_strict_lower_and_upper_tensor(lower, upper);
        }

        if lower.logical_memory_space() != LogicalMemorySpace::MainMemory {
            return Err(Error::DeviceError(
                "merge_strict_lower_and_upper only supports main-memory tensors in Phase 1".into(),
            ));
        }

        let lower_src = lower.buffer().as_slice().ok_or_else(|| {
            Error::DeviceError(
                "merge_strict_lower_and_upper requires a CPU-accessible lower tensor".into(),
            )
        })?;
        let upper_src = upper.buffer().as_slice().ok_or_else(|| {
            Error::DeviceError(
                "merge_strict_lower_and_upper requires a CPU-accessible upper tensor".into(),
            )
        })?;
        let mut output_dims = vec![m, n];
        output_dims.extend(batch_dims.iter().copied());
        let mut out = vec![T::zero(); output_dims.iter().product()];
        if out.is_empty() {
            return Ok(Tensor::from_owned_contiguous_data(
                out,
                output_dims.into(),
                MemoryOrder::ColumnMajor,
                lower.logical_memory_space(),
                None,
                lower.is_conjugated(),
            ));
        }

        let out_len = out.len();
        let mut index = vec![0usize; output_dims.len()];
        for (dst_pos, dst) in out.iter_mut().enumerate() {
            // Strictly-below-diagonal entries come from `lower`; the diagonal
            // and everything above come from `upper`.
            let below_diagonal = index[0] > index[1];
            let src_tensor = if below_diagonal { lower } else { upper };
            let src_slice = if below_diagonal { lower_src } else { upper_src };
            let src_pos = src_tensor
                .offset()
                .checked_add(
                    index
                        .iter()
                        .zip(src_tensor.strides().iter())
                        .try_fold(0isize, |acc, (&coord, &stride)| {
                            (coord as isize)
                                .checked_mul(stride)
                                .and_then(|term| acc.checked_add(term))
                        })
                        .ok_or_else(|| {
                            Error::StrideError(format!(
                                "source offset overflow for index {:?} and strides {:?}",
                                index,
                                src_tensor.strides()
                            ))
                        })?,
                )
                .and_then(|pos| usize::try_from(pos).ok())
                .ok_or_else(|| {
                    Error::StrideError(format!(
                        "source position overflow for index {:?} and offset {}",
                        index,
                        src_tensor.offset()
                    ))
                })?;
            *dst = src_slice[src_pos];

            if dst_pos + 1 == out_len {
                continue;
            }
            // Advance the multi-index in column-major (odometer) order.
            for axis_index in 0..output_dims.len() {
                index[axis_index] += 1;
                if index[axis_index] < output_dims[axis_index] {
                    break;
                }
                index[axis_index] = 0;
            }
        }

        Ok(Tensor::from_owned_contiguous_data(
            out,
            output_dims.into(),
            MemoryOrder::ColumnMajor,
            lower.logical_memory_space(),
            None,
            lower.is_conjugated(),
        ))
    }
}