tenferro_linalg/frules/linear_systems/
solve.rs1use super::*;
2
3pub fn solve_frule<T: KernelLinalgScalar, C>(
23 ctx: &mut C,
24 a: &Tensor<T>,
25 b: &Tensor<T>,
26 tangent_a: &Tensor<T>,
27 tangent_b: &Tensor<T>,
28) -> AdResult<(Tensor<T>, Tensor<T>)>
29where
30 C: backend::TensorLinalgContextFor<T>
31 + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T>>,
32 C::Backend: 'static,
33 <C as tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T>>>::ScalarBackend:
34 'static + tenferro_prims::TensorAnalyticPrims<tenferro_algebra::Standard<T>, Context = C>,
35{
36 let x = solve(ctx, a, b).map_err(to_ad_err)?;
38 let (n, batch_dims) = validate_square(a).map_err(to_ad_err)?;
39 let rhs = backend::tensor_helpers::validate_solve_rhs_shape(b, n, batch_dims, "solve_frule")
40 .map_err(to_ad_err)?;
41 let nrhs = rhs.nrhs;
42 let sr = rhs.structural_rank;
43
44 let x_mat = rhs_to_matrix(&x, sr).map_err(to_ad_err)?;
46 let db_mat = rhs_to_matrix(tangent_b, sr).map_err(to_ad_err)?;
47
48 let da_x = prims_bridge::batched_gemm_with_semiring_tensors(ctx, tangent_a, &x_mat, n, n, nrhs)
50 .map_err(to_ad_err)?;
51
52 let rhs_tangent = prims_bridge::scalar_binary_same_shape(
54 ctx,
55 &db_mat,
56 &da_x,
57 tenferro_prims::ScalarBinaryOp::Sub,
58 )
59 .map_err(to_ad_err)?;
60
61 let dx_mat = solve(ctx, a, &rhs_tangent).map_err(to_ad_err)?;
63 let dx = matrix_to_rhs(dx_mat, sr).map_err(to_ad_err)?;
64
65 Ok((x, dx))
66}
67
68pub fn solve_triangular_frule<T: KernelLinalgScalar, C>(
77 ctx: &mut C,
78 a: &Tensor<T>,
79 b: &Tensor<T>,
80 tangent_a: &Tensor<T>,
81 tangent_b: &Tensor<T>,
82 upper: bool,
83) -> AdResult<(Tensor<T>, Tensor<T>)>
84where
85 T: KernelLinalgScalar,
86 C: backend::TensorLinalgContextFor<T>
87 + tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T>>
88 + tenferro_prims::TensorMetadataContextFor,
89 C::Backend: 'static,
90 C::MetadataBackend: tenferro_prims::TensorMetadataPrims<Context = C>,
91 <C as tenferro_prims::TensorScalarContextFor<tenferro_algebra::Standard<T>>>::ScalarBackend:
92 'static
93 + tenferro_prims::TensorAnalyticPrims<tenferro_algebra::Standard<T>, Context = C>
94 + tenferro_prims::TensorMetadataCastPrims<T, Context = C>,
95{
96 if tangent_a.dims() != a.dims() {
97 return Err(chainrules_core::AutodiffError::InvalidArgument(format!(
98 "solve_triangular_frule: tangent_a shape mismatch: expected {:?}, got {:?}",
99 a.dims(),
100 tangent_a.dims()
101 )));
102 }
103 if tangent_b.dims() != b.dims() {
104 return Err(chainrules_core::AutodiffError::InvalidArgument(format!(
105 "solve_triangular_frule: tangent_b shape mismatch: expected {:?}, got {:?}",
106 b.dims(),
107 tangent_b.dims()
108 )));
109 }
110
111 let x = solve_triangular(ctx, a, b, upper).map_err(to_ad_err)?;
113 let (n, _) = validate_square(a).map_err(to_ad_err)?;
114 let rhs = backend::tensor_helpers::validate_solve_rhs_shape(
115 b,
116 n,
117 &a.dims()[2..],
118 "solve_triangular_frule",
119 )
120 .map_err(to_ad_err)?;
121 let nrhs = rhs.nrhs;
122 let sr = rhs.structural_rank;
123
124 let x_mat = rhs_to_matrix(&x, sr).map_err(to_ad_err)?;
125 let db_mat = rhs_to_matrix(tangent_b, sr).map_err(to_ad_err)?;
126
127 let da_proj = if upper {
129 tenferro_prims::tensor_ops::triu(ctx, tangent_a, 0).map_err(to_ad_err)?
130 } else {
131 tenferro_prims::tensor_ops::tril(ctx, tangent_a, 0).map_err(to_ad_err)?
132 };
133
134 let da_x = prims_bridge::batched_gemm_with_semiring_tensors(ctx, &da_proj, &x_mat, n, n, nrhs)
136 .map_err(to_ad_err)?;
137
138 let rhs_tangent = prims_bridge::scalar_binary_same_shape(
140 ctx,
141 &db_mat,
142 &da_x,
143 tenferro_prims::ScalarBinaryOp::Sub,
144 )
145 .map_err(to_ad_err)?;
146
147 let dx_mat = solve_triangular(ctx, a, &rhs_tangent, upper).map_err(to_ad_err)?;
149 let dx = matrix_to_rhs(dx_mat, sr).map_err(to_ad_err)?;
150
151 Ok((x, dx))
152}