From 5bb562ad496f9ec42ec59abfe31d3576266c6a6d Mon Sep 17 00:00:00 2001 From: SoniEx2 Date: Mon, 22 Mar 2021 20:50:59 -0300 Subject: Initial commit --- src/lib.rs | 344 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 344 insertions(+) create mode 100644 src/lib.rs (limited to 'src/lib.rs') diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..061345d --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,344 @@ +// impl_trait - Rust proc macro that significantly reduces boilerplate +// Copyright (C) 2021 Soni L. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +extern crate proc_macro; +use proc_macro::{TokenStream, TokenTree, Delimiter, Span}; +//use syn::parse::{Parse, ParseStream, Result as ParseResult}; +use syn::{Generics, GenericParam}; +use std::cmp::Ordering; +use quote::ToTokens; + +#[proc_macro] +#[allow(unreachable_code)] +pub fn impl_trait(item: TokenStream) -> TokenStream { + //eprintln!("INPUT: {:#?}", item); + let mut output: Vec<_> = item.into_iter().collect(); + let attributes: Vec = { + let mut pos = 0; + let mut len = 0; + let mut in_attr = false; + while pos != output.len() { + let tt = &output[pos]; + pos += 1; + match tt { + &TokenTree::Punct(ref punct) => { + if punct.as_char() == '#' && !in_attr { + in_attr = true; + continue; + } + } + &TokenTree::Group(ref group) => { + if group.delimiter() == Delimiter::Bracket && in_attr { + in_attr = false; + len = pos; + continue; + } + } + _ => {} + } + break; + } + output.drain(0..len).collect() + }; + //eprintln!("attributes: {:#?}", attributes); + // check for impl. + // unsafe impls are only available for traits and are automatically rejected. + 'check_impl: loop { break { + if let &TokenTree::Ident(ref ident) = &output[0] { + if format!("{}", ident) == "impl" { + break 'check_impl; + } + } + panic!("impl_trait! may only be applied to inherent impls"); + } } + let mut has_where: Option<&TokenTree> = None; + 'check_no_for_before_where: loop { break { + for tt in &output { + if let &TokenTree::Ident(ref ident) = tt { + let formatted = format!("{}", ident); + if formatted == "where" { + has_where = Some(tt); + break 'check_no_for_before_where; + } else if formatted == "for" { + panic!("impl_trait! may only be applied to inherent impls"); + } + } + } + } } + // this is the "where [...]" part, including the "where". + let mut where_bounds = Vec::new(); + if let Some(where_in) = has_where { + where_bounds = output.split_last().unwrap().1.into_iter().skip_while(|&tt| { + !std::ptr::eq(tt, where_in) + }).cloned().collect(); + } + let where_bounds = where_bounds; + drop(has_where); + let mut count = 0; + // this is the "<...>" part, immediately after the "impl", and including the "<>". + let generics = output.split_first().unwrap().1.into_iter().take_while(|&tt| { + let mut result = count > 0; + if let &TokenTree::Punct(ref punct) = tt { + let c = punct.as_char(); + if c == '<' { + count += 1; + result = true; + } else if c == '>' { + count -= 1; + } + } + result + }).cloned().collect::>(); + // so how do you find the target? well... + // "impl" + [_; generics.len()] + [_; target.len()] + [_; where_bounds.len()] + "{}" + // we have generics and where_bounds, and the total, so we can easily find target! + let target_start = 1 + generics.len(); + let target_end = output.len() - 1 - where_bounds.len(); + let target_range = target_start..target_end; + let target = (&output[target_range]).into_iter().cloned().collect::>(); + //eprintln!("generics: {:#?}", generics); + //eprintln!("target: {:#?}", target); + //eprintln!("where_bounds: {:#?}", where_bounds); + let items = output.last_mut(); + if let &mut TokenTree::Group(ref mut group) = items.unwrap() { + // TODO: parse "[unsafe] impl trait" somehow. use syn for it maybe (after swallowing the "trait") + // luckily for us, there's only one thing that can come after an impl trait: a path + // (and optional generics). + // but we can't figure out how to parse the `where`. + //todo!(); + let span = group.span(); + let mut items = group.stream().into_iter().collect::>(); + let mut in_unsafe = false; + let mut in_impl = false; + let mut in_path = false; + let mut in_generic = false; + let mut in_attr = false; + let mut in_attr_cont = false; + let mut has_injected_generics = false; + let mut in_where = false; + let mut start = 0; + let mut found: Vec> = Vec::new(); + let mut to_remove: Vec> = Vec::new(); + let mut generics_scratchpad = Vec::new(); + let mut count = 0; + let mut trait_span: Option = None; + 'main_loop: for (pos, tt) in (&items).into_iter().enumerate() { + if in_generic { + // collect the generics + let mut result = count > 0; + if let &TokenTree::Punct(ref punct) = tt { + let c = punct.as_char(); + if c == '<' { + count += 1; + result = true; + } else if c == '>' { + count -= 1; + if count == 0 { + in_generic = false; + in_path = true; + } + } + } + if result { + generics_scratchpad.push(tt.clone()); + continue; + } + } + if in_path { + // inject the generics + if !has_injected_generics { + has_injected_generics = true; + if generics_scratchpad.is_empty() { + found.last_mut().unwrap().extend(generics.clone()); + } else { + let mut this_generics: Generics = syn::parse(generics_scratchpad.drain(..).collect()).unwrap(); + let parent_generics: Generics = syn::parse(generics.clone().into_iter().collect()).unwrap(); + let mut target = parent_generics.params.into_pairs().chain(this_generics.params.clone().into_pairs()).collect::>(); + target.sort_by(|a, b| { + match (a.value(), b.value()) { + (&GenericParam::Lifetime(_), &GenericParam::Const(_)) => Ordering::Less, + (&GenericParam::Type(_), &GenericParam::Const(_)) => Ordering::Less, + (&GenericParam::Lifetime(_), &GenericParam::Type(_)) => Ordering::Less, + (&GenericParam::Lifetime(_), &GenericParam::Lifetime(_)) => Ordering::Equal, + (&GenericParam::Type(_), &GenericParam::Type(_)) => Ordering::Equal, + (&GenericParam::Const(_), &GenericParam::Const(_)) => Ordering::Equal, + (&GenericParam::Type(_), &GenericParam::Lifetime(_)) => Ordering::Greater, + (&GenericParam::Const(_), &GenericParam::Type(_)) => Ordering::Greater, + (&GenericParam::Const(_), &GenericParam::Lifetime(_)) => Ordering::Greater, + } + }); + this_generics.params = target.into_iter().collect(); + let new_generics = TokenStream::from(this_generics.into_token_stream()); + found.last_mut().unwrap().extend(new_generics); + } + } + in_generic = false; + if let &TokenTree::Ident(ref ident) = tt { + let formatted = format!("{}", ident); + if count == 0 && formatted == "where" { + in_path = false; + in_where = true; + // add "for" + found.last_mut().unwrap().push(proc_macro::Ident::new("for", trait_span.unwrap()).into()); + // add Target + found.last_mut().unwrap().extend(target.clone()); + // *then* add the "where" (from the impl-trait) + found.last_mut().unwrap().push(tt.clone()); + // and the parent bounds (except the "where") + found.last_mut().unwrap().extend((&where_bounds).into_iter().skip(1).cloned()); + // also make sure that there's an ',' at the correct place + if let Some(&TokenTree::Punct(ref x)) = where_bounds.last() { + if x.as_char() == ',' { + continue 'main_loop; + } + } + found.last_mut().unwrap().push(proc_macro::Punct::new(',', proc_macro::Spacing::Alone).into()); + continue 'main_loop; + } + } + if let &TokenTree::Punct(ref punct) = tt { + let c = punct.as_char(); + if c == '<' { + count += 1; + } else if c == '>' { + // this is broken so just give up + // FIXME better error handling + if count == 0 { + in_path = false; + continue 'main_loop; + } + count -= 1; + } + } + if let &TokenTree::Group(ref group) = tt { + if group.delimiter() == Delimiter::Brace && count == 0 { + to_remove.push(start..pos+1); + // add "for" + found.last_mut().unwrap().push(proc_macro::Ident::new("for", tt.span()).into()); + // add Target + found.last_mut().unwrap().extend(target.clone()); + // and the parent bounds (including the "where") + found.last_mut().unwrap().extend(where_bounds.clone()); + in_path = false; + in_where = false; + // fall through to add the block + } + } + found.last_mut().unwrap().push(tt.clone()); + continue 'main_loop; + } + if in_where { + // just try to find the block, and add all the stuff. + if let &TokenTree::Punct(ref punct) = tt { + let c = punct.as_char(); + if c == '<' { + count += 1; + } else if c == '>' { + // this is broken so just give up + // FIXME better error handling + if count == 0 { + in_where = false; + continue 'main_loop; + } + count -= 1; + } + } + if let &TokenTree::Group(ref group) = tt { + if group.delimiter() == Delimiter::Brace && count == 0 { + // call it done! + to_remove.push(start..pos+1); + in_where = false; + } + } + found.last_mut().unwrap().push(tt.clone()); + continue 'main_loop; + } + if found.len() == to_remove.len() { + found.push(Vec::new()); + } + match tt { + &TokenTree::Ident(ref ident) => { + let formatted = format!("{}", ident); + if formatted == "unsafe" && !in_impl { + found.last_mut().unwrap().push(tt.clone()); + if !in_attr_cont { + start = pos; + } + in_attr = false; + in_unsafe = true; + continue; + } else if formatted == "impl" && !in_impl { + if !in_attr_cont && !in_unsafe { + start = pos; + } + found.last_mut().unwrap().push(tt.clone()); + in_unsafe = false; + in_attr = false; + in_impl = true; + continue; + } else if formatted == "trait" && in_impl { + // swallowed. doesn't go into found. + trait_span = Some(tt.span()); + in_generic = true; + in_path = true; + continue; + } + }, + &TokenTree::Punct(ref punct) => { + if punct.as_char() == '#' && !in_attr { + found.last_mut().unwrap().push(tt.clone()); + if !in_attr_cont { + start = pos; + } + in_attr = true; + continue; + } + } + &TokenTree::Group(ref group) => { + if group.delimiter() == Delimiter::Bracket && in_attr { + found.last_mut().unwrap().push(tt.clone()); + in_attr = false; + in_attr_cont = true; + continue; + } + } + _ => {} + } + found.truncate(to_remove.len()); + in_unsafe = false; + in_impl = false; + in_where = false; + in_path = false; + in_attr_cont = false; + in_generic = false; + has_injected_generics = false; + count = 0; + } + // must be iterated backwards + for range in to_remove.into_iter().rev() { + items.drain(range); + } + *group = proc_macro::Group::new(group.delimiter(), items.into_iter().collect()); + group.set_span(span); + output.extend(found.into_iter().flatten()); + } + drop(generics); + drop(target); + drop(where_bounds); + //eprintln!("attributes: {:#?}", attributes); + //eprintln!("OUTPUT: {:#?}", output); + attributes.into_iter().chain(output.into_iter()).collect() +} -- cgit 1.4.1