tenferro_ops/ext_op.rs
1//! Out-of-tree extension-operation mechanism.
2//!
//! This module defines the [`ExtensionOp`] trait and the explicit,
//! caller-owned rule set that holds AD rules for extension families.
//! Together they let external crates contribute fused primitives
5//! that participate in the [`crate::std_tensor_op::StdTensorOp`] graph through
6//! the single carrier variant
7//! `StdTensorOp::Extension(Arc<dyn ExtensionOp>)`.
8//!
9//! See `docs/spec/extension-op.md` for the normative contract. Key points:
10//!
11//! - Identity / hashing / equality are expressed on the trait so the
12//! type-erased `Arc<dyn ExtensionOp>` carrier can satisfy
13//! `Clone + Hash + Eq + Send + Sync + 'static` (computegraph's
14//! `GraphOperation` requirements).
15//! - AD rules are owned by explicit [`ExtensionRuleSet`] values. A rule may
16//! emit core [`StdTensorOp`] values and registered `Extension` values so
17//! out-of-tree operations remain in the same graph.
18//! - Extension ops themselves do not require process-global registration.
19//! Frontends carry them directly as `Arc<dyn ExtensionOp>`.
20
21use std::any::Any;
22use std::fmt::Debug;
23use std::hash::{Hash, Hasher};
24use std::sync::Arc;
25
26use computegraph::graph::GraphBuilder;
27#[cfg(not(feature = "autodiff"))]
28use computegraph::types::ValueRef;
29#[cfg(feature = "autodiff")]
30use computegraph::types::{LocalValueId, OperationRole, ValueKey, ValueRef};
31use tenferro_tensor::{DType, Tensor};
32#[cfg(feature = "autodiff")]
33use tidu::{ADRuleError, ADRuleKind, ADRuleResult};
34
35#[cfg(feature = "autodiff")]
36use crate::ad::context::ShapeGuardContext;
37#[cfg(feature = "autodiff")]
38use crate::ad::PrimitiveRuleBuilder;
39use crate::std_tensor_op::StdTensorOp;
40use crate::sym_dim::SymDim;
41#[cfg(feature = "autodiff")]
42use std::collections::HashMap;
43
/// Error returned when an extension cannot expand itself into standard ops.
///
/// Implementations of [`ExtensionOp::lower_to_standard_ops`] return this
/// (through [`ExtensionLoweringResult`]) when the payload is malformed or the
/// supplied metadata is invalid. Its `Display` output is exactly the message
/// given to [`ExtensionLoweringError::new`].
///
/// # Examples
///
/// ```
/// use tenferro_ops::ext_op::ExtensionLoweringError;
///
/// let err = ExtensionLoweringError::new("example extension cannot lower");
/// assert!(err.to_string().contains("cannot lower"));
/// ```
#[derive(Clone, Debug, PartialEq, Eq, thiserror::Error)]
#[error("{message}")]
pub struct ExtensionLoweringError {
    /// Human-readable diagnostic, rendered verbatim by `Display`.
    message: String,
}
59
60impl ExtensionLoweringError {
61 /// Create a lowering error with a human-readable diagnostic.
62 ///
63 /// # Examples
64 ///
65 /// ```
66 /// use tenferro_ops::ext_op::ExtensionLoweringError;
67 ///
68 /// let err = ExtensionLoweringError::new("shape must be static");
69 /// assert_eq!(err.to_string(), "shape must be static");
70 /// ```
71 pub fn new(message: impl Into<String>) -> Self {
72 Self {
73 message: message.into(),
74 }
75 }
76}
77
/// Result returned by [`ExtensionOp::lower_to_standard_ops`].
///
/// - `Ok(Some(outputs))`: the extension was expanded into standard
///   [`StdTensorOp`] operations, and `outputs` are the resulting values.
/// - `Ok(None)`: this extension family has no standard-op lowering for the
///   supplied metadata. Strict lowerers should report an explicit
///   unsupported-extension error.
/// - `Err(_)`: the payload is malformed or the metadata is invalid.
pub type ExtensionLoweringResult =
    std::result::Result<Option<Vec<ValueRef<StdTensorOp>>>, ExtensionLoweringError>;
81
/// The contract every out-of-tree extension primitive must satisfy.
///
/// Implementations appear in the core graph as
/// `StdTensorOp::Extension(Arc<dyn ExtensionOp>)`. Every method is part of the
/// `ExtensionOp` spec (`docs/spec/extension-op.md`); the short form:
///
/// - identity via [`family_id`][Self::family_id],
///   [`payload_hash`][Self::payload_hash], and [`payload_eq`][Self::payload_eq];
/// - fixed arity via [`input_count`][Self::input_count] / [`output_count`][Self::output_count];
/// - shape / dtype inference via [`infer_output_meta`][Self::infer_output_meta];
/// - host/reference forward execution via [`eager_execute`][Self::eager_execute];
///   runtime-owned eager and compiled paths dispatch through registered
///   extension runtimes instead of falling back to this method;
/// - optional fixed-shape standard-op expansion via
///   [`lower_to_standard_ops`][Self::lower_to_standard_ops] for peer lowerers
///   such as XLA that cannot execute extension runtimes;
/// - AD via a separately supplied `ExtensionAdRule` (available with the
///   `autodiff` feature), not via methods on this trait.
///
/// # Downcast convention
///
/// [`ExtensionOp::payload_eq`] receives its other operand as
/// `&dyn ExtensionOp`, so implementations need a way back to the concrete
/// type. This trait requires `'static`, so every implementor already
/// implements [`Any`]. What an implementation must provide is
/// [`ExtensionOp::as_any`], which returns `&dyn Any` for downcasting.
/// Implementations usually define it as
/// `fn as_any(&self) -> &dyn Any { self }`.
///
/// # Examples
///
/// ```
/// # use std::any::Any;
/// use std::sync::Arc;
/// use tenferro_ops::ext_op::ExtensionOp;
/// use tenferro_ops::SymDim;
/// use tenferro_tensor::{DType, Tensor};
///
/// #[derive(Clone, Debug)]
/// struct IdentityExt;
///
/// impl ExtensionOp for IdentityExt {
///     fn family_id(&self) -> &'static str { "example.identity.v1" }
///     fn payload_hash(&self, _hasher: &mut dyn std::hash::Hasher) {}
///     fn payload_eq(&self, other: &dyn ExtensionOp) -> bool {
///         other.as_any().downcast_ref::<IdentityExt>().is_some()
///     }
///     fn clone_arc(&self) -> Arc<dyn ExtensionOp> { Arc::new(self.clone()) }
///     fn as_any(&self) -> &dyn Any { self }
///     fn input_count(&self) -> usize { 1 }
///     fn output_count(&self) -> usize { 1 }
///     fn infer_output_meta(
///         &self,
///         dtypes: &[DType],
///         shapes: &[&[SymDim]],
///     ) -> Vec<(DType, Vec<SymDim>)> {
///         vec![(dtypes[0], shapes[0].to_vec())]
///     }
///     fn eager_execute(&self, inputs: &[&Tensor]) -> tenferro_tensor::Result<Vec<Tensor>> {
///         Ok(vec![inputs[0].clone()])
///     }
/// }
///
/// let op: Arc<dyn ExtensionOp> = Arc::new(IdentityExt);
/// assert_eq!(op.input_count(), 1);
/// ```
pub trait ExtensionOp: Debug + Send + Sync + 'static {
    // ----- Identity, hashing, equality (spec Section 5) -----

    /// Stable, process-independent family identifier.
    ///
    /// MUST be unique per extension *family* (payload schema), not per
    /// *instance*, and MUST follow the reserved format
    /// `"<crate-name>.<op-name>.v<major>"`.
    fn family_id(&self) -> &'static str;

    /// Hash the payload (everything except `family_id`).
    ///
    /// The carrier hashes the `family_id` bytes itself before calling this
    /// method, so implementations hash only their own parameters.
    ///
    /// Implementations MUST be pure and deterministic across calls on the same
    /// value. Hashes MUST NOT include transient state such as allocation
    /// addresses or atomically updated counters.
    fn payload_hash(&self, hasher: &mut dyn Hasher);

    /// Structural equality against another extension value.
    ///
    /// The carrier's equality first compares `family_id`s. It calls
    /// `payload_eq` only when they match. Implementations MUST return `true`
    /// iff the payloads are semantically equal AND
    /// `other.family_id() == self.family_id()`. A typical implementation
    /// downcasts `other` through [`Self::as_any`].
    fn payload_eq(&self, other: &dyn ExtensionOp) -> bool;

    /// Deep-clone the payload behind an `Arc`.
    ///
    /// The carrier's `Clone` impl uses `Arc::clone` on the fast path. This
    /// method exists for the rare cases that need a second, independent `Arc`.
    fn clone_arc(&self) -> Arc<dyn ExtensionOp>;

    /// Upcast this extension to `&dyn Any` for downcasting in `payload_eq`.
    ///
    /// Implementations SHOULD return `self` verbatim. The method is
    /// object-safe (no `Self: Sized` bound), so it can be called on an
    /// `&dyn ExtensionOp`. That is what makes
    /// `other.as_any().downcast_ref::<ConcreteType>()` work from
    /// [`Self::payload_eq`] implementations.
    fn as_any(&self) -> &dyn Any;

    // ----- Arity (spec Section 6) -----

    /// Number of primal inputs. MUST be constant for any given
    /// `Arc<dyn ExtensionOp>` value.
    fn input_count(&self) -> usize;

    /// Number of outputs. MUST match the length of the vector returned by
    /// [`Self::infer_output_meta`].
    fn output_count(&self) -> usize;

    // ----- Shape and dtype inference (spec Section 7) -----

    /// Infer output dtypes and shapes for each output slot.
    ///
    /// `input_dtypes.len()` and `input_shapes.len()` both equal
    /// `self.input_count()`. The returned vector MUST have length
    /// `self.output_count()`, with one `(dtype, shape)` entry per output slot.
    /// Shapes use [`SymDim`] so extension ops compose with graph-global
    /// symbolic metadata.
    fn infer_output_meta(
        &self,
        input_dtypes: &[DType],
        input_shapes: &[&[SymDim]],
    ) -> Vec<(DType, Vec<SymDim>)>;

    // ----- Forward execution dispatch (spec Section 8) -----

    // NOTE(review): confirm which execution paths still call this method.
    // The trait-level docs say runtime-owned eager and compiled paths
    // dispatch through registered extension runtimes rather than falling
    // back here. Earlier wording described this as the eager-path entry
    // point that the compiled path also reaches indirectly.
    /// Host/reference forward execution.
    ///
    /// Input tensors are on the device the caller already arranged. Output
    /// tensors MUST have shapes matching [`Self::infer_output_meta`] and MUST
    /// be placed on a device the caller can consume.
    fn eager_execute(&self, inputs: &[&Tensor]) -> tenferro_tensor::Result<Vec<Tensor>>;

    /// Optionally expand this extension into standard tensor graph operations.
    ///
    /// Peer lowerers call this when all input metadata is known and extension
    /// runtime dispatch is not available.
    ///
    /// - Return `Ok(Some(outputs))` after adding only standard
    ///   [`StdTensorOp`] operations to `builder`.
    /// - Return `Ok(None)` when this extension family has no standard-op
    ///   lowering for the supplied metadata. Strict lowerers should surface
    ///   that as an explicit unsupported-extension error.
    /// - Return [`ExtensionLoweringError`] when the payload is malformed or
    ///   the lowering detects invalid metadata.
    ///
    /// The default implementation returns `Ok(None)`, so existing extension
    /// runtimes keep their native dispatch behavior until their owning crate
    /// deliberately implements this hook.
    fn lower_to_standard_ops(
        &self,
        _builder: &mut GraphBuilder<StdTensorOp>,
        _inputs: &[ValueRef<StdTensorOp>],
        _input_dtypes: &[DType],
        _input_shapes: &[&[SymDim]],
    ) -> ExtensionLoweringResult {
        Ok(None)
    }

    // AD rules are not part of this trait. They are provided as
    // `ExtensionAdRule` values (behind the `autodiff` feature) and looked up
    // by `family_id`.
}
245
/// AD rule provider for an extension family.
///
/// Rules are registered independently from the primal operation. This lets an
/// out-of-tree crate provide forward execution without AD, or gate AD support
/// behind an optional feature. Rule methods receive the concrete
/// [`ExtensionOp`] payload as a trait object. Implementations should downcast
/// through [`ExtensionOp::as_any`] when they need payload-specific parameters.
///
/// Rules are looked up by [`ExtensionAdRule::family_id`]. It must therefore
/// equal the [`ExtensionOp::family_id`] of the operations the rule handles.
/// It must also use the `"<crate-name>.<op-name>.v<major>"` format, because
/// rule sets reject malformed identifiers at registration time.
#[cfg(feature = "autodiff")]
pub trait ExtensionAdRule: Debug + Send + Sync + 'static {
    /// The extension family this rule handles.
    ///
    /// Used as the lookup key when an extension op is differentiated, and
    /// validated against the namespaced format when the rule is added to a
    /// rule set.
    fn family_id(&self) -> &'static str;

    /// Emit the linear (JVP) rule for `op`.
    ///
    /// - `op`: the extension operation being linearized. Downcast through
    ///   [`ExtensionOp::as_any`] to read payload parameters.
    /// - `builder`: the rule builder the linearization is emitted into.
    /// - `primal_in` / `primal_out`: keys of the operation's primal inputs
    ///   and outputs.
    /// - `tangent_in`: the incoming tangents.
    /// - `ctx`: the shape-guard context of the current AD pass.
    ///
    /// Returns the output tangents, or an [`ADRuleError`] when the rule
    /// cannot be applied.
    // NOTE(review): presumably `tangent_in` has one entry per primal input
    // and the result has one entry per output, with `None` meaning an absent
    // (zero) tangent. Confirm against the tidu AD driver.
    fn linearize(
        &self,
        op: &dyn ExtensionOp,
        builder: &mut dyn PrimitiveRuleBuilder,
        primal_in: &[ValueKey<StdTensorOp>],
        primal_out: &[ValueKey<StdTensorOp>],
        tangent_in: &[Option<LocalValueId>],
        ctx: &mut ShapeGuardContext,
    ) -> ADRuleResult<Vec<Option<LocalValueId>>>;

    /// Emit the transpose (VJP) rule for `op`.
    ///
    /// - `op`: the extension operation being transposed.
    /// - `builder`: the rule builder the transpose is emitted into.
    /// - `cotangent_out`: the cotangents associated with the operation's
    ///   outputs.
    /// - `inputs`: the operation's input values.
    /// - `mode`: the [`OperationRole`] supplied by the AD driver.
    /// - `ctx`: the shape-guard context of the current AD pass.
    ///
    /// Returns the input cotangents, or an [`ADRuleError`] when the rule
    /// cannot be applied.
    // NOTE(review): presumably the result has one entry per input, with
    // `None` meaning no cotangent flows to that input. Confirm against the
    // tidu AD driver.
    fn transpose_rule(
        &self,
        op: &dyn ExtensionOp,
        builder: &mut dyn PrimitiveRuleBuilder,
        cotangent_out: &[Option<LocalValueId>],
        inputs: &[ValueRef<StdTensorOp>],
        mode: &OperationRole,
        ctx: &mut ShapeGuardContext,
    ) -> ADRuleResult<Vec<Option<LocalValueId>>>;
}
280
/// Errors returned from extension registries.
///
/// Produced when rules are added to an `ExtensionRuleSet` (`register_rule`,
/// `with_rule`, `merge`). A rejected insertion leaves the rule set unchanged.
#[cfg(feature = "autodiff")]
#[derive(Clone, Debug, PartialEq, Eq, thiserror::Error)]
pub enum ExtensionRegistryError {
    /// An AD rule with the same `family_id` was already registered.
    #[error("AD rule for family_id {family_id:?} already registered")]
    DuplicateRule {
        /// The family identifier that was already present.
        family_id: &'static str,
    },
    /// The `family_id` does not match the namespaced format
    /// `"<crate-name>.<op-name>.v<major>"`.
    // The message spells out the expected format so the error is actionable
    // on its own. It keeps the previous wording as a prefix.
    #[error(
        "family_id {family_id:?} does not match the namespaced format \
         `<crate-name>.<op-name>.v<major>`"
    )]
    MalformedFamilyId {
        /// The rejected family identifier.
        family_id: &'static str,
    },
}
293
// Rules keyed by `ExtensionAdRule::family_id`. `family_id` returns a
// `&'static str`, so that string is stored directly as the key.
#[cfg(feature = "autodiff")]
type RuleMap = HashMap<&'static str, Arc<dyn ExtensionAdRule>>;
296
/// Explicit, owned set of extension AD rules.
///
/// This is the rule container used by higher-level AD contexts. Extension AD
/// intentionally has no process-global fallback; callers must pass the rule set
/// that their graph needs.
///
/// Cloning is cheap: clones share the underlying map through an `Arc`.
#[cfg(feature = "autodiff")]
#[derive(Clone, Default)]
pub struct ExtensionRuleSet {
    // Shared map from family id to rule. Mutation either goes through
    // `Arc::make_mut` (clone-on-write) or replaces the `Arc` wholesale, so
    // previously made clones never observe later registrations.
    rules: Arc<RuleMap>,
}
307
#[cfg(feature = "autodiff")]
impl std::fmt::Debug for ExtensionRuleSet {
    /// Render the registered family ids, sorted, so the output does not
    /// depend on hash-map iteration order. Rule objects themselves are not
    /// printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut family_ids = self.rules.keys().copied().collect::<Vec<_>>();
        // Keys are unique, so a stable sort yields the same order as an
        // unstable one.
        family_ids.sort();
        let mut out = f.debug_struct("ExtensionRuleSet");
        out.field("families", &family_ids);
        out.finish()
    }
}
318
#[cfg(feature = "autodiff")]
impl ExtensionRuleSet {
    /// Create an empty extension rule set.
    ///
    /// # Examples
    ///
    /// ```
    /// use tenferro_ops::ExtensionRuleSet;
    ///
    /// let rules = ExtensionRuleSet::new();
    /// assert!(!rules.is_rule_registered("example.missing.v1"));
    /// ```
    pub fn new() -> Self {
        Self::default()
    }

    /// Add one rule to this owned set.
    ///
    /// The rule is keyed by its [`ExtensionAdRule::family_id`].
    ///
    /// # Errors
    ///
    /// - [`ExtensionRegistryError::MalformedFamilyId`] if the family id does
    ///   not follow `"<crate-name>.<op-name>.v<major>"`.
    /// - [`ExtensionRegistryError::DuplicateRule`] if a rule for the same
    ///   family is already present.
    ///
    /// On error the set is left unchanged.
    ///
    /// # Examples
    ///
    /// ```
    /// use std::sync::Arc;
    /// use tidu::ADRuleResult;
    /// use computegraph::types::{LocalValueId, OperationRole, ValueKey, ValueRef};
    /// use tenferro_ops::ad::PrimitiveRuleBuilder;
    /// use tenferro_ops::ext_op::{ExtensionAdRule, ExtensionOp};
    /// use tenferro_ops::{ExtensionRuleSet, ShapeGuardContext};
    /// use tenferro_ops::std_tensor_op::StdTensorOp;
    ///
    /// #[derive(Debug)]
    /// struct Rule;
    ///
    /// impl ExtensionAdRule for Rule {
    ///     fn family_id(&self) -> &'static str { "example.register_rule.v1" }
    ///     fn linearize(
    ///         &self,
    ///         _op: &dyn ExtensionOp,
    ///         _builder: &mut dyn PrimitiveRuleBuilder,
    ///         _primal_in: &[ValueKey<StdTensorOp>],
    ///         _primal_out: &[ValueKey<StdTensorOp>],
    ///         tangent_in: &[Option<LocalValueId>],
    ///         _ctx: &mut ShapeGuardContext,
    ///     ) -> ADRuleResult<Vec<Option<LocalValueId>>> {
    ///         Ok(tangent_in.to_vec())
    ///     }
    ///     fn transpose_rule(
    ///         &self,
    ///         _op: &dyn ExtensionOp,
    ///         _builder: &mut dyn PrimitiveRuleBuilder,
    ///         cotangent_out: &[Option<LocalValueId>],
    ///         _inputs: &[ValueRef<StdTensorOp>],
    ///         _mode: &OperationRole,
    ///         _ctx: &mut ShapeGuardContext,
    ///     ) -> ADRuleResult<Vec<Option<LocalValueId>>> {
    ///         Ok(cotangent_out.to_vec())
    ///     }
    /// }
    ///
    /// let mut rules = ExtensionRuleSet::new();
    /// rules.register_rule(Arc::new(Rule)).unwrap();
    /// assert!(rules.lookup_rule("example.register_rule.v1").is_some());
    /// ```
    pub fn register_rule(
        &mut self,
        rule: Arc<dyn ExtensionAdRule>,
    ) -> Result<(), ExtensionRegistryError> {
        let family_id = rule.family_id();
        // Validate before `make_mut` so a rejected rule never forces a
        // copy of a shared map.
        validate_rule_insert(&self.rules, family_id)?;
        let rules = Arc::make_mut(&mut self.rules);
        rules.insert(family_id, rule);
        Ok(())
    }

    /// Consume this set and return it with `rule` added.
    ///
    /// # Errors
    ///
    /// Fails under the same conditions as [`ExtensionRuleSet::register_rule`].
    ///
    /// # Examples
    ///
    /// ```
    /// use std::sync::Arc;
    /// use tidu::ADRuleResult;
    /// use computegraph::types::{LocalValueId, OperationRole, ValueKey, ValueRef};
    /// use tenferro_ops::ad::PrimitiveRuleBuilder;
    /// use tenferro_ops::ext_op::{ExtensionAdRule, ExtensionOp};
    /// use tenferro_ops::{ExtensionRuleSet, ShapeGuardContext};
    /// use tenferro_ops::std_tensor_op::StdTensorOp;
    ///
    /// #[derive(Debug)]
    /// struct Rule;
    ///
    /// impl ExtensionAdRule for Rule {
    ///     fn family_id(&self) -> &'static str { "example.with_rule.v1" }
    ///     fn linearize(
    ///         &self,
    ///         _op: &dyn ExtensionOp,
    ///         _builder: &mut dyn PrimitiveRuleBuilder,
    ///         _primal_in: &[ValueKey<StdTensorOp>],
    ///         _primal_out: &[ValueKey<StdTensorOp>],
    ///         tangent_in: &[Option<LocalValueId>],
    ///         _ctx: &mut ShapeGuardContext,
    ///     ) -> ADRuleResult<Vec<Option<LocalValueId>>> {
    ///         Ok(tangent_in.to_vec())
    ///     }
    ///     fn transpose_rule(
    ///         &self,
    ///         _op: &dyn ExtensionOp,
    ///         _builder: &mut dyn PrimitiveRuleBuilder,
    ///         cotangent_out: &[Option<LocalValueId>],
    ///         _inputs: &[ValueRef<StdTensorOp>],
    ///         _mode: &OperationRole,
    ///         _ctx: &mut ShapeGuardContext,
    ///     ) -> ADRuleResult<Vec<Option<LocalValueId>>> {
    ///         Ok(cotangent_out.to_vec())
    ///     }
    /// }
    ///
    /// let rules = ExtensionRuleSet::new().with_rule(Arc::new(Rule)).unwrap();
    /// assert!(rules.is_rule_registered("example.with_rule.v1"));
    /// ```
    pub fn with_rule(
        mut self,
        rule: Arc<dyn ExtensionAdRule>,
    ) -> Result<Self, ExtensionRegistryError> {
        self.register_rule(rule)?;
        Ok(self)
    }

    /// Merge another owned rule set into this one.
    ///
    /// The merge is atomic: if any rule in `other` is invalid or duplicates an
    /// existing family, `self` is left unchanged.
    ///
    /// # Errors
    ///
    /// Returns the same errors as [`ExtensionRuleSet::register_rule`]. When
    /// several incoming families conflict, the error names the
    /// lexicographically smallest offending `family_id`, so the report is
    /// reproducible across runs.
    ///
    /// # Examples
    ///
    /// ```
    /// use tenferro_ops::ExtensionRuleSet;
    ///
    /// let mut rules = ExtensionRuleSet::new();
    /// rules.merge(ExtensionRuleSet::new()).unwrap();
    /// assert!(!rules.is_rule_registered("example.missing.v1"));
    /// ```
    pub fn merge(&mut self, other: ExtensionRuleSet) -> Result<(), ExtensionRegistryError> {
        // Nothing to add: skip cloning the existing map.
        if other.rules.is_empty() {
            return Ok(());
        }
        // `HashMap` iteration order is randomized per process. Visit the
        // incoming families in sorted order so that, when several collide,
        // the reported family is the same on every run.
        let mut incoming: Vec<_> = other.rules.iter().collect();
        incoming.sort_unstable_by_key(|&(family_id, _)| *family_id);
        // Build the merged map on the side and publish it only on success,
        // which is what makes the merge atomic.
        let mut merged = (*self.rules).clone();
        for (_, rule) in incoming {
            insert_rule(&mut merged, Arc::clone(rule))?;
        }
        self.rules = Arc::new(merged);
        Ok(())
    }

    /// Look up a rule in this set.
    ///
    /// Returns a cloned `Arc` handle to the rule registered for `family_id`,
    /// or `None` if there is none.
    ///
    /// # Examples
    ///
    /// ```
    /// use tenferro_ops::ExtensionRuleSet;
    ///
    /// let rules = ExtensionRuleSet::new();
    /// assert!(rules.lookup_rule("example.missing.v1").is_none());
    /// ```
    pub fn lookup_rule(&self, family_id: &str) -> Option<Arc<dyn ExtensionAdRule>> {
        self.rules.get(family_id).cloned()
    }

    /// Return whether `family_id` is present in this set.
    ///
    /// # Examples
    ///
    /// ```
    /// use tenferro_ops::ExtensionRuleSet;
    ///
    /// let rules = ExtensionRuleSet::new();
    /// assert!(!rules.is_rule_registered("example.missing.v1"));
    /// ```
    pub fn is_rule_registered(&self, family_id: &str) -> bool {
        self.rules.contains_key(family_id)
    }
}
496
#[cfg(feature = "autodiff")]
fn is_valid_family_id(family_id: &str) -> bool {
    // Accepted layout: `<crate>.<op>.v<major>`.
    // - The text after the final `.` is `v` followed by one or more ASCII
    //   digits.
    // - The text before it splits at its first `.` into a non-empty
    //   `<crate>` and a non-empty `<op>`, so `<op>` may itself contain `.`.
    // - `<crate>` and `<op>` consist only of ASCII, non-whitespace chars.
    let Some((prefix, version)) = family_id.rsplit_once('.') else {
        return false;
    };
    let Some(major) = version.strip_prefix('v') else {
        return false;
    };
    if major.is_empty() || !major.bytes().all(|b| b.is_ascii_digit()) {
        return false;
    }
    let Some((crate_name, op_name)) = prefix.split_once('.') else {
        return false;
    };
    let is_clean_segment = |segment: &str| {
        !segment.is_empty() && segment.chars().all(|c| c.is_ascii() && !c.is_whitespace())
    };
    is_clean_segment(crate_name) && is_clean_segment(op_name)
}
529
#[cfg(feature = "autodiff")]
fn insert_rule(
    rules: &mut RuleMap,
    rule: Arc<dyn ExtensionAdRule>,
) -> Result<(), ExtensionRegistryError> {
    // Validate first; the map is only touched once the rule is accepted.
    let key = rule.family_id();
    validate_rule_insert(rules, key).map(|()| {
        rules.insert(key, rule);
    })
}
540
#[cfg(feature = "autodiff")]
fn validate_rule_insert(
    rules: &RuleMap,
    family_id: &'static str,
) -> Result<(), ExtensionRegistryError> {
    // A malformed id is reported ahead of a duplicate one.
    if !is_valid_family_id(family_id) {
        Err(ExtensionRegistryError::MalformedFamilyId { family_id })
    } else if rules.contains_key(family_id) {
        Err(ExtensionRegistryError::DuplicateRule { family_id })
    } else {
        Ok(())
    }
}
554
/// Emit a registered extension linearization rule.
///
/// Looks up the rule for `op.family_id()` through `ctx` and delegates to its
/// [`ExtensionAdRule::linearize`]. If `ctx` has no rule for that family, an
/// "unsupported" JVP error naming the family is returned instead.
#[cfg(feature = "autodiff")]
pub fn linearize_extension_rule(
    op: &dyn ExtensionOp,
    builder: &mut dyn PrimitiveRuleBuilder,
    primal_in: &[ValueKey<StdTensorOp>],
    primal_out: &[ValueKey<StdTensorOp>],
    tangent_in: &[Option<LocalValueId>],
    ctx: &mut ShapeGuardContext,
) -> ADRuleResult<Vec<Option<LocalValueId>>> {
    let family_id = op.family_id();
    let Some(rule) = ctx.extension_rule_for(family_id) else {
        return Err(ADRuleError::unsupported(family_id, ADRuleKind::Jvp));
    };
    rule.linearize(op, builder, primal_in, primal_out, tangent_in, ctx)
}
570
/// Emit a registered extension transpose rule.
///
/// Looks up the rule for `op.family_id()` through `ctx` and delegates to its
/// [`ExtensionAdRule::transpose_rule`]. If `ctx` has no rule for that family,
/// an "unsupported" transpose error naming the family is returned instead.
#[cfg(feature = "autodiff")]
pub fn transpose_extension_rule(
    op: &dyn ExtensionOp,
    builder: &mut dyn PrimitiveRuleBuilder,
    cotangent_out: &[Option<LocalValueId>],
    inputs: &[ValueRef<StdTensorOp>],
    mode: &OperationRole,
    ctx: &mut ShapeGuardContext,
) -> ADRuleResult<Vec<Option<LocalValueId>>> {
    let family_id = op.family_id();
    let rule = ctx
        .extension_rule_for(family_id)
        .ok_or_else(|| ADRuleError::unsupported(family_id, ADRuleKind::Transpose))?;
    rule.transpose_rule(op, builder, cotangent_out, inputs, mode, ctx)
}
589
/// Thin adapter that lets a generic `H: Hasher` satisfy the object-safe
/// `&mut dyn Hasher` signature required by [`ExtensionOp::payload_hash`].
///
/// `H` may be unsized (for example `dyn Hasher` itself). That is why a wrapper
/// is used rather than a plain unsizing coercion.
///
/// Every stable `Hasher` method is forwarded to the same method on the inner
/// hasher. Any specialised `write_u64` / `write_usize` / ... implementation it
/// provides is therefore used exactly as if the caller had hashed into it
/// directly, instead of being bypassed via the byte-slice `write` fallback.
/// The unstable `write_str` / `write_length_prefix` methods cannot be
/// overridden on stable Rust and keep their default behaviour.
pub(crate) struct DynHasherProxy<'a, H: Hasher + ?Sized> {
    inner: &'a mut H,
}

impl<'a, H: Hasher + ?Sized> DynHasherProxy<'a, H> {
    /// Wrap `inner`; everything written to the proxy is written to it.
    pub(crate) fn new(inner: &'a mut H) -> Self {
        Self { inner }
    }
}

impl<H: Hasher + ?Sized> Hasher for DynHasherProxy<'_, H> {
    fn finish(&self) -> u64 {
        self.inner.finish()
    }

    fn write(&mut self, bytes: &[u8]) {
        self.inner.write(bytes);
    }

    // Forward the integer writers explicitly. The trait defaults would route
    // them through `write`, skipping the inner hasher's own versions.
    fn write_u8(&mut self, i: u8) {
        self.inner.write_u8(i);
    }

    fn write_u16(&mut self, i: u16) {
        self.inner.write_u16(i);
    }

    fn write_u32(&mut self, i: u32) {
        self.inner.write_u32(i);
    }

    fn write_u64(&mut self, i: u64) {
        self.inner.write_u64(i);
    }

    fn write_u128(&mut self, i: u128) {
        self.inner.write_u128(i);
    }

    fn write_usize(&mut self, i: usize) {
        self.inner.write_usize(i);
    }

    fn write_i8(&mut self, i: i8) {
        self.inner.write_i8(i);
    }

    fn write_i16(&mut self, i: i16) {
        self.inner.write_i16(i);
    }

    fn write_i32(&mut self, i: i32) {
        self.inner.write_i32(i);
    }

    fn write_i64(&mut self, i: i64) {
        self.inner.write_i64(i);
    }

    fn write_i128(&mut self, i: i128) {
        self.inner.write_i128(i);
    }

    fn write_isize(&mut self, i: isize) {
        self.inner.write_isize(i);
    }
}
615
616/// Hash an `Arc<dyn ExtensionOp>` payload using the extension's
617/// [`ExtensionOp::family_id`] plus [`ExtensionOp::payload_hash`]. Shared
618/// between the `StdTensorOp::Extension` carrier's `Hash` impl and callers
619/// that need to fingerprint an `ExtensionOp` independently.
620pub(crate) fn hash_extension<H: Hasher>(op: &(dyn ExtensionOp + '_), state: &mut H) {
621 op.family_id().as_bytes().hash(state);
622 op.payload_hash(&mut DynHasherProxy::new(state));
623}
624
625/// Structural equality used by the `StdTensorOp::Extension` carrier.
626///
627/// Short-circuits on `family_id` inequality so two extensions with
628/// accidentally similar payloads but different families cannot be unified
629/// by the op interner.
630pub(crate) fn ext_op_eq(a: &dyn ExtensionOp, b: &dyn ExtensionOp) -> bool {
631 a.family_id() == b.family_id() && a.payload_eq(b)
632}