1use std::sync::Arc;
25
26use computegraph::types::{LocalValueId, OperationRole, ValueKey, ValueRef};
27use tenferro_ad::extension::{
28 ExtensionAdRule, ExtensionOp, ExtensionRegistryError, ExtensionRuleSet,
29};
30use tenferro_ops::ad::PrimitiveRuleBuilder;
31use tenferro_ops::std_tensor_op::StdTensorOp;
32use tenferro_ops::ShapeGuardContext;
33use tidu::{ADRuleError, ADRuleKind, ADRuleResult};
34
35use crate::extension::{LinalgExtensionOp, LinalgOp};
36use crate::LINALG_EXTENSION_FAMILY_ID;
37
38mod rules;
39pub mod support;
40
41pub fn ad_rules() -> Result<ExtensionRuleSet, ExtensionRegistryError> {
50 ExtensionRuleSet::new().with_rule(Arc::new(LinalgAdRule))
51}
52
53#[derive(Debug)]
54struct LinalgAdRule;
55
56impl ExtensionAdRule for LinalgAdRule {
57 fn family_id(&self) -> &'static str {
58 LINALG_EXTENSION_FAMILY_ID
59 }
60
61 fn linearize(
62 &self,
63 op: &dyn ExtensionOp,
64 builder: &mut dyn PrimitiveRuleBuilder,
65 primal_in: &[ValueKey<StdTensorOp>],
66 primal_out: &[ValueKey<StdTensorOp>],
67 tangent_in: &[Option<LocalValueId>],
68 ctx: &mut ShapeGuardContext,
69 ) -> ADRuleResult<Vec<Option<LocalValueId>>> {
70 let op = downcast_ad_op(op, ADRuleKind::Jvp)?;
71 match op.op() {
72 LinalgOp::Lu => rules::linearize_lu(builder, primal_in, primal_out, tangent_in, ctx),
73 LinalgOp::LuFactor => Ok(vec![None; op.output_count()]),
74 LinalgOp::LuSolvePrepared {
75 transpose_a,
76 conjugate_a,
77 } => rules::linearize_lu_solve_prepared(
78 builder,
79 primal_in,
80 primal_out,
81 tangent_in,
82 transpose_a,
83 conjugate_a,
84 ctx,
85 ),
86 LinalgOp::FullPivLu => {
87 rules::linearize_full_piv_lu(builder, primal_in, primal_out, tangent_in, ctx)
88 }
89 LinalgOp::FullPivLuSolve { transpose_a } => rules::linearize_full_piv_lu_solve(
90 builder,
91 primal_in,
92 primal_out,
93 tangent_in,
94 transpose_a,
95 ctx,
96 ),
97 LinalgOp::TriangularSolve {
98 left_side,
99 lower,
100 transpose_a,
101 unit_diagonal,
102 } => rules::linearize_triangular_solve(
103 builder,
104 primal_in,
105 primal_out,
106 tangent_in,
107 rules::TriangularSolveFlags::new(left_side, lower, transpose_a, unit_diagonal),
108 ctx,
109 ),
110 LinalgOp::Cholesky => {
111 rules::linearize_cholesky(builder, primal_in, primal_out, tangent_in, ctx)
112 }
113 LinalgOp::Svd { eps } => {
114 rules::linearize_svd(builder, primal_in, primal_out, tangent_in, eps, ctx)
115 }
116 LinalgOp::SvdVals { eps } => {
117 rules::linearize_svd_values(builder, primal_in, tangent_in, eps, ctx)
118 }
119 LinalgOp::Qr => rules::linearize_qr(builder, primal_in, primal_out, tangent_in, ctx),
120 LinalgOp::Eigh { eps } => {
121 rules::linearize_eigh(builder, primal_in, primal_out, tangent_in, eps, ctx)
122 }
123 LinalgOp::EighVals { eps } => {
124 rules::linearize_eigh_values(builder, primal_in, tangent_in, eps, ctx)
125 }
126 LinalgOp::Eig { input_dtype } => Ok(rules::linearize_eig(
127 builder,
128 primal_in,
129 primal_out,
130 tangent_in,
131 input_dtype,
132 ctx,
133 )),
134 LinalgOp::EigVals { input_dtype } => Ok(rules::linearize_eig_values(
135 builder,
136 primal_in,
137 tangent_in,
138 input_dtype,
139 ctx,
140 )),
141 }
142 }
143
144 fn transpose_rule(
145 &self,
146 op: &dyn ExtensionOp,
147 builder: &mut dyn PrimitiveRuleBuilder,
148 cotangent_out: &[Option<LocalValueId>],
149 inputs: &[ValueRef<StdTensorOp>],
150 mode: &OperationRole,
151 ctx: &mut ShapeGuardContext,
152 ) -> ADRuleResult<Vec<Option<LocalValueId>>> {
153 let op = downcast_ad_op(op, ADRuleKind::Transpose)?;
154 let mut builder = DynBuilder(builder);
155 match op.op() {
156 LinalgOp::TriangularSolve {
157 left_side,
158 lower,
159 transpose_a,
160 unit_diagonal,
161 } => rules::transpose_triangular_solve(
162 &mut builder,
163 cotangent_out,
164 inputs,
165 mode,
166 rules::TriangularSolveFlags::new(left_side, lower, transpose_a, unit_diagonal),
167 ctx,
168 ),
169 LinalgOp::LuSolvePrepared {
170 transpose_a,
171 conjugate_a,
172 } => Ok(rules::transpose_lu_solve_prepared(
173 &mut builder,
174 cotangent_out,
175 inputs,
176 mode,
177 transpose_a,
178 conjugate_a,
179 ctx,
180 )),
181 LinalgOp::FullPivLuSolve { transpose_a } => rules::transpose_full_piv_lu_solve(
182 &mut builder,
183 cotangent_out,
184 inputs,
185 mode,
186 transpose_a,
187 ctx,
188 ),
189 LinalgOp::Cholesky
190 | LinalgOp::Lu
191 | LinalgOp::LuFactor
192 | LinalgOp::FullPivLu
193 | LinalgOp::Svd { .. }
194 | LinalgOp::SvdVals { .. }
195 | LinalgOp::Qr
196 | LinalgOp::Eigh { .. }
197 | LinalgOp::EighVals { .. }
198 | LinalgOp::Eig { .. }
199 | LinalgOp::EigVals { .. } => Ok(vec![None; op.input_count()]),
200 }
201 }
202}
203
204struct DynBuilder<'a>(&'a mut dyn PrimitiveRuleBuilder);
205
206impl PrimitiveRuleBuilder for DynBuilder<'_> {
207 fn add_operation(
208 &mut self,
209 op: StdTensorOp,
210 inputs: Vec<ValueRef<StdTensorOp>>,
211 mode: OperationRole,
212 ) -> Vec<LocalValueId> {
213 self.0.add_operation(op, inputs, mode)
214 }
215}
216
217fn downcast_ad_op(op: &dyn ExtensionOp, kind: ADRuleKind) -> ADRuleResult<&LinalgExtensionOp> {
218 op.as_any()
219 .downcast_ref::<LinalgExtensionOp>()
220 .ok_or_else(|| {
221 ADRuleError::invalid_input("tenferro-linalg.linalg.v1", kind, "payload type mismatch")
222 })
223}
224
#[cfg(test)]
mod tests {
    use super::*;
    use computegraph::graph::GraphBuilder;
    use tenferro_ops::input_key::TensorInputKey;
    use tenferro_ops::{ShapeExtent, SymDim, TensorMeta};
    use tenferro_tensor::DType;

    /// External user-input value key with the given id.
    fn input_key(id: u64) -> ValueKey<StdTensorOp> {
        ValueKey::Input(TensorInputKey::User { id })
    }

    /// Records exact `F64` metadata of the given `shape` for `key`.
    fn insert_meta(ctx: &mut ShapeGuardContext, key: ValueKey<StdTensorOp>, shape: &[usize]) {
        let dims = shape.iter().map(|&extent| SymDim::from(extent)).collect();
        ctx.insert_metadata(key, TensorMeta::exact(DType::F64, dims));
    }

    /// Records `F64` metadata for `key` whose extents are only upper bounds.
    fn insert_upper_bound_meta(
        ctx: &mut ShapeGuardContext,
        key: ValueKey<StdTensorOp>,
        bounds: &[usize],
    ) {
        let extents = bounds
            .iter()
            .map(|&bound| ShapeExtent::upper_bound(SymDim::from(bound)))
            .collect::<Vec<_>>();
        ctx.insert_metadata(key, TensorMeta::with_extents(DType::F64, extents));
    }

    /// `TriangularSolve` configuration shared by the triangular-solve tests:
    /// left side, lower triangle, no transpose, non-unit diagonal.
    fn lower_left_solve() -> LinalgOp {
        LinalgOp::TriangularSolve {
            left_side: true,
            lower: true,
            transpose_a: false,
            unit_diagonal: false,
        }
    }

    #[test]
    fn full_piv_lu_jvp_returns_inactive_outputs_for_non_square_input() {
        let mut builder = GraphBuilder::<StdTensorOp>::new();
        let mut ctx = ShapeGuardContext::default();
        let primal = input_key(1);
        insert_meta(&mut ctx, primal.clone(), &[2, 3]);
        let tangent = builder.add_input(TensorInputKey::User { id: 2 });
        let outputs: Vec<_> = (10..15).map(input_key).collect();
        let op = LinalgExtensionOp::new(LinalgOp::FullPivLu);

        let tangents_out = LinalgAdRule
            .linearize(
                &op,
                &mut builder,
                &[primal],
                &outputs,
                &[Some(tangent)],
                &mut ctx,
            )
            .unwrap();

        assert_eq!(tangents_out, vec![None; 5]);
        assert!(builder.build().operations().is_empty());
    }

    #[test]
    fn triangular_solve_jvp_rejects_non_matrix_operands() {
        let mut builder = GraphBuilder::<StdTensorOp>::new();
        let mut ctx = ShapeGuardContext::default();
        let (lhs, rhs) = (input_key(20), input_key(21));
        insert_meta(&mut ctx, lhs.clone(), &[2, 2]);
        insert_meta(&mut ctx, rhs.clone(), &[2]);
        let rhs_tangent = builder.add_input(TensorInputKey::User { id: 22 });
        let op = LinalgExtensionOp::new(lower_left_solve());

        let err = LinalgAdRule
            .linearize(
                &op,
                &mut builder,
                &[lhs, rhs],
                &[input_key(23)],
                &[None, Some(rhs_tangent)],
                &mut ctx,
            )
            .unwrap_err();

        assert_eq!(err.rule(), ADRuleKind::Jvp);
        let message = err.to_string();
        assert!(message.contains("expected matrix operands with rank >= 2"));
        assert!(builder.build().operations().is_empty());
    }

    #[test]
    fn triangular_solve_jvp_accepts_upper_bound_matrix_metadata() {
        let mut builder = GraphBuilder::<StdTensorOp>::new();
        let mut ctx = ShapeGuardContext::default();
        let (lhs, rhs) = (input_key(30), input_key(31));
        insert_upper_bound_meta(&mut ctx, lhs.clone(), &[4, 4]);
        insert_upper_bound_meta(&mut ctx, rhs.clone(), &[4, 2]);
        let rhs_tangent = builder.add_input(TensorInputKey::User { id: 32 });
        let op = LinalgExtensionOp::new(lower_left_solve());

        let tangents_out = LinalgAdRule
            .linearize(
                &op,
                &mut builder,
                &[lhs.clone(), rhs],
                &[input_key(33)],
                &[None, Some(rhs_tangent)],
                &mut ctx,
            )
            .unwrap();

        assert!(tangents_out[0].is_some());
        let graph = builder.build();
        let operations = graph.operations();
        assert_eq!(operations.len(), 1);
        let solve = &operations[0];
        assert_eq!(solve.inputs[0], ValueRef::External(lhs));
        assert_eq!(solve.inputs[1], ValueRef::Local(rhs_tangent));
    }

    #[test]
    fn triangular_solve_transpose_accepts_upper_bound_matrix_metadata() {
        let mut builder = GraphBuilder::<StdTensorOp>::new();
        let mut ctx = ShapeGuardContext::default();
        let (lhs, rhs) = (input_key(40), input_key(41));
        insert_upper_bound_meta(&mut ctx, lhs.clone(), &[4, 4]);
        insert_upper_bound_meta(&mut ctx, rhs.clone(), &[4, 2]);
        let cotangent = builder.add_input(TensorInputKey::User { id: 42 });
        let op = LinalgExtensionOp::new(lower_left_solve());
        let role = OperationRole::Linearized {
            active_mask: vec![false, true],
        };

        let cotangents_in = LinalgAdRule
            .transpose_rule(
                &op,
                &mut builder,
                &[Some(cotangent)],
                &[ValueRef::External(lhs.clone()), ValueRef::External(rhs)],
                &role,
                &mut ctx,
            )
            .unwrap();

        assert_eq!(cotangents_in[0], None);
        assert!(cotangents_in[1].is_some());
        let graph = builder.build();
        let operations = graph.operations();
        assert_eq!(operations.len(), 1);
        assert_eq!(operations[0].inputs[0], ValueRef::External(lhs));
        assert_eq!(operations[0].inputs[1], ValueRef::Local(cotangent));
    }

    #[test]
    fn cholesky_jvp_uses_rank_when_input_metadata_is_upper_bound() {
        let mut builder = GraphBuilder::<StdTensorOp>::new();
        let mut ctx = ShapeGuardContext::default();
        let primal = input_key(50);
        insert_upper_bound_meta(&mut ctx, primal.clone(), &[4, 4]);
        let tangent = builder.add_input(TensorInputKey::User { id: 51 });
        let op = LinalgExtensionOp::new(LinalgOp::Cholesky);

        let tangents_out = LinalgAdRule
            .linearize(
                &op,
                &mut builder,
                &[primal],
                &[input_key(52)],
                &[Some(tangent)],
                &mut ctx,
            )
            .unwrap();

        assert!(tangents_out[0].is_some());
        assert!(!builder.build().operations().is_empty());
    }

    #[test]
    fn one_input_linalg_jvps_return_inactive_for_non_matrix_input() {
        let cases = [
            LinalgOp::Cholesky,
            LinalgOp::Lu,
            LinalgOp::FullPivLu,
            LinalgOp::Svd { eps: 1e-12 },
            LinalgOp::SvdVals { eps: 1e-12 },
            LinalgOp::Qr,
            LinalgOp::Eigh { eps: 1e-12 },
            LinalgOp::EighVals { eps: 1e-12 },
            LinalgOp::Eig {
                input_dtype: DType::F64,
            },
            LinalgOp::EigVals {
                input_dtype: DType::F64,
            },
        ];

        for (case_index, kind) in cases.into_iter().enumerate() {
            // Offsets every key id per case so no two cases share an id.
            let case_id = case_index as u64;
            let mut builder = GraphBuilder::<StdTensorOp>::new();
            let mut ctx = ShapeGuardContext::default();
            let primal = input_key(100 + case_id);
            insert_meta(&mut ctx, primal.clone(), &[3]);
            let tangent = builder.add_input(TensorInputKey::User { id: 200 + case_id });
            let op = LinalgExtensionOp::new(kind);
            let output_count = op.output_count();
            let outputs: Vec<_> = (0..output_count)
                .map(|offset| input_key(300 + case_id * 10 + offset as u64))
                .collect();

            let tangents_out = LinalgAdRule
                .linearize(
                    &op,
                    &mut builder,
                    &[primal],
                    &outputs,
                    &[Some(tangent)],
                    &mut ctx,
                )
                .unwrap();

            assert_eq!(tangents_out, vec![None; output_count], "{kind:?}");
            assert!(
                builder.build().operations().is_empty(),
                "{kind:?} should not emit a malformed matrix AD graph"
            );
        }
    }
}
503}