// impl_trait - Rust proc macro that significantly reduces boilerplate // Copyright (C) 2021 Soni L. // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published by // the Free Software Foundation, either version 3 of the License, or // (at your option) any later version. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . extern crate proc_macro; use proc_macro::{TokenStream, TokenTree, Delimiter, Span}; //use syn::parse::{Parse, ParseStream, Result as ParseResult}; use syn::{Generics, GenericParam}; use std::cmp::Ordering; use quote::ToTokens; #[proc_macro] #[allow(unreachable_code)] pub fn impl_trait(item: TokenStream) -> TokenStream { //eprintln!("INPUT: {:#?}", item); let mut output: Vec<_> = item.into_iter().collect(); let attributes: Vec = { let mut pos = 0; let mut len = 0; let mut in_attr = false; while pos != output.len() { let tt = &output[pos]; pos += 1; match tt { &TokenTree::Punct(ref punct) => { if punct.as_char() == '#' && !in_attr { in_attr = true; continue; } } &TokenTree::Group(ref group) => { if group.delimiter() == Delimiter::Bracket && in_attr { in_attr = false; len = pos; continue; } } _ => {} } break; } output.drain(0..len).collect() }; //eprintln!("attributes: {:#?}", attributes); // check for impl. // unsafe impls are only available for traits and are automatically rejected. 'check_impl: loop { break { if let &TokenTree::Ident(ref ident) = &output[0] { if format!("{}", ident) == "impl" { break 'check_impl; } } panic!("impl_trait! may only be applied to inherent impls"); } } let mut has_where: Option<&TokenTree> = None; 'check_no_for_before_where: loop { break { for tt in &output { if let &TokenTree::Ident(ref ident) = tt { let formatted = format!("{}", ident); if formatted == "where" { has_where = Some(tt); break 'check_no_for_before_where; } else if formatted == "for" { panic!("impl_trait! may only be applied to inherent impls"); } } } } } // this is the "where [...]" part, including the "where". let mut where_bounds = Vec::new(); if let Some(where_in) = has_where { where_bounds = output.split_last().unwrap().1.into_iter().skip_while(|&tt| { !std::ptr::eq(tt, where_in) }).cloned().collect(); } let where_bounds = where_bounds; drop(has_where); let mut count = 0; // this is the "<...>" part, immediately after the "impl", and including the "<>". let generics = output.split_first().unwrap().1.into_iter().take_while(|&tt| { let mut result = count > 0; if let &TokenTree::Punct(ref punct) = tt { let c = punct.as_char(); if c == '<' { count += 1; result = true; } else if c == '>' { count -= 1; } } result }).cloned().collect::>(); // so how do you find the target? well... // "impl" + [_; generics.len()] + [_; target.len()] + [_; where_bounds.len()] + "{}" // we have generics and where_bounds, and the total, so we can easily find target! let target_start = 1 + generics.len(); let target_end = output.len() - 1 - where_bounds.len(); let target_range = target_start..target_end; let target = (&output[target_range]).into_iter().cloned().collect::>(); //eprintln!("generics: {:#?}", generics); //eprintln!("target: {:#?}", target); //eprintln!("where_bounds: {:#?}", where_bounds); let items = output.last_mut(); if let &mut TokenTree::Group(ref mut group) = items.unwrap() { // TODO: parse "[unsafe] impl trait" somehow. use syn for it maybe (after swallowing the "trait") // luckily for us, there's only one thing that can come after an impl trait: a path // (and optional generics). // but we can't figure out how to parse the `where`. //todo!(); let span = group.span(); let mut items = group.stream().into_iter().collect::>(); let mut in_unsafe = false; let mut in_impl = false; let mut in_path = false; let mut in_generic = false; let mut in_attr = false; let mut in_attr_cont = false; let mut has_injected_generics = false; let mut in_where = false; let mut start = 0; let mut found: Vec> = Vec::new(); let mut to_remove: Vec> = Vec::new(); let mut generics_scratchpad = Vec::new(); let mut count = 0; let mut trait_span: Option = None; 'main_loop: for (pos, tt) in (&items).into_iter().enumerate() { if in_generic { // collect the generics let mut result = count > 0; if let &TokenTree::Punct(ref punct) = tt { let c = punct.as_char(); if c == '<' { count += 1; result = true; } else if c == '>' { count -= 1; if count == 0 { in_generic = false; in_path = true; } } } if result { generics_scratchpad.push(tt.clone()); continue; } } if in_path { // inject the generics if !has_injected_generics { has_injected_generics = true; if generics_scratchpad.is_empty() { found.last_mut().unwrap().extend(generics.clone()); } else if generics.is_empty() { found.last_mut().unwrap().extend(generics_scratchpad.clone()); } else { // need to *combine* generics. this is not exactly trivial. // thankfully we don't need to worry about defaults on impls. let mut this_generics: Generics = syn::parse(generics_scratchpad.drain(..).collect()).unwrap(); let parent_generics: Generics = syn::parse(generics.clone().into_iter().collect()).unwrap(); let mut target = parent_generics.params.into_pairs().chain(this_generics.params.clone().into_pairs()).collect::>(); target.sort_by(|a, b| { match (a.value(), b.value()) { (&GenericParam::Lifetime(_), &GenericParam::Const(_)) => Ordering::Less, (&GenericParam::Type(_), &GenericParam::Const(_)) => Ordering::Less, (&GenericParam::Lifetime(_), &GenericParam::Type(_)) => Ordering::Less, (&GenericParam::Lifetime(_), &GenericParam::Lifetime(_)) => Ordering::Equal, (&GenericParam::Type(_), &GenericParam::Type(_)) => Ordering::Equal, (&GenericParam::Const(_), &GenericParam::Const(_)) => Ordering::Equal, (&GenericParam::Type(_), &GenericParam::Lifetime(_)) => Ordering::Greater, (&GenericParam::Const(_), &GenericParam::Type(_)) => Ordering::Greater, (&GenericParam::Const(_), &GenericParam::Lifetime(_)) => Ordering::Greater, } }); // just need to fix the one Pair::End in the middle of the thing. for item in &mut target { if matches!(item, syn::punctuated::Pair::End(_)) { let value = item.value().clone(); *item = syn::punctuated::Pair::Punctuated(value, syn::token::Comma { spans: [trait_span.unwrap().into()] }); break; } } this_generics.params = target.into_iter().collect(); let new_generics = TokenStream::from(this_generics.into_token_stream()); found.last_mut().unwrap().extend(new_generics); } } in_generic = false; if let &TokenTree::Ident(ref ident) = tt { let formatted = format!("{}", ident); if count == 0 && formatted == "where" { in_path = false; in_where = true; // add "for" found.last_mut().unwrap().push(proc_macro::Ident::new("for", trait_span.unwrap()).into()); // add Target found.last_mut().unwrap().extend(target.clone()); // *then* add the "where" (from the impl-trait) found.last_mut().unwrap().push(tt.clone()); // and the parent bounds (except the "where") if !where_bounds.is_empty() { found.last_mut().unwrap().extend((&where_bounds).into_iter().skip(1).cloned()); // also make sure that there's an ',' at the correct place if let Some(&TokenTree::Punct(ref x)) = where_bounds.last() { if x.as_char() == ',' { continue 'main_loop; } } found.last_mut().unwrap().push(proc_macro::Punct::new(',', proc_macro::Spacing::Alone).into()); } continue 'main_loop; } } if let &TokenTree::Punct(ref punct) = tt { let c = punct.as_char(); if c == '<' { count += 1; } else if c == '>' { // this is broken so just give up // FIXME better error handling if count == 0 { in_path = false; continue 'main_loop; } count -= 1; } } if let &TokenTree::Group(ref group) = tt { if group.delimiter() == Delimiter::Brace && count == 0 { to_remove.push(start..pos+1); // add "for" found.last_mut().unwrap().push(proc_macro::Ident::new("for", tt.span()).into()); // add Target found.last_mut().unwrap().extend(target.clone()); // and the parent bounds (including the "where") found.last_mut().unwrap().extend(where_bounds.clone()); in_path = false; in_where = false; // fall through to add the block } } found.last_mut().unwrap().push(tt.clone()); continue 'main_loop; } if in_where { // just try to find the block, and add all the stuff. if let &TokenTree::Punct(ref punct) = tt { let c = punct.as_char(); if c == '<' { count += 1; } else if c == '>' { // this is broken so just give up // FIXME better error handling if count == 0 { in_where = false; continue 'main_loop; } count -= 1; } } if let &TokenTree::Group(ref group) = tt { if group.delimiter() == Delimiter::Brace && count == 0 { // call it done! to_remove.push(start..pos+1); in_where = false; } } found.last_mut().unwrap().push(tt.clone()); continue 'main_loop; } if found.len() == to_remove.len() { found.push(Vec::new()); in_unsafe = false; in_impl = false; in_where = false; in_path = false; in_attr_cont = false; in_generic = false; has_injected_generics = false; count = 0; } match tt { &TokenTree::Ident(ref ident) => { let formatted = format!("{}", ident); if formatted == "unsafe" && !in_impl { found.last_mut().unwrap().push(tt.clone()); if !in_attr_cont { start = pos; } in_attr = false; in_unsafe = true; continue; } else if formatted == "impl" && !in_impl { if !in_attr_cont && !in_unsafe { start = pos; } found.last_mut().unwrap().push(tt.clone()); in_unsafe = false; in_attr = false; in_impl = true; continue; } else if formatted == "trait" && in_impl { // swallowed. doesn't go into found. trait_span = Some(tt.span()); in_generic = true; in_path = true; in_impl = false; has_injected_generics = false; continue; } }, &TokenTree::Punct(ref punct) => { if punct.as_char() == '#' && !in_attr { found.last_mut().unwrap().push(tt.clone()); if !in_attr_cont { start = pos; } in_attr = true; continue; } } &TokenTree::Group(ref group) => { if group.delimiter() == Delimiter::Bracket && in_attr { found.last_mut().unwrap().push(tt.clone()); in_attr = false; in_attr_cont = true; continue; } } _ => {} } found.truncate(to_remove.len()); in_unsafe = false; in_impl = false; in_where = false; in_path = false; in_attr_cont = false; in_generic = false; has_injected_generics = false; count = 0; } // must be iterated backwards for range in to_remove.into_iter().rev() { items.drain(range); } *group = proc_macro::Group::new(group.delimiter(), items.into_iter().collect()); group.set_span(span); output.extend(found.into_iter().flatten()); } drop(generics); drop(target); drop(where_bounds); //eprintln!("attributes: {:#?}", attributes); //eprintln!("OUTPUT: {:#?}", output); //eprintln!("OUTPUT: {}", (&output).into_iter().cloned().collect::()); attributes.into_iter().chain(output.into_iter()).collect() }