summary refs log tree commit diff stats
path: root/src/vm/de.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/vm/de.rs')
-rw-r--r--src/vm/de.rs303
1 files changed, 212 insertions, 91 deletions
diff --git a/src/vm/de.rs b/src/vm/de.rs
index 9583962..c906226 100644
--- a/src/vm/de.rs
+++ b/src/vm/de.rs
@@ -95,7 +95,13 @@ impl<'pat, 'state, 'de, O: Serialize> Packer<'pat, 'state, O> {
     }
 
     /// Steps the VM into the next operation.
-    fn step_in(&mut self) {
+    fn step_in<E: serde::de::Error>(&mut self) -> Result<(), E> {
+        if self.call_limit > 0 {
+            self.call_limit -= 1;
+        } else {
+            self.interp.error.insert(crate::errors::MatchError::StackOverflow);
+            return Err(todo!());
+        }
         // iterate up to the *live* length (i.e. the loop is allowed to modify
         // the length).
         // NOTE: we need to use while-let so as to not borrow anything in an
@@ -148,6 +154,7 @@ impl<'pat, 'state, 'de, O: Serialize> Packer<'pat, 'state, O> {
                 }
             }
         }
+        Ok(())
     }
 
     /// Steps the VM back into the previous operation.
@@ -155,6 +162,7 @@ impl<'pat, 'state, 'de, O: Serialize> Packer<'pat, 'state, O> {
         &mut self,
         mut packs: Vec<Pack<'pat, 'de>>,
     ) -> Vec<Pack<'pat, 'de>> {
+        self.call_limit += 1;
         let mut index_iter = 0..;
         while let Some(index) = index_iter.next().filter(|&i| {
             i < self.interp.frames.len()
@@ -212,7 +220,7 @@ where
     where
         D: serde::Deserializer<'de>
     {
-        self.step_in();
+        if let Err(e) = self.step_in() { return Err(e); }
         let pat = self.interp.pat;
         let target_type = self.frames().iter_active().fold(
             Type::IgnoredAny,
@@ -300,32 +308,48 @@ where
 }
 
 /// visit method generator for simple values (primitives).
+///
+/// can generate whole function or just the glue.
 macro_rules! vs {
-    ($visit:ident $obj:ident $t:ty) => {
-        fn $visit<E>(mut self, v: $t) -> Result<Self::Value, E>
+    (fn $visit:ident $obj:ident ($data_type:pat) $rust_type:ty) => {
+        fn $visit<E>(mut self, v: $rust_type) -> Result<Self::Value, E>
         where
             E: serde::de::Error,
         {
-            // FIXME filtering/errors
-            let pat = self.interp.pat;
+            vs!(self (v) $obj ($data_type))
+        }
+    };
+    ($this:ident $v:tt $obj:ident ($data_type:pat)) => {
+        {
+            let pat = $this.interp.pat;
             let mut obj = None;
-            if self.collecting {
-                obj = Some(SerdeObject::$obj(v));
+            if $this.collecting {
+                obj = Some(SerdeObject::$obj$v);
             }
             let mut packs = Vec::new();
-            self.frames_mut().iter_active_mut().try_for_each(|frame| {
+            $this.frames_mut().iter_active_mut().try_for_each(|frame| {
+                let ty = frame.get_type(pat);
+                match ty {
+                    | Some(($data_type, _))
+                    | Some((Type::Any, _))
+                    | Some((Type::IgnoredAny, _))
+                    => {},
+                    Some((_, false)) => todo!(),
+                    Some((_, true)) => return Err(todo!()),
+                    None => unreachable!(),
+                }
                 let mut pack = Pack::default();
-                let mut map = IndexMap::new();
                 if let Some(name) = frame.get_name(pat) {
-                    map.insert(name, (Pack::default(), SerdeObject::$obj(v)));
+                    let mut map = IndexMap::new();
+                    map.insert(name, (Pack::default(), SerdeObject::$obj$v));
+                    pack.subpacks.push(map);
                 }
-                pack.subpacks.push(map);
                 packs.push(pack);
                 Ok(())
             })?;
             Ok((packs, obj))
         }
-    }
+    };
 }
 
 impl<'pat, 'state, 'de, O> serde::de::Visitor<'de>
@@ -338,119 +362,77 @@ where
         write!(f, "unsure")
     }
 
-    vs!(visit_bool Bool bool);
-    vs!(visit_i8 I8 i8);
-    vs!(visit_i16 I16 i16);
-    vs!(visit_i32 I32 i32);
-    vs!(visit_i64 I64 i64);
-    vs!(visit_i128 I128 i128);
-    vs!(visit_u8 U8 u8);
-    vs!(visit_u16 U16 u16);
-    vs!(visit_u32 U32 u32);
-    vs!(visit_u64 U64 u64);
-    vs!(visit_u128 U128 u128);
-    vs!(visit_f32 F32 f32);
-    vs!(visit_f64 F64 f64);
-    vs!(visit_char Char char);
+    vs!(fn visit_bool Bool (Type::Bool) bool);
+    vs!(fn visit_i8 I8 (Type::I8) i8);
+    vs!(fn visit_i16 I16 (Type::I16) i16);
+    vs!(fn visit_i32 I32 (Type::I32) i32);
+    vs!(fn visit_i64 I64 (Type::I64) i64);
+    vs!(fn visit_i128 I128 (Type::I128) i128);
+    vs!(fn visit_u8 U8 (Type::U8) u8);
+    vs!(fn visit_u16 U16 (Type::U16) u16);
+    vs!(fn visit_u32 U32 (Type::U32) u32);
+    vs!(fn visit_u64 U64 (Type::U64) u64);
+    vs!(fn visit_u128 U128 (Type::U128) u128);
+    vs!(fn visit_f32 F32 (Type::F32) f32);
+    vs!(fn visit_f64 F64 (Type::F64) f64);
+    vs!(fn visit_char Char (Type::Char) char);
 
     fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
     where
         E: serde::de::Error,
     {
-        let mut obj = None;
-        if self.collecting {
-            obj = Some(SerdeObject::Str(Cow::Owned(String::from(v))));
-        }
-        todo!()
+        // no real option but to clone.
+        vs!(self (Cow::Owned(v.to_owned())) Str (Type::String | Type::Str))
     }
     fn visit_borrowed_str<E>(self, v: &'de str) -> Result<Self::Value, E>
     where
         E: serde::de::Error,
     {
-        let mut obj = None;
-        if self.collecting {
-            obj = Some(SerdeObject::Str(Cow::Borrowed(v)));
-        }
-        todo!()
+        vs!(self (Cow::Borrowed(v)) Str (Type::String | Type::Str))
     }
     fn visit_string<E>(self, v: String) -> Result<Self::Value, E>
     where
         E: serde::de::Error,
     {
-        let mut obj = None;
-        if self.collecting {
-            obj = Some(SerdeObject::Str(Cow::Owned(v)));
-        }
-        todo!()
+        // TODO try to avoid cloning
+        self.visit_str(&*v)
     }
     fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
     where
         E: serde::de::Error,
     {
-        let mut obj = None;
-        if self.collecting {
-            obj = Some(SerdeObject::Bytes(Cow::Owned(Vec::from(v))));
-        }
-        todo!()
+        vs!(self (Cow::Owned(v.to_owned())) Bytes (Type::Bytes | Type::ByteBuf))
     }
     fn visit_borrowed_bytes<E>(self, v: &'de [u8]) -> Result<Self::Value, E>
     where
         E: serde::de::Error,
     {
-        let mut obj = None;
-        if self.collecting {
-            obj = Some(SerdeObject::Bytes(Cow::Borrowed(v)));
-        }
-        todo!()
+        vs!(self (Cow::Borrowed(v)) Bytes (Type::Bytes | Type::ByteBuf))
     }
     fn visit_byte_buf<E>(self, v: Vec<u8>) -> Result<Self::Value, E>
     where
         E: serde::de::Error,
     {
-        let mut obj = None;
-        if self.collecting {
-            obj = Some(SerdeObject::Bytes(Cow::Owned(v)));
-        }
-        todo!()
+        // TODO try to avoid cloning
+        self.visit_byte_buf(&*v)
     }
     fn visit_none<E>(self) -> Result<Self::Value, E>
     where
         E: serde::de::Error,
     {
-        let mut obj = None;
-        if self.collecting {
-            obj = Some(SerdeObject::None);
-        }
-        todo!()
+        vs!(self {} None (Type::Option))
     }
     fn visit_some<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
     where
         D: serde::de::Deserializer<'de>,
     {
-        let mut obj = None;
-        if self.collecting {
-            obj = Some(SerdeObject::Some(todo!()));
-        }
         todo!()
     }
     fn visit_unit<E>(self) -> Result<Self::Value, E>
     where
         E: serde::de::Error,
     {
-        // FIXME subtrees
-        let mut obj = None;
-        let mut packs = Vec::new();
-        let mut pack = Pack::default();
-        if self.collecting {
-            obj = Some(SerdeObject::Unit);
-        }
-        let mut map = IndexMap::new();
-        //for name in self.get_name() {
-        //    map.insert(name, (Default::default(), SerdeObject::Unit));
-        //}
-        pack.subpacks.push(map);
-        packs.push(pack);
-        Ok((packs, obj))
+        vs!(self {} Unit (Type::Unit))
     }
     fn visit_newtype_struct<D>(
         self,
@@ -459,10 +441,6 @@ where
     where
         D: serde::de::Deserializer<'de>,
     {
-        let mut obj = None;
-        if self.collecting {
-            obj = Some(SerdeObject::NewtypeStruct(todo!()));
-        }
         todo!()
     }
     fn visit_seq<A>(self, seq: A) -> Result<Self::Value, A::Error>
@@ -475,15 +453,84 @@ where
         }
         todo!()
     }
-    fn visit_map<A>(self, map: A) -> Result<Self::Value, A::Error>
+    fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
     where
         A: serde::de::MapAccess<'de>,
     {
-        let mut obj = None;
-        if self.collecting {
-            obj = Some(SerdeObject::Map(Vec::new()));
+        let old_collecting = self.collecting;
+        let pat = self.interp.pat;
+        let mut collecting = old_collecting;
+        self.frames_mut().iter_active_mut().try_for_each(|frame| {
+            let ty = frame.get_type(pat);
+            match ty {
+                | Some((Type::Map, _))
+                | Some((Type::Any, _))
+                | Some((Type::IgnoredAny, _))
+                => {},
+                Some((_, false)) => {
+                    frame.matches = false;
+                    todo!()
+                },
+                Some((_, true)) => return Err(todo!()),
+                None => unreachable!(),
+            }
+            if frame.get_name(pat).is_some() {
+                collecting = true;
+            }
+            Ok(())
+        })?;
+        if let Err(e) = self.step_in() { return Err(e); }
+        self.collecting = collecting;
+        let mut subframes = Vec::new();
+        self.frames().iter_active().for_each(|frame| {
+            if let PatternElement::Tag { key_subtree } = frame.op() {
+                if let Some(key_subtree) = key_subtree {
+                    subframes.push(Frame {
+                        ops: &pat.protos[key_subtree],
+                        iar: None,
+                        overstep: 0,
+                        matches: true,
+                    });
+                }
+            } else {
+                unreachable!()
+            }
+        });
+        let mut obj_inner = Vec::new();
+        while let Some(packed_key) = {
+            let subinterp = Interpreter {
+                pat: pat,
+                frames: &mut subframes,
+                error: self.interp.error,
+            };
+            let mut subpacker = Packer {
+                interp: subinterp,
+                collecting: self.collecting,
+                call_limit: self.call_limit,
+            };
+            map.next_key_seed(&mut subpacker)?
+        } {
+            self.frames_mut().iter_active_mut().filter(|frame| {
+                if let PatternElement::Tag { key_subtree } = frame.op() {
+                    key_subtree.is_some()
+                } else {
+                    unreachable!()
+                }
+            }).zip(&mut subframes).for_each(|(frame, subframe)| {
+                frame.matches = subframe.matches;
+                // reset subframe for next iteration
+                subframe.matches = true;
+                subframe.iar = None;
+            });
+            let packed_value = map.next_value_seed(&mut *self)?;
+            if self.collecting {
+                obj_inner.push(
+                    (packed_key.1.unwrap(), packed_value.1.unwrap()),
+                );
+            }
+            todo!("merge kv");
         }
-        todo!()
+        todo!();
     }
     fn visit_enum<A>(self, data: A) -> Result<Self::Value, A::Error>
     where
@@ -676,8 +723,8 @@ mod tests {
                 name_and_value: These::Both(0, Value::Type {
                     ty: Type::U64,
                     skippable: false,
-                })
-            }
+                }),
+            },
         ]);
         let mut der = JsonDeserializer::from_str("3");
         let mut err = Default::default();
@@ -688,5 +735,79 @@ mod tests {
         assert!(obj.is_none());
         assert_eq!(packs[0].subpacks[0]["hello"].1, SerdeObject::U64(3));
     }
+
+    #[test]
+    fn test_simple_error() {
+        let mut consts = PatternConstants::<()>::default();
+        consts.strings.push("hello".into());
+        consts.protos.push(vec![
+            PatternElement::Value {
+                name_and_value: These::Both(0, Value::Type {
+                    ty: Type::U64,
+                    skippable: false,
+                }),
+            },
+        ]);
+        let mut der = JsonDeserializer::from_str("\"hello\"");
+        let mut err = Default::default();
+        let mut frames = Default::default();
+        let interp = Interpreter::new(&consts, &mut err, &mut frames);
+        let packed = Packer::new(interp, MAX_CALLS).deserialize(&mut der);
+        dbg!(&packed);
+        // error produced by serde_json
+        assert!(packed.is_err());
+    }
+
+    #[test]
+    fn test_map() {
+        let mut consts = PatternConstants::<()>::default();
+        consts.strings.push("key".into());
+        consts.strings.push("value".into());
+        consts.protos.push(vec![
+            PatternElement::Value {
+                name_and_value: These::This(0),
+            },
+        ]);
+        consts.protos.push(vec![
+            PatternElement::Value {
+                name_and_value: These::That(Value::Type {
+                    ty: Type::Map,
+                    skippable: false,
+                }),
+            },
+            PatternElement::Tag {
+                key_subtree: Some(0),
+            },
+            PatternElement::Value {
+                name_and_value: These::Both(1, Value::Type {
+                    ty: Type::U64,
+                    skippable: false,
+                }),
+            },
+        ]);
+        let mut der = JsonDeserializer::from_str(r#"{"hello": 0, "world": 1}"#);
+        let mut err = Default::default();
+        let mut frames = Default::default();
+        let interp = Interpreter::new(&consts, &mut err, &mut frames);
+        let packed = Packer::new(interp, MAX_CALLS).deserialize(&mut der);
+        let (packs, obj) = packed.unwrap();
+        assert!(obj.is_none());
+        assert_eq!(
+            packs[0].subpacks[0]["key"].1,
+            SerdeObject::Str("hello".into()),
+        );
+        assert_eq!(
+            packs[0].subpacks[0]["value"].1,
+            SerdeObject::U64(0),
+        );
+        assert_eq!(
+            packs[0].subpacks[1]["key"].1,
+            SerdeObject::Str("world".into()),
+        );
+        assert_eq!(
+            packs[0].subpacks[1]["value"].1,
+            SerdeObject::U64(1),
+        );
+    }
 }