summary refs log tree commit diff stats
path: root/src/vm/de.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/vm/de.rs')
-rw-r--r--src/vm/de.rs93
1 files changed, 70 insertions, 23 deletions
diff --git a/src/vm/de.rs b/src/vm/de.rs
index 471b541..85a24fb 100644
--- a/src/vm/de.rs
+++ b/src/vm/de.rs
@@ -153,6 +153,7 @@ impl<'pat, 'state, 'de, O: Serialize> Packer<'pat, 'state, O> {
                     // tho first we set it as overstep because it has special
                     // handling.
                     frame.overstep = 1;
+                    frame.matches = false;
                     let mut at = index + 1;
                     while self.interp.frames[index].next() {
                         let op = self.interp.frames[index].raw_op();
@@ -180,10 +181,10 @@ impl<'pat, 'state, 'de, O: Serialize> Packer<'pat, 'state, O> {
     }
 
     /// Steps the VM back into the previous operation.
-    fn step_out(
+    fn step_out<E: serde::de::Error>(
         &mut self,
         mut packs: Vec<Pack<'pat, 'de>>,
-    ) -> Vec<Pack<'pat, 'de>> {
+    ) -> Result<Vec<Pack<'pat, 'de>>, E> {
         // this code attempts to maintain the logical invariant of:
         // self.frames().iter_active().count() == packs.len()
         self.call_limit += 1;
@@ -212,7 +213,6 @@ impl<'pat, 'state, 'de, O: Serialize> Packer<'pat, 'state, O> {
                     let mut count = 1;
                     let mut target = index;
                     let mut target_pack = pack_index;
-                    let mut target_unwound = false;
                     while count > 0 && target > 0 {
                         target -= 1;
                         if self.interp.frames[target].matches {
@@ -220,11 +220,10 @@ impl<'pat, 'state, 'de, O: Serialize> Packer<'pat, 'state, O> {
                             target_pack -= 1;
                         }
                         match self.interp.frames[target].num_subtrees() {
-                            Some((num, unwound)) if num < count => {
+                            Some((num, _)) if num < count => {
                                 count -= num;
                             },
-                            Some((num, unwound)) => {
-                                target_unwound = unwound;
+                            Some((num, _)) => {
                                 count = 0;
                             },
                             None => {
@@ -233,32 +232,36 @@ impl<'pat, 'state, 'de, O: Serialize> Packer<'pat, 'state, O> {
                         }
                     }
                     if count == 0 {
+                        // found target frame
                         let frame = self.interp.frames.remove(index);
                         let target_frame = &mut self.interp.frames[target];
-                        // FIXME check frame.matches vs frames[target].op()
-                        // FIXME actually test that this is correct
-                        let op = target_frame.raw_op();
+                        let (_, optional) = target_frame.value_subtree();
                         target_frame.prev().then(|| ()).unwrap();
-                        if !target_unwound {
-                            packs.insert(target_pack, Default::default());
-                            pack_index += 1;
-                            // FIXME this is VERY wrong
-                            target_frame.matches = true;
-                        }
                         if has_pack {
-                            // has parent frame
                             let pack = packs.remove(pack_index);
-                            packs[target_pack].merge_from(pack);
+                            if !target_frame.matches {
+                                packs.insert(target_pack, pack);
+                                target_frame.matches = true;
+                                pack_index += 1;
+                            } else {
+                                packs[target_pack].merge_from(pack);
+                            }
+                        } else {
+                            if !optional {
+                                self.interp.error.insert({
+                                    MatchError::ValidationError
+                                });
+                                return Err(E::custom("subtree failed"));
+                            }
                         }
                         if let Some((0, _)) = target_frame.num_subtrees() {
-                            //target_frame.prev().then(|| ()).unwrap();
                             target_frame.overstep = 0;
                         }
                     }
                 }
             }
         }
-        packs
+        Ok(packs)
     }
 }
 
@@ -380,7 +383,7 @@ where
             Type::Enum { name, variants } => {
                 deserializer.deserialize_enum(name, variants, &mut *self)
             },
-        }.map(|(packs, obj)| (self.step_out(packs), obj))
+        }.and_then(|(packs, obj)| Ok((self.step_out(packs)?, obj)))
     }
 }
 
@@ -842,7 +845,7 @@ where
             }
         }
         let obj = SerdeObject::Map(obj_inner);
-        let mut final_packs = self.step_out(output_packs);
+        let mut final_packs = self.step_out(output_packs)?;
         let mut iter_final_packs = 0..;
         self.frames_mut().iter_active_mut().for_each(|frame| {
             let ty = frame.get_type();
@@ -1015,6 +1018,7 @@ where
 mod tests {
     use super::Packer;
     use super::super::PatternConstants;
+
     use crate::vm::MAX_CALLS;
     use crate::vm::Interpreter;
     use crate::vm::Type;
@@ -1022,9 +1026,12 @@ mod tests {
     use crate::vm::PatternElement;
     use crate::vm::SerdeObject;
     use crate::vm::Frame;
-    use serde_json::Deserializer as JsonDeserializer;
+
     use postcard::Deserializer as PostcardDeserializer;
     use serde::de::DeserializeSeed as _;
+    use serde_json::Deserializer as JsonDeserializer;
+
+    use crate::errors::MatchError;
 
     #[test]
     #[should_panic]
@@ -1330,7 +1337,7 @@ mod tests {
         // use a parsed pattern with subtrees to test Packer
         // also test a non-self-describing format (postcard)
         let consts = crate::parser::parse::<&'static str, &'static str, ()>(
-            ":map(->['name'?]name:str)(->['value'?]value:u32)(->[:str]:?ignored_any)",
+            ":map(->['name'?]name:str)?(->['value'?]value:u32)?(->[:str]:?ignored_any)",
             None,
             None,
         ).unwrap();
@@ -1365,5 +1372,45 @@ mod tests {
         assert_eq!(pack.subpacks[0]["name"].1, SerdeObject::Str(From::from("a")));
         assert_eq!(pack.subpacks[1]["value"].1, SerdeObject::U32(1));
     }
+
+    #[test]
+    fn test_parser_subtrees_strict() {
+        // use a parsed pattern with subtrees to test Packer
+        // also test a non-self-describing format (postcard)
+        // also require at least one subtree to match on every iteration.
+        // (also this test fails)
+        let consts = crate::parser::parse::<&'static str, &'static str, ()>(
+            ":map((->['name'?]name:u32)?(->['value'?]value:u32)?)(->[:str]:u32)",
+            None,
+            None,
+        ).unwrap();
+        let data = &[
+            0x03, // map length (3)
+            0x04, // string length (4)
+            0x6E, 0x61, 0x6D, 0x65, // b'name'
+            0x01, // 1
+            0x05, // string length (5)
+            0x76, 0x61, 0x6C, 0x75, 0x65, // b'value'
+            0x01, // 1
+            0x05, // string length (5)
+            0x76, 0x65, 0x6C, 0x75, 0x65, // b'velue'
+            0x01, // 1
+        ];
+        let mut der = PostcardDeserializer::from_bytes(data);
+        let mut err = Default::default();
+        let mut frames = Default::default();
+        let interp = Interpreter::new(
+            &consts,
+            &mut err,
+            &mut frames,
+            //&mut output,
+        );
+        let result = Packer::new(
+            interp,
+            MAX_CALLS,
+        ).deserialize(&mut der);
+        assert!(matches!(err, Some(MatchError::ValidationError)));
+        assert!(result.is_err());
+    }
 }