tenferro_extension_macros/
lib.rs1use proc_macro::TokenStream;
16use quote::quote;
17use syn::parse::{Parse, ParseStream};
18use syn::{parse_macro_input, DeriveInput, Expr, ExprLit, Ident, Lit, Path, Token};
19
/// Arguments accepted inside `#[tenferro_extension(...)]`.
///
/// Every field is optional at parse time; which ones are mandatory (and what
/// defaults apply) is decided by the derive expansion, not by the parser.
#[derive(Debug, Default)]
struct ExtensionArgs {
    /// Value of `namespace = "..."`, when supplied.
    namespace: Option<String>,
    /// Value of `name = "..."`, when supplied.
    name: Option<String>,
    /// Value of `version = N`, when supplied.
    version: Option<u64>,
}
26
27struct RuntimeArgs {
28 runtime: Ident,
29 family_id: Path,
30 op_type: Path,
31 execute: Path,
32 execute_reads: Path,
33 register_fn: Ident,
34 backend_bound: Path,
35}
36
37impl Parse for ExtensionArgs {
38 fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
39 let mut args = Self::default();
40 while !input.is_empty() {
41 let key: syn::Ident = input.parse()?;
42 input.parse::<Token![=]>()?;
43 let value: Expr = input.parse()?;
44 match key.to_string().as_str() {
45 "namespace" => args.namespace = Some(expect_string(value, "namespace")?),
46 "name" => args.name = Some(expect_string(value, "name")?),
47 "version" => args.version = Some(expect_u64(value, "version")?),
48 other => {
49 return Err(syn::Error::new(
50 key.span(),
51 format!("unsupported tenferro_extension argument {other:?}"),
52 ));
53 }
54 }
55 if input.is_empty() {
56 break;
57 }
58 input.parse::<Token![,]>()?;
59 }
60 Ok(args)
61 }
62}
63
64impl Parse for RuntimeArgs {
65 fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
66 let mut runtime = None;
67 let mut family_id = None;
68 let mut op_type = None;
69 let mut execute = None;
70 let mut execute_reads = None;
71 let mut register_fn = None;
72 let mut backend_bound = None;
73
74 while !input.is_empty() {
75 let key: Ident = input.parse()?;
76 input.parse::<Token![=]>()?;
77 match key.to_string().as_str() {
78 "runtime" => runtime = Some(input.parse()?),
79 "family_id" => family_id = Some(input.parse()?),
80 "op_type" => op_type = Some(input.parse()?),
81 "execute" => execute = Some(input.parse()?),
82 "execute_reads" => execute_reads = Some(input.parse()?),
83 "register_fn" => register_fn = Some(input.parse()?),
84 "backend_bound" => backend_bound = Some(input.parse()?),
85 other => {
86 return Err(syn::Error::new(
87 key.span(),
88 format!("unsupported define_extension_runtime argument {other:?}"),
89 ));
90 }
91 }
92 if input.is_empty() {
93 break;
94 }
95 input.parse::<Token![,]>()?;
96 }
97
98 Ok(Self {
99 runtime: required(runtime, "runtime")?,
100 family_id: required(family_id, "family_id")?,
101 op_type: required(op_type, "op_type")?,
102 execute: required(execute, "execute")?,
103 execute_reads: required(execute_reads, "execute_reads")?,
104 register_fn: required(register_fn, "register_fn")?,
105 backend_bound: backend_bound
106 .unwrap_or_else(|| syn::parse_quote!(tenferro_tensor::TensorBackend)),
107 })
108 }
109}
110
111#[proc_macro_derive(ExtensionFamilyId, attributes(tenferro_extension))]
118pub fn derive_extension_family_id(input: TokenStream) -> TokenStream {
119 let input = parse_macro_input!(input as DeriveInput);
120 match expand_extension_family_id(input) {
121 Ok(tokens) => tokens.into(),
122 Err(err) => err.to_compile_error().into(),
123 }
124}
125
126#[proc_macro]
134pub fn define_extension_runtime(input: TokenStream) -> TokenStream {
135 let args = parse_macro_input!(input as RuntimeArgs);
136 expand_extension_runtime(args).into()
137}
138
139fn expand_extension_family_id(input: DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
140 let mut parsed = None;
141 for attr in &input.attrs {
142 if attr.path().is_ident("tenferro_extension") {
143 let args = attr.parse_args::<ExtensionArgs>()?;
144 parsed = Some(args);
145 }
146 }
147 let args = parsed.ok_or_else(|| {
148 syn::Error::new_spanned(
149 &input.ident,
150 "missing #[tenferro_extension(namespace = \"...\", version = N)]",
151 )
152 })?;
153 let namespace = args.namespace.ok_or_else(|| {
154 syn::Error::new_spanned(&input.ident, "missing tenferro_extension namespace")
155 })?;
156 let version = args.version.ok_or_else(|| {
157 syn::Error::new_spanned(&input.ident, "missing tenferro_extension version")
158 })?;
159 let name = args
160 .name
161 .unwrap_or_else(|| to_snake_case(&input.ident.to_string()));
162 let family_id = format!("{namespace}.{name}.v{version}");
163 let ident = input.ident;
164
165 Ok(quote! {
166 impl #ident {
167 pub const FAMILY_ID: &'static str = #family_id;
169 }
170 })
171}
172
173fn expand_extension_runtime(args: RuntimeArgs) -> proc_macro2::TokenStream {
174 let RuntimeArgs {
175 runtime,
176 family_id,
177 op_type,
178 execute,
179 execute_reads,
180 register_fn,
181 backend_bound,
182 } = args;
183 quote! {
184 #[derive(Debug, Default)]
185 pub(crate) struct #runtime;
186
187 impl<B: #backend_bound + 'static> tenferro_runtime::extension::ExtensionRuntime<B>
188 for #runtime
189 {
190 fn family_id(&self) -> &'static str {
191 #family_id
192 }
193
194 fn execute(
195 &self,
196 op: &dyn tenferro_runtime::extension::ExtensionOp,
197 inputs: &[&tenferro_tensor::Tensor],
198 ctx: &mut tenferro_runtime::extension::ExtensionExecutionContext<'_, B>,
199 ) -> tenferro_tensor::Result<Vec<tenferro_tensor::Tensor>> {
200 let op = op
201 .as_any()
202 .downcast_ref::<#op_type>()
203 .ok_or_else(|| tenferro_tensor::Error::InvalidConfig {
204 op: "extension_runtime",
205 message: format!("payload type mismatch for {}", #family_id),
206 })?;
207 #execute(op, inputs, ctx)
208 }
209
210 fn execute_reads(
211 &self,
212 op: &dyn tenferro_runtime::extension::ExtensionOp,
213 inputs: &[tenferro_tensor::TensorRead<'_>],
214 ctx: &mut tenferro_runtime::extension::ExtensionExecutionContext<'_, B>,
215 ) -> tenferro_tensor::Result<Vec<tenferro_tensor::Tensor>> {
216 let op = op
217 .as_any()
218 .downcast_ref::<#op_type>()
219 .ok_or_else(|| tenferro_tensor::Error::InvalidConfig {
220 op: "extension_runtime",
221 message: format!("payload type mismatch for {}", #family_id),
222 })?;
223 #execute_reads(op, inputs, ctx)
224 }
225 }
226
227 pub fn #register_fn<B: #backend_bound + 'static>(
228 executor: &mut tenferro_runtime::extension::ExtensionExecutor<B>,
229 ) -> std::result::Result<
230 (),
231 tenferro_runtime::extension::ExtensionRuntimeRegistryError,
232 > {
233 executor.registry_mut().register(std::sync::Arc::new(#runtime))
234 }
235 }
236}
237
238fn expect_string(value: Expr, field: &str) -> syn::Result<String> {
239 match value {
240 Expr::Lit(ExprLit {
241 lit: Lit::Str(value),
242 ..
243 }) => Ok(value.value()),
244 other => Err(syn::Error::new_spanned(
245 other,
246 format!("{field} must be a string literal"),
247 )),
248 }
249}
250
251fn expect_u64(value: Expr, field: &str) -> syn::Result<u64> {
252 match value {
253 Expr::Lit(ExprLit {
254 lit: Lit::Int(value),
255 ..
256 }) => value.base10_parse(),
257 other => Err(syn::Error::new_spanned(
258 other,
259 format!("{field} must be an integer literal"),
260 )),
261 }
262}
263
264fn required<T>(value: Option<T>, field: &str) -> syn::Result<T> {
265 value.ok_or_else(|| syn::Error::new(proc_macro2::Span::call_site(), format!("missing {field}")))
266}
267
/// Converts an UpperCamelCase identifier to snake_case.
///
/// ASCII uppercase letters are lowercased. An underscore is inserted before
/// one only when the previous character is an ASCII lowercase letter or digit,
/// so runs of capitals stay in one word (`HTTPServer` -> `httpserver`,
/// `Op2D` -> `op2_d`). All other characters, including non-ASCII ones, are
/// copied unchanged.
fn to_snake_case(input: &str) -> String {
    let (snake, _) = input.chars().fold(
        (String::with_capacity(input.len()), false),
        |(mut acc, follows_word_char), c| {
            if follows_word_char && c.is_ascii_uppercase() {
                acc.push('_');
            }
            // `to_ascii_lowercase` only alters `A..=Z`; everything else passes through.
            acc.push(c.to_ascii_lowercase());
            (acc, c.is_ascii_lowercase() || c.is_ascii_digit())
        },
    );
    snake
}
285
286#[cfg(test)]
287mod tests;