summary refs log tree commit diff stats
diff options
context:
space:
mode:
-rw-r--r--Cargo.toml2
-rw-r--r--src/lib.rs312
-rw-r--r--tests/unit.rs13
3 files changed, 308 insertions, 19 deletions
diff --git a/Cargo.toml b/Cargo.toml
index be0542c..8065cca 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "serde_transmute"
-version = "0.1.1"
+version = "0.1.2"
 edition = "2021"
 description = "Transmute objects through serde."
 license = "MIT OR Apache-2.0"
diff --git a/src/lib.rs b/src/lib.rs
index a4a1371..c15c45b 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -13,6 +13,7 @@
 //! But we don't care because this crate was built to power parts of `datafu`.
 //! And it's pretty good at that.
 
+use serde::de::Error as _;
 use serde::de::{Deserialize, DeserializeSeed, Deserializer, Visitor};
 use serde::ser::{Serialize, Serializer};
 
@@ -65,6 +66,7 @@ pub struct Settings {
     human_readable: bool,
     /// Whether structs are sequences.
     structs_are_seqs: bool,
+    numeric_variants: bool,
 }
 
 impl Default for Settings {
@@ -72,6 +74,7 @@ impl Default for Settings {
         Self {
             human_readable: true,
             structs_are_seqs: false,
+            numeric_variants: false,
         }
     }
 }
@@ -108,6 +111,20 @@ impl Settings {
         self.structs_are_seqs = true;
         self
     }
+
+    /// Treat enum variants as keyed by strings.
+    ///
+    /// This is the default.
+    pub fn variants_are_strings(&mut self) -> &mut Self {
+        self.numeric_variants = false;
+        self
+    }
+
+    /// Treat enum variants as keyed by numbers.
+    pub fn variants_are_numbers(&mut self) -> &mut Self {
+        self.numeric_variants = true;
+        self
+    }
 }
 
 struct Transmute<'settings> {
@@ -130,7 +147,78 @@ impl<'settings, 'de> Deserializer<'de> for Transmute<'settings> {
         self,
         visitor: Vis,
     ) -> Result<Vis::Value, Self::Error> {
-        todo!()
+        let settings = self.settings;
+        match self.collection {
+            Collection::Bool(value) => visitor.visit_bool(value),
+            Collection::I8(value) => visitor.visit_i8(value),
+            Collection::I16(value) => visitor.visit_i16(value),
+            Collection::I32(value) => visitor.visit_i32(value),
+            Collection::I64(value) => visitor.visit_i64(value),
+            Collection::I128(value) => visitor.visit_i128(value),
+            Collection::U8(value) => visitor.visit_u8(value),
+            Collection::U16(value) => visitor.visit_u16(value),
+            Collection::U32(value) => visitor.visit_u32(value),
+            Collection::U64(value) => visitor.visit_u64(value),
+            Collection::U128(value) => visitor.visit_u128(value),
+            Collection::F32(value) => visitor.visit_f32(value),
+            Collection::F64(value) => visitor.visit_f64(value),
+            Collection::Char(value) => visitor.visit_char(value),
+            Collection::Str(value) => {
+                visitor.visit_string(String::from(value))
+            },
+            Collection::Bytes(value) => {
+                visitor.visit_byte_buf(Vec::from(value))
+            },
+            Collection::Some(value) => {
+                visitor.visit_some(Self {
+                    settings,
+                    collection: *value,
+                })
+            },
+            Collection::None => visitor.visit_none(),
+            | Collection::Unit
+            | Collection::UnitStruct { .. }
+            | Collection::UnitVariant { .. }
+            => visitor.visit_unit(),
+            | Collection::NewtypeVariant { value, .. }
+            | Collection::NewtypeStruct { value, .. }
+            => {
+                visitor.visit_newtype_struct(Self {
+                    settings,
+                    collection: *value,
+                })
+            },
+            | Collection::Seq(_)
+            | Collection::Tuple(_)
+            | Collection::TupleVariant { .. }
+            | Collection::TupleStruct { .. }
+            => visitor.visit_seq(self),
+            Collection::Map(_) => {
+                todo!()
+            }
+            | Collection::StructVariant { .. }
+            | Collection::Struct { .. }
+            => {
+                todo!()
+            }
+        }
+    }
+
+    fn deserialize_enum<Vis: Visitor<'de>>(
+        self,
+        name: &'static str,
+        variants: &'static [&'static str],
+        visitor: Vis,
+    ) -> Result<Vis::Value, Self::Error> {
+        match self.collection {
+            | Collection::UnitVariant { .. }
+            | Collection::NewtypeVariant { .. }
+            | Collection::TupleVariant { .. }
+            | Collection::StructVariant { .. }
+            => visitor.visit_enum(self),
+            // FIXME?
+            _ => Err(TransmuteError(())),
+        }
     }
 
     fn deserialize_ignored_any<Vis: Visitor<'de>>(
@@ -144,7 +232,195 @@ impl<'settings, 'de> Deserializer<'de> for Transmute<'settings> {
         <Vis: Visitor<'de>>
         bool i8 i16 i32 i64 i128 u8 u16 u32 u64 u128 f32 f64 char str string
         bytes byte_buf option unit unit_struct newtype_struct seq tuple
-        tuple_struct map struct enum identifier
+        tuple_struct map struct identifier
+    }
+
+    fn is_human_readable(&self) -> bool {
+        self.settings.human_readable
+    }
+}
+
+impl<'settings, 'de> serde::de::SeqAccess<'de> for Transmute<'settings> {
+    type Error = TransmuteError;
+    fn next_element_seed<T>(
+        &mut self,
+        seed: T,
+    ) -> Result<Option<T::Value>, Self::Error>
+    where
+        T: DeserializeSeed<'de>,
+    {
+        match &mut self.collection {
+            | Collection::Seq(values)
+            | Collection::Tuple(values)
+            | Collection::TupleVariant { values, .. }
+            | Collection::TupleStruct { values, .. }
+            => values.next(),
+            _ => unreachable!(),
+        }.map(|value| seed.deserialize(Self {
+            settings: self.settings,
+            collection: value,
+        })).transpose()
+    }
+    fn size_hint(&self) -> Option<usize> {
+        match &self.collection {
+            | Collection::Seq(values)
+            | Collection::Tuple(values)
+            | Collection::TupleVariant { values, .. }
+            | Collection::TupleStruct { values, .. }
+            => values.size_hint().1,
+            _ => unreachable!(),
+        }
+    }
+}
+
+impl<'settings, 'de> serde::de::EnumAccess<'de> for Transmute<'settings> {
+    type Error = TransmuteError;
+    type Variant = Self;
+
+    fn variant_seed<T>(
+        self,
+        seed: T,
+    ) -> Result<(T::Value, Self), Self::Error>
+    where
+        T: DeserializeSeed<'de>,
+    {
+        let tag = match self.collection {
+            | Collection::UnitVariant { variant, variant_index, .. }
+            | Collection::NewtypeVariant { variant, variant_index, .. }
+            | Collection::TupleVariant { variant, variant_index, .. }
+            | Collection::StructVariant { variant, variant_index, .. }
+            => {
+                (variant, variant_index)
+            }
+            _ => unreachable!(),
+        };
+        let val = match self.settings.numeric_variants {
+            false => seed.deserialize({
+                serde::de::value::StrDeserializer::new(tag.0)
+            }),
+            true => seed.deserialize({
+                serde::de::value::U32Deserializer::new(tag.1)
+            }),
+        };
+        val.map(|val| (val, self))
+    }
+}
+
+impl<'settings, 'de> serde::de::VariantAccess<'de> for Transmute<'settings> {
+    type Error = TransmuteError;
+    fn unit_variant(self) -> Result<(), Self::Error> {
+        match self.collection {
+            Collection::UnitVariant { .. } => Ok(()),
+            _ => {
+                let unexp = match self.collection {
+                    Collection::UnitVariant { .. } => {
+                        serde::de::Unexpected::UnitVariant
+                    },
+                    Collection::NewtypeVariant { .. } => {
+                        serde::de::Unexpected::NewtypeVariant
+                    },
+                    Collection::TupleVariant { .. } => {
+                        serde::de::Unexpected::TupleVariant
+                    },
+                    Collection::StructVariant { .. } => {
+                        serde::de::Unexpected::StructVariant
+                    },
+                    _ => unreachable!(),
+                };
+                Err(Self::Error::invalid_type(unexp, &"unit variant"))
+            }
+        }
+    }
+    fn newtype_variant_seed<T>(self, seed: T) -> Result<T::Value, Self::Error>
+    where
+        T: DeserializeSeed<'de>,
+    {
+        match self.collection {
+            Collection::NewtypeVariant { value, .. } => {
+                seed.deserialize(Self {
+                    settings: self.settings,
+                    collection: *value,
+                })
+            },
+            _ => {
+                let unexp = match self.collection {
+                    Collection::UnitVariant { .. } => {
+                        serde::de::Unexpected::UnitVariant
+                    },
+                    Collection::NewtypeVariant { .. } => {
+                        serde::de::Unexpected::NewtypeVariant
+                    },
+                    Collection::TupleVariant { .. } => {
+                        serde::de::Unexpected::TupleVariant
+                    },
+                    Collection::StructVariant { .. } => {
+                        serde::de::Unexpected::StructVariant
+                    },
+                    _ => unreachable!(),
+                };
+                Err(Self::Error::invalid_type(unexp, &"newtype variant"))
+            }
+        }
+    }
+    fn tuple_variant<Vis>(
+        self,
+        len: usize,
+        visitor: Vis,
+    ) -> Result<Vis::Value, Self::Error>
+    where
+        Vis: Visitor<'de>,
+    {
+        match self.collection {
+            Collection::TupleVariant { .. } => todo!(),
+            _ => {
+                let unexp = match self.collection {
+                    Collection::UnitVariant { .. } => {
+                        serde::de::Unexpected::UnitVariant
+                    },
+                    Collection::NewtypeVariant { .. } => {
+                        serde::de::Unexpected::NewtypeVariant
+                    },
+                    Collection::TupleVariant { .. } => {
+                        serde::de::Unexpected::TupleVariant
+                    },
+                    Collection::StructVariant { .. } => {
+                        serde::de::Unexpected::StructVariant
+                    },
+                    _ => unreachable!(),
+                };
+                Err(Self::Error::invalid_type(unexp, &"tuple variant"))
+            }
+        }
+    }
+    fn struct_variant<Vis>(
+        self,
+        fields: &'static [&'static str],
+        visitor: Vis,
+    ) -> Result<Vis::Value, Self::Error>
+    where
+        Vis: Visitor<'de>,
+    {
+        match self.collection {
+            Collection::StructVariant { .. } => todo!(),
+            _ => {
+                let unexp = match self.collection {
+                    Collection::UnitVariant { .. } => {
+                        serde::de::Unexpected::UnitVariant
+                    },
+                    Collection::NewtypeVariant { .. } => {
+                        serde::de::Unexpected::NewtypeVariant
+                    },
+                    Collection::TupleVariant { .. } => {
+                        serde::de::Unexpected::TupleVariant
+                    },
+                    Collection::StructVariant { .. } => {
+                        serde::de::Unexpected::StructVariant
+                    },
+                    _ => unreachable!(),
+                };
+                Err(Self::Error::invalid_type(unexp, &"struct variant"))
+            }
+        }
     }
 }
 
@@ -416,7 +692,7 @@ impl<'settings> serde::ser::SerializeSeq for CollectSeq<'settings> {
         value.serialize(self.ser).map(|it| self.values.push(it))
     }
     fn end(self) -> Result<Self::Ok, Self::Error> {
-        Ok(Collection::Seq(self.values))
+        Ok(Collection::Seq(self.values.into_iter()))
     }
 }
 impl<'settings> serde::ser::SerializeTuple for CollectTuple<'settings> {
@@ -433,7 +709,7 @@ impl<'settings> serde::ser::SerializeTuple for CollectTuple<'settings> {
             // FIXME?
             Err(TransmuteError(()))
         } else {
-            Ok(Collection::Tuple(self.values))
+            Ok(Collection::Tuple(self.values.into_iter()))
         }
     }
 }
@@ -453,7 +729,7 @@ impl<'settings> serde::ser::SerializeTupleStruct for CollectTupleStruct<'setting
         } else {
             Ok(Collection::TupleStruct {
                 name: self.name,
-                values: self.values
+                values: self.values.into_iter()
             })
         }
     }
@@ -476,7 +752,7 @@ impl<'settings> serde::ser::SerializeTupleVariant for CollectTupleVariant<'setti
                 name: self.name,
                 variant: self.variant,
                 variant_index: self.variant_index,
-                values: self.values
+                values: self.values.into_iter()
             })
         }
     }
@@ -511,7 +787,7 @@ impl<'settings> serde::ser::SerializeMap for CollectMap<'settings> {
             // FIXME?
             Err(TransmuteError(()))
         } else {
-            Ok(Collection::Map(self.values))
+            Ok(Collection::Map(self.values.into_iter()))
         }
     }
 }
@@ -535,7 +811,7 @@ impl<'settings> serde::ser::SerializeStruct for CollectStruct<'settings> {
         } else {
             Ok(Collection::Struct {
                 name: self.name,
-                values: self.values
+                values: self.values.into_iter()
             })
         }
     }
@@ -562,14 +838,14 @@ impl<'settings> serde::ser::SerializeStructVariant for CollectStructVariant<'set
                 name: self.name,
                 variant: self.variant,
                 variant_index: self.variant_index,
-                values: self.values
+                values: self.values.into_iter()
             })
         }
     }
 }
 
 /// Types serialized by serde. Refer to `Serializer` for details.
-#[derive(Debug)]
+//#[derive(Debug)]
 enum Collection {
     Bool(bool),
     I8(i8),
@@ -586,7 +862,7 @@ enum Collection {
     F64(f64),
     Char(char),
     Str(Box<str>),
-    Bytes(Vec<u8>),
+    Bytes(Box<[u8]>),
     Some(Box<Collection>),
     None,
     Unit,
@@ -608,28 +884,28 @@ enum Collection {
         variant_index: u32,
         value: Box<Collection>,
     },
-    Seq(Vec<Collection>),
-    Tuple(Vec<Collection>),
+    Seq(std::vec::IntoIter<Collection>),
+    Tuple(std::vec::IntoIter<Collection>),
     TupleStruct {
         name: &'static str,
-        values: Vec<Collection>,
+        values: std::vec::IntoIter<Collection>,
     },
     TupleVariant {
         name: &'static str,
         variant: &'static str,
         variant_index: u32,
-        values: Vec<Collection>,
+        values: std::vec::IntoIter<Collection>,
     },
     // NOTE: support for multimaps!
-    Map(Vec<(Collection, Collection)>),
+    Map(std::vec::IntoIter<(Collection, Collection)>),
     Struct {
         name: &'static str,
-        values: Vec<(&'static str, Collection)>,
+        values: std::vec::IntoIter<(&'static str, Collection)>,
     },
     StructVariant {
         name: &'static str,
         variant: &'static str,
         variant_index: u32,
-        values: Vec<(&'static str, Collection)>,
+        values: std::vec::IntoIter<(&'static str, Collection)>,
     },
 }
diff --git a/tests/unit.rs b/tests/unit.rs
new file mode 100644
index 0000000..c18ee31
--- /dev/null
+++ b/tests/unit.rs
@@ -0,0 +1,13 @@
+#[derive(serde::Serialize)]
+struct Foo(String);
+
+#[derive(serde::Deserialize)]
+struct Bar(String);
+
+#[test]
+fn transmute() {
+    let settings = Default::default();
+    let foo = Foo(String::from("Hello!"));
+    let bar: Bar = serde_transmute::transmute(&foo, &settings).unwrap();
+    assert_eq!(foo.0, bar.0);
+}