// impl_trait - Rust proc macro that significantly reduces boilerplate
// Copyright (C) 2021 Soni L.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
extern crate proc_macro;
use proc_macro::{TokenStream, TokenTree, Delimiter, Span};
//use syn::parse::{Parse, ParseStream, Result as ParseResult};
use syn::{Generics, GenericParam};
use std::cmp::Ordering;
use quote::ToTokens;
/// Expands `impl trait` items nested inside an inherent `impl` block.
///
/// The macro input must have the shape
///
/// ```text
/// [#[attr]...] impl [<generics>] Target [where bounds] { body }
/// ```
///
/// Every item inside `body` of the form
///
/// ```text
/// [#[attr]...] [unsafe] impl trait [<generics>] Path [where bounds] { ... }
/// ```
///
/// is removed from `body` and re-emitted after the inherent impl as
///
/// ```text
/// [#[attr]...] [unsafe] impl<merged generics> Path for Target where <merged bounds> { ... }
/// ```
///
/// where the generics and `where` bounds of the enclosing impl are combined
/// with those written on the nested item. Everything else in `body` is left
/// where it is. Leading attributes of the enclosing impl are passed through
/// unchanged in front of the output.
///
/// # Panics
///
/// Panics (which the compiler reports as a macro error) when the input does
/// not start with `impl`, when a `for` appears before any `where` in the impl
/// header (i.e. it is a trait impl, not an inherent impl), when no target can
/// be located between the generics and the body, or when the generics that
/// need merging cannot be parsed by `syn`.
#[proc_macro]
pub fn impl_trait(item: TokenStream) -> TokenStream {
    let mut output: Vec<_> = item.into_iter().collect();

    // Split off the leading `#[...]` attributes; they are re-attached verbatim
    // at the very end. `len` only advances once a complete `#` + `[...]` pair
    // has been seen, so a dangling `#` is left in `output`.
    let attributes: Vec<TokenTree> = {
        let mut pos = 0;
        let mut len = 0;
        let mut in_attr = false;
        while pos != output.len() {
            let tt = &output[pos];
            pos += 1;
            match tt {
                TokenTree::Punct(punct) => {
                    if punct.as_char() == '#' && !in_attr {
                        in_attr = true;
                        continue;
                    }
                }
                TokenTree::Group(group) => {
                    if group.delimiter() == Delimiter::Bracket && in_attr {
                        in_attr = false;
                        len = pos;
                        continue;
                    }
                }
                _ => {}
            }
            break;
        }
        output.drain(0..len).collect()
    };

    // The first remaining token must be `impl`. `unsafe impl` only exists for
    // trait impls, so it is rejected here as well. Using `first()` makes empty
    // input produce this diagnostic instead of an index-out-of-bounds panic.
    match output.first() {
        Some(TokenTree::Ident(ident)) if ident.to_string() == "impl" => {}
        _ => panic!("impl_trait! may only be applied to inherent impls"),
    }

    // Find the top-level `where`, rejecting any `for` that comes before it.
    let mut where_pos: Option<usize> = None;
    for (idx, tt) in output.iter().enumerate() {
        if let TokenTree::Ident(ident) = tt {
            let formatted = ident.to_string();
            if formatted == "where" {
                where_pos = Some(idx);
                break;
            } else if formatted == "for" {
                panic!("impl_trait! may only be applied to inherent impls");
            }
        }
    }

    // `output` is non-empty here (it starts with `impl`), so this cannot
    // underflow. The last token is expected to be the braced body.
    let body_index = output.len() - 1;

    // This is the "where [...]" part, including the "where" itself and
    // excluding the trailing body token.
    let where_bounds: Vec<TokenTree> = match where_pos {
        Some(idx) if idx < body_index => output[idx..body_index].to_vec(),
        _ => Vec::new(),
    };

    // This is the "<...>" part immediately after the "impl", including the
    // angle brackets. The closing `>` is kept because `result` is computed
    // before the depth is decremented.
    let mut depth = 0;
    let generics: Vec<TokenTree> = output[1..]
        .iter()
        .take_while(|&tt| {
            let mut result = depth > 0;
            if let TokenTree::Punct(punct) = tt {
                let c = punct.as_char();
                if c == '<' {
                    depth += 1;
                    result = true;
                } else if c == '>' {
                    depth -= 1;
                }
            }
            result
        })
        .cloned()
        .collect();

    // Layout of `output`:
    //   "impl" + generics + target + where_bounds + "{...}"
    // Generics and where_bounds are known, so the target is what is left.
    let target_start = 1 + generics.len();
    let target_end = body_index - where_bounds.len();
    if target_start > target_end {
        // Happens for malformed headers such as a lone `impl` or unbalanced
        // `<`; report it instead of panicking on the slice below.
        panic!("impl_trait! could not find the target type of the impl");
    }
    let target: Vec<TokenTree> = output[target_start..target_end].to_vec();

    if let Some(TokenTree::Group(group)) = output.last_mut() {
        let span = group.span();
        let mut items = group.stream().into_iter().collect::<Vec<_>>();

        // State of the scanner over the body tokens.
        let mut in_unsafe = false;
        let mut in_impl = false;
        let mut in_path = false;
        let mut in_generic = false;
        let mut in_attr = false;
        let mut in_attr_cont = false;
        let mut has_injected_generics = false;
        let mut in_where = false;
        // Index of the first token of the candidate item being collected.
        let mut start = 0;
        // One rewritten token list per `impl trait` item; the last entry may
        // be a candidate that is still being collected.
        let mut found: Vec<Vec<TokenTree>> = Vec::new();
        // Body ranges of the completed items, parallel to `found`.
        let mut to_remove: Vec<std::ops::Range<usize>> = Vec::new();
        // Generics written directly after the `trait` keyword.
        let mut generics_scratchpad: Vec<TokenTree> = Vec::new();
        // Angle-bracket nesting depth.
        let mut count = 0;
        let mut trait_span: Option<Span> = None;

        'main_loop: for (pos, tt) in items.iter().enumerate() {
            if in_generic {
                // Collect the generics following `trait`, brackets included.
                let mut result = count > 0;
                if let TokenTree::Punct(punct) = tt {
                    let c = punct.as_char();
                    if c == '<' {
                        count += 1;
                        result = true;
                    } else if c == '>' {
                        count -= 1;
                        if count == 0 {
                            in_generic = false;
                            in_path = true;
                        }
                    }
                }
                if result {
                    generics_scratchpad.push(tt.clone());
                    continue;
                }
            }
            if in_path {
                let current = found
                    .last_mut()
                    .expect("an `impl trait` item is being collected");
                // Inject the generics right after the already-collected `impl`.
                if !has_injected_generics {
                    has_injected_generics = true;
                    // Take the scratchpad so that these generics cannot leak
                    // into a later `impl trait` item of the same body.
                    let own_generics = std::mem::take(&mut generics_scratchpad);
                    if own_generics.is_empty() {
                        current.extend(generics.iter().cloned());
                    } else if generics.is_empty() {
                        current.extend(own_generics);
                    } else {
                        // Both sides have generics and need to be combined.
                        // Defaults are not a concern on impls.
                        let mut this_generics: Generics =
                            syn::parse(own_generics.into_iter().collect())
                                .expect("impl_trait! could not parse the generics of `impl trait`");
                        let parent_generics: Generics =
                            syn::parse(generics.iter().cloned().collect())
                                .expect("impl_trait! could not parse the generics of the impl");
                        let mut merged = parent_generics
                            .params
                            .into_pairs()
                            .chain(this_generics.params.clone().into_pairs())
                            .collect::<Vec<_>>();
                        // Stable sort into lifetimes, then types, then consts,
                        // keeping the relative order within each kind.
                        merged.sort_by(|a, b| match (a.value(), b.value()) {
                            (GenericParam::Lifetime(_), GenericParam::Const(_)) => Ordering::Less,
                            (GenericParam::Type(_), GenericParam::Const(_)) => Ordering::Less,
                            (GenericParam::Lifetime(_), GenericParam::Type(_)) => Ordering::Less,
                            (GenericParam::Lifetime(_), GenericParam::Lifetime(_)) => {
                                Ordering::Equal
                            }
                            (GenericParam::Type(_), GenericParam::Type(_)) => Ordering::Equal,
                            (GenericParam::Const(_), GenericParam::Const(_)) => Ordering::Equal,
                            (GenericParam::Type(_), GenericParam::Lifetime(_)) => {
                                Ordering::Greater
                            }
                            (GenericParam::Const(_), GenericParam::Type(_)) => Ordering::Greater,
                            (GenericParam::Const(_), GenericParam::Lifetime(_)) => {
                                Ordering::Greater
                            }
                        });
                        // Each side contributes at most one comma-less
                        // `Pair::End`; give the first one a comma so that no
                        // `End` is left in the middle of the list.
                        for pair in &mut merged {
                            if matches!(pair, syn::punctuated::Pair::End(_)) {
                                let value = pair.value().clone();
                                *pair = syn::punctuated::Pair::Punctuated(
                                    value,
                                    syn::token::Comma {
                                        spans: [trait_span.unwrap().into()],
                                    },
                                );
                                break;
                            }
                        }
                        this_generics.params = merged.into_iter().collect();
                        let new_generics = TokenStream::from(this_generics.into_token_stream());
                        current.extend(new_generics);
                    }
                }
                in_generic = false;
                if let TokenTree::Ident(ident) = tt {
                    if count == 0 && ident.to_string() == "where" {
                        in_path = false;
                        in_where = true;
                        // add "for"
                        current.push(proc_macro::Ident::new("for", trait_span.unwrap()).into());
                        // add Target
                        current.extend(target.iter().cloned());
                        // *then* add the "where" (from the impl-trait)
                        current.push(tt.clone());
                        // and the parent bounds (except the "where")
                        current.extend(where_bounds.iter().skip(1).cloned());
                        // Separate the parent bounds from the item's own
                        // bounds. Only needed when the parent actually has
                        // predicates (more than just `where`) that do not
                        // already end in a comma; otherwise a leading
                        // `where ,` would be produced.
                        let ends_with_comma = matches!(
                            where_bounds.last(),
                            Some(TokenTree::Punct(p)) if p.as_char() == ','
                        );
                        if where_bounds.len() > 1 && !ends_with_comma {
                            current.push(
                                proc_macro::Punct::new(',', proc_macro::Spacing::Alone).into(),
                            );
                        }
                        continue 'main_loop;
                    }
                }
                if let TokenTree::Punct(punct) = tt {
                    let c = punct.as_char();
                    if c == '<' {
                        count += 1;
                    } else if c == '>' {
                        // this is broken so just give up
                        // FIXME better error handling
                        if count == 0 {
                            in_path = false;
                            continue 'main_loop;
                        }
                        count -= 1;
                    }
                }
                if let TokenTree::Group(block) = tt {
                    if block.delimiter() == Delimiter::Brace && count == 0 {
                        // Reached the body without an own `where` clause.
                        to_remove.push(start..pos + 1);
                        // add "for"
                        current.push(proc_macro::Ident::new("for", tt.span()).into());
                        // add Target
                        current.extend(target.iter().cloned());
                        // and the parent bounds (including the "where")
                        current.extend(where_bounds.iter().cloned());
                        in_path = false;
                        in_where = false;
                        // fall through to add the block
                    }
                }
                current.push(tt.clone());
                continue 'main_loop;
            }
            if in_where {
                let current = found
                    .last_mut()
                    .expect("an `impl trait` item is being collected");
                // Copy the item's own bounds until its body is found.
                if let TokenTree::Punct(punct) = tt {
                    let c = punct.as_char();
                    if c == '<' {
                        count += 1;
                    } else if c == '>' {
                        // this is broken so just give up
                        // FIXME better error handling
                        if count == 0 {
                            in_where = false;
                            continue 'main_loop;
                        }
                        count -= 1;
                    }
                }
                if let TokenTree::Group(block) = tt {
                    if block.delimiter() == Delimiter::Brace && count == 0 {
                        // call it done!
                        to_remove.push(start..pos + 1);
                        in_where = false;
                    }
                }
                current.push(tt.clone());
                continue 'main_loop;
            }
            // Not inside a recognised item: start a fresh candidate if the
            // previous one was completed (or discarded).
            if found.len() == to_remove.len() {
                found.push(Vec::new());
                in_unsafe = false;
                in_impl = false;
                in_where = false;
                in_path = false;
                in_attr_cont = false;
                in_generic = false;
                has_injected_generics = false;
                count = 0;
            }
            match tt {
                TokenTree::Ident(ident) => {
                    let formatted = ident.to_string();
                    if formatted == "unsafe" && !in_impl {
                        found.last_mut().unwrap().push(tt.clone());
                        if !in_attr_cont {
                            start = pos;
                        }
                        in_attr = false;
                        in_unsafe = true;
                        continue;
                    } else if formatted == "impl" && !in_impl {
                        if !in_attr_cont && !in_unsafe {
                            start = pos;
                        }
                        found.last_mut().unwrap().push(tt.clone());
                        in_unsafe = false;
                        in_attr = false;
                        in_impl = true;
                        continue;
                    } else if formatted == "trait" && in_impl {
                        // swallowed. doesn't go into found.
                        trait_span = Some(tt.span());
                        // Start from a clean slate so tokens collected for an
                        // earlier, abandoned item are not picked up here.
                        generics_scratchpad.clear();
                        in_generic = true;
                        in_path = true;
                        in_impl = false;
                        has_injected_generics = false;
                        continue;
                    }
                }
                TokenTree::Punct(punct) => {
                    if punct.as_char() == '#' && !in_attr {
                        found.last_mut().unwrap().push(tt.clone());
                        if !in_attr_cont {
                            start = pos;
                        }
                        in_attr = true;
                        continue;
                    }
                }
                TokenTree::Group(attr_group) => {
                    if attr_group.delimiter() == Delimiter::Bracket && in_attr {
                        found.last_mut().unwrap().push(tt.clone());
                        in_attr = false;
                        in_attr_cont = true;
                        continue;
                    }
                }
                _ => {}
            }
            // The token does not continue an `impl trait` item: discard the
            // candidate and reset the scanner.
            found.truncate(to_remove.len());
            in_unsafe = false;
            in_impl = false;
            in_where = false;
            in_path = false;
            in_attr_cont = false;
            in_generic = false;
            has_injected_generics = false;
            count = 0;
        }

        // A candidate that was still being collected when the body ended was
        // never removed from the body, so emitting it would duplicate tokens.
        found.truncate(to_remove.len());

        // Must be iterated backwards so earlier ranges stay valid.
        for range in to_remove.into_iter().rev() {
            items.drain(range);
        }
        *group = proc_macro::Group::new(group.delimiter(), items.into_iter().collect());
        group.set_span(span);
        output.extend(found.into_iter().flatten());
    }

    attributes.into_iter().chain(output).collect()
}