Skip to main content

tenferro_ops/
ext_op.rs

1//! Out-of-tree extension-operation mechanism.
2//!
//! This module defines the [`ExtensionOp`] trait and the explicit, owned
//! [`ExtensionRuleSet`] that carries extension AD rules. Together they let
//! external crates contribute fused primitives
5//! that participate in the [`crate::std_tensor_op::StdTensorOp`] graph through
6//! the single carrier variant
7//! `StdTensorOp::Extension(Arc<dyn ExtensionOp>)`.
8//!
9//! See `docs/spec/extension-op.md` for the normative contract. Key points:
10//!
11//! - Identity / hashing / equality are expressed on the trait so the
12//!   type-erased `Arc<dyn ExtensionOp>` carrier can satisfy
13//!   `Clone + Hash + Eq + Send + Sync + 'static` (computegraph's
14//!   `GraphOperation` requirements).
15//! - AD rules are owned by explicit [`ExtensionRuleSet`] values. A rule may
16//!   emit core [`StdTensorOp`] values and registered `Extension` values so
17//!   out-of-tree operations remain in the same graph.
18//! - Extension ops themselves do not require process-global registration.
19//!   Frontends carry them directly as `Arc<dyn ExtensionOp>`.
20
21use std::any::Any;
22use std::fmt::Debug;
23use std::hash::{Hash, Hasher};
24use std::sync::Arc;
25
26use computegraph::graph::GraphBuilder;
27#[cfg(not(feature = "autodiff"))]
28use computegraph::types::ValueRef;
29#[cfg(feature = "autodiff")]
30use computegraph::types::{LocalValueId, OperationRole, ValueKey, ValueRef};
31use tenferro_tensor::{DType, Tensor};
32#[cfg(feature = "autodiff")]
33use tidu::{ADRuleError, ADRuleKind, ADRuleResult};
34
35#[cfg(feature = "autodiff")]
36use crate::ad::context::ShapeGuardContext;
37#[cfg(feature = "autodiff")]
38use crate::ad::PrimitiveRuleBuilder;
39use crate::std_tensor_op::StdTensorOp;
40use crate::sym_dim::SymDim;
41#[cfg(feature = "autodiff")]
42use std::collections::HashMap;
43
44/// Error returned when an extension cannot expand itself into standard ops.
45///
46/// # Examples
47///
48/// ```
49/// use tenferro_ops::ext_op::ExtensionLoweringError;
50///
51/// let err = ExtensionLoweringError::new("example extension cannot lower");
52/// assert!(err.to_string().contains("cannot lower"));
53/// ```
54#[derive(Clone, Debug, PartialEq, Eq, thiserror::Error)]
55#[error("{message}")]
56pub struct ExtensionLoweringError {
57    message: String,
58}
59
60impl ExtensionLoweringError {
61    /// Create a lowering error with a human-readable diagnostic.
62    ///
63    /// # Examples
64    ///
65    /// ```
66    /// use tenferro_ops::ext_op::ExtensionLoweringError;
67    ///
68    /// let err = ExtensionLoweringError::new("shape must be static");
69    /// assert_eq!(err.to_string(), "shape must be static");
70    /// ```
71    pub fn new(message: impl Into<String>) -> Self {
72        Self {
73            message: message.into(),
74        }
75    }
76}
77
78/// Result returned by [`ExtensionOp::lower_to_standard_ops`].
79pub type ExtensionLoweringResult =
80    std::result::Result<Option<Vec<ValueRef<StdTensorOp>>>, ExtensionLoweringError>;
81
82/// The contract every out-of-tree extension primitive must satisfy.
83///
84/// Implementations appear in the core graph as
85/// `StdTensorOp::Extension(Arc<dyn ExtensionOp>)`. Every method is part of the
86/// `ExtensionOp` spec (`docs/spec/extension-op.md`); the short form:
87///
88/// - identity via [`family_id`][Self::family_id] + [`payload_hash`][Self::payload_hash]
89///   + [`payload_eq`][Self::payload_eq];
90/// - fixed arity via [`input_count`][Self::input_count] / [`output_count`][Self::output_count];
91/// - shape / dtype inference via [`infer_output_meta`][Self::infer_output_meta];
92/// - host/reference forward execution via [`eager_execute`][Self::eager_execute];
93///   runtime-owned eager and compiled paths dispatch through registered
94///   extension runtimes instead of falling back to this method;
95/// - optional fixed-shape standard-op expansion via
96///   [`lower_to_standard_ops`][Self::lower_to_standard_ops] for peer lowerers
97///   such as XLA that cannot execute extension runtimes;
98/// - AD via a separately registered [`ExtensionAdRule`].
99///
100/// # Downcast convention
101///
102/// Implementations MUST also implement [`Any`] so that
103/// [`ExtensionOp::payload_eq`] can downcast a trait-object reference to
104/// the concrete type. The helper [`ExtensionOp::as_any`] returns
105/// `&dyn Any` for this purpose. Implementations usually define it as
106/// `fn as_any(&self) -> &dyn Any { self }`.
107///
108/// # Examples
109///
110/// ```
111/// # use std::any::Any;
112/// use std::sync::Arc;
113/// use tenferro_ops::ext_op::ExtensionOp;
114/// use tenferro_ops::SymDim;
115/// use tenferro_tensor::{DType, Tensor};
116///
117/// #[derive(Clone, Debug)]
118/// struct IdentityExt;
119///
120/// impl ExtensionOp for IdentityExt {
121///     fn family_id(&self) -> &'static str { "example.identity.v1" }
122///     fn payload_hash(&self, _hasher: &mut dyn std::hash::Hasher) {}
123///     fn payload_eq(&self, other: &dyn ExtensionOp) -> bool {
124///         other.as_any().downcast_ref::<IdentityExt>().is_some()
125///     }
126///     fn clone_arc(&self) -> Arc<dyn ExtensionOp> { Arc::new(self.clone()) }
127///     fn as_any(&self) -> &dyn Any { self }
128///     fn input_count(&self) -> usize { 1 }
129///     fn output_count(&self) -> usize { 1 }
130///     fn infer_output_meta(
131///         &self,
132///         dtypes: &[DType],
133///         shapes: &[&[SymDim]],
134///     ) -> Vec<(DType, Vec<SymDim>)> {
135///         vec![(dtypes[0], shapes[0].to_vec())]
136///     }
137///     fn eager_execute(&self, inputs: &[&Tensor]) -> tenferro_tensor::Result<Vec<Tensor>> {
138///         Ok(vec![inputs[0].clone()])
139///     }
140/// }
141///
142/// let op: Arc<dyn ExtensionOp> = Arc::new(IdentityExt);
143/// assert_eq!(op.input_count(), 1);
144/// ```
145pub trait ExtensionOp: Debug + Send + Sync + 'static {
146    // ----- Identity, hashing, equality (spec Section 5) -----
147
148    /// Stable, process-independent family identifier.
149    ///
150    /// MUST be unique per extension *family* (payload schema), not per
151    /// *instance*, and MUST follow the reserved format
152    /// `"<crate-name>.<op-name>.v<major>"`.
153    fn family_id(&self) -> &'static str;
154
155    /// Hash the payload (everything except `family_id`).
156    ///
157    /// Implementations MUST be pure and deterministic across calls on the same
158    /// value. Hashes MUST NOT include transient state such as allocation
159    /// addresses or atomically updated counters.
160    fn payload_hash(&self, hasher: &mut dyn Hasher);
161
162    /// Structural equality against another extension value.
163    ///
164    /// The carrier's `PartialEq` impl first compares `family_id`s. When the
165    /// family IDs match, it calls `payload_eq`. Implementations MUST return
166    /// `true` iff the payloads are semantically equal AND
167    /// `other.family_id() == self.family_id()`.
168    fn payload_eq(&self, other: &dyn ExtensionOp) -> bool;
169
170    /// Deep-clone the payload behind an `Arc`.
171    ///
172    /// The carrier's `Clone` impl uses `Arc::clone` on the fast path; this
173    /// method exists for rare cases that need a second independent `Arc`.
174    fn clone_arc(&self) -> Arc<dyn ExtensionOp>;
175
176    /// Upcast this extension to `&dyn Any` for downcasting in `payload_eq`.
177    ///
178    /// Implementations SHOULD return `self` verbatim. The method is
179    /// object-safe (no `Self: Sized` bound) so it can be called on an
180    /// `&dyn ExtensionOp`; that's what makes
181    /// `other.as_any().downcast_ref::<ConcreteType>()` work from
182    /// [`Self::payload_eq`] implementations.
183    fn as_any(&self) -> &dyn Any;
184
185    // ----- Arity (spec Section 6) -----
186
187    /// Number of primal inputs. MUST be constant for any given
188    /// `Arc<dyn ExtensionOp>` value.
189    fn input_count(&self) -> usize;
190
191    /// Number of outputs. MUST match the length of the vector returned by
192    /// [`Self::infer_output_meta`].
193    fn output_count(&self) -> usize;
194
195    // ----- Shape and dtype inference (spec Section 7) -----
196
197    /// Infer output dtypes and shapes for each output slot.
198    ///
199    /// `input_dtypes.len()` and `input_shapes.len()` both equal
200    /// `self.input_count()`. The returned vector MUST have length
201    /// `self.output_count()`, one `(dtype, shape)` entry per output slot.
202    /// Shapes use [`SymDim`] so extension ops compose with graph-global
203    /// symbolic metadata.
204    fn infer_output_meta(
205        &self,
206        input_dtypes: &[DType],
207        input_shapes: &[&[SymDim]],
208    ) -> Vec<(DType, Vec<SymDim>)>;
209
210    // ----- Forward execution dispatch (spec Section 8) -----
211
212    /// Eager forward execution; called from the eager path and indirectly
213    /// from the compiled path.
214    ///
215    /// Input tensors are on the device the caller already arranged. Output
216    /// tensors MUST have shapes matching [`Self::infer_output_meta`] and MUST
217    /// be placed on a device the caller can consume.
218    fn eager_execute(&self, inputs: &[&Tensor]) -> tenferro_tensor::Result<Vec<Tensor>>;
219
220    /// Optionally expand this extension into standard tensor graph operations.
221    ///
222    /// Peer lowerers call this when all input metadata is known and extension
223    /// runtime dispatch is not available. Return `Ok(Some(outputs))` after
224    /// adding only standard [`StdTensorOp`] operations to `builder`. Return
225    /// `Ok(None)` when this extension family has no standard-op lowering for
226    /// the supplied metadata; strict lowerers should surface that as an
227    /// explicit unsupported-extension error. Return [`ExtensionLoweringError`]
228    /// when the payload is malformed or the lowering detects invalid metadata.
229    ///
230    /// The default implementation returns `Ok(None)` so existing extension
231    /// runtimes keep their native dispatch behavior until their owning crate
232    /// deliberately implements this hook.
233    fn lower_to_standard_ops(
234        &self,
235        _builder: &mut GraphBuilder<StdTensorOp>,
236        _inputs: &[ValueRef<StdTensorOp>],
237        _input_dtypes: &[DType],
238        _input_shapes: &[&[SymDim]],
239    ) -> ExtensionLoweringResult {
240        Ok(None)
241    }
242
243    // AD rules are registered separately; see [`ExtensionAdRule`].
244}
245
/// Supplies the AD rules for one extension family.
///
/// AD rules live apart from the primal [`ExtensionOp`], so an out-of-tree
/// crate may ship forward execution alone or put AD support behind an
/// optional feature. Each rule method is handed the primal op as a trait
/// object; use [`ExtensionOp::as_any`] to downcast when payload-specific
/// parameters are needed.
#[cfg(feature = "autodiff")]
pub trait ExtensionAdRule: Debug + Send + Sync + 'static {
    /// Family identifier of the extension this rule applies to.
    fn family_id(&self) -> &'static str;

    /// Emit the linearization (JVP) of `op`.
    fn linearize(
        &self,
        op: &dyn ExtensionOp,
        builder: &mut dyn PrimitiveRuleBuilder,
        primal_in: &[ValueKey<StdTensorOp>],
        primal_out: &[ValueKey<StdTensorOp>],
        tangent_in: &[Option<LocalValueId>],
        ctx: &mut ShapeGuardContext,
    ) -> ADRuleResult<Vec<Option<LocalValueId>>>;

    /// Emit the transpose (VJP) of `op`.
    fn transpose_rule(
        &self,
        op: &dyn ExtensionOp,
        builder: &mut dyn PrimitiveRuleBuilder,
        cotangent_out: &[Option<LocalValueId>],
        inputs: &[ValueRef<StdTensorOp>],
        mode: &OperationRole,
        ctx: &mut ShapeGuardContext,
    ) -> ADRuleResult<Vec<Option<LocalValueId>>>;
}
280
/// Failures reported when adding AD rules to an extension rule set.
#[cfg(feature = "autodiff")]
#[derive(Debug, thiserror::Error)]
pub enum ExtensionRegistryError {
    /// The target set already holds an AD rule for this `family_id`.
    #[error("AD rule for family_id {family_id:?} already registered")]
    DuplicateRule {
        /// Family identifier that was already present.
        family_id: &'static str,
    },
    /// The `family_id` is not of the namespaced form
    /// `"<crate-name>.<op-name>.v<major>"`.
    #[error("family_id {family_id:?} does not match the namespaced format")]
    MalformedFamilyId {
        /// Family identifier that failed validation.
        family_id: &'static str,
    },
}
293
294#[cfg(feature = "autodiff")]
295type RuleMap = HashMap<&'static str, Arc<dyn ExtensionAdRule>>;
296
/// An owned collection of extension AD rules, passed around explicitly.
///
/// Higher-level AD contexts use this as their rule container. There is
/// deliberately no process-global fallback for extension AD: callers hand
/// over exactly the rule set their graph requires.
#[cfg(feature = "autodiff")]
#[derive(Clone, Default)]
pub struct ExtensionRuleSet {
    /// Shared rule table; cloning the set clones only the `Arc`.
    rules: Arc<RuleMap>,
}
307
308#[cfg(feature = "autodiff")]
309impl std::fmt::Debug for ExtensionRuleSet {
310    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
311        let mut families: Vec<_> = self.rules.keys().copied().collect();
312        families.sort_unstable();
313        f.debug_struct("ExtensionRuleSet")
314            .field("families", &families)
315            .finish()
316    }
317}
318
319#[cfg(feature = "autodiff")]
320impl ExtensionRuleSet {
321    /// Create an empty extension rule set.
322    ///
323    /// # Examples
324    ///
325    /// ```
326    /// use tenferro_ops::ExtensionRuleSet;
327    ///
328    /// let rules = ExtensionRuleSet::new();
329    /// assert!(!rules.is_rule_registered("example.missing.v1"));
330    /// ```
331    pub fn new() -> Self {
332        Self::default()
333    }
334
335    /// Add one rule to this owned set.
336    ///
337    /// # Examples
338    ///
339    /// ```
340    /// use std::sync::Arc;
341    /// use tidu::ADRuleResult;
342    /// use computegraph::types::{LocalValueId, OperationRole, ValueKey, ValueRef};
343    /// use tenferro_ops::ad::PrimitiveRuleBuilder;
344    /// use tenferro_ops::ext_op::{ExtensionAdRule, ExtensionOp};
345    /// use tenferro_ops::{ExtensionRuleSet, ShapeGuardContext};
346    /// use tenferro_ops::std_tensor_op::StdTensorOp;
347    ///
348    /// #[derive(Debug)]
349    /// struct Rule;
350    ///
351    /// impl ExtensionAdRule for Rule {
352    ///     fn family_id(&self) -> &'static str { "example.register_rule.v1" }
353    ///     fn linearize(
354    ///         &self,
355    ///         _op: &dyn ExtensionOp,
356    ///         _builder: &mut dyn PrimitiveRuleBuilder,
357    ///         _primal_in: &[ValueKey<StdTensorOp>],
358    ///         _primal_out: &[ValueKey<StdTensorOp>],
359    ///         tangent_in: &[Option<LocalValueId>],
360    ///         _ctx: &mut ShapeGuardContext,
361    ///     ) -> ADRuleResult<Vec<Option<LocalValueId>>> {
362    ///         Ok(tangent_in.to_vec())
363    ///     }
364    ///     fn transpose_rule(
365    ///         &self,
366    ///         _op: &dyn ExtensionOp,
367    ///         _builder: &mut dyn PrimitiveRuleBuilder,
368    ///         cotangent_out: &[Option<LocalValueId>],
369    ///         _inputs: &[ValueRef<StdTensorOp>],
370    ///         _mode: &OperationRole,
371    ///         _ctx: &mut ShapeGuardContext,
372    ///     ) -> ADRuleResult<Vec<Option<LocalValueId>>> {
373    ///         Ok(cotangent_out.to_vec())
374    ///     }
375    /// }
376    ///
377    /// let mut rules = ExtensionRuleSet::new();
378    /// rules.register_rule(Arc::new(Rule)).unwrap();
379    /// assert!(rules.lookup_rule("example.register_rule.v1").is_some());
380    /// ```
381    pub fn register_rule(
382        &mut self,
383        rule: Arc<dyn ExtensionAdRule>,
384    ) -> Result<(), ExtensionRegistryError> {
385        let family_id = rule.family_id();
386        validate_rule_insert(&self.rules, family_id)?;
387        let rules = Arc::make_mut(&mut self.rules);
388        rules.insert(family_id, rule);
389        Ok(())
390    }
391
392    /// Return a new rule set containing `rule`.
393    ///
394    /// # Examples
395    ///
396    /// ```
397    /// use std::sync::Arc;
398    /// use tidu::ADRuleResult;
399    /// use computegraph::types::{LocalValueId, OperationRole, ValueKey, ValueRef};
400    /// use tenferro_ops::ad::PrimitiveRuleBuilder;
401    /// use tenferro_ops::ext_op::{ExtensionAdRule, ExtensionOp};
402    /// use tenferro_ops::{ExtensionRuleSet, ShapeGuardContext};
403    /// use tenferro_ops::std_tensor_op::StdTensorOp;
404    ///
405    /// #[derive(Debug)]
406    /// struct Rule;
407    ///
408    /// impl ExtensionAdRule for Rule {
409    ///     fn family_id(&self) -> &'static str { "example.with_rule.v1" }
410    ///     fn linearize(
411    ///         &self,
412    ///         _op: &dyn ExtensionOp,
413    ///         _builder: &mut dyn PrimitiveRuleBuilder,
414    ///         _primal_in: &[ValueKey<StdTensorOp>],
415    ///         _primal_out: &[ValueKey<StdTensorOp>],
416    ///         tangent_in: &[Option<LocalValueId>],
417    ///         _ctx: &mut ShapeGuardContext,
418    ///     ) -> ADRuleResult<Vec<Option<LocalValueId>>> {
419    ///         Ok(tangent_in.to_vec())
420    ///     }
421    ///     fn transpose_rule(
422    ///         &self,
423    ///         _op: &dyn ExtensionOp,
424    ///         _builder: &mut dyn PrimitiveRuleBuilder,
425    ///         cotangent_out: &[Option<LocalValueId>],
426    ///         _inputs: &[ValueRef<StdTensorOp>],
427    ///         _mode: &OperationRole,
428    ///         _ctx: &mut ShapeGuardContext,
429    ///     ) -> ADRuleResult<Vec<Option<LocalValueId>>> {
430    ///         Ok(cotangent_out.to_vec())
431    ///     }
432    /// }
433    ///
434    /// let rules = ExtensionRuleSet::new().with_rule(Arc::new(Rule)).unwrap();
435    /// assert!(rules.is_rule_registered("example.with_rule.v1"));
436    /// ```
437    pub fn with_rule(
438        mut self,
439        rule: Arc<dyn ExtensionAdRule>,
440    ) -> Result<Self, ExtensionRegistryError> {
441        self.register_rule(rule)?;
442        Ok(self)
443    }
444
445    /// Merge another owned rule set into this one.
446    ///
447    /// The merge is atomic: if any rule in `other` is invalid or duplicates an
448    /// existing family, `self` is left unchanged.
449    ///
450    /// # Examples
451    ///
452    /// ```
453    /// use tenferro_ops::ExtensionRuleSet;
454    ///
455    /// let mut rules = ExtensionRuleSet::new();
456    /// rules.merge(ExtensionRuleSet::new()).unwrap();
457    /// assert!(!rules.is_rule_registered("example.missing.v1"));
458    /// ```
459    pub fn merge(&mut self, other: ExtensionRuleSet) -> Result<(), ExtensionRegistryError> {
460        let mut rules = (*self.rules).clone();
461        for rule in other.rules.values() {
462            insert_rule(&mut rules, Arc::clone(rule))?;
463        }
464        self.rules = Arc::new(rules);
465        Ok(())
466    }
467
468    /// Look up a rule in this set.
469    ///
470    /// # Examples
471    ///
472    /// ```
473    /// use tenferro_ops::ExtensionRuleSet;
474    ///
475    /// let rules = ExtensionRuleSet::new();
476    /// assert!(rules.lookup_rule("example.missing.v1").is_none());
477    /// ```
478    pub fn lookup_rule(&self, family_id: &str) -> Option<Arc<dyn ExtensionAdRule>> {
479        self.rules.get(family_id).cloned()
480    }
481
482    /// Return whether `family_id` is present in this set.
483    ///
484    /// # Examples
485    ///
486    /// ```
487    /// use tenferro_ops::ExtensionRuleSet;
488    ///
489    /// let rules = ExtensionRuleSet::new();
490    /// assert!(!rules.is_rule_registered("example.missing.v1"));
491    /// ```
492    pub fn is_rule_registered(&self, family_id: &str) -> bool {
493        self.rules.contains_key(family_id)
494    }
495}
496
497#[cfg(feature = "autodiff")]
/// Check that `family_id` has the namespaced shape `<crate>.<op>.v<major>`.
///
/// An identifier is accepted when all of the following hold:
///
/// - the text after the last `.` is `v` followed by one or more ASCII digits;
/// - the text before the last `.` splits at its first `.` into a non-empty
///   `<crate>` part and a non-empty `<op>` part (`<op>` may itself contain
///   further `.` characters);
/// - neither part contains whitespace or non-ASCII characters.
fn is_valid_family_id(family_id: &str) -> bool {
    // Everything after the last `.` is the version component.
    let Some((prefix, version)) = family_id.rsplit_once('.') else {
        return false;
    };
    // `strip_prefix` checks for the leading `v` and yields the digits in one
    // step, so no manual byte-offset slicing is needed.
    let Some(major) = version.strip_prefix('v') else {
        return false;
    };
    if major.is_empty() || !major.bytes().all(|b| b.is_ascii_digit()) {
        return false;
    }
    // The first `.` of the remainder separates `<crate>` from `<op>`.
    let Some((crate_name, op_name)) = prefix.split_once('.') else {
        return false;
    };
    [crate_name, op_name].iter().all(|segment| {
        !segment.is_empty() && segment.chars().all(|c| c.is_ascii() && !c.is_whitespace())
    })
}
529
530#[cfg(feature = "autodiff")]
531fn insert_rule(
532    rules: &mut RuleMap,
533    rule: Arc<dyn ExtensionAdRule>,
534) -> Result<(), ExtensionRegistryError> {
535    let family_id = rule.family_id();
536    validate_rule_insert(rules, family_id)?;
537    rules.insert(family_id, rule);
538    Ok(())
539}
540
541#[cfg(feature = "autodiff")]
542fn validate_rule_insert(
543    rules: &RuleMap,
544    family_id: &'static str,
545) -> Result<(), ExtensionRegistryError> {
546    if !is_valid_family_id(family_id) {
547        return Err(ExtensionRegistryError::MalformedFamilyId { family_id });
548    }
549    if rules.contains_key(family_id) {
550        return Err(ExtensionRegistryError::DuplicateRule { family_id });
551    }
552    Ok(())
553}
554
/// Linearize `op` with the extension rule that `ctx` holds for its family.
///
/// The rule is looked up through `ctx.extension_rule_for(op.family_id())`
/// and its `linearize` method receives all arguments unchanged. When no rule
/// is found, the result is `ADRuleError::unsupported` for the family with
/// `ADRuleKind::Jvp`.
#[cfg(feature = "autodiff")]
pub fn linearize_extension_rule(
    op: &dyn ExtensionOp,
    builder: &mut dyn PrimitiveRuleBuilder,
    primal_in: &[ValueKey<StdTensorOp>],
    primal_out: &[ValueKey<StdTensorOp>],
    tangent_in: &[Option<LocalValueId>],
    ctx: &mut ShapeGuardContext,
) -> ADRuleResult<Vec<Option<LocalValueId>>> {
    let family_id = op.family_id();
    let Some(rule) = ctx.extension_rule_for(family_id) else {
        return Err(ADRuleError::unsupported(family_id, ADRuleKind::Jvp));
    };
    rule.linearize(op, builder, primal_in, primal_out, tangent_in, ctx)
}
570
/// Transpose `op` with the extension rule that `ctx` holds for its family.
///
/// The rule is looked up through `ctx.extension_rule_for(op.family_id())`
/// and its `transpose_rule` method receives all arguments unchanged. When no
/// rule is found, the result is `ADRuleError::unsupported` for the family
/// with `ADRuleKind::Transpose`.
#[cfg(feature = "autodiff")]
pub fn transpose_extension_rule(
    op: &dyn ExtensionOp,
    builder: &mut dyn PrimitiveRuleBuilder,
    cotangent_out: &[Option<LocalValueId>],
    inputs: &[ValueRef<StdTensorOp>],
    mode: &OperationRole,
    ctx: &mut ShapeGuardContext,
) -> ADRuleResult<Vec<Option<LocalValueId>>> {
    let family_id = op.family_id();
    let Some(rule) = ctx.extension_rule_for(family_id) else {
        return Err(ADRuleError::unsupported(family_id, ADRuleKind::Transpose));
    };
    rule.transpose_rule(op, builder, cotangent_out, inputs, mode, ctx)
}
589
/// Sized wrapper around a borrowed hasher, so a generic `H: Hasher` can be
/// handed to the object-safe `&mut dyn Hasher` parameter of
/// [`ExtensionOp::payload_hash`].
///
/// Only `write` and `finish` are forwarded to the wrapped hasher. Integer
/// helpers such as `write_u8` / `write_u16` called on the proxy use the
/// `Hasher` trait's default implementations, which route through `write`.
pub(crate) struct DynHasherProxy<'a, H: Hasher + ?Sized> {
    /// The hasher that receives every forwarded byte.
    inner: &'a mut H,
}

impl<'a, H: Hasher + ?Sized> DynHasherProxy<'a, H> {
    /// Wrap `inner` for the lifetime of the borrow.
    pub(crate) fn new(inner: &'a mut H) -> Self {
        DynHasherProxy { inner }
    }
}

impl<H: Hasher + ?Sized> Hasher for DynHasherProxy<'_, H> {
    /// Feed `bytes` to the wrapped hasher.
    fn write(&mut self, bytes: &[u8]) {
        Hasher::write(&mut *self.inner, bytes)
    }

    /// Return the wrapped hasher's current digest.
    fn finish(&self) -> u64 {
        Hasher::finish(&*self.inner)
    }
}
615
616/// Hash an `Arc<dyn ExtensionOp>` payload using the extension's
617/// [`ExtensionOp::family_id`] plus [`ExtensionOp::payload_hash`]. Shared
618/// between the `StdTensorOp::Extension` carrier's `Hash` impl and callers
619/// that need to fingerprint an `ExtensionOp` independently.
620pub(crate) fn hash_extension<H: Hasher>(op: &(dyn ExtensionOp + '_), state: &mut H) {
621    op.family_id().as_bytes().hash(state);
622    op.payload_hash(&mut DynHasherProxy::new(state));
623}
624
625/// Structural equality used by the `StdTensorOp::Extension` carrier.
626///
627/// Short-circuits on `family_id` inequality so two extensions with
628/// accidentally similar payloads but different families cannot be unified
629/// by the op interner.
630pub(crate) fn ext_op_eq(a: &dyn ExtensionOp, b: &dyn ExtensionOp) -> bool {
631    a.family_id() == b.family_id() && a.payload_eq(b)
632}