summary refs log tree commit diff stats
diff options
context:
space:
mode:
-rw-r--r--abdl/_vm.py40
-rw-r--r--testing/test_abdl.py48
2 files changed, 77 insertions, 11 deletions
diff --git a/abdl/_vm.py b/abdl/_vm.py
index 0ec1018..5ab7efb 100644
--- a/abdl/_vm.py
+++ b/abdl/_vm.py
@@ -41,6 +41,14 @@ class PatternElement:
         """
         raise RuntimeError(self)
 
+    def on_end(self, frame, path, defs, in_key):
+        """Called when the pattern has reached the end.
+
+        Returns the new value of in_key and a dict to be yielded, or
+        None and a dict to be yielded.
+        """
+        raise RuntimeError(self)
+
     def collect_params(self, res: list):
         """Appends parameter names used in this pattern to ``res``.
         """
@@ -245,6 +253,10 @@ class ApplyPredicate(PatternElement):
     def collect_params(self, res: list):
         res.append(self.key)
 
+    def on_end(self, frame, path, defs, in_key):
+        assert not in_key
+        raise NotImplementedError
+
 class End(PatternElement):
     """Pseudo-token, used to advance iteration."""
 
@@ -260,6 +272,19 @@ class End(PatternElement):
                 path.clear()
         return True
 
+    def on_end(self, frame, path, defs, in_key):
+        assert not path[-1].empty
+        res = {}
+        for holder in path:
+            if holder.subtree:
+                for name, pair in holder.match.items():
+                    res[name] = pair
+            elif holder.name is not None:
+                res[holder.name] = (holder.match, holder.value)
+        if not frame.prev():
+            return (None, res)
+        return (True, res)
+
     @classmethod
     def action(cls, toks):
         return [cls()]
@@ -356,24 +381,17 @@ def match_helper(ops, defs, tree):
     """
 
     frame = _Frame(ops)
+    if not len(frame.ops): # no ops?
+        return # do nothing
 
     path = [Holder(value=tree, parent=None, iterator=iter(()))]
     in_key = False
     while path:
         if not frame.next():
-            assert not path[-1].empty
-            res = {}
-            for holder in path:
-                if holder.subtree:
-                    for name, pair in holder.match.items():
-                        res[name] = pair
-                elif holder.name is not None:
-                    res[holder.name] = (holder.match, holder.value)
+            in_key, res = frame.current_op.on_end(frame, path, defs, in_key)
             yield res
-            assert len(path) == 1 or isinstance(frame.current_op, End)
-            if not frame.prev():
+            if in_key is None:
                 return
-            in_key = True
         else:
             if in_key:
                 in_key = frame.current_op.on_in_key(frame, path, defs)
diff --git a/testing/test_abdl.py b/testing/test_abdl.py
index d951c7b..65549ed 100644
--- a/testing/test_abdl.py
+++ b/testing/test_abdl.py
@@ -31,21 +31,33 @@ class LogAndCompare:
         self._itr = right
         self.left = []
         self.right = []
+        self.done = False
     def __iter__(self):
         return self
     def __next__(self):
+        if self.done:
+            raise StopIteration
         try:
             left = next(self._itl)
         except abdl.ValidationError as e:
             e.tb = traceback.format_exc()
             left = e
+        except StopIteration as e:
+            e.tb = traceback.format_exc()
+            left = e
         try:
             right = next(self._itr)
         except abdl.ValidationError as e:
             e.tb = traceback.format_exc()
             right = e
+        except StopIteration as e:
+            e.tb = traceback.format_exc()
+            right = e
         self.left.append(left)
         self.right.append(right)
+        if StopIteration in (type(left), type(right)):
+            self.done = True
+            return (type(left), type(right)) == (StopIteration,)*2
         return left == right or (type(left), type(right)) == (abdl.ValidationError,)*2
     def __repr__(self):
         return "LogAndCompare(left=" + repr(self.left) + ", right=" + repr(self.right) + ")"
@@ -159,6 +171,15 @@ def test_multi_type_with_validation_errors(foo, pat):
                 raise abdl.ValidationError
     assert all(LogAndCompare(pat.match(foo), deep(foo)))
 
+defs = {'a': (dict, list, set), 'b': (dict, set), 'c': dict}
+@hypothesis.given(objtree, st.just(abdl.compile("->X:?$a:?$b:?$c", defs=defs)))
+def test_multi_type_at_end(foo, pat):
+    def deep(foo):
+        for x in pairs(foo):
+            if isinstance(x[1], dict):
+                    yield {"X": x}
+    assert all(LogAndCompare(pat.match(foo), deep(foo)))
+
 @hypothesis.given(st.dictionaries(st.frozensets(st.text()), st.text()), st.just(abdl.compile("->[:?$sets->A]->D", {'sets': collections.abc.Set})))
 def test_subtree_partial(foo, pat):
     def deep(foo):
@@ -236,6 +257,33 @@ def test_key_predicate(foo, pat):
                     yield {"D": d}
     assert all(LogAndCompare(pat.match(foo), deep(foo)))
 
+@hypothesis.given(objtree, st.just(abdl.compile("->[:?$sets:?$sets->V]->D", {'sets': collections.abc.Set})))
+def test_multi_key_predicate_with_values(foo, pat):
+    def deep(foo):
+        for x in pairs(foo):
+            if isinstance(x[0], collections.abc.Set):
+                if isinstance(x[0], collections.abc.Set):
+                    for v in pairs(x[0]):
+                        for d in pairs(x[1]):
+                            yield {"V": v, "D": d}
+    assert all(LogAndCompare(pat.match(foo), deep(foo)))
+
+@hypothesis.given(objtree, st.just(abdl.compile("->[:?$sets:?$sets]->D", {'sets': collections.abc.Set})))
+def test_multi_key_predicate(foo, pat):
+    def deep(foo):
+        for x in pairs(foo):
+            if isinstance(x[0], collections.abc.Set):
+                if isinstance(x[0], collections.abc.Set):
+                    for d in pairs(x[1]):
+                        yield {"D": d}
+    assert all(LogAndCompare(pat.match(foo), deep(foo)))
+
+@hypothesis.given(objtree, st.just(abdl.compile("")))
+def test_empty(foo, pat):
+    def deep(foo):
+        yield from ()
+    assert all(LogAndCompare(pat.match(foo), deep(foo)))
+
 # FIXME
 #@hypothesis.given(objtree, st.text())
 #def test_exhaustive(foo, pat):