summary refs log tree commit diff stats
diff options
context:
space:
mode:
-rw-r--r--Cargo.toml2
-rw-r--r--src/lib.rs12
-rw-r--r--tests/12-combining-generics.rs20
-rw-r--r--tests/13-trait-generics.rs20
-rw-r--r--tests/progress.rs2
5 files changed, 55 insertions, 1 deletions
diff --git a/Cargo.toml b/Cargo.toml
index af7d8ba..2657493 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "impl_trait"
-version = "0.1.3"
+version = "0.1.4"
 authors = ["SoniEx2 <endermoneymod@gmail.com>"]
 edition = "2018"
 description = "Allows impl trait inside inherent impl."
diff --git a/src/lib.rs b/src/lib.rs
index 2a3b90f..192d56b 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -163,7 +163,11 @@ pub fn impl_trait(item: TokenStream) -> TokenStream {
                     has_injected_generics = true;
                     if generics_scratchpad.is_empty() {
                         found.last_mut().unwrap().extend(generics.clone());
+                    } else if generics.is_empty() {
+                        found.last_mut().unwrap().extend(generics_scratchpad.clone());
                     } else {
+                        // need to *combine* generics. this is not exactly trivial.
+                        // thankfully we don't need to worry about defaults on impls.
                         let mut this_generics: Generics = syn::parse(generics_scratchpad.drain(..).collect()).unwrap();
                         let parent_generics: Generics = syn::parse(generics.clone().into_iter().collect()).unwrap();
                         let mut target = parent_generics.params.into_pairs().chain(this_generics.params.clone().into_pairs()).collect::<Vec<_>>();
@@ -180,6 +184,14 @@ pub fn impl_trait(item: TokenStream) -> TokenStream {
                                 (&GenericParam::Const(_), &GenericParam::Lifetime(_)) => Ordering::Greater,
                             }
                         });
+                        // just need to fix the one Pair::End in the middle of the thing.
+                        for item in &mut target {
+                            if matches!(item, syn::punctuated::Pair::End(_)) {
+                                let value = item.value().clone();
+                                *item = syn::punctuated::Pair::Punctuated(value, syn::token::Comma { spans: [trait_span.unwrap().into()] });
+                                break;
+                            }
+                        }
                         this_generics.params = target.into_iter().collect();
                         let new_generics = TokenStream::from(this_generics.into_token_stream());
                         found.last_mut().unwrap().extend(new_generics);
diff --git a/tests/12-combining-generics.rs b/tests/12-combining-generics.rs
new file mode 100644
index 0000000..0401887
--- /dev/null
+++ b/tests/12-combining-generics.rs
@@ -0,0 +1,20 @@
+// Checks that the impl trait can have its own generics.
+
+use impl_trait::impl_trait;
+
+struct Foo<T>(T);
+trait Bar<U> {
+}
+
+impl_trait! {
+    impl<T> Foo<T> {
+        impl trait<U> Bar<U> {
+        }
+    }
+}
+
+fn static_assert_1<T: Bar<U>, U>(_t: T, _u: U) {}
+
+fn main() {
+    static_assert_1(Foo(()), ());
+}
diff --git a/tests/13-trait-generics.rs b/tests/13-trait-generics.rs
new file mode 100644
index 0000000..889d168
--- /dev/null
+++ b/tests/13-trait-generics.rs
@@ -0,0 +1,20 @@
+// Checks that the impl trait can have generics without the inherent impl having generics.
+
+use impl_trait::impl_trait;
+
+struct Foo;
+trait Bar<U> {
+}
+
+impl_trait! {
+    impl Foo {
+        impl trait<U> Bar<U> {
+        }
+    }
+}
+
+fn static_assert_1<T: Bar<U>, U>(_t: T, _u: U) {}
+
+fn main() {
+    static_assert_1(Foo, ());
+}
diff --git a/tests/progress.rs b/tests/progress.rs
index 5a1bdb6..ef302b0 100644
--- a/tests/progress.rs
+++ b/tests/progress.rs
@@ -12,6 +12,8 @@ fn tests() {
     t.pass("tests/09-multiple-traits.rs");
     t.pass("tests/10-multiple-traits-with-generics.rs");
     t.pass("tests/11-traits-generics-docs.rs");
+    t.pass("tests/12-combining-generics.rs");
+    t.pass("tests/13-trait-generics.rs");
     t.pass("tests/98-readme.rs");
     t.pass("tests/99-goal.rs");
 }