1use super::*;
2
3pub fn cross<T: KernelLinalgScalar, C>(
22 ctx: &mut C,
23 a: &Tensor<T>,
24 b: &Tensor<T>,
25) -> Result<Tensor<T>>
26where
27 C: backend::TensorLinalgContextFor<T>
28 + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T>>,
29 C::Backend: 'static,
30{
31 if a.ndim() != b.ndim() {
32 return Err(Error::InvalidArgument(format!(
33 "cross expects matching ranks, got {:?} and {:?}",
34 a.dims(),
35 b.dims()
36 )));
37 }
38 if a.ndim() == 0 || a.dims()[0] != 3 {
39 return Err(Error::InvalidArgument(format!(
40 "cross expects leading vector dimension of size 3, got {:?}",
41 a.dims()
42 )));
43 }
44 if b.ndim() == 0 || b.dims()[0] != 3 {
45 return Err(Error::InvalidArgument(format!(
46 "cross expects leading vector dimension of size 3, got {:?}",
47 b.dims()
48 )));
49 }
50 let mut out_dims = vec![3];
51 for axis in 1..a.ndim() {
52 let lhs = a.dims()[axis];
53 let rhs = b.dims()[axis];
54 if lhs != rhs && lhs != 1 && rhs != 1 {
55 return Err(Error::InvalidArgument(format!(
56 "cross broadcast mismatch on axis {axis}: left={}, right={}",
57 lhs, rhs
58 )));
59 }
60 out_dims.push(lhs.max(rhs));
61 }
62
63 let a_input = ensure_col_major(a);
64 let b_input = ensure_col_major(b);
65 let out_tail_dims = &out_dims[1..];
66 let ax = a_input.select(0, 0)?.broadcast(out_tail_dims)?;
67 let ay = a_input.select(0, 1)?.broadcast(out_tail_dims)?;
68 let az = a_input.select(0, 2)?.broadcast(out_tail_dims)?;
69 let bx = b_input.select(0, 0)?.broadcast(out_tail_dims)?;
70 let by = b_input.select(0, 1)?.broadcast(out_tail_dims)?;
71 let bz = b_input.select(0, 2)?.broadcast(out_tail_dims)?;
72
73 let ay_bz = crate::prims_bridge::scalar_binary_same_shape(
74 ctx,
75 &ay,
76 &bz,
77 tenferro_prims::ScalarBinaryOp::Mul,
78 )?;
79 let az_by = crate::prims_bridge::scalar_binary_same_shape(
80 ctx,
81 &az,
82 &by,
83 tenferro_prims::ScalarBinaryOp::Mul,
84 )?;
85 let az_bx = crate::prims_bridge::scalar_binary_same_shape(
86 ctx,
87 &az,
88 &bx,
89 tenferro_prims::ScalarBinaryOp::Mul,
90 )?;
91 let ax_bz = crate::prims_bridge::scalar_binary_same_shape(
92 ctx,
93 &ax,
94 &bz,
95 tenferro_prims::ScalarBinaryOp::Mul,
96 )?;
97 let ax_by = crate::prims_bridge::scalar_binary_same_shape(
98 ctx,
99 &ax,
100 &by,
101 tenferro_prims::ScalarBinaryOp::Mul,
102 )?;
103 let ay_bx = crate::prims_bridge::scalar_binary_same_shape(
104 ctx,
105 &ay,
106 &bx,
107 tenferro_prims::ScalarBinaryOp::Mul,
108 )?;
109
110 let out_x = crate::prims_bridge::scalar_binary_same_shape(
111 ctx,
112 &ay_bz,
113 &az_by,
114 tenferro_prims::ScalarBinaryOp::Sub,
115 )?;
116 let out_y = crate::prims_bridge::scalar_binary_same_shape(
117 ctx,
118 &az_bx,
119 &ax_bz,
120 tenferro_prims::ScalarBinaryOp::Sub,
121 )?;
122 let out_z = crate::prims_bridge::scalar_binary_same_shape(
123 ctx,
124 &ax_by,
125 &ay_bx,
126 tenferro_prims::ScalarBinaryOp::Sub,
127 )?;
128
129 Tensor::stack(&[&out_x, &out_y, &out_z], 0)
130}
131
/// Accumulates the explicit Q factor from a set of Householder reflectors,
/// in the spirit of LAPACK's `orgqr`/`ungqr`.
///
/// The reflector vectors are read from the strict lower part of `a`
/// (one per column, with an implicit unit first entry) and the scalar
/// factors from `tau`. NOTE(review): `a` appears to be laid out as
/// `(m, n, *batch)` and `tau` as `(k, *batch)` — confirm against
/// `validate_2d`'s contract, which is defined elsewhere.
///
/// Errors with `InvalidArgument` when `tau`'s rank or batch dims don't
/// match `a`, or when `k > min(m, n)`.
pub fn householder_product<T: KernelLinalgScalar + tenferro_algebra::Conjugate, C>(
    ctx: &mut C,
    a: &Tensor<T>,
    tau: &Tensor<T>,
) -> Result<Tensor<T>>
where
    C: backend::TensorLinalgContextFor<T>
        + tenferro_prims::TensorResolveConjContextFor<T>
        + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T>>
        + tenferro_prims::TensorSemiringContextFor<tenferro_algebra::Standard<T>>,
    C::Backend: 'static,
{
    let (m, n, batch_dims) = validate_2d(a)?;
    // tau must be rank (1 + batch rank): a vector of k scalars per batch entry.
    if tau.ndim() != 1 + batch_dims.len() {
        return Err(Error::InvalidArgument(format!(
            "householder_product expects tau shape (k, *), got {:?}",
            tau.dims()
        )));
    }
    if &tau.dims()[1..] != batch_dims {
        return Err(Error::InvalidArgument(format!(
            "householder_product batch dims mismatch: expected {:?}, got {:?}",
            batch_dims,
            &tau.dims()[1..]
        )));
    }

    let k = tau.dims()[0];
    if k > m.min(n) {
        return Err(Error::InvalidArgument(format!(
            "householder_product expects tau length <= min(m, n) = {}, got {}",
            m.min(n),
            k
        )));
    }

    let a_input = ensure_col_major(a);
    let tau_input = ensure_col_major(tau);
    let memory_space = a_input.logical_memory_space();
    // Start from the leading n columns of the m x m identity, then add one
    // trailing axis per batch dim and broadcast to the full batched shape.
    let mut q = crate::prims_bridge::identity_matrix(m, memory_space)?.narrow(1, 0, n)?;
    for _ in batch_dims {
        q = q.unsqueeze(-1)?;
    }
    let q_target_dims = output_dims(&[m, n], batch_dims);
    q = q.broadcast(&q_target_dims)?;

    // Shape (1, *batch): the implicit unit head of each reflector vector.
    let vector_tail_dims = {
        let mut dims = Vec::with_capacity(1 + batch_dims.len());
        dims.push(1);
        dims.extend_from_slice(batch_dims);
        dims
    };

    // Apply reflectors last-to-first, so that after the loop
    // q = H_0 * H_1 * ... * H_{k-1} * [I; 0].
    for reflector in (0..k).rev() {
        let tail_rows = m - reflector;
        // v = [1, a[reflector+1.., reflector]]: unit head plus the stored
        // strict-lower part of column `reflector`.
        let tail = if tail_rows == 1 {
            Tensor::ones(&vector_tail_dims, memory_space, MemoryOrder::ColumnMajor)?
        } else {
            let head = Tensor::ones(&vector_tail_dims, memory_space, MemoryOrder::ColumnMajor)?;
            let lower = a_input
                .narrow(0, reflector + 1, tail_rows - 1)?
                .select(1, reflector)?;
            Tensor::cat(&[&head, &lower], 0)?
        };

        // Only rows reflector.. of q are touched by this reflector.
        let q_tail = q.narrow(0, reflector, tail_rows)?;
        let v_col = tail.unsqueeze(1)?;
        // Build v^H as a row vector: conjugate, then swap the two matrix axes.
        let mut adj_perm: Vec<usize> = (0..v_col.ndim()).collect();
        adj_perm.swap(0, 1);
        let v_adj_view = v_col.conj().permute(&adj_perm)?;
        let v_adj = crate::prims_bridge::resolve_conj(ctx, &v_adj_view);
        // reflected = v^H * q_tail, shape (1, n, *batch).
        let reflected = crate::prims_bridge::batched_gemm_with_semiring_tensors(
            ctx, &v_adj, &q_tail, 1, tail_rows, n,
        )?;

        // Broadcast tau[reflector] to (1, n, *batch) and scale the projection.
        let mut tau_scale = tau_input.select(0, reflector)?;
        tau_scale = tau_scale.unsqueeze(0)?;
        tau_scale = tau_scale.unsqueeze(0)?;
        let tau_scale = tau_scale.broadcast(reflected.dims())?;
        let scaled = crate::prims_bridge::scalar_binary_same_shape(
            ctx,
            &tau_scale,
            &reflected,
            tenferro_prims::ScalarBinaryOp::Mul,
        )?;
        // update = v * (tau * v^H * q_tail); q_tail' = q_tail - update,
        // i.e. q_tail' = (I - tau * v * v^H) * q_tail.
        let update = crate::prims_bridge::batched_gemm_with_semiring_tensors(
            ctx, &v_col, &scaled, tail_rows, 1, n,
        )?;
        let updated_tail = crate::prims_bridge::scalar_binary_same_shape(
            ctx,
            &q_tail,
            &update,
            tenferro_prims::ScalarBinaryOp::Sub,
        )?;

        // Reassemble q: rows above `reflector` are unchanged by this step.
        q = if reflector == 0 {
            updated_tail
        } else {
            let prefix = q.narrow(0, 0, reflector)?;
            Tensor::cat(&[&prefix, &updated_tail], 0)?
        };
    }

    Ok(q)
}
259
260pub fn vander<T: KernelLinalgScalar, C>(
276 ctx: &mut C,
277 x: &Tensor<T>,
278 columns: Option<usize>,
279 increasing: bool,
280) -> Result<Tensor<T>>
281where
282 C: backend::TensorLinalgContextFor<T>
283 + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T>>,
284{
285 let (vector_len, batch_dims): (usize, &[usize]) = if x.ndim() == 0 {
286 (1, &[])
287 } else {
288 (x.dims()[0], &x.dims()[1..])
289 };
290 let columns = columns.unwrap_or(vector_len);
291 let output_dims = output_dims(&[vector_len, columns], batch_dims);
292 let output_numel = output_dims.iter().product::<usize>();
293
294 let x_input = ensure_col_major(x);
295 let base = if x.ndim() == 0 {
296 x_input.reshape(&[1])?
297 } else {
298 x_input
299 };
300 let memory_space = base.logical_memory_space();
301 if output_numel == 0 {
302 return Tensor::zeros(&output_dims, memory_space, MemoryOrder::ColumnMajor);
303 }
304
305 let mut current = Tensor::ones(base.dims(), memory_space, MemoryOrder::ColumnMajor)?;
306 let mut columns_out = Vec::with_capacity(columns);
307
308 columns_out.push(current.clone());
309 for _ in 1..columns {
310 let mut next = Tensor::zeros(base.dims(), memory_space, MemoryOrder::ColumnMajor)?;
311 crate::prims_bridge::scalar_binary_same_shape_into(
312 ctx,
313 ¤t,
314 &base,
315 tenferro_prims::ScalarBinaryOp::Mul,
316 &mut next,
317 )?;
318 columns_out.push(next.clone());
319 current = next;
320 }
321
322 if !increasing {
323 columns_out.reverse();
324 }
325
326 let column_refs: Vec<&Tensor<T>> = columns_out.iter().collect();
327 Tensor::stack(&column_refs, 1)
328}
329
330pub fn tensorinv<T: KernelLinalgScalar, C>(
350 ctx: &mut C,
351 tensor: &Tensor<T>,
352 ind: usize,
353) -> Result<Tensor<T>>
354where
355 T: KernelLinalgScalar,
356 C: backend::TensorLinalgContextFor<T>,
357 C::Backend: 'static,
358{
359 require_linalg_support::<T, C>(backend::LinalgCapabilityOp::TensorInv, "tensorinv")?;
360
361 if ind == 0 || ind >= tensor.ndim() {
362 return Err(Error::InvalidArgument(format!(
363 "tensorinv expects 0 < ind < rank, got ind={ind} for shape {:?}",
364 tensor.dims()
365 )));
366 }
367
368 let left_dims = &tensor.dims()[..ind];
369 let right_dims = &tensor.dims()[ind..];
370 let left_prod = left_dims.iter().product::<usize>();
371 let right_prod = right_dims.iter().product::<usize>();
372 if left_prod != right_prod {
373 return Err(Error::InvalidArgument(format!(
374 "tensorinv requires prod(shape[..ind]) == prod(shape[ind..]); got {} and {} for {:?}",
375 left_prod,
376 right_prod,
377 tensor.dims()
378 )));
379 }
380
381 let input = ensure_col_major(tensor);
382 let matrix = input.reshape(&[left_prod, right_prod])?;
383 let inverse = inv(ctx, &matrix)?;
384
385 let mut out_dims = right_dims.to_vec();
386 out_dims.extend_from_slice(left_dims);
387 inverse.reshape(&out_dims)
388}
389
390pub fn tensorsolve<T: KernelLinalgScalar, C>(
409 ctx: &mut C,
410 a: &Tensor<T>,
411 b: &Tensor<T>,
412 dims: Option<&[usize]>,
413) -> Result<Tensor<T>>
414where
415 T: KernelLinalgScalar,
416 C: backend::TensorLinalgContextFor<T>,
417 C::Backend: 'static,
418{
419 require_linalg_support::<T, C>(backend::LinalgCapabilityOp::TensorSolve, "tensorsolve")?;
420
421 if b.ndim() > a.ndim() {
422 return Err(Error::InvalidArgument(format!(
423 "tensorsolve expects b rank <= a rank, got {:?} and {:?}",
424 a.dims(),
425 b.dims()
426 )));
427 }
428
429 let solution_rank = a.ndim() - b.ndim();
430 let solution_axes = validate_tensor_solve_axes(a.ndim(), solution_rank, dims)?;
431 let perm = axes_to_end_permutation(a.ndim(), &solution_axes);
432 let a_permuted = if is_identity_permutation(&perm) {
433 a.clone()
434 } else {
435 a.permute(&perm)?
436 };
437
438 if &a_permuted.dims()[..b.ndim()] != b.dims() {
439 return Err(Error::InvalidArgument(format!(
440 "tensorsolve leading dims of permuted a must match b; got {:?} and {:?}",
441 a_permuted.dims(),
442 b.dims()
443 )));
444 }
445
446 let lhs_prod = b.dims().iter().product::<usize>();
447 let rhs_dims = &a_permuted.dims()[b.ndim()..];
448 let rhs_prod = rhs_dims.iter().product::<usize>();
449 if lhs_prod != rhs_prod {
450 return Err(Error::InvalidArgument(format!(
451 "tensorsolve requires matching flattened system size, got {} and {}",
452 lhs_prod, rhs_prod
453 )));
454 }
455
456 let a_contiguous = ensure_col_major(&a_permuted);
457 let a_matrix = a_contiguous.reshape(&[lhs_prod, rhs_prod])?;
458 let b_contiguous = ensure_col_major(b);
459 let b_vector = b_contiguous.reshape(&[lhs_prod])?;
460 let x = solve(ctx, &a_matrix, &b_vector)?;
461 x.reshape(rhs_dims)
462}