diff --git a/macros/src/lib.rs b/macros/src/lib.rs index 7f199c3..995f029 100644 --- a/macros/src/lib.rs +++ b/macros/src/lib.rs @@ -1,120 +1,61 @@ use proc_macro::TokenStream; use quote::quote; -use syn::{ - parse_macro_input, Attribute, FnArg, Ident, ItemFn, Pat, PatIdent, PatType, Type, -}; +use syn::{parse_macro_input, ItemFn, FnArg, Pat, PatType, PatIdent, Type, LitStr}; -/// Represents a function argument we care about -#[derive(Clone)] -struct Arg { - ident: Ident, - ty: Type, - attrs: Vec, -} +use syn::{punctuated::Punctuated, token::Comma}; -/// Extract arguments marked as `#[path]` or the rest -fn extract_args<'a, I>(inputs: I, path_only: bool) -> Vec -where - I: IntoIterator, -{ +fn extract_path_args(inputs: &Punctuated) -> Vec<(syn::Ident, syn::Type)> { let mut args = Vec::new(); - for fnarg in inputs { - if let FnArg::Typed(PatType { pat, ty, attrs, .. }) = fnarg { - let ident = if let Pat::Ident(PatIdent { ident, .. }) = &**pat { - ident.clone() - } else { - panic!("Only simple identifiers are supported in endpoint arguments"); - }; - - if path_only && !attrs.iter().any(|a| a.path().is_ident("path")) { - continue; + for arg in inputs { + if let syn::FnArg::Typed(pat_type) = arg { + if pat_type.attrs.iter().any(|a| a.path().is_ident("path")) { + if let syn::Pat::Ident(pat_ident) = &*pat_type.pat { + args.push((pat_ident.ident.clone(), (*pat_type.ty).clone())); + } } - - args.push(Arg { - ident, - ty: *ty.clone(), - attrs: attrs.clone(), - }); } } args } -/// Implements #[get("…")] and #[post("…")] -macro_rules! endpoint_macro { - ($method:ident) => { - #[proc_macro_attribute] - pub fn $method(attr: TokenStream, item: TokenStream) -> TokenStream { - let path_lit = parse_macro_input!(attr as syn::LitStr); - let input_fn = parse_macro_input!(item as ItemFn); +fn generate_endpoint(attr: TokenStream, item: TokenStream, method: &str) -> TokenStream { + let item_fn = parse_macro_input!(item as ItemFn); + let path_lit = parse_macro_input!(attr as LitStr); - let vis = &input_fn.vis; - let fn_name = &input_fn.sig.ident; - let generics = &input_fn.sig.generics; - let inputs = &input_fn.sig.inputs; - let output = &input_fn.sig.output; + let fn_name = &item_fn.sig.ident; + let vis = &item_fn.vis; + let generics = &item_fn.sig.generics; - let path_args = extract_args(inputs, true); - let other_args: Vec<_> = extract_args(inputs, false) - .into_iter() - .filter(|a| !path_args.iter().any(|p| p.ident == a.ident)) - .collect(); + let path_args = extract_path_args(&item_fn.sig.inputs); + let path_idents: Vec<_> = path_args.iter().map(|(id, _)| id).collect(); + let path_types: Vec<_> = path_args.iter().map(|(_, ty)| ty).collect(); - // for builder: path args first - let path_idents: Vec<_> = path_args.iter().map(|a| &a.ident).collect(); - let path_types: Vec<_> = path_args.iter().map(|a| &a.ty).collect(); + let ret_type = match &item_fn.sig.output { + syn::ReturnType::Default => quote! { () }, + syn::ReturnType::Type(_, ty) => quote! { #ty }, + }; - let body_idents: Vec<_> = other_args.iter().map(|a| &a.ident).collect(); - let body_types: Vec<_> = other_args.iter().map(|a| &a.ty).collect(); + let expanded = quote! { + #[bon::builder] + #vis async fn #fn_name #generics( + #[builder(finish_fn)] client: restson::RestClient, + #( #[builder(finish_fn)] #path_idents: #path_types, )* + ) -> Result<#ret_type, restson::Error> { + let path = format!(#path_lit, #( #path_idents = #path_idents, )* ); - let method_upper = stringify!($method).to_uppercase(); - - let expanded = quote! { - #[bon::builder] - #vis async fn #fn_name #generics ( - #[ builder(finish_fn) ] client: restson::RestClient, - #( #[builder(finish_fn)] #path_idents: #path_types, )* - #( #body_idents: #body_types, )* - ) -> Result<#output, restson::Error> { - // build path - let path = format!(#path_lit, #( #path_idents = #path_idents ),*); - - #[derive(serde::Serialize)] - struct Body { - #( #body_idents: #body_idents, )* - } - - let body = Body { - #( #body_idents, )* - }; - - // for query arguments, if any - #[derive(serde::Serialize)] - struct Query { - #( #body_idents: #body_idents, )* - } - - let query = Query { - #( #body_idents, )* - }; - // placeholder: convert query to vec of pairs - let query_pairs: Vec<(&str, &str)> = Vec::new(); - - let response = match #method_upper { - "GET" => client.get_with(path, query_pairs).await?, - "POST" => client.post_capture_with(body, query_pairs).await?, - _ => unreachable!(), - }; - - - todo!(response) - } - }; - - expanded.into() + todo!("Replace this with client.{} call", #method) } }; + + TokenStream::from(expanded) } -endpoint_macro!(get); -endpoint_macro!(post); +#[proc_macro_attribute] +pub fn get(attr: TokenStream, item: TokenStream) -> TokenStream { + generate_endpoint(attr, item, "get_with") +} + +#[proc_macro_attribute] +pub fn post(attr: TokenStream, item: TokenStream) -> TokenStream { + generate_endpoint(attr, item, "post_capture_with") +}