This commit is contained in:
Jonas Rabenstein 2026-02-27 05:52:35 +01:00
commit a0ddaf89a9
5 changed files with 197 additions and 168 deletions

View file

@ -1,109 +1,132 @@
use proc_macro::TokenStream;
use quote::{quote, ToTokens};
use quote::quote;
use syn::{
parse_macro_input, Attribute, FnArg, ItemFn, Lit, Meta, Pat, PatType, Type, spanned::Spanned,
parse_macro_input, FnArg, ItemFn, LitStr, Pat, PatType, Type, Attribute, spanned::Spanned,
};
/// Endpoint macro
/// Path args parser for `#[endpoint((x:u128,y:String): "/path/{x}/{y}")]`
struct EndpointPath {
args: Vec<(syn::Ident, syn::Type)>,
path: LitStr,
}
impl syn::parse::Parse for EndpointPath {
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
let content;
syn::parenthesized!(content in input);
let mut args = Vec::new();
while !content.is_empty() {
let ident: syn::Ident = content.parse()?;
content.parse::<syn::Token![:]>()?;
let ty: syn::Type = content.parse()?;
args.push((ident, ty));
if content.peek(syn::Token![,]) {
content.parse::<syn::Token![,]>()?;
}
}
input.parse::<syn::Token![:]>()?;
let path: LitStr = input.parse()?;
Ok(EndpointPath { args, path })
}
}
#[proc_macro_attribute]
pub fn endpoint(attr: TokenStream, item: TokenStream) -> TokenStream {
// parse the endpoint attribute
let attr = proc_macro2::TokenStream::from(attr);
let path_args = parse_macro_input!(attr as EndpointPath);
let func = parse_macro_input!(item as ItemFn);
let vis = &func.vis;
let name = &func.sig.ident;
let inputs = &func.sig.inputs;
let output = &func.sig.output;
// must be async
if func.sig.asyncness.is_none() {
return syn::Error::new(name.span(), "endpoint function must be async")
.to_compile_error()
.into();
}
// Separate query/body args
let mut query_fields = Vec::new();
let mut body_fields = Vec::new();
// defaults
let mut method = quote! { GET };
let mut path = None;
for arg in &func.sig.inputs {
if let FnArg::Typed(PatType { pat, ty, attrs, .. }) = arg {
let pat_ident = match &**pat {
Pat::Ident(pi) => &pi.ident,
_ => continue,
};
// parse #[endpoint(...)]
if !attr.is_empty() {
let attr_str = attr.to_string();
// simple heuristic: if contains "POST", switch
if attr_str.contains("POST") {
method = quote! { POST };
}
// simple heuristic: extract path in quotes
if let Some(start) = attr_str.find('"') {
if let Some(end) = attr_str[start+1..].find('"') {
path = Some(attr_str[start+1..start+1+end].to_string());
if attrs.iter().any(|a| a.path().is_ident("query")) {
query_fields.push((pat_ident.clone(), (*ty).clone(), attrs.clone()));
} else if attrs.iter().any(|a| a.path().is_ident("body")) {
body_fields.push((pat_ident.clone(), (*ty).clone(), attrs.clone()));
}
}
}
let path = match path {
Some(p) => p,
None => return syn::Error::new(name.span(), "endpoint path must be provided")
.to_compile_error()
.into(),
};
// Path args
let path_idents: Vec<_> = path_args.args.iter().map(|(i, _)| i).collect();
let path_types: Vec<_> = path_args.args.iter().map(|(_, t)| t).collect();
let path_fmt = &path_args.path;
// process arguments
let mut client_arg = None;
let mut other_args = Vec::new();
for input in inputs {
match input {
FnArg::Receiver(_) => continue, // skip self
FnArg::Typed(PatType { pat, ty, .. }) => {
if let Pat::Ident(ident) = &**pat {
if ident.ident == "client" {
client_arg = Some((ident.ident.clone(), ty));
} else {
other_args.push((ident.ident.clone(), ty));
}
}
// Build query serialization
let query_pairs = query_fields.iter().map(|(ident, _, _)| {
quote! {
if let Some(v) = &#ident {
query_pairs.push((stringify!(#ident), v.to_string()));
}
}
}
});
// generate tokens for function with builder
let arg_defs: Vec<proc_macro2::TokenStream> = other_args.iter().map(|(id, ty)| {
// wrap Option<T> with #[builder(default)]
if let Type::Path(tp) = ty.as_ref() {
let is_option = tp.path.segments.last().map(|seg| seg.ident == "Option").unwrap_or(false);
if is_option {
quote! { #[builder(default)] #id: #ty }
// Build body serialization
let body_pairs = body_fields.iter().map(|(ident, _, _)| {
quote! {
body_map.insert(stringify!(#ident).to_string(), serde_json::to_value(&#ident)?);
}
});
// Determine method
let method = if body_fields.is_empty() { quote! { GET } } else { quote! { POST } };
// Expand query/body fields for function signature
let query_sig = query_fields.iter().map(|(ident, ty, _attrs)| {
if let Type::Path(tp) = &**ty { // <-- dereference the Box<Type>
if tp.path.segments.last().unwrap().ident == "Option" {
quote! { #[builder(default)] #ident: #ty }
} else {
quote! { #id: #ty }
quote! { #ident: #ty }
}
} else {
quote! { #id: #ty }
quote! { #ident: #ty }
}
}).collect();
});
let client_def = match client_arg {
Some((id, ty)) => quote! { #[builder(finish_fn)] #id: #ty },
None => quote! { #[builder(finish_fn)] client: &restson::RestClient },
};
let call_args: Vec<proc_macro2::TokenStream> = other_args.iter().map(|(id, _)| {
quote! { #id }
}).collect();
let body_sig = body_fields.iter().map(|(ident, ty, _attrs)| {
quote! { #ident: #ty }
});
let expanded = quote! {
#[bon::builder]
#vis fn #name(
#client_def,
#(#arg_defs),*
) -> impl std::future::Future<Output = Result<_, restson::Error>> + '_ {
#[builder(finish_fn)]
client: &restson::RestClient,
#( #[builder(finish_fn)] #path_idents: #path_types, )*
#( #query_sig, )*
#( #body_sig, )*
) -> impl std::future::Future<Output = Result<serde_json::Value, restson::Error>> + '_ {
let mut path = format!(#path_fmt, #( #path_idents = #path_idents ),*);
let mut query_pairs = Vec::new();
#( #query_pairs )*
let mut body_map = serde_json::Map::new();
#( #body_pairs )*
async move {
let result = client.get::<_, serde_json::Value>(#path).await?;
Ok(result)
if body_map.is_empty() {
client.request_with(#method, &path, &query_pairs, &()).await
} else {
client.request_with(#method, &path, &query_pairs, &body_map).await
}
}
}
};
expanded.into()
TokenStream::from(expanded)
}