tenferro_linalg/frules/linear_systems/
solve.rs

1use super::*;
2
3/// Forward-mode AD rule for linear solve (JVP / pushforward).
4///
5/// # Examples
6///
7/// ```
8/// use tenferro_linalg::solve_frule;
9/// use tenferro_prims::CpuContext;
10/// use tenferro_tensor::{Tensor, MemoryOrder};
11/// use tenferro_device::LogicalMemorySpace;
12///
13/// let col = MemoryOrder::ColumnMajor;
14/// let mem = LogicalMemorySpace::MainMemory;
15/// let mut ctx = CpuContext::new(1);
16/// let a = Tensor::<f64>::eye(3, mem, col).unwrap();
17/// let b = Tensor::<f64>::ones(&[3], mem, col).unwrap();
18/// let da = Tensor::<f64>::ones(&[3, 3], mem, col).unwrap();
19/// let db = Tensor::<f64>::ones(&[3], mem, col).unwrap();
20/// let (x, dx) = solve_frule(&mut ctx, &a, &b, &da, &db).unwrap();
21/// ```
22pub fn solve_frule<T: KernelLinalgScalar, C>(
23    ctx: &mut C,
24    a: &Tensor<T>,
25    b: &Tensor<T>,
26    tangent_a: &Tensor<T>,
27    tangent_b: &Tensor<T>,
28) -> AdResult<(Tensor<T>, Tensor<T>)>
29where
30    C: backend::TensorLinalgContextFor<T>
31        + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T>>,
32    C::Backend: 'static,
33    <C as tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T>>>::ScalarBackend:
34        'static + tenferro_prims::TensorAnalyticPrims<tenferro_algebra::Standard<T>, Context = C>,
35{
36    // dx = A^{-1} (db - dA x)
37    let x = solve(ctx, a, b).map_err(to_ad_err)?;
38    let (n, batch_dims) = validate_square(a).map_err(to_ad_err)?;
39    let rhs = backend::tensor_helpers::validate_solve_rhs_shape(b, n, batch_dims, "solve_frule")
40        .map_err(to_ad_err)?;
41    let nrhs = rhs.nrhs;
42    let sr = rhs.structural_rank;
43
44    // Promote x and tangent_b to matrix form [n, nrhs, batch...]
45    let x_mat = rhs_to_matrix(&x, sr).map_err(to_ad_err)?;
46    let db_mat = rhs_to_matrix(tangent_b, sr).map_err(to_ad_err)?;
47
48    // dA @ x  (n×n @ n×nrhs = n×nrhs)
49    let da_x = prims_bridge::batched_gemm_with_semiring_tensors(ctx, tangent_a, &x_mat, n, n, nrhs)
50        .map_err(to_ad_err)?;
51
52    // db - dA @ x
53    let rhs_tangent = prims_bridge::scalar_binary_same_shape(
54        ctx,
55        &db_mat,
56        &da_x,
57        tenferro_prims::ScalarBinaryOp::Sub,
58    )
59    .map_err(to_ad_err)?;
60
61    // dx = A^{-1} (db - dA x)
62    let dx_mat = solve(ctx, a, &rhs_tangent).map_err(to_ad_err)?;
63    let dx = matrix_to_rhs(dx_mat, sr).map_err(to_ad_err)?;
64
65    Ok((x, dx))
66}
67
68/// Forward-mode AD rule for triangular solve (JVP / pushforward).
69///
70/// Computes:
71/// - `x = solve_triangular(a, b, upper)`
72/// - `dx = solve_triangular(a, db - proj(dA) * x, upper)`
73///
74/// where `proj(dA)` keeps only the active triangular part
75/// (`triu` when `upper=true`, `tril` when `upper=false`).
76pub fn solve_triangular_frule<T: KernelLinalgScalar, C>(
77    ctx: &mut C,
78    a: &Tensor<T>,
79    b: &Tensor<T>,
80    tangent_a: &Tensor<T>,
81    tangent_b: &Tensor<T>,
82    upper: bool,
83) -> AdResult<(Tensor<T>, Tensor<T>)>
84where
85    T: KernelLinalgScalar,
86    C: backend::TensorLinalgContextFor<T>
87        + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T>>
88        + tenferro_prims::TensorMetadataContextFor,
89    C::Backend: 'static,
90    C::MetadataBackend: tenferro_prims::TensorMetadataPrims<Context = C>,
91    <C as tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T>>>::ScalarBackend:
92        'static
93            + tenferro_prims::TensorAnalyticPrims<tenferro_algebra::Standard<T>, Context = C>
94            + tenferro_prims::TensorMetadataCastPrims<T, Context = C>,
95{
96    if tangent_a.dims() != a.dims() {
97        return Err(chainrules_core::AutodiffError::InvalidArgument(format!(
98            "solve_triangular_frule: tangent_a shape mismatch: expected {:?}, got {:?}",
99            a.dims(),
100            tangent_a.dims()
101        )));
102    }
103    if tangent_b.dims() != b.dims() {
104        return Err(chainrules_core::AutodiffError::InvalidArgument(format!(
105            "solve_triangular_frule: tangent_b shape mismatch: expected {:?}, got {:?}",
106            b.dims(),
107            tangent_b.dims()
108        )));
109    }
110
111    // dX = A^{-1} (dB - proj(dA) X), with projection to the triangular tangent space.
112    let x = solve_triangular(ctx, a, b, upper).map_err(to_ad_err)?;
113    let (n, _) = validate_square(a).map_err(to_ad_err)?;
114    let rhs = backend::tensor_helpers::validate_solve_rhs_shape(
115        b,
116        n,
117        &a.dims()[2..],
118        "solve_triangular_frule",
119    )
120    .map_err(to_ad_err)?;
121    let nrhs = rhs.nrhs;
122    let sr = rhs.structural_rank;
123
124    let x_mat = rhs_to_matrix(&x, sr).map_err(to_ad_err)?;
125    let db_mat = rhs_to_matrix(tangent_b, sr).map_err(to_ad_err)?;
126
127    // Project dA onto the same triangular structure as A.
128    let da_proj = if upper {
129        tenferro_prims::tensor_ops::triu(ctx, tangent_a, 0).map_err(to_ad_err)?
130    } else {
131        tenferro_prims::tensor_ops::tril(ctx, tangent_a, 0).map_err(to_ad_err)?
132    };
133
134    // proj(dA) @ x
135    let da_x = prims_bridge::batched_gemm_with_semiring_tensors(ctx, &da_proj, &x_mat, n, n, nrhs)
136        .map_err(to_ad_err)?;
137
138    // dB - proj(dA) @ x
139    let rhs_tangent = prims_bridge::scalar_binary_same_shape(
140        ctx,
141        &db_mat,
142        &da_x,
143        tenferro_prims::ScalarBinaryOp::Sub,
144    )
145    .map_err(to_ad_err)?;
146
147    // dX = solve_triangular(A, rhs_tangent, upper)
148    let dx_mat = solve_triangular(ctx, a, &rhs_tangent, upper).map_err(to_ad_err)?;
149    let dx = matrix_to_rhs(dx_mat, sr).map_err(to_ad_err)?;
150
151    Ok((x, dx))
152}