tenferro_linalg/rrules/linear_systems/solve.rs

use super::*;

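/// Reverse-mode (pullback) rule for `solve`. For the primal `x = a^{-1} b`
/// and an incoming cotangent `x̄`, the standard adjoint identities (see e.g.
/// Giles, "Collected matrix derivative results for forward and reverse mode
/// algorithmic differentiation", 2008) give `grad_b = a^{-H} x̄` and
/// `grad_a = -grad_b x^{H}`.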
pub fn solve_rrule<T, C>(
    ctx: &mut C,
    a: &Tensor<T>,
    b: &Tensor<T>,
    cotangent: &Tensor<T>,
) -> AdResult<SolveGrad<T>>
where
    T: KernelLinalgScalar + tenferro_algebra::Conjugate,
    C: backend::TensorLinalgContextFor<T> + tenferro_prims::TensorResolveConjContextFor<T>,
    C::Backend: 'static,
{
    // Recompute the primal solution and validate the operand shapes.
    let x = solve(ctx, a, b).map_err(to_ad_err)?;
    let (n, batch_dims) = validate_square(a).map_err(to_ad_err)?;
    let rhs = backend::tensor_helpers::validate_solve_rhs_shape(b, n, batch_dims, "solve_rrule")
        .map_err(to_ad_err)?;
    let nrhs = rhs.nrhs;
    let sr = rhs.structural_rank;

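    // Bring the cotangent and the primal solution into (n, nrhs) matrix form.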
    let dx_mat = rhs_to_matrix(cotangent, sr).map_err(to_ad_err)?;
    let x_mat = rhs_to_matrix(&x, sr).map_err(to_ad_err)?;

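    // grad_b = a^{-H} x̄: solve against the conjugate transpose of `a`.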
    let a_h = matrix_adjoint_eager(ctx, a).map_err(to_ad_err)?;
    let g = solve(ctx, &a_h, &dx_mat).map_err(to_ad_err)?;

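    // Restore the right-hand-side shape of `b` for its gradient.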
    let grad_b = matrix_to_rhs(g.clone(), sr).map_err(to_ad_err)?;

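    // grad_a = -grad_b x^{H}, computed as one batched GEMM with alpha = -1.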
    let x_h = matrix_adjoint_eager(ctx, &x_mat).map_err(to_ad_err)?;
    let grad_a = prims_bridge::batched_gemm_alpha_tensors(ctx, &g, &x_h, n, nrhs, n, -T::one())
        .map_err(to_ad_err)?;

    Ok(SolveGrad {
        a: grad_a,
        b: grad_b,
    })
}

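/// Reverse-mode rule for `solve_triangular`. Same identities as for `solve`,
/// with two twists: the adjoint system is triangular with the opposite
/// orientation (`!upper`), and `grad_a` is projected back onto the triangle
/// the solve actually read, since the other triangle never influences the
/// output.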
pub fn solve_triangular_rrule<T, C>(
    ctx: &mut C,
    a: &Tensor<T>,
    b: &Tensor<T>,
    cotangent: &Tensor<T>,
    upper: bool,
) -> AdResult<SolveGrad<T>>
where
    T: KernelLinalgScalar + tenferro_algebra::Conjugate,
    C: backend::TensorLinalgContextFor<T>
        + tenferro_prims::TensorResolveConjContextFor<T>
        + tenferro_prims::TensorMetadataContextFor,
    C::Backend: 'static,
{
    // Recompute the primal solution and validate the operand shapes.
    let x = solve_triangular(ctx, a, b, upper).map_err(to_ad_err)?;
    let (n, _) = validate_square(a).map_err(to_ad_err)?;
    let rhs = backend::tensor_helpers::validate_solve_rhs_shape(
        b,
        n,
        &a.dims()[2..],
        "solve_triangular_rrule",
    )
    .map_err(to_ad_err)?;
    let nrhs = rhs.nrhs;
    let sr = rhs.structural_rank;

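    // Bring the cotangent and the primal solution into (n, nrhs) matrix form.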
    let dx_mat = rhs_to_matrix(cotangent, sr).map_err(to_ad_err)?;
    let x_mat = rhs_to_matrix(&x, sr).map_err(to_ad_err)?;

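    // grad_b = a^{-H} x̄; the adjoint of an upper-triangular matrix is
    // lower-triangular, hence the flipped `!upper` flag.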
    let a_h = matrix_adjoint_eager(ctx, a).map_err(to_ad_err)?;
    let g = solve_triangular(ctx, &a_h, &dx_mat, !upper).map_err(to_ad_err)?;

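    // Restore the right-hand-side shape of `b` for its gradient.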
    let grad_b = matrix_to_rhs(g.clone(), sr).map_err(to_ad_err)?;

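    // grad_a = -grad_b x^{H}, masked to the accessed triangle: entries in the
    // ignored triangle never affect the primal output, so their gradient is zero.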
    let x_h = matrix_adjoint_eager(ctx, &x_mat).map_err(to_ad_err)?;
    let neg_g_xh = prims_bridge::batched_gemm_alpha_tensors(ctx, &g, &x_h, n, nrhs, n, -T::one())
        .map_err(to_ad_err)?;
    let grad_a = if upper {
        tenferro_prims::tensor_ops::triu(ctx, &neg_g_xh, 0).map_err(to_ad_err)?
    } else {
        tenferro_prims::tensor_ops::tril(ctx, &neg_g_xh, 0).map_err(to_ad_err)?
    };

    Ok(SolveGrad {
        a: grad_a,
        b: grad_b,
    })
}