diff --git a/macros/src/lib.rs b/macros/src/lib.rs index a09a442..7f199c3 100644 --- a/macros/src/lib.rs +++ b/macros/src/lib.rs @@ -1,89 +1,120 @@ use proc_macro::TokenStream; use quote::quote; -use syn::{parse_macro_input, Expr, ExprClosure, FnArg, ItemFn, Pat, PatIdent, PatType}; +use syn::{ + parse_macro_input, Attribute, FnArg, Ident, ItemFn, Pat, PatIdent, PatType, Type, +}; -#[proc_macro_attribute] -pub fn endpoint(attr: TokenStream, item: TokenStream) -> TokenStream { - let input_fn = parse_macro_input!(item as ItemFn); - let closure_expr = parse_macro_input!(attr as Expr); +/// Represents a function argument we care about +#[derive(Clone)] +struct Arg { + ident: Ident, + ty: Type, + attrs: Vec, +} - let fn_name = input_fn.sig.ident.clone(); - let vis = input_fn.vis.clone(); - let generics = input_fn.sig.generics.clone(); - - // Collect query and body args - let mut query_idents = Vec::new(); - let mut query_types = Vec::new(); - let mut body_idents = Vec::new(); - let mut body_types = Vec::new(); - - for input in &input_fn.sig.inputs { - if let FnArg::Typed(PatType { pat, ty, attrs, .. }) = input { +/// Extract arguments marked as `#[path]` or the rest +fn extract_args<'a, I>(inputs: I, path_only: bool) -> Vec +where + I: IntoIterator, +{ + let mut args = Vec::new(); + for fnarg in inputs { + if let FnArg::Typed(PatType { pat, ty, attrs, .. }) = fnarg { let ident = if let Pat::Ident(PatIdent { ident, .. }) = &**pat { ident.clone() - } else { continue; }; - - if attrs.iter().any(|a| a.path().is_ident("query")) { - query_idents.push(ident); - query_types.push(*ty.clone()); } else { - body_idents.push(ident); - body_types.push(*ty.clone()); + panic!("Only simple identifiers are supported in endpoint arguments"); + }; + + if path_only && !attrs.iter().any(|a| a.path().is_ident("path")) { + continue; } + + args.push(Arg { + ident, + ty: *ty.clone(), + attrs: attrs.clone(), + }); } } + args +} - // Extract path args from closure - let mut path_idents = Vec::new(); - let mut path_types = Vec::new(); +/// Implements #[get("…")] and #[post("…")] +macro_rules! endpoint_macro { + ($method:ident) => { + #[proc_macro_attribute] + pub fn $method(attr: TokenStream, item: TokenStream) -> TokenStream { + let path_lit = parse_macro_input!(attr as syn::LitStr); + let input_fn = parse_macro_input!(item as ItemFn); - if let Expr::Closure(ExprClosure { inputs, .. }) = closure_expr { - for input in inputs.iter() { - if let Pat::Type(PatType { pat, ty, .. }) = input { - if let Pat::Ident(PatIdent { ident, .. }) = &**pat { - path_idents.push(ident.clone()); - path_types.push(*ty.clone()); + let vis = &input_fn.vis; + let fn_name = &input_fn.sig.ident; + let generics = &input_fn.sig.generics; + let inputs = &input_fn.sig.inputs; + let output = &input_fn.sig.output; + + let path_args = extract_args(inputs, true); + let other_args: Vec<_> = extract_args(inputs, false) + .into_iter() + .filter(|a| !path_args.iter().any(|p| p.ident == a.ident)) + .collect(); + + // for builder: path args first + let path_idents: Vec<_> = path_args.iter().map(|a| &a.ident).collect(); + let path_types: Vec<_> = path_args.iter().map(|a| &a.ty).collect(); + + let body_idents: Vec<_> = other_args.iter().map(|a| &a.ident).collect(); + let body_types: Vec<_> = other_args.iter().map(|a| &a.ty).collect(); + + let method_upper = stringify!($method).to_uppercase(); + + let expanded = quote! { + #[bon::builder] + #vis async fn #fn_name #generics ( + #[ builder(finish_fn) ] client: restson::RestClient, + #( #[builder(finish_fn)] #path_idents: #path_types, )* + #( #body_idents: #body_types, )* + ) -> Result<#output, restson::Error> { + // build path + let path = format!(#path_lit, #( #path_idents = #path_idents ),*); + + #[derive(serde::Serialize)] + struct Body { + #( #body_idents: #body_idents, )* + } + + let body = Body { + #( #body_idents, )* + }; + + // for query arguments, if any + #[derive(serde::Serialize)] + struct Query { + #( #body_idents: #body_idents, )* + } + + let query = Query { + #( #body_idents, )* + }; + // placeholder: convert query to vec of pairs + let query_pairs: Vec<(&str, &str)> = Vec::new(); + + let response = match #method_upper { + "GET" => client.get_with(path, query_pairs).await?, + "POST" => client.post_capture_with(body, query_pairs).await?, + _ => unreachable!(), + }; + + + todo!(response) } - } - } - } + }; - // Generate the final function - let expanded = quote! { - #[bon::builder] - #vis async fn #fn_name #generics ( - #( #[builder(finish_fn)] #path_idents: #path_types, )* - #( #query_idents: #query_types, )* - #( #body_idents: #body_types, )* - ) -> Result<_, restson::Error> { - - // Query struct - #[derive(serde::Serialize)] - struct Query { - #( #query_idents: #query_types, )* - } - let query = Query { #( #query_idents ),* }; - let query_vec = query.to_vec::<(&str, &str)>(); - - // Body struct - #[derive(serde::Serialize)] - struct Body { - #( #body_idents: #body_types, )* - } - let body = Body { #( #body_idents ),* }; - - // Response - #[derive(serde::de::DeserializeOwned)] - struct Response(Vec); - - let response: restson::Response = - client.post_capture_with(body, query_vec).await?; - - | #( #path_idents ),* | { - todo!(response -> _) - } ( #( #path_idents ),* ) + expanded.into() } }; - - TokenStream::from(expanded) } + +endpoint_macro!(get); +endpoint_macro!(post);