Skip to main content

tenferro_xla/
stablehlo.rs

1use sha2::{Digest, Sha256};
2
3/// StableHLO MLIR module text plus a deterministic fingerprint.
4///
5/// # Examples
6///
7/// ```
8/// use tenferro_xla::StableHloModule;
9///
10/// let module = StableHloModule::new("module {}".to_string());
11/// assert_eq!(module.as_str(), "module {}");
12/// assert_eq!(module.fingerprint().as_bytes().len(), 32);
13/// ```
14#[derive(Clone, Debug, PartialEq, Eq)]
15pub struct StableHloModule {
16    text: String,
17    fingerprint: StableHloModuleFingerprint,
18}
19
20impl StableHloModule {
21    /// Create a module wrapper and fingerprint its text.
22    ///
23    /// # Examples
24    ///
25    /// ```
26    /// use tenferro_xla::StableHloModule;
27    ///
28    /// let module = StableHloModule::new("module {}".to_string());
29    /// assert!(module.fingerprint().to_hex().len() == 64);
30    /// ```
31    pub fn new(text: String) -> Self {
32        let fingerprint = StableHloModuleFingerprint::from_text(&text);
33        Self { text, fingerprint }
34    }
35
36    /// Borrow the StableHLO MLIR text.
37    ///
38    /// # Examples
39    ///
40    /// ```
41    /// use tenferro_xla::StableHloModule;
42    ///
43    /// let module = StableHloModule::new("module {}".to_string());
44    /// assert!(module.as_str().starts_with("module"));
45    /// ```
46    pub fn as_str(&self) -> &str {
47        &self.text
48    }
49
50    /// Return the module fingerprint.
51    ///
52    /// # Examples
53    ///
54    /// ```
55    /// use tenferro_xla::StableHloModule;
56    ///
57    /// let module = StableHloModule::new("module {}".to_string());
58    /// let _fingerprint = module.fingerprint();
59    /// ```
60    pub fn fingerprint(&self) -> StableHloModuleFingerprint {
61        self.fingerprint
62    }
63}
64
65/// SHA-256 fingerprint of StableHLO module text.
66///
67/// # Examples
68///
69/// ```
70/// use tenferro_xla::StableHloModuleFingerprint;
71///
72/// let fingerprint = StableHloModuleFingerprint::from_text("module {}");
73/// assert_eq!(fingerprint.as_bytes().len(), 32);
74/// ```
75#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
76pub struct StableHloModuleFingerprint([u8; 32]);
77
78impl StableHloModuleFingerprint {
79    /// Hash StableHLO text into a fingerprint.
80    ///
81    /// # Examples
82    ///
83    /// ```
84    /// use tenferro_xla::StableHloModuleFingerprint;
85    ///
86    /// let a = StableHloModuleFingerprint::from_text("module {}");
87    /// let b = StableHloModuleFingerprint::from_text("module {}");
88    /// assert_eq!(a, b);
89    /// ```
90    pub fn from_text(text: &str) -> Self {
91        let digest = Sha256::digest(text.as_bytes());
92        let mut bytes = [0_u8; 32];
93        bytes.copy_from_slice(&digest);
94        Self(bytes)
95    }
96
97    /// Borrow the raw fingerprint bytes.
98    ///
99    /// # Examples
100    ///
101    /// ```
102    /// use tenferro_xla::StableHloModuleFingerprint;
103    ///
104    /// let fingerprint = StableHloModuleFingerprint::from_text("module {}");
105    /// assert_eq!(fingerprint.as_bytes().len(), 32);
106    /// ```
107    pub fn as_bytes(&self) -> &[u8; 32] {
108        &self.0
109    }
110
111    /// Return the lowercase hexadecimal fingerprint.
112    ///
113    /// # Examples
114    ///
115    /// ```
116    /// use tenferro_xla::StableHloModuleFingerprint;
117    ///
118    /// let hex = StableHloModuleFingerprint::from_text("module {}").to_hex();
119    /// assert_eq!(hex.len(), 64);
120    /// ```
121    pub fn to_hex(&self) -> String {
122        self.0.iter().map(|byte| format!("{byte:02x}")).collect()
123    }
124}