summary refs log blame commit diff stats
path: root/src/vm/de.rs
blob: 640b1aa5210ae5ead8d0b86a42f45748a1a12368 (plain) (tree)
1
2
3
4
5
6
7
8
9
10


                                             
                                            
 




                             

                          
                                     
 



                       
                 

                       

                            


                       



                                                                            



                                                      
                      

                                        

 





























                                                                                                       

 
                                                               
                             
                      
                                             


                          


                                   

         
 








                                                     
         
     

 




















                                                                               


                 




                                                             


                                   




































































                                                                              
              







                                                                      
              






                                                          
                                                               


                                

                                      
                               


                                                 








                                                                              
































































































                                                                              
         






















                                                                            


                                                                        















































                                                                    






                                                                           

                                


                      
                                     
                                
                                                                  





                             
                                                                  










































                                                                                                                                                                 




                                                   
 




























                                                                              
                                                                        



                                                


                                                             



































                                                                              

                                                                  







                                                           

                                                                  








                                                           

                                                                  
















                                                                                 




                                                                          



                                                                     
// Copyright (C) 2022 Soni L.
// SPDX-License-Identifier: MIT OR Apache-2.0

//! Deserialization-related parts of the VM.

use std::borrow::Cow;
use std::marker::PhantomData;

use indexmap::IndexMap;

use serde::Serialize;
use serde::de::Error as _;
use serde::de::IntoDeserializer as _;

use smallvec::SmallVec;

use these::These;

use super::Frame;
use super::Interpreter;
use super::Pack;
use super::PatternConstants;
use super::PatternElement;
use super::SerdeObject;
use super::Type;
use super::Value;

/// A `DeserializeSeed` for Datafu input.
///
/// This converts from Serde to Datafu's internal representation (a "pack").
///
/// The same type is also passed to the deserializer as the
/// `serde::de::Visitor` that receives the deserialized values.
pub(crate) struct Packer<'pat, 'state, O: Serialize> {
    /// The global interpreter state.
    interp: Interpreter<'pat, 'state, O>,
    /// Current call limit.
    // NOTE(review): stored by `new`, but no `Packer` code visible in this
    // file reads it yet.
    call_limit: usize,
    /// Whether we're collecting values.
    ///
    /// When set, the visitor methods also hand back the visited value as a
    /// `SerdeObject` next to the pack. `new` initializes this to `false`.
    collecting: bool,
}

/// Exclusive (mutable) view of a list of frames.
///
/// Owns a `RefMut` guard, so the mutable borrow of the underlying frame list
/// lasts for as long as this value is alive.
struct FramesMut<'packer, 'pat> {
    /// Guard for the mutably borrowed frame list.
    frames: std::cell::RefMut<'packer, Vec<Frame<'pat>>>,
}

/// Shared (read-only) view of a list of frames.
///
/// Owns a `Ref` guard, so the shared borrow of the underlying frame list
/// lasts for as long as this value is alive.
struct Frames<'packer, 'pat> {
    /// Guard for the immutably borrowed frame list.
    frames: std::cell::Ref<'packer, Vec<Frame<'pat>>>,
}

impl<'packer, 'pat> FramesMut<'packer, 'pat> {
    /// Iterates mutably over every frame, whether or not it still matches.
    fn iter_mut<'a>(&'a mut self) -> impl Iterator<Item = &'a mut Frame<'pat>>
    where
        'packer: 'a,
    {
        self.frames.iter_mut()
    }

    /// Iterates mutably over the frames whose `matches` flag is set.
    fn iter_active_mut<'a>(&'a mut self) -> impl Iterator<Item = &'a mut Frame<'pat>>
    where
        'packer: 'a,
    {
        self.frames.iter_mut().filter(|candidate| candidate.matches)
    }
}

impl<'packer, 'pat> Frames<'packer, 'pat> {
    /// Iterates over every frame, whether or not it still matches.
    fn iter<'a>(&'a self) -> impl Iterator<Item = &'a Frame<'pat>>
    where
        'packer: 'a,
    {
        self.frames.iter()
    }

    /// Iterates over the frames whose `matches` flag is set.
    fn iter_active<'a>(&'a self) -> impl Iterator<Item = &'a Frame<'pat>>
    where
        'packer: 'a,
    {
        self.frames.iter().filter(|candidate| candidate.matches)
    }
}

impl<'pat, 'state, O: Serialize> Packer<'pat, 'state, O> {
    /// Creates a new Packer.
    ///
    /// `interp` is the interpreter state to drive and `call_limit` is stored
    /// as the current call limit. Value collection starts out disabled.
    pub(crate) fn new(
        interp: Interpreter<'pat, 'state, O>,
        call_limit: usize,
    ) -> Self {
        Self {
            interp,
            call_limit,
            collecting: false,
        }
    }

    /// Mutably borrows the interpreter's frames.
    ///
    /// The borrow is released when the returned guard is dropped.
    ///
    /// # Panics
    ///
    /// Panics if the frames are currently borrowed (the usual `RefCell`
    /// dynamic borrow rule).
    fn frames_mut(&mut self) -> FramesMut<'_, 'pat> {
        FramesMut {
            frames: self.interp.frames.borrow_mut(),
        }
    }

    /// Immutably borrows the interpreter's frames.
    ///
    /// Only a shared `RefCell` borrow is taken, so `&self` is sufficient.
    ///
    /// # Panics
    ///
    /// Panics if the frames are currently mutably borrowed.
    fn frames(&self) -> Frames<'_, 'pat> {
        Frames {
            frames: self.interp.frames.borrow(),
        }
    }
}

// what steps do we have to take?
//
//  1. figure out what type we need to deserialize (and ask the deserializer
//      for it).
//  2. visit value. figure out whether we need to store it or not?
//  3. if we need to store it how do we figure out *where* to store it?
//  4. if we *don't* need to store it, what do we do?
//  5. how do we tell if we do or don't need to store it? how do we propagate
//      those requirements deeper into the nested Deserializes, and how do we
//      bring the values back out (recursively?) to the parent Deserializes
//      without wasting time storing things we don't actually care about?
//      5.a. just have a flag in the DeserializeSeed for whether to capture the
//          values. propagation is more or less trivial from there.
//  6. how do you handle value subtrees?
//      6.a. you don't. for now.
//  7. how do you handle errors?
//      7.a. put them into a "state" and raise a D::Error::custom. then
//          override it in the relevant Pattern call.

impl<'pat, 'state, 'de, O> serde::de::DeserializeSeed<'de>
for Packer<'pat, 'state, O>
where
    O: Serialize,
{
    /// The resulting pack plus, optionally, the value itself.
    type Value = (Pack<'pat, 'de>, Option<SerdeObject<'de>>);

    /// Advances every frame, merges the type requests of the frames that
    /// still match into one type hint, and asks `deserializer` for that type
    /// with this packer acting as the visitor.
    fn deserialize<D>(
        mut self,
        deserializer: D,
    ) -> Result<Self::Value, D::Error>
    where
        D: serde::Deserializer<'de>
    {
        // Step all frames; a frame whose `next()` reports false stops
        // matching.
        for frame in self.frames_mut().iter_mut() {
            if !frame.next() {
                frame.matches = false;
            }
        }

        let pat = self.interp.pat;

        // Fold the requests of the active frames into a single hint. With no
        // active frames the hint stays `IgnoredAny`.
        let mut wanted = Type::IgnoredAny;
        for frame in self.frames().iter_active() {
            wanted = match (wanted, frame.get_type(pat)) {
                // `IgnoredAny` on either side defers to the other side.
                (Type::IgnoredAny, Some((ty, _))) => ty,
                (ty, Some((Type::IgnoredAny, _))) => ty,
                // Borrowed and owned text/bytes merge into the owned form.
                (Type::String, Some((Type::Str, _)))
                | (Type::Str, Some((Type::String, _))) => Type::String,
                (Type::Bytes, Some((Type::ByteBuf, _)))
                | (Type::ByteBuf, Some((Type::Bytes, _))) => Type::ByteBuf,
                (left, Some((right, _))) if left == right => left,
                // Conflicting requests, or a frame without a type.
                _ => Type::Any,
            };
        }

        match wanted {
            Type::Any => deserializer.deserialize_any(self),
            Type::IgnoredAny => deserializer.deserialize_ignored_any(self),
            Type::Bool => deserializer.deserialize_bool(self),
            Type::I8 => deserializer.deserialize_i8(self),
            Type::I16 => deserializer.deserialize_i16(self),
            Type::I32 => deserializer.deserialize_i32(self),
            Type::I64 => deserializer.deserialize_i64(self),
            Type::I128 => deserializer.deserialize_i128(self),
            Type::U8 => deserializer.deserialize_u8(self),
            Type::U16 => deserializer.deserialize_u16(self),
            Type::U32 => deserializer.deserialize_u32(self),
            Type::U64 => deserializer.deserialize_u64(self),
            Type::U128 => deserializer.deserialize_u128(self),
            Type::F32 => deserializer.deserialize_f32(self),
            Type::F64 => deserializer.deserialize_f64(self),
            Type::Char => deserializer.deserialize_char(self),
            // The borrowed forms are only requested while not collecting;
            // otherwise the owned form is requested.
            Type::Str if !self.collecting => {
                deserializer.deserialize_str(self)
            },
            Type::Str | Type::String => deserializer.deserialize_string(self),
            Type::Bytes if !self.collecting => {
                deserializer.deserialize_bytes(self)
            },
            Type::Bytes | Type::ByteBuf => {
                deserializer.deserialize_byte_buf(self)
            },
            Type::Option => deserializer.deserialize_option(self),
            Type::Unit => deserializer.deserialize_unit(self),
            Type::Seq => deserializer.deserialize_seq(self),
            Type::Map => deserializer.deserialize_map(self),
            Type::Identifier => deserializer.deserialize_identifier(self),
            Type::Tuple(len) => deserializer.deserialize_tuple(len, self),
            Type::UnitStruct(name) => {
                deserializer.deserialize_unit_struct(name, self)
            },
            Type::NewtypeStruct(name) => {
                deserializer.deserialize_newtype_struct(name, self)
            },
            Type::TupleStruct { name, len } => {
                deserializer.deserialize_tuple_struct(name, len, self)
            },
            Type::Struct { name, fields } => {
                deserializer.deserialize_struct(name, fields, self)
            },
            Type::Enum { name, variants } => {
                deserializer.deserialize_enum(name, variants, self)
            },
        }
    }
}

/// visit method generator for simple values (primitives).
///
/// `vs!(visit_x Variant ty)` expands to `fn visit_x(self, v: ty)`. The
/// generated method pushes one map per active frame onto the pack; a frame
/// that has a name contributes a `name -> (empty pack, Variant(v))` entry.
/// When collecting, `Variant(v)` is also returned next to the pack.
macro_rules! vs {
    ($visit:ident $obj:ident $t:ty) => {
        fn $visit<E>(mut self, v: $t) -> Result<Self::Value, E>
        where
            E: serde::de::Error,
        {
            // FIXME filtering/errors
            let pat = self.interp.pat;
            let obj = if self.collecting {
                Some(SerdeObject::$obj(v))
            } else {
                None
            };
            let mut pack = Pack::default();
            for frame in self.frames_mut().iter_active_mut() {
                let mut named = IndexMap::new();
                if let Some(name) = frame.get_name(pat) {
                    named.insert(name, (Pack::default(), SerdeObject::$obj(v)));
                }
                // Unnamed frames still get an (empty) map of their own.
                pack.subpacks.push(named);
            }
            Ok((pack, obj))
        }
    }
}

impl<'pat, 'state, 'de, O> serde::de::Visitor<'de>
for Packer<'pat, 'state, O>
where
    O: Serialize,
{
    type Value = (Pack<'pat, 'de>, Option<SerdeObject<'de>>);
    fn expecting(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "unsure")
    }

    vs!(visit_bool Bool bool);
    vs!(visit_i8 I8 i8);
    vs!(visit_i16 I16 i16);
    vs!(visit_i32 I32 i32);
    vs!(visit_i64 I64 i64);
    vs!(visit_i128 I128 i128);
    vs!(visit_u8 U8 u8);
    vs!(visit_u16 U16 u16);
    vs!(visit_u32 U32 u32);
    vs!(visit_u64 U64 u64);
    vs!(visit_u128 U128 u128);
    vs!(visit_f32 F32 f32);
    vs!(visit_f64 F64 f64);
    vs!(visit_char Char char);

    fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
    where
        E: serde::de::Error,
    {
        let mut obj = None;
        if self.collecting {
            obj = Some(SerdeObject::Str(Cow::Owned(String::from(v))));
        }
        todo!()
    }
    fn visit_borrowed_str<E>(self, v: &'de str) -> Result<Self::Value, E>
    where
        E: serde::de::Error,
    {
        let mut obj = None;
        if self.collecting {
            obj = Some(SerdeObject::Str(Cow::Borrowed(v)));
        }
        todo!()
    }
    fn visit_string<E>(self, v: String) -> Result<Self::Value, E>
    where
        E: serde::de::Error,
    {
        let mut obj = None;
        if self.collecting {
            obj = Some(SerdeObject::Str(Cow::Owned(v)));
        }
        todo!()
    }
    fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
    where
        E: serde::de::Error,
    {
        let mut obj = None;
        if self.collecting {
            obj = Some(SerdeObject::Bytes(Cow::Owned(Vec::from(v))));
        }
        todo!()
    }
    fn visit_borrowed_bytes<E>(self, v: &'de [u8]) -> Result<Self::Value, E>
    where
        E: serde::de::Error,
    {
        let mut obj = None;
        if self.collecting {
            obj = Some(SerdeObject::Bytes(Cow::Borrowed(v)));
        }
        todo!()
    }
    fn visit_byte_buf<E>(self, v: Vec<u8>) -> Result<Self::Value, E>
    where
        E: serde::de::Error,
    {
        let mut obj = None;
        if self.collecting {
            obj = Some(SerdeObject::Bytes(Cow::Owned(v)));
        }
        todo!()
    }
    fn visit_none<E>(self) -> Result<Self::Value, E>
    where
        E: serde::de::Error,
    {
        let mut obj = None;
        if self.collecting {
            obj = Some(SerdeObject::None);
        }
        todo!()
    }
    fn visit_some<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
    where
        D: serde::de::Deserializer<'de>,
    {
        let mut obj = None;
        if self.collecting {
            obj = Some(SerdeObject::Some(todo!()));
        }
        todo!()
    }
    fn visit_unit<E>(self) -> Result<Self::Value, E>
    where
        E: serde::de::Error,
    {
        // FIXME subtrees
        let mut obj = None;
        let mut pack = Pack::default();
        if self.collecting {
            obj = Some(SerdeObject::Unit);
        }
        let mut map = IndexMap::new();
        //for name in self.get_name() {
        //    map.insert(name, (Default::default(), SerdeObject::Unit));
        //}
        pack.subpacks.push(map);
        Ok((pack, obj))
    }
    fn visit_newtype_struct<D>(
        self,
        deserializer: D
    ) -> Result<Self::Value, D::Error>
    where
        D: serde::de::Deserializer<'de>,
    {
        let mut obj = None;
        if self.collecting {
            obj = Some(SerdeObject::NewtypeStruct(todo!()));
        }
        todo!()
    }
    fn visit_seq<A>(self, seq: A) -> Result<Self::Value, A::Error>
    where
        A: serde::de::SeqAccess<'de>,
    {
        let mut obj = None;
        if self.collecting {
            obj = Some(SerdeObject::Seq(Vec::new()));
        }
        todo!()
    }
    fn visit_map<A>(self, map: A) -> Result<Self::Value, A::Error>
    where
        A: serde::de::MapAccess<'de>,
    {
        let mut obj = None;
        if self.collecting {
            obj = Some(SerdeObject::Map(Vec::new()));
        }
        todo!()
    }
    fn visit_enum<A>(self, data: A) -> Result<Self::Value, A::Error>
    where
        A: serde::de::EnumAccess<'de>,
    {
        let mut obj = None;
        if self.collecting {
            obj = Some(SerdeObject::Enum {
                variant: todo!(),
                data: todo!(),
            });
        }
        todo!()
    }
}

/// A `Deserializer` for Datafu output.
///
/// This converts from Datafu's internal representation (a "pack") into the
/// desired output type.
///
/// NOTE(review): every `Deserializer` method on this type in this file is
/// still a `todo!()` stub.
pub struct Unpacker<'pat, 'de> {
    /// The pack to unpack, as handed to `new`.
    pack: Pack<'pat, 'de>,
    /// Call limit, as handed to `new`.
    call_limit: usize,
}

impl<'pat, 'de> Unpacker<'pat, 'de> {
    /// Unpacks a Datafu "pack".
    pub fn new(pack: Pack<'pat, 'de>, call_limit: usize) -> Self {
        Self {
            pack, call_limit,
        }
    }
}

impl<'pat, 'de> serde::Deserializer<'de> for Unpacker<'pat, 'de> {
    // TODO datafu errors
    type Error = serde::de::value::Error;
    fn deserialize_any<V>(self, _: V) -> Result<V::Value, Self::Error> where V: serde::de::Visitor<'de> { todo!() }
    fn deserialize_bool<V>(self, _: V) -> Result<V::Value, Self::Error> where V: serde::de::Visitor<'de> { todo!() }
    fn deserialize_i8<V>(self, _: V) -> Result<V::Value, Self::Error> where V: serde::de::Visitor<'de> { todo!() }
    fn deserialize_i16<V>(self, _: V) -> Result<V::Value, Self::Error> where V: serde::de::Visitor<'de> { todo!() }
    fn deserialize_i32<V>(self, _: V) -> Result<V::Value, Self::Error> where V: serde::de::Visitor<'de> { todo!() }
    fn deserialize_i64<V>(self, _: V) -> Result<V::Value, Self::Error> where V: serde::de::Visitor<'de> { todo!() }
    fn deserialize_u8<V>(self, _: V) -> Result<V::Value, Self::Error> where V: serde::de::Visitor<'de> { todo!() }
    fn deserialize_u16<V>(self, _: V) -> Result<V::Value, Self::Error> where V: serde::de::Visitor<'de> { todo!() }
    fn deserialize_u32<V>(self, _: V) -> Result<V::Value, Self::Error> where V: serde::de::Visitor<'de> { todo!() }
    fn deserialize_u64<V>(self, _: V) -> Result<V::Value, Self::Error> where V: serde::de::Visitor<'de> { todo!() }
    fn deserialize_f32<V>(self, _: V) -> Result<V::Value, Self::Error> where V: serde::de::Visitor<'de> { todo!() }
    fn deserialize_f64<V>(self, _: V) -> Result<V::Value, Self::Error> where V: serde::de::Visitor<'de> { todo!() }
    fn deserialize_char<V>(self, _: V) -> Result<V::Value, Self::Error> where V: serde::de::Visitor<'de> { todo!() }
    fn deserialize_str<V>(self, _: V) -> Result<V::Value, Self::Error> where V: serde::de::Visitor<'de> { todo!() }
    fn deserialize_string<V>(self, _: V) -> Result<V::Value, Self::Error> where V: serde::de::Visitor<'de> { todo!() }
    fn deserialize_bytes<V>(self, _: V) -> Result<V::Value, Self::Error> where V: serde::de::Visitor<'de> { todo!() }
    fn deserialize_byte_buf<V>(self, _: V) -> Result<V::Value, Self::Error> where V: serde::de::Visitor<'de> { todo!() }
    fn deserialize_option<V>(self, _: V) -> Result<V::Value, Self::Error> where V: serde::de::Visitor<'de> { todo!() }
    fn deserialize_unit<V>(self, _: V) -> Result<V::Value, Self::Error> where V: serde::de::Visitor<'de> { todo!() }
    fn deserialize_unit_struct<V>(self, _: &'static str, _: V) -> Result<V::Value, Self::Error> where V: serde::de::Visitor<'de> { todo!() }
    fn deserialize_newtype_struct<V>(self, _: &'static str, _: V) -> Result<V::Value, Self::Error> where V: serde::de::Visitor<'de> { todo!() }
    fn deserialize_seq<V>(self, _: V) -> Result<V::Value, Self::Error> where V: serde::de::Visitor<'de> { todo!() }
    fn deserialize_tuple<V>(self, _: usize, _: V) -> Result<V::Value, Self::Error> where V: serde::de::Visitor<'de> { todo!() }
    fn deserialize_tuple_struct<V>(self, _: &'static str, _: usize, _: V) -> Result<V::Value, Self::Error> where V: serde::de::Visitor<'de> { todo!() }
    fn deserialize_map<V>(self, _: V) -> Result<V::Value, Self::Error> where V: serde::de::Visitor<'de> { todo!() }
    fn deserialize_struct<V>(
        self,
        _: &'static str,
        fields: &'static [&'static str],
        visitor: V,
    ) -> Result<V::Value, Self::Error>
    where
        V: serde::de::Visitor<'de>,
    {
        todo!()
    }
    fn deserialize_enum<V>(self, _: &'static str, _: &'static [&'static str], _: V) -> Result<V::Value, Self::Error> where V: serde::de::Visitor<'de> { todo!() }
    fn deserialize_identifier<V>(self, _: V) -> Result<V::Value, Self::Error> where V: serde::de::Visitor<'de> { todo!() }
    fn deserialize_ignored_any<V>(self, _: V) -> Result<V::Value, Self::Error> where V: serde::de::Visitor<'de> { todo!() }
}

/// Deserializes a SerdeObject
///
/// `E` is the error type reported by the `Deserializer` implementation.
pub(crate) struct SerdeObjectDeserializer<'de, E> {
    /// The object whose contents are handed to the visitor.
    pub(crate) obj: SerdeObject<'de>,
    /// NOTE(review): not read by the `Deserializer` impl in this file;
    /// purpose to be confirmed against the callers.
    pub(crate) value: Option<SerdeObject<'de>>,
    /// Marker that ties the error type `E` to this struct without storing a
    /// value of it.
    pub(crate) _e: PhantomData<fn() -> E>,
}

impl<'de, E> serde::de::Deserializer<'de> for SerdeObjectDeserializer<'de, E>
where
    E: serde::de::Error,
{
    type Error = E;

    /// Hands the stored object to `visitor` through the visit method that
    /// matches its variant.
    ///
    /// Sequences, maps and enums are not implemented yet and panic.
    fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
    where
        V: serde::de::Visitor<'de>,
    {
        match self.obj {
            SerdeObject::Bool(value) => visitor.visit_bool(value),
            SerdeObject::I8(value) => visitor.visit_i8(value),
            SerdeObject::I16(value) => visitor.visit_i16(value),
            SerdeObject::I32(value) => visitor.visit_i32(value),
            SerdeObject::I64(value) => visitor.visit_i64(value),
            SerdeObject::I128(value) => visitor.visit_i128(value),
            SerdeObject::U8(value) => visitor.visit_u8(value),
            SerdeObject::U16(value) => visitor.visit_u16(value),
            SerdeObject::U32(value) => visitor.visit_u32(value),
            SerdeObject::U64(value) => visitor.visit_u64(value),
            SerdeObject::U128(value) => visitor.visit_u128(value),
            SerdeObject::F32(value) => visitor.visit_f32(value),
            SerdeObject::F64(value) => visitor.visit_f64(value),
            SerdeObject::Char(value) => visitor.visit_char(value),
            // Owned data is passed on by value, borrowed data keeps its
            // `'de` lifetime.
            SerdeObject::Str(text) => match text {
                Cow::Owned(owned) => visitor.visit_string(owned),
                Cow::Borrowed(borrowed) => visitor.visit_borrowed_str(borrowed),
            },
            SerdeObject::Bytes(bytes) => match bytes {
                Cow::Owned(owned) => visitor.visit_byte_buf(owned),
                Cow::Borrowed(borrowed) => {
                    visitor.visit_borrowed_bytes(borrowed)
                },
            },
            SerdeObject::Some(inner) => {
                visitor.visit_some(inner.into_deserializer())
            },
            SerdeObject::None => visitor.visit_none(),
            SerdeObject::Unit => visitor.visit_unit(),
            SerdeObject::Seq(_items) => todo!(),
            SerdeObject::Map(_entries) => todo!(),
            SerdeObject::NewtypeStruct(inner) => {
                visitor.visit_newtype_struct(inner.into_deserializer())
            },
            SerdeObject::Enum { variant: _variant, data: _data } => todo!(),
        }
    }

    /// Discards the stored object and reports a unit to `visitor`.
    fn deserialize_ignored_any<V>(
        self,
        visitor: V,
    ) -> Result<V::Value, Self::Error>
    where
        V: serde::de::Visitor<'de>,
    {
        drop(self);
        visitor.visit_unit()
    }

    // Every typed request is answered with whatever the object contains.
    serde::forward_to_deserialize_any! {
        bool i8 i16 i32 i64 i128 u8 u16 u32 u64 u128 f32 f64 char str string
        bytes byte_buf option unit unit_struct newtype_struct seq tuple
        tuple_struct map struct enum identifier
    }
}

/// Tests for `Packer`, driven through `serde_json`'s deserializer.
#[cfg(test)]
mod tests {
    use super::Packer;
    use super::super::PatternConstants;
    use crate::vm::MAX_CALLS;
    use crate::vm::Interpreter;
    use crate::vm::Type;
    use crate::vm::Value;
    use crate::vm::PatternElement;
    use crate::vm::SerdeObject;
    use these::These;
    use serde_json::Deserializer as JsonDeserializer;
    use serde::de::DeserializeSeed as _;

    /// Setting up a packer from constants without any proto is expected to
    /// panic. The only difference from `test_empty_create` is the missing
    /// proto; the test does not pin which call panics.
    #[test]
    #[should_panic]
    fn test_broken() {
        let consts = PatternConstants::<()>::default();
        let mut err = Default::default();
        let frames = Default::default();
        let interp = Interpreter::new(&consts, &mut err, &frames);
        let _ = Packer::new(interp, MAX_CALLS);
    }

    /// With a single empty proto, creating the interpreter and the packer
    /// must not panic.
    #[test]
    fn test_empty_create() {
        let mut consts = PatternConstants::<()>::default();
        consts.protos.push(Vec::new());
        let mut err = Default::default();
        let frames = Default::default();
        let interp = Interpreter::new(&consts, &mut err, &frames);
        let _ = Packer::new(interp, MAX_CALLS);
    }

    /// An empty proto run against the JSON document `{}` must deserialize
    /// without an error; the resulting pack is not inspected.
    #[test]
    fn test_empty_match() {
        let mut consts = PatternConstants::<()>::default();
        consts.protos.push(Vec::new());
        let mut der = JsonDeserializer::from_str("{}");
        let mut err = Default::default();
        let frames = Default::default();
        let interp = Interpreter::new(&consts, &mut err, &frames);
        let pack = Packer::new(interp, MAX_CALLS).deserialize(&mut der).unwrap();
    }

    /// A proto with one named `U64` value element run against the JSON
    /// document `3`: nothing is collected, and the first subpack maps
    /// "hello" to `U64(3)`.
    #[test]
    fn test_simple_match() {
        let mut consts = PatternConstants::<()>::default();
        consts.strings.push("hello".into());
        consts.protos.push(vec![
            PatternElement::Value {
                // presumably the `0` is an index into `consts.strings`,
                // i.e. the name "hello" -- confirm against `PatternElement`.
                name_and_value: These::Both(0, Value::Type {
                    ty: Type::U64,
                    skippable: false,
                })
            }
        ]);
        let mut der = JsonDeserializer::from_str("3");
        let mut err = Default::default();
        let frames = Default::default();
        let interp = Interpreter::new(&consts, &mut err, &frames);
        let packed = Packer::new(interp, MAX_CALLS).deserialize(&mut der);
        let (pack, obj) = packed.unwrap();
        assert!(obj.is_none());
        assert_eq!(pack.subpacks[0]["hello"].1, SerdeObject::U64(3));
    }
}