Skip to main content

tenferro_linalg/
ad.rs

1//! Automatic differentiation support for `tenferro-linalg`.
2//!
3//! This module is enabled by the `autodiff` feature. It provides the linalg
4//! extension rule set used by explicit `tenferro_ad::AdContext` values.
5//!
6//! # Examples
7//!
8//! ```rust
9//! use tenferro_ad::AdContext;
10//! use tenferro_linalg::TracedTensorLinalgExt;
11//! use tenferro_runtime::TracedTensor;
12//!
13//! let ad = AdContext::builder()
14//!     .with_extension_rules(tenferro_linalg::ad_rules().unwrap())
15//!     .build()
16//!     .unwrap();
17//! let x = TracedTensor::from_vec_col_major(vec![2, 2], vec![1.0_f64, 0.0, 0.0, 2.0]).unwrap();
18//! let (_u, s, _vt) = x.svd().unwrap();
19//! let loss = s.reduce_sum(&[0]).unwrap();
20//! let grad = ad.grad(&loss, &x).unwrap();
21//! assert_eq!(grad.rank, 2);
22//! ```
23
24use std::sync::Arc;
25
26use computegraph::types::{LocalValueId, OperationRole, ValueKey, ValueRef};
27use tenferro_ad::extension::{
28    ExtensionAdRule, ExtensionOp, ExtensionRegistryError, ExtensionRuleSet,
29};
30use tenferro_ops::ad::PrimitiveRuleBuilder;
31use tenferro_ops::std_tensor_op::StdTensorOp;
32use tenferro_ops::ShapeGuardContext;
33use tidu::{ADRuleError, ADRuleKind, ADRuleResult};
34
35use crate::extension::{LinalgExtensionOp, LinalgOp};
36use crate::LINALG_EXTENSION_FAMILY_ID;
37
38mod rules;
39pub mod support;
40
41/// Return the explicit linalg extension AD rule set.
42///
43/// # Examples
44///
45/// ```rust
46/// let rules = tenferro_linalg::ad_rules().unwrap();
47/// assert!(rules.is_rule_registered(tenferro_linalg::LINALG_EXTENSION_FAMILY_ID));
48/// ```
49pub fn ad_rules() -> Result<ExtensionRuleSet, ExtensionRegistryError> {
50    ExtensionRuleSet::new().with_rule(Arc::new(LinalgAdRule))
51}
52
53#[derive(Debug)]
54struct LinalgAdRule;
55
56impl ExtensionAdRule for LinalgAdRule {
57    fn family_id(&self) -> &'static str {
58        LINALG_EXTENSION_FAMILY_ID
59    }
60
61    fn linearize(
62        &self,
63        op: &dyn ExtensionOp,
64        builder: &mut dyn PrimitiveRuleBuilder,
65        primal_in: &[ValueKey<StdTensorOp>],
66        primal_out: &[ValueKey<StdTensorOp>],
67        tangent_in: &[Option<LocalValueId>],
68        ctx: &mut ShapeGuardContext,
69    ) -> ADRuleResult<Vec<Option<LocalValueId>>> {
70        let op = downcast_ad_op(op, ADRuleKind::Jvp)?;
71        match op.op() {
72            LinalgOp::Lu => rules::linearize_lu(builder, primal_in, primal_out, tangent_in, ctx),
73            LinalgOp::LuFactor => Ok(vec![None; op.output_count()]),
74            LinalgOp::LuSolvePrepared {
75                transpose_a,
76                conjugate_a,
77            } => rules::linearize_lu_solve_prepared(
78                builder,
79                primal_in,
80                primal_out,
81                tangent_in,
82                transpose_a,
83                conjugate_a,
84                ctx,
85            ),
86            LinalgOp::FullPivLu => {
87                rules::linearize_full_piv_lu(builder, primal_in, primal_out, tangent_in, ctx)
88            }
89            LinalgOp::FullPivLuSolve { transpose_a } => rules::linearize_full_piv_lu_solve(
90                builder,
91                primal_in,
92                primal_out,
93                tangent_in,
94                transpose_a,
95                ctx,
96            ),
97            LinalgOp::TriangularSolve {
98                left_side,
99                lower,
100                transpose_a,
101                unit_diagonal,
102            } => rules::linearize_triangular_solve(
103                builder,
104                primal_in,
105                primal_out,
106                tangent_in,
107                rules::TriangularSolveFlags::new(left_side, lower, transpose_a, unit_diagonal),
108                ctx,
109            ),
110            LinalgOp::Cholesky => {
111                rules::linearize_cholesky(builder, primal_in, primal_out, tangent_in, ctx)
112            }
113            LinalgOp::Svd { eps } => {
114                rules::linearize_svd(builder, primal_in, primal_out, tangent_in, eps, ctx)
115            }
116            LinalgOp::SvdVals { eps } => {
117                rules::linearize_svd_values(builder, primal_in, tangent_in, eps, ctx)
118            }
119            LinalgOp::Qr => rules::linearize_qr(builder, primal_in, primal_out, tangent_in, ctx),
120            LinalgOp::Eigh { eps } => {
121                rules::linearize_eigh(builder, primal_in, primal_out, tangent_in, eps, ctx)
122            }
123            LinalgOp::EighVals { eps } => {
124                rules::linearize_eigh_values(builder, primal_in, tangent_in, eps, ctx)
125            }
126            LinalgOp::Eig { input_dtype } => Ok(rules::linearize_eig(
127                builder,
128                primal_in,
129                primal_out,
130                tangent_in,
131                input_dtype,
132                ctx,
133            )),
134            LinalgOp::EigVals { input_dtype } => Ok(rules::linearize_eig_values(
135                builder,
136                primal_in,
137                tangent_in,
138                input_dtype,
139                ctx,
140            )),
141        }
142    }
143
144    fn transpose_rule(
145        &self,
146        op: &dyn ExtensionOp,
147        builder: &mut dyn PrimitiveRuleBuilder,
148        cotangent_out: &[Option<LocalValueId>],
149        inputs: &[ValueRef<StdTensorOp>],
150        mode: &OperationRole,
151        ctx: &mut ShapeGuardContext,
152    ) -> ADRuleResult<Vec<Option<LocalValueId>>> {
153        let op = downcast_ad_op(op, ADRuleKind::Transpose)?;
154        let mut builder = DynBuilder(builder);
155        match op.op() {
156            LinalgOp::TriangularSolve {
157                left_side,
158                lower,
159                transpose_a,
160                unit_diagonal,
161            } => rules::transpose_triangular_solve(
162                &mut builder,
163                cotangent_out,
164                inputs,
165                mode,
166                rules::TriangularSolveFlags::new(left_side, lower, transpose_a, unit_diagonal),
167                ctx,
168            ),
169            LinalgOp::LuSolvePrepared {
170                transpose_a,
171                conjugate_a,
172            } => Ok(rules::transpose_lu_solve_prepared(
173                &mut builder,
174                cotangent_out,
175                inputs,
176                mode,
177                transpose_a,
178                conjugate_a,
179                ctx,
180            )),
181            LinalgOp::FullPivLuSolve { transpose_a } => rules::transpose_full_piv_lu_solve(
182                &mut builder,
183                cotangent_out,
184                inputs,
185                mode,
186                transpose_a,
187                ctx,
188            ),
189            LinalgOp::Cholesky
190            | LinalgOp::Lu
191            | LinalgOp::LuFactor
192            | LinalgOp::FullPivLu
193            | LinalgOp::Svd { .. }
194            | LinalgOp::SvdVals { .. }
195            | LinalgOp::Qr
196            | LinalgOp::Eigh { .. }
197            | LinalgOp::EighVals { .. }
198            | LinalgOp::Eig { .. }
199            | LinalgOp::EigVals { .. } => Ok(vec![None; op.input_count()]),
200        }
201    }
202}
203
204struct DynBuilder<'a>(&'a mut dyn PrimitiveRuleBuilder);
205
206impl PrimitiveRuleBuilder for DynBuilder<'_> {
207    fn add_operation(
208        &mut self,
209        op: StdTensorOp,
210        inputs: Vec<ValueRef<StdTensorOp>>,
211        mode: OperationRole,
212    ) -> Vec<LocalValueId> {
213        self.0.add_operation(op, inputs, mode)
214    }
215}
216
217fn downcast_ad_op(op: &dyn ExtensionOp, kind: ADRuleKind) -> ADRuleResult<&LinalgExtensionOp> {
218    op.as_any()
219        .downcast_ref::<LinalgExtensionOp>()
220        .ok_or_else(|| {
221            ADRuleError::invalid_input("tenferro-linalg.linalg.v1", kind, "payload type mismatch")
222        })
223}
224
#[cfg(test)]
mod tests {
    use super::*;
    use computegraph::graph::GraphBuilder;
    use tenferro_ops::input_key::TensorInputKey;
    use tenferro_ops::{ShapeExtent, SymDim, TensorMeta};
    use tenferro_tensor::DType;

    /// Key of the user graph input with the given `id`.
    fn input_key(id: u64) -> ValueKey<StdTensorOp> {
        ValueKey::Input(TensorInputKey::User { id })
    }

    /// Record exact `F64` metadata with the static `shape` for `key`.
    fn insert_meta(ctx: &mut ShapeGuardContext, key: ValueKey<StdTensorOp>, shape: &[usize]) {
        ctx.insert_metadata(
            key,
            TensorMeta::exact(
                DType::F64,
                shape.iter().copied().map(SymDim::from).collect(),
            ),
        );
    }

    /// Record `F64` metadata for `key` in which every dimension is described only
    /// by `ShapeExtent::upper_bound`, one bound per entry of `bounds`.
    fn insert_upper_bound_meta(
        ctx: &mut ShapeGuardContext,
        key: ValueKey<StdTensorOp>,
        bounds: &[usize],
    ) {
        let extents = bounds
            .iter()
            .map(|&bound| ShapeExtent::upper_bound(SymDim::from(bound)))
            .collect::<Vec<_>>();
        ctx.insert_metadata(key, TensorMeta::with_extents(DType::F64, extents));
    }

    /// `TriangularSolve` op with `left_side` and `lower` set and `transpose_a` and
    /// `unit_diagonal` cleared; the configuration shared by the triangular-solve tests.
    fn lower_left_triangular_solve() -> LinalgExtensionOp {
        LinalgExtensionOp::new(LinalgOp::TriangularSolve {
            left_side: true,
            lower: true,
            transpose_a: false,
            unit_diagonal: false,
        })
    }

    #[test]
    fn full_piv_lu_jvp_returns_inactive_outputs_for_non_square_input() {
        let mut builder = GraphBuilder::<StdTensorOp>::new();
        let mut ctx = ShapeGuardContext::default();
        let primal = input_key(1);
        insert_meta(&mut ctx, primal.clone(), &[2, 3]);
        let tangent = builder.add_input(TensorInputKey::User { id: 2 });
        let outputs: Vec<_> = (10..15).map(input_key).collect();
        let op = LinalgExtensionOp::new(LinalgOp::FullPivLu);

        let result = LinalgAdRule
            .linearize(
                &op,
                &mut builder,
                &[primal],
                &outputs,
                &[Some(tangent)],
                &mut ctx,
            )
            .unwrap();

        assert_eq!(result, vec![None; 5]);
        assert!(builder.build().operations().is_empty());
    }

    #[test]
    fn triangular_solve_jvp_rejects_non_matrix_operands() {
        let mut builder = GraphBuilder::<StdTensorOp>::new();
        let mut ctx = ShapeGuardContext::default();
        let lhs = input_key(20);
        let rhs = input_key(21);
        insert_meta(&mut ctx, lhs.clone(), &[2, 2]);
        insert_meta(&mut ctx, rhs.clone(), &[2]);
        let rhs_tangent = builder.add_input(TensorInputKey::User { id: 22 });
        let op = lower_left_triangular_solve();

        let err = LinalgAdRule
            .linearize(
                &op,
                &mut builder,
                &[lhs, rhs],
                &[input_key(23)],
                &[None, Some(rhs_tangent)],
                &mut ctx,
            )
            .unwrap_err();

        assert_eq!(err.rule(), ADRuleKind::Jvp);
        assert!(err
            .to_string()
            .contains("expected matrix operands with rank >= 2"));
        assert!(builder.build().operations().is_empty());
    }

    #[test]
    fn triangular_solve_jvp_accepts_upper_bound_matrix_metadata() {
        let mut builder = GraphBuilder::<StdTensorOp>::new();
        let mut ctx = ShapeGuardContext::default();
        let lhs = input_key(30);
        let rhs = input_key(31);
        insert_upper_bound_meta(&mut ctx, lhs.clone(), &[4, 4]);
        insert_upper_bound_meta(&mut ctx, rhs.clone(), &[4, 2]);
        let rhs_tangent = builder.add_input(TensorInputKey::User { id: 32 });
        let op = lower_left_triangular_solve();

        let result = LinalgAdRule
            .linearize(
                &op,
                &mut builder,
                &[lhs.clone(), rhs],
                &[input_key(33)],
                &[None, Some(rhs_tangent)],
                &mut ctx,
            )
            .unwrap();

        assert!(result[0].is_some());
        let graph = builder.build();
        assert_eq!(graph.operations().len(), 1);
        let solve = &graph.operations()[0];
        assert_eq!(solve.inputs[0], ValueRef::External(lhs));
        assert_eq!(solve.inputs[1], ValueRef::Local(rhs_tangent));
    }

    #[test]
    fn triangular_solve_transpose_accepts_upper_bound_matrix_metadata() {
        let mut builder = GraphBuilder::<StdTensorOp>::new();
        let mut ctx = ShapeGuardContext::default();
        let lhs = input_key(40);
        let rhs = input_key(41);
        insert_upper_bound_meta(&mut ctx, lhs.clone(), &[4, 4]);
        insert_upper_bound_meta(&mut ctx, rhs.clone(), &[4, 2]);
        let cotangent = builder.add_input(TensorInputKey::User { id: 42 });
        let op = lower_left_triangular_solve();

        let result = LinalgAdRule
            .transpose_rule(
                &op,
                &mut builder,
                &[Some(cotangent)],
                &[ValueRef::External(lhs.clone()), ValueRef::External(rhs)],
                &OperationRole::Linearized {
                    active_mask: vec![false, true],
                },
                &mut ctx,
            )
            .unwrap();

        assert_eq!(result[0], None);
        assert!(result[1].is_some());
        let graph = builder.build();
        assert_eq!(graph.operations().len(), 1);
        assert_eq!(graph.operations()[0].inputs[0], ValueRef::External(lhs));
        assert_eq!(graph.operations()[0].inputs[1], ValueRef::Local(cotangent));
    }

    #[test]
    fn cholesky_jvp_uses_rank_when_input_metadata_is_upper_bound() {
        let mut builder = GraphBuilder::<StdTensorOp>::new();
        let mut ctx = ShapeGuardContext::default();
        let primal = input_key(50);
        insert_upper_bound_meta(&mut ctx, primal.clone(), &[4, 4]);
        let tangent = builder.add_input(TensorInputKey::User { id: 51 });
        let op = LinalgExtensionOp::new(LinalgOp::Cholesky);

        let result = LinalgAdRule
            .linearize(
                &op,
                &mut builder,
                &[primal],
                &[input_key(52)],
                &[Some(tangent)],
                &mut ctx,
            )
            .unwrap();

        assert!(result[0].is_some());
        assert!(!builder.build().operations().is_empty());
    }

    #[test]
    fn one_input_linalg_jvps_return_inactive_for_non_matrix_input() {
        let cases = [
            LinalgOp::Cholesky,
            LinalgOp::Lu,
            LinalgOp::LuFactor,
            LinalgOp::FullPivLu,
            LinalgOp::Svd { eps: 1e-12 },
            LinalgOp::SvdVals { eps: 1e-12 },
            LinalgOp::Qr,
            LinalgOp::Eigh { eps: 1e-12 },
            LinalgOp::EighVals { eps: 1e-12 },
            LinalgOp::Eig {
                input_dtype: DType::F64,
            },
            LinalgOp::EigVals {
                input_dtype: DType::F64,
            },
        ];

        for (case_index, kind) in cases.into_iter().enumerate() {
            let mut builder = GraphBuilder::<StdTensorOp>::new();
            let mut ctx = ShapeGuardContext::default();
            let primal = input_key(100 + case_index as u64);
            insert_meta(&mut ctx, primal.clone(), &[3]);
            let tangent = builder.add_input(TensorInputKey::User {
                id: 200 + case_index as u64,
            });
            let op = LinalgExtensionOp::new(kind);
            let outputs: Vec<_> = (0..op.output_count())
                .map(|offset| input_key(300 + case_index as u64 * 10 + offset as u64))
                .collect();

            let result = LinalgAdRule
                .linearize(
                    &op,
                    &mut builder,
                    &[primal],
                    &outputs,
                    &[Some(tangent)],
                    &mut ctx,
                )
                .unwrap();

            assert_eq!(result, vec![None; op.output_count()], "{kind:?}");
            assert!(
                builder.build().operations().is_empty(),
                "{kind:?} should not emit a malformed matrix AD graph"
            );
        }
    }
}
499                "{kind:?} should not emit a malformed matrix AD graph"
500            );
501        }
502    }
503}