tenferro_linalg/primal/
tensor_ops.rs

1use super::*;
2
3/// Compute the cross product along the leading vector axis.
4///
5/// Both inputs must have a leading dimension of size 3.
6///
7/// # Examples
8///
9/// ```
10/// use tenferro_linalg::cross;
11/// use tenferro_prims::CpuContext;
12/// use tenferro_tensor::{MemoryOrder, Tensor};
13///
14/// let mut ctx = CpuContext::new(1);
15/// let col = MemoryOrder::ColumnMajor;
16/// let a = Tensor::<f64>::from_slice(&[1.0, 0.0, 0.0], &[3], col).unwrap();
17/// let b = Tensor::<f64>::from_slice(&[0.0, 1.0, 0.0], &[3], col).unwrap();
18/// let c = cross(&mut ctx, &a, &b).unwrap();
19/// assert_eq!(c.dims(), &[3]);
20/// ```
21pub fn cross<T: KernelLinalgScalar, C>(
22    ctx: &mut C,
23    a: &Tensor<T>,
24    b: &Tensor<T>,
25) -> Result<Tensor<T>>
26where
27    C: backend::TensorLinalgContextFor<T>
28        + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T>>,
29    C::Backend: 'static,
30{
31    if a.ndim() != b.ndim() {
32        return Err(Error::InvalidArgument(format!(
33            "cross expects matching ranks, got {:?} and {:?}",
34            a.dims(),
35            b.dims()
36        )));
37    }
38    if a.ndim() == 0 || a.dims()[0] != 3 {
39        return Err(Error::InvalidArgument(format!(
40            "cross expects leading vector dimension of size 3, got {:?}",
41            a.dims()
42        )));
43    }
44    if b.ndim() == 0 || b.dims()[0] != 3 {
45        return Err(Error::InvalidArgument(format!(
46            "cross expects leading vector dimension of size 3, got {:?}",
47            b.dims()
48        )));
49    }
50    let mut out_dims = vec![3];
51    for axis in 1..a.ndim() {
52        let lhs = a.dims()[axis];
53        let rhs = b.dims()[axis];
54        if lhs != rhs && lhs != 1 && rhs != 1 {
55            return Err(Error::InvalidArgument(format!(
56                "cross broadcast mismatch on axis {axis}: left={}, right={}",
57                lhs, rhs
58            )));
59        }
60        out_dims.push(lhs.max(rhs));
61    }
62
63    let a_input = ensure_col_major(a);
64    let b_input = ensure_col_major(b);
65    let out_tail_dims = &out_dims[1..];
66    let ax = a_input.select(0, 0)?.broadcast(out_tail_dims)?;
67    let ay = a_input.select(0, 1)?.broadcast(out_tail_dims)?;
68    let az = a_input.select(0, 2)?.broadcast(out_tail_dims)?;
69    let bx = b_input.select(0, 0)?.broadcast(out_tail_dims)?;
70    let by = b_input.select(0, 1)?.broadcast(out_tail_dims)?;
71    let bz = b_input.select(0, 2)?.broadcast(out_tail_dims)?;
72
73    let ay_bz = crate::prims_bridge::scalar_binary_same_shape(
74        ctx,
75        &ay,
76        &bz,
77        tenferro_prims::ScalarBinaryOp::Mul,
78    )?;
79    let az_by = crate::prims_bridge::scalar_binary_same_shape(
80        ctx,
81        &az,
82        &by,
83        tenferro_prims::ScalarBinaryOp::Mul,
84    )?;
85    let az_bx = crate::prims_bridge::scalar_binary_same_shape(
86        ctx,
87        &az,
88        &bx,
89        tenferro_prims::ScalarBinaryOp::Mul,
90    )?;
91    let ax_bz = crate::prims_bridge::scalar_binary_same_shape(
92        ctx,
93        &ax,
94        &bz,
95        tenferro_prims::ScalarBinaryOp::Mul,
96    )?;
97    let ax_by = crate::prims_bridge::scalar_binary_same_shape(
98        ctx,
99        &ax,
100        &by,
101        tenferro_prims::ScalarBinaryOp::Mul,
102    )?;
103    let ay_bx = crate::prims_bridge::scalar_binary_same_shape(
104        ctx,
105        &ay,
106        &bx,
107        tenferro_prims::ScalarBinaryOp::Mul,
108    )?;
109
110    let out_x = crate::prims_bridge::scalar_binary_same_shape(
111        ctx,
112        &ay_bz,
113        &az_by,
114        tenferro_prims::ScalarBinaryOp::Sub,
115    )?;
116    let out_y = crate::prims_bridge::scalar_binary_same_shape(
117        ctx,
118        &az_bx,
119        &ax_bz,
120        tenferro_prims::ScalarBinaryOp::Sub,
121    )?;
122    let out_z = crate::prims_bridge::scalar_binary_same_shape(
123        ctx,
124        &ax_by,
125        &ay_bx,
126        tenferro_prims::ScalarBinaryOp::Sub,
127    )?;
128
129    Tensor::stack(&[&out_x, &out_y, &out_z], 0)
130}
131
132/// Form the explicit product of Householder reflectors.
133///
134/// Given a lower-triangular Householder factor matrix `a` of shape `(m, n, *)`
135/// and a vector `tau` of shape `(k, *)`, computes the orthogonal matrix `Q`.
136///
137/// # Examples
138///
139/// ```
140/// use tenferro_linalg::householder_product;
141/// use tenferro_prims::CpuContext;
142/// use tenferro_tensor::{MemoryOrder, Tensor};
143///
144/// let mut ctx = CpuContext::new(1);
145/// let col = MemoryOrder::ColumnMajor;
146/// // Typically obtained from an intermediate QR step.
147/// let a = Tensor::<f64>::from_slice(
148///     &[1.0, 0.5, 0.0, 1.0], &[2, 2], col,
149/// ).unwrap();
150/// let tau = Tensor::<f64>::from_slice(&[1.0, 0.5], &[2], col).unwrap();
151/// let q = householder_product(&mut ctx, &a, &tau).unwrap();
152/// assert_eq!(q.dims(), &[2, 2]);
153/// ```
154pub fn householder_product<T: KernelLinalgScalar + tenferro_algebra::Conjugate, C>(
155    ctx: &mut C,
156    a: &Tensor<T>,
157    tau: &Tensor<T>,
158) -> Result<Tensor<T>>
159where
160    C: backend::TensorLinalgContextFor<T>
161        + tenferro_prims::TensorResolveConjContextFor<T>
162        + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T>>
163        + tenferro_prims::TensorSemiringContextFor<tenferro_algebra::Standard<T>>,
164    C::Backend: 'static,
165{
166    let (m, n, batch_dims) = validate_2d(a)?;
167    if tau.ndim() != 1 + batch_dims.len() {
168        return Err(Error::InvalidArgument(format!(
169            "householder_product expects tau shape (k, *), got {:?}",
170            tau.dims()
171        )));
172    }
173    if &tau.dims()[1..] != batch_dims {
174        return Err(Error::InvalidArgument(format!(
175            "householder_product batch dims mismatch: expected {:?}, got {:?}",
176            batch_dims,
177            &tau.dims()[1..]
178        )));
179    }
180
181    let k = tau.dims()[0];
182    if k > m.min(n) {
183        return Err(Error::InvalidArgument(format!(
184            "householder_product expects tau length <= min(m, n) = {}, got {}",
185            m.min(n),
186            k
187        )));
188    }
189
190    let a_input = ensure_col_major(a);
191    let tau_input = ensure_col_major(tau);
192    let memory_space = a_input.logical_memory_space();
193    let mut q = crate::prims_bridge::identity_matrix(m, memory_space)?.narrow(1, 0, n)?;
194    for _ in batch_dims {
195        q = q.unsqueeze(-1)?;
196    }
197    let q_target_dims = output_dims(&[m, n], batch_dims);
198    q = q.broadcast(&q_target_dims)?;
199
200    let vector_tail_dims = {
201        let mut dims = Vec::with_capacity(1 + batch_dims.len());
202        dims.push(1);
203        dims.extend_from_slice(batch_dims);
204        dims
205    };
206
207    for reflector in (0..k).rev() {
208        let tail_rows = m - reflector;
209        let tail = if tail_rows == 1 {
210            Tensor::ones(&vector_tail_dims, memory_space, MemoryOrder::ColumnMajor)?
211        } else {
212            let head = Tensor::ones(&vector_tail_dims, memory_space, MemoryOrder::ColumnMajor)?;
213            let lower = a_input
214                .narrow(0, reflector + 1, tail_rows - 1)?
215                .select(1, reflector)?;
216            Tensor::cat(&[&head, &lower], 0)?
217        };
218
219        let q_tail = q.narrow(0, reflector, tail_rows)?;
220        let v_col = tail.unsqueeze(1)?;
221        let mut adj_perm: Vec<usize> = (0..v_col.ndim()).collect();
222        adj_perm.swap(0, 1);
223        let v_adj_view = v_col.conj().permute(&adj_perm)?;
224        let v_adj = crate::prims_bridge::resolve_conj(ctx, &v_adj_view);
225        let reflected = crate::prims_bridge::batched_gemm_with_semiring_tensors(
226            ctx, &v_adj, &q_tail, 1, tail_rows, n,
227        )?;
228
229        let mut tau_scale = tau_input.select(0, reflector)?;
230        tau_scale = tau_scale.unsqueeze(0)?;
231        tau_scale = tau_scale.unsqueeze(0)?;
232        let tau_scale = tau_scale.broadcast(reflected.dims())?;
233        let scaled = crate::prims_bridge::scalar_binary_same_shape(
234            ctx,
235            &tau_scale,
236            &reflected,
237            tenferro_prims::ScalarBinaryOp::Mul,
238        )?;
239        let update = crate::prims_bridge::batched_gemm_with_semiring_tensors(
240            ctx, &v_col, &scaled, tail_rows, 1, n,
241        )?;
242        let updated_tail = crate::prims_bridge::scalar_binary_same_shape(
243            ctx,
244            &q_tail,
245            &update,
246            tenferro_prims::ScalarBinaryOp::Sub,
247        )?;
248
249        q = if reflector == 0 {
250            updated_tail
251        } else {
252            let prefix = q.narrow(0, 0, reflector)?;
253            Tensor::cat(&[&prefix, &updated_tail], 0)?
254        };
255    }
256
257    Ok(q)
258}
259
260/// Build a Vandermonde matrix from leading-dimension vectors.
261///
262/// # Examples
263///
264/// ```
265/// use tenferro_linalg::vander;
266/// use tenferro_prims::CpuContext;
267/// use tenferro_tensor::{MemoryOrder, Tensor};
268///
269/// let mut ctx = CpuContext::new(1);
270/// let col = MemoryOrder::ColumnMajor;
271/// let x = Tensor::<f64>::from_slice(&[1.0, 2.0, 3.0], &[3], col).unwrap();
272/// let v = vander(&mut ctx, &x, None, false).unwrap();
273/// assert_eq!(v.dims(), &[3, 3]);
274/// ```
275pub fn vander<T: KernelLinalgScalar, C>(
276    ctx: &mut C,
277    x: &Tensor<T>,
278    columns: Option<usize>,
279    increasing: bool,
280) -> Result<Tensor<T>>
281where
282    C: backend::TensorLinalgContextFor<T>
283        + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T>>,
284{
285    let (vector_len, batch_dims): (usize, &[usize]) = if x.ndim() == 0 {
286        (1, &[])
287    } else {
288        (x.dims()[0], &x.dims()[1..])
289    };
290    let columns = columns.unwrap_or(vector_len);
291    let output_dims = output_dims(&[vector_len, columns], batch_dims);
292    let output_numel = output_dims.iter().product::<usize>();
293
294    let x_input = ensure_col_major(x);
295    let base = if x.ndim() == 0 {
296        x_input.reshape(&[1])?
297    } else {
298        x_input
299    };
300    let memory_space = base.logical_memory_space();
301    if output_numel == 0 {
302        return Tensor::zeros(&output_dims, memory_space, MemoryOrder::ColumnMajor);
303    }
304
305    let mut current = Tensor::ones(base.dims(), memory_space, MemoryOrder::ColumnMajor)?;
306    let mut columns_out = Vec::with_capacity(columns);
307
308    columns_out.push(current.clone());
309    for _ in 1..columns {
310        let mut next = Tensor::zeros(base.dims(), memory_space, MemoryOrder::ColumnMajor)?;
311        crate::prims_bridge::scalar_binary_same_shape_into(
312            ctx,
313            &current,
314            &base,
315            tenferro_prims::ScalarBinaryOp::Mul,
316            &mut next,
317        )?;
318        columns_out.push(next.clone());
319        current = next;
320    }
321
322    if !increasing {
323        columns_out.reverse();
324    }
325
326    let column_refs: Vec<&Tensor<T>> = columns_out.iter().collect();
327    Tensor::stack(&column_refs, 1)
328}
329
330/// Invert a tensorized square operator.
331///
332/// Reshapes the tensor into a square matrix using `ind` to split the
333/// dimensions, computes the inverse, and reshapes back.
334///
335/// # Examples
336///
337/// ```
338/// use tenferro_linalg::tensorinv;
339/// use tenferro_prims::CpuContext;
340/// use tenferro_tensor::{MemoryOrder, Tensor};
341///
342/// let mut ctx = CpuContext::new(1);
343/// let col = MemoryOrder::ColumnMajor;
344/// // Shape [2, 2] with ind=1: left product = 2, right product = 2.
345/// let a = Tensor::<f64>::from_slice(&[1.0, 0.0, 0.0, 1.0], &[2, 2], col).unwrap();
346/// let inv = tensorinv(&mut ctx, &a, 1).unwrap();
347/// assert_eq!(inv.dims(), &[2, 2]);
348/// ```
349pub fn tensorinv<T: KernelLinalgScalar, C>(
350    ctx: &mut C,
351    tensor: &Tensor<T>,
352    ind: usize,
353) -> Result<Tensor<T>>
354where
355    T: KernelLinalgScalar,
356    C: backend::TensorLinalgContextFor<T>,
357    C::Backend: 'static,
358{
359    require_linalg_support::<T, C>(backend::LinalgCapabilityOp::TensorInv, "tensorinv")?;
360
361    if ind == 0 || ind >= tensor.ndim() {
362        return Err(Error::InvalidArgument(format!(
363            "tensorinv expects 0 < ind < rank, got ind={ind} for shape {:?}",
364            tensor.dims()
365        )));
366    }
367
368    let left_dims = &tensor.dims()[..ind];
369    let right_dims = &tensor.dims()[ind..];
370    let left_prod = left_dims.iter().product::<usize>();
371    let right_prod = right_dims.iter().product::<usize>();
372    if left_prod != right_prod {
373        return Err(Error::InvalidArgument(format!(
374            "tensorinv requires prod(shape[..ind]) == prod(shape[ind..]); got {} and {} for {:?}",
375            left_prod,
376            right_prod,
377            tensor.dims()
378        )));
379    }
380
381    let input = ensure_col_major(tensor);
382    let matrix = input.reshape(&[left_prod, right_prod])?;
383    let inverse = inv(ctx, &matrix)?;
384
385    let mut out_dims = right_dims.to_vec();
386    out_dims.extend_from_slice(left_dims);
387    inverse.reshape(&out_dims)
388}
389
390/// Solve a tensorized linear system.
391///
392/// Reshapes `a` and `b` into a standard linear system and solves it.
393///
394/// # Examples
395///
396/// ```
397/// use tenferro_linalg::tensorsolve;
398/// use tenferro_prims::CpuContext;
399/// use tenferro_tensor::{MemoryOrder, Tensor};
400///
401/// let mut ctx = CpuContext::new(1);
402/// let col = MemoryOrder::ColumnMajor;
403/// let a = Tensor::<f64>::from_slice(&[1.0, 0.0, 0.0, 1.0], &[2, 2], col).unwrap();
404/// let b = Tensor::<f64>::from_slice(&[3.0, 4.0], &[2], col).unwrap();
405/// let x = tensorsolve(&mut ctx, &a, &b, None).unwrap();
406/// assert_eq!(x.dims(), &[2]);
407/// ```
408pub fn tensorsolve<T: KernelLinalgScalar, C>(
409    ctx: &mut C,
410    a: &Tensor<T>,
411    b: &Tensor<T>,
412    dims: Option<&[usize]>,
413) -> Result<Tensor<T>>
414where
415    T: KernelLinalgScalar,
416    C: backend::TensorLinalgContextFor<T>,
417    C::Backend: 'static,
418{
419    require_linalg_support::<T, C>(backend::LinalgCapabilityOp::TensorSolve, "tensorsolve")?;
420
421    if b.ndim() > a.ndim() {
422        return Err(Error::InvalidArgument(format!(
423            "tensorsolve expects b rank <= a rank, got {:?} and {:?}",
424            a.dims(),
425            b.dims()
426        )));
427    }
428
429    let solution_rank = a.ndim() - b.ndim();
430    let solution_axes = validate_tensor_solve_axes(a.ndim(), solution_rank, dims)?;
431    let perm = axes_to_end_permutation(a.ndim(), &solution_axes);
432    let a_permuted = if is_identity_permutation(&perm) {
433        a.clone()
434    } else {
435        a.permute(&perm)?
436    };
437
438    if &a_permuted.dims()[..b.ndim()] != b.dims() {
439        return Err(Error::InvalidArgument(format!(
440            "tensorsolve leading dims of permuted a must match b; got {:?} and {:?}",
441            a_permuted.dims(),
442            b.dims()
443        )));
444    }
445
446    let lhs_prod = b.dims().iter().product::<usize>();
447    let rhs_dims = &a_permuted.dims()[b.ndim()..];
448    let rhs_prod = rhs_dims.iter().product::<usize>();
449    if lhs_prod != rhs_prod {
450        return Err(Error::InvalidArgument(format!(
451            "tensorsolve requires matching flattened system size, got {} and {}",
452            lhs_prod, rhs_prod
453        )));
454    }
455
456    let a_contiguous = ensure_col_major(&a_permuted);
457    let a_matrix = a_contiguous.reshape(&[lhs_prod, rhs_prod])?;
458    let b_contiguous = ensure_col_major(b);
459    let b_vector = b_contiguous.reshape(&[lhs_prod])?;
460    let x = solve(ctx, &a_matrix, &b_vector)?;
461    x.reshape(rhs_dims)
462}