summary refs log tree commit diff stats
path: root/src
diff options
context:
space:
mode:
authorSoniEx2 <endermoneymod@gmail.com>2021-03-22 20:50:59 -0300
committerSoniEx2 <endermoneymod@gmail.com>2021-03-22 20:52:47 -0300
commit5bb562ad496f9ec42ec59abfe31d3576266c6a6d (patch)
tree4c792fcc67d4211c6c5b4a0e18ebb832b0d4ce66 /src
Initial commit
Diffstat (limited to 'src')
-rw-r--r--src/lib.rs344
1 files changed, 344 insertions, 0 deletions
diff --git a/src/lib.rs b/src/lib.rs
new file mode 100644
index 0000000..061345d
--- /dev/null
+++ b/src/lib.rs
@@ -0,0 +1,344 @@
+// impl_trait - Rust proc macro that significantly reduces boilerplate
+// Copyright (C) 2021  Soni L.
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program.  If not, see <https://www.gnu.org/licenses/>.
+
+extern crate proc_macro;
+use proc_macro::{TokenStream, TokenTree, Delimiter, Span};
+//use syn::parse::{Parse, ParseStream, Result as ParseResult};
+use syn::{Generics, GenericParam};
+use std::cmp::Ordering;
+use quote::ToTokens;
+
+#[proc_macro]
+#[allow(unreachable_code)]
+pub fn impl_trait(item: TokenStream) -> TokenStream {
+    //eprintln!("INPUT: {:#?}", item);
+    let mut output: Vec<_> = item.into_iter().collect();
+    let attributes: Vec<TokenTree> = {
+        let mut pos = 0;
+        let mut len = 0;
+        let mut in_attr = false;
+        while pos != output.len() {
+            let tt = &output[pos];
+            pos += 1;
+            match tt {
+                &TokenTree::Punct(ref punct) => {
+                    if punct.as_char() == '#' && !in_attr {
+                        in_attr = true;
+                        continue;
+                    }
+                }
+                &TokenTree::Group(ref group) => {
+                    if group.delimiter() == Delimiter::Bracket && in_attr {
+                        in_attr = false;
+                        len = pos;
+                        continue;
+                    }
+                }
+                _ => {}
+            }
+            break;
+        }
+        output.drain(0..len).collect()
+    };
+    //eprintln!("attributes: {:#?}", attributes);
+    // check for impl.
+    // unsafe impls are only available for traits and are automatically rejected.
+    'check_impl: loop { break {
+        if let &TokenTree::Ident(ref ident) = &output[0] {
+            if format!("{}", ident) == "impl" {
+                break 'check_impl;
+            }
+        }
+        panic!("impl_trait! may only be applied to inherent impls");
+    } }
+    let mut has_where: Option<&TokenTree> = None;
+    'check_no_for_before_where: loop { break {
+        for tt in &output {
+            if let &TokenTree::Ident(ref ident) = tt {
+                let formatted = format!("{}", ident);
+                if formatted == "where" {
+                    has_where = Some(tt);
+                    break 'check_no_for_before_where;
+                } else if formatted == "for" {
+                    panic!("impl_trait! may only be applied to inherent impls");
+                }
+            }
+        }
+    } }
+    // this is the "where [...]" part, including the "where".
+    let mut where_bounds = Vec::new();
+    if let Some(where_in) = has_where {
+        where_bounds = output.split_last().unwrap().1.into_iter().skip_while(|&tt| {
+            !std::ptr::eq(tt, where_in)
+        }).cloned().collect();
+    }
+    let where_bounds = where_bounds;
+    drop(has_where);
+    let mut count = 0;
+    // this is the "<...>" part, immediately after the "impl", and including the "<>".
+    let generics = output.split_first().unwrap().1.into_iter().take_while(|&tt| {
+        let mut result = count > 0;
+        if let &TokenTree::Punct(ref punct) = tt {
+            let c = punct.as_char();
+            if c == '<' {
+                count += 1;
+                result = true;
+            } else if c == '>' {
+                count -= 1;
+            }
+        }
+        result
+    }).cloned().collect::<Vec<_>>();
+    // so how do you find the target? well...
+    // "impl" + [_; generics.len()] + [_; target.len()] + [_; where_bounds.len()] + "{}"
+    // we have generics and where_bounds, and the total, so we can easily find target!
+    let target_start = 1 + generics.len();
+    let target_end = output.len() - 1 - where_bounds.len();
+    let target_range = target_start..target_end;
+    let target = (&output[target_range]).into_iter().cloned().collect::<Vec<_>>();
+    //eprintln!("generics: {:#?}", generics);
+    //eprintln!("target: {:#?}", target);
+    //eprintln!("where_bounds: {:#?}", where_bounds);
+    let items = output.last_mut();
+    if let &mut TokenTree::Group(ref mut group) = items.unwrap() {
+        // TODO: parse "[unsafe] impl trait" somehow. use syn for it maybe (after swallowing the "trait")
+        // luckily for us, there's only one thing that can come after an impl trait: a path
+        // (and optional generics).
+        // but we can't figure out how to parse the `where`.
+        //todo!();
+        let span = group.span();
+        let mut items = group.stream().into_iter().collect::<Vec<_>>();
+        let mut in_unsafe = false;
+        let mut in_impl = false;
+        let mut in_path = false;
+        let mut in_generic = false;
+        let mut in_attr = false;
+        let mut in_attr_cont = false;
+        let mut has_injected_generics = false;
+        let mut in_where = false;
+        let mut start = 0;
+        let mut found: Vec<Vec<TokenTree>> = Vec::new();
+        let mut to_remove: Vec<std::ops::Range<usize>> = Vec::new();
+        let mut generics_scratchpad = Vec::new();
+        let mut count = 0;
+        let mut trait_span: Option<Span> = None;
+        'main_loop: for (pos, tt) in (&items).into_iter().enumerate() {
+            if in_generic {
+                // collect the generics
+                let mut result = count > 0;
+                if let &TokenTree::Punct(ref punct) = tt {
+                    let c = punct.as_char();
+                    if c == '<' {
+                        count += 1;
+                        result = true;
+                    } else if c == '>' {
+                        count -= 1;
+                        if count == 0 {
+                            in_generic = false;
+                            in_path = true;
+                        }
+                    }
+                }
+                if result {
+                    generics_scratchpad.push(tt.clone());
+                    continue;
+                }
+            }
+            if in_path {
+                // inject the generics
+                if !has_injected_generics {
+                    has_injected_generics = true;
+                    if generics_scratchpad.is_empty() {
+                        found.last_mut().unwrap().extend(generics.clone());
+                    } else {
+                        let mut this_generics: Generics = syn::parse(generics_scratchpad.drain(..).collect()).unwrap();
+                        let parent_generics: Generics = syn::parse(generics.clone().into_iter().collect()).unwrap();
+                        let mut target = parent_generics.params.into_pairs().chain(this_generics.params.clone().into_pairs()).collect::<Vec<_>>();
+                        target.sort_by(|a, b| {
+                            match (a.value(), b.value()) {
+                                (&GenericParam::Lifetime(_), &GenericParam::Const(_)) => Ordering::Less,
+                                (&GenericParam::Type(_), &GenericParam::Const(_)) => Ordering::Less,
+                                (&GenericParam::Lifetime(_), &GenericParam::Type(_)) => Ordering::Less,
+                                (&GenericParam::Lifetime(_), &GenericParam::Lifetime(_)) => Ordering::Equal,
+                                (&GenericParam::Type(_), &GenericParam::Type(_)) => Ordering::Equal,
+                                (&GenericParam::Const(_), &GenericParam::Const(_)) => Ordering::Equal,
+                                (&GenericParam::Type(_), &GenericParam::Lifetime(_)) => Ordering::Greater,
+                                (&GenericParam::Const(_), &GenericParam::Type(_)) => Ordering::Greater,
+                                (&GenericParam::Const(_), &GenericParam::Lifetime(_)) => Ordering::Greater,
+                            }
+                        });
+                        this_generics.params = target.into_iter().collect();
+                        let new_generics = TokenStream::from(this_generics.into_token_stream());
+                        found.last_mut().unwrap().extend(new_generics);
+                    }
+                }
+                in_generic = false;
+                if let &TokenTree::Ident(ref ident) = tt {
+                    let formatted = format!("{}", ident);
+                    if count == 0 && formatted == "where" {
+                        in_path = false;
+                        in_where = true;
+                        // add "for"
+                        found.last_mut().unwrap().push(proc_macro::Ident::new("for", trait_span.unwrap()).into());
+                        // add Target
+                        found.last_mut().unwrap().extend(target.clone());
+                        // *then* add the "where" (from the impl-trait)
+                        found.last_mut().unwrap().push(tt.clone());
+                        // and the parent bounds (except the "where")
+                        found.last_mut().unwrap().extend((&where_bounds).into_iter().skip(1).cloned());
+                        // also make sure that there's an ',' at the correct place
+                        if let Some(&TokenTree::Punct(ref x)) = where_bounds.last() {
+                            if x.as_char() == ',' {
+                                continue 'main_loop;
+                            }
+                        }
+                        found.last_mut().unwrap().push(proc_macro::Punct::new(',', proc_macro::Spacing::Alone).into());
+                        continue 'main_loop;
+                    }
+                }
+                if let &TokenTree::Punct(ref punct) = tt {
+                    let c = punct.as_char();
+                    if c == '<' {
+                        count += 1;
+                    } else if c == '>' {
+                        // this is broken so just give up
+                        // FIXME better error handling
+                        if count == 0 {
+                            in_path = false;
+                            continue 'main_loop;
+                        }
+                        count -= 1;
+                    }
+                }
+                if let &TokenTree::Group(ref group) = tt {
+                    if group.delimiter() == Delimiter::Brace && count == 0 {
+                        to_remove.push(start..pos+1);
+                        // add "for"
+                        found.last_mut().unwrap().push(proc_macro::Ident::new("for", tt.span()).into());
+                        // add Target
+                        found.last_mut().unwrap().extend(target.clone());
+                        // and the parent bounds (including the "where")
+                        found.last_mut().unwrap().extend(where_bounds.clone());
+                        in_path = false;
+                        in_where = false;
+                        // fall through to add the block
+                    }
+                }
+                found.last_mut().unwrap().push(tt.clone());
+                continue 'main_loop;
+            }
+            if in_where {
+                // just try to find the block, and add all the stuff.
+                if let &TokenTree::Punct(ref punct) = tt {
+                    let c = punct.as_char();
+                    if c == '<' {
+                        count += 1;
+                    } else if c == '>' {
+                        // this is broken so just give up
+                        // FIXME better error handling
+                        if count == 0 {
+                            in_where = false;
+                            continue 'main_loop;
+                        }
+                        count -= 1;
+                    }
+                }
+                if let &TokenTree::Group(ref group) = tt {
+                    if group.delimiter() == Delimiter::Brace && count == 0 {
+                        // call it done!
+                        to_remove.push(start..pos+1);
+                        in_where = false;
+                    }
+                }
+                found.last_mut().unwrap().push(tt.clone());
+                continue 'main_loop;
+            }
+            if found.len() == to_remove.len() {
+                found.push(Vec::new());
+            }
+            match tt {
+                &TokenTree::Ident(ref ident) => {
+                    let formatted = format!("{}", ident);
+                    if formatted == "unsafe" && !in_impl {
+                        found.last_mut().unwrap().push(tt.clone());
+                        if !in_attr_cont {
+                            start = pos;
+                        }
+                        in_attr = false;
+                        in_unsafe = true;
+                        continue;
+                    } else if formatted == "impl" && !in_impl {
+                        if !in_attr_cont && !in_unsafe {
+                            start = pos;
+                        }
+                        found.last_mut().unwrap().push(tt.clone());
+                        in_unsafe = false;
+                        in_attr = false;
+                        in_impl = true;
+                        continue;
+                    } else if formatted == "trait" && in_impl {
+                        // swallowed. doesn't go into found.
+                        trait_span = Some(tt.span());
+                        in_generic = true;
+                        in_path = true;
+                        continue;
+                    }
+                },
+                &TokenTree::Punct(ref punct) => {
+                    if punct.as_char() == '#' && !in_attr {
+                        found.last_mut().unwrap().push(tt.clone());
+                        if !in_attr_cont {
+                            start = pos;
+                        }
+                        in_attr = true;
+                        continue;
+                    }
+                }
+                &TokenTree::Group(ref group) => {
+                    if group.delimiter() == Delimiter::Bracket && in_attr {
+                        found.last_mut().unwrap().push(tt.clone());
+                        in_attr = false;
+                        in_attr_cont = true;
+                        continue;
+                    }
+                }
+                _ => {}
+            }
+            found.truncate(to_remove.len());
+            in_unsafe = false;
+            in_impl = false;
+            in_where = false;
+            in_path = false;
+            in_attr_cont = false;
+            in_generic = false;
+            has_injected_generics = false;
+            count = 0;
+        }
+        // must be iterated backwards
+        for range in to_remove.into_iter().rev() {
+            items.drain(range);
+        }
+        *group = proc_macro::Group::new(group.delimiter(), items.into_iter().collect());
+        group.set_span(span);
+        output.extend(found.into_iter().flatten());
+    }
+    drop(generics);
+    drop(target);
+    drop(where_bounds);
+    //eprintln!("attributes: {:#?}", attributes);
+    //eprintln!("OUTPUT: {:#?}", output);
+    attributes.into_iter().chain(output.into_iter()).collect()
+}