Skip to main content

tenferro_extension_macros/
lib.rs

1//! Procedural macros for tenferro extension crates.
2//!
3//! # Examples
4//!
5//! ```
6//! use tenferro_extension_macros::ExtensionFamilyId;
7//!
8//! #[derive(ExtensionFamilyId)]
9//! #[tenferro_extension(namespace = "my-crate", name = "fft", version = 1)]
10//! struct FftOp;
11//!
12//! assert_eq!(FftOp::FAMILY_ID, "my-crate.fft.v1");
13//! ```
14
15use proc_macro::TokenStream;
16use quote::quote;
17use syn::parse::{Parse, ParseStream};
18use syn::{parse_macro_input, DeriveInput, Expr, ExprLit, Ident, Lit, Path, Token};
19
/// Key/value arguments collected from a `#[tenferro_extension(...)]` attribute.
///
/// Every field starts out as `None` and is filled in only when the matching
/// key appears in the attribute's argument list.
#[derive(Debug, Default)]
struct ExtensionArgs {
    /// Value of the `namespace = "..."` argument, if present.
    namespace: Option<String>,
    /// Value of the `name = "..."` argument, if present.
    name: Option<String>,
    /// Value of the `version = N` argument, if present.
    version: Option<u64>,
}
26
27struct RuntimeArgs {
28    runtime: Ident,
29    family_id: Path,
30    op_type: Path,
31    execute: Path,
32    execute_reads: Path,
33    register_fn: Ident,
34    backend_bound: Path,
35}
36
37impl Parse for ExtensionArgs {
38    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
39        let mut args = Self::default();
40        while !input.is_empty() {
41            let key: syn::Ident = input.parse()?;
42            input.parse::<Token![=]>()?;
43            let value: Expr = input.parse()?;
44            match key.to_string().as_str() {
45                "namespace" => args.namespace = Some(expect_string(value, "namespace")?),
46                "name" => args.name = Some(expect_string(value, "name")?),
47                "version" => args.version = Some(expect_u64(value, "version")?),
48                other => {
49                    return Err(syn::Error::new(
50                        key.span(),
51                        format!("unsupported tenferro_extension argument {other:?}"),
52                    ));
53                }
54            }
55            if input.is_empty() {
56                break;
57            }
58            input.parse::<Token![,]>()?;
59        }
60        Ok(args)
61    }
62}
63
64impl Parse for RuntimeArgs {
65    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
66        let mut runtime = None;
67        let mut family_id = None;
68        let mut op_type = None;
69        let mut execute = None;
70        let mut execute_reads = None;
71        let mut register_fn = None;
72        let mut backend_bound = None;
73
74        while !input.is_empty() {
75            let key: Ident = input.parse()?;
76            input.parse::<Token![=]>()?;
77            match key.to_string().as_str() {
78                "runtime" => runtime = Some(input.parse()?),
79                "family_id" => family_id = Some(input.parse()?),
80                "op_type" => op_type = Some(input.parse()?),
81                "execute" => execute = Some(input.parse()?),
82                "execute_reads" => execute_reads = Some(input.parse()?),
83                "register_fn" => register_fn = Some(input.parse()?),
84                "backend_bound" => backend_bound = Some(input.parse()?),
85                other => {
86                    return Err(syn::Error::new(
87                        key.span(),
88                        format!("unsupported define_extension_runtime argument {other:?}"),
89                    ));
90                }
91            }
92            if input.is_empty() {
93                break;
94            }
95            input.parse::<Token![,]>()?;
96        }
97
98        Ok(Self {
99            runtime: required(runtime, "runtime")?,
100            family_id: required(family_id, "family_id")?,
101            op_type: required(op_type, "op_type")?,
102            execute: required(execute, "execute")?,
103            execute_reads: required(execute_reads, "execute_reads")?,
104            register_fn: required(register_fn, "register_fn")?,
105            backend_bound: backend_bound
106                .unwrap_or_else(|| syn::parse_quote!(tenferro_tensor::TensorBackend)),
107        })
108    }
109}
110
111/// Derive an inherent `FAMILY_ID` constant for an extension payload type.
112///
113/// The required attribute is:
114/// `#[tenferro_extension(namespace = "...", version = N)]`.
115/// `name = "..."` is optional; when omitted, the Rust type name is converted
116/// to snake_case.
117#[proc_macro_derive(ExtensionFamilyId, attributes(tenferro_extension))]
118pub fn derive_extension_family_id(input: TokenStream) -> TokenStream {
119    let input = parse_macro_input!(input as DeriveInput);
120    match expand_extension_family_id(input) {
121        Ok(tokens) => tokens.into(),
122        Err(err) => err.to_compile_error().into(),
123    }
124}
125
126/// Generate a standard extension runtime and registration function.
127///
128/// The `execute` function must have this signature:
129/// `fn<B: BackendBound + 'static>(&OpType, &[&Tensor], &mut ExtensionExecutionContext<'_, B>)`.
130///
131/// `execute_reads` is required. It must have this signature:
132/// `fn<B: BackendBound + 'static>(&OpType, &[TensorRead<'_>], &mut ExtensionExecutionContext<'_, B>)`.
133#[proc_macro]
134pub fn define_extension_runtime(input: TokenStream) -> TokenStream {
135    let args = parse_macro_input!(input as RuntimeArgs);
136    expand_extension_runtime(args).into()
137}
138
139fn expand_extension_family_id(input: DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
140    let mut parsed = None;
141    for attr in &input.attrs {
142        if attr.path().is_ident("tenferro_extension") {
143            let args = attr.parse_args::<ExtensionArgs>()?;
144            parsed = Some(args);
145        }
146    }
147    let args = parsed.ok_or_else(|| {
148        syn::Error::new_spanned(
149            &input.ident,
150            "missing #[tenferro_extension(namespace = \"...\", version = N)]",
151        )
152    })?;
153    let namespace = args.namespace.ok_or_else(|| {
154        syn::Error::new_spanned(&input.ident, "missing tenferro_extension namespace")
155    })?;
156    let version = args.version.ok_or_else(|| {
157        syn::Error::new_spanned(&input.ident, "missing tenferro_extension version")
158    })?;
159    let name = args
160        .name
161        .unwrap_or_else(|| to_snake_case(&input.ident.to_string()));
162    let family_id = format!("{namespace}.{name}.v{version}");
163    let ident = input.ident;
164
165    Ok(quote! {
166        impl #ident {
167            /// Stable extension family identifier generated by `ExtensionFamilyId`.
168            pub const FAMILY_ID: &'static str = #family_id;
169        }
170    })
171}
172
173fn expand_extension_runtime(args: RuntimeArgs) -> proc_macro2::TokenStream {
174    let RuntimeArgs {
175        runtime,
176        family_id,
177        op_type,
178        execute,
179        execute_reads,
180        register_fn,
181        backend_bound,
182    } = args;
183    quote! {
184        #[derive(Debug, Default)]
185        pub(crate) struct #runtime;
186
187        impl<B: #backend_bound + 'static> tenferro_runtime::extension::ExtensionRuntime<B>
188            for #runtime
189        {
190            fn family_id(&self) -> &'static str {
191                #family_id
192            }
193
194            fn execute(
195                &self,
196                op: &dyn tenferro_runtime::extension::ExtensionOp,
197                inputs: &[&tenferro_tensor::Tensor],
198                ctx: &mut tenferro_runtime::extension::ExtensionExecutionContext<'_, B>,
199            ) -> tenferro_tensor::Result<Vec<tenferro_tensor::Tensor>> {
200                let op = op
201                    .as_any()
202                    .downcast_ref::<#op_type>()
203                    .ok_or_else(|| tenferro_tensor::Error::InvalidConfig {
204                        op: "extension_runtime",
205                        message: format!("payload type mismatch for {}", #family_id),
206                    })?;
207                #execute(op, inputs, ctx)
208            }
209
210            fn execute_reads(
211                &self,
212                op: &dyn tenferro_runtime::extension::ExtensionOp,
213                inputs: &[tenferro_tensor::TensorRead<'_>],
214                ctx: &mut tenferro_runtime::extension::ExtensionExecutionContext<'_, B>,
215            ) -> tenferro_tensor::Result<Vec<tenferro_tensor::Tensor>> {
216                let op = op
217                    .as_any()
218                    .downcast_ref::<#op_type>()
219                    .ok_or_else(|| tenferro_tensor::Error::InvalidConfig {
220                        op: "extension_runtime",
221                        message: format!("payload type mismatch for {}", #family_id),
222                    })?;
223                #execute_reads(op, inputs, ctx)
224            }
225        }
226
227        pub fn #register_fn<B: #backend_bound + 'static>(
228            executor: &mut tenferro_runtime::extension::ExtensionExecutor<B>,
229        ) -> std::result::Result<
230            (),
231            tenferro_runtime::extension::ExtensionRuntimeRegistryError,
232        > {
233            executor.registry_mut().register(std::sync::Arc::new(#runtime))
234        }
235    }
236}
237
238fn expect_string(value: Expr, field: &str) -> syn::Result<String> {
239    match value {
240        Expr::Lit(ExprLit {
241            lit: Lit::Str(value),
242            ..
243        }) => Ok(value.value()),
244        other => Err(syn::Error::new_spanned(
245            other,
246            format!("{field} must be a string literal"),
247        )),
248    }
249}
250
251fn expect_u64(value: Expr, field: &str) -> syn::Result<u64> {
252    match value {
253        Expr::Lit(ExprLit {
254            lit: Lit::Int(value),
255            ..
256        }) => value.base10_parse(),
257        other => Err(syn::Error::new_spanned(
258            other,
259            format!("{field} must be an integer literal"),
260        )),
261    }
262}
263
264fn required<T>(value: Option<T>, field: &str) -> syn::Result<T> {
265    value.ok_or_else(|| syn::Error::new(proc_macro2::Span::call_site(), format!("missing {field}")))
266}
267
/// Converts a type name such as `FftOp` into `fft_op`.
///
/// An underscore is inserted before an ASCII uppercase letter only when the
/// character directly before it is an ASCII lowercase letter or ASCII digit,
/// so a run of capitals stays together (`FFTOp` becomes `fftop`). ASCII
/// uppercase letters are lowercased; every other character is copied through
/// unchanged.
fn to_snake_case(input: &str) -> String {
    let mut snake = String::new();
    let mut previous: Option<char> = None;
    for current in input.chars() {
        let starts_word = current.is_ascii_uppercase()
            && matches!(previous, Some(p) if p.is_ascii_lowercase() || p.is_ascii_digit());
        if starts_word {
            snake.push('_');
        }
        // `to_ascii_lowercase` leaves anything outside `A-Z` untouched.
        snake.push(current.to_ascii_lowercase());
        previous = Some(current);
    }
    snake
}
285
// Unit tests are compiled only for test builds and live in the separate
// `tests` module file.
#[cfg(test)]
mod tests;