Merge pull request #22 from gebressler/size_hint

Implement size_hint for de::SeqAccess and de::MapAccess
diff --git a/src/de.rs b/src/de.rs
index a9f274f..7b3ba04 100644
--- a/src/de.rs
+++ b/src/de.rs
@@ -4,6 +4,7 @@
 use serde::de;
 use serde::forward_to_deserialize_any;
 use std::char;
+use std::collections::VecDeque;
 use std::f64;
 
 use crate::error::{Error, Result};
@@ -58,12 +59,8 @@
                     visitor.visit_f64(parse_number(&pair)?)
                 }
             }
-            Rule::array => visitor.visit_seq(Seq {
-                pairs: pair.into_inner(),
-            }),
-            Rule::object => visitor.visit_map(Map {
-                pairs: pair.into_inner(),
-            }),
+            Rule::array => visitor.visit_seq(Seq::new(pair)),
+            Rule::object => visitor.visit_map(Map::new(pair)),
             _ => unreachable!(),
         }
     }
@@ -284,17 +281,27 @@
 }
 
 struct Seq<'de> {
-    pairs: Pairs<'de, Rule>,
+    pairs: VecDeque<Pair<'de, Rule>>, // elements not yet visited; buffered so len() is exact
+}
+
+impl<'de> Seq<'de> {
+    fn new(pair: Pair<'de, Rule>) -> Self { // callers pass the `Rule::array` pair
+        Self { pairs: pair.into_inner().collect() }
+    }
 }
 
 impl<'de> de::SeqAccess<'de> for Seq<'de> {
     type Error = Error;
 
+    fn size_hint(&self) -> Option<usize> {
+        Some(self.pairs.len())
+    }
+
     fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>>
     where
         T: de::DeserializeSeed<'de>,
     {
-        if let Some(pair) = self.pairs.next() {
+        if let Some(pair) = self.pairs.pop_front() {
             seed.deserialize(&mut Deserializer::from_pair(pair))
                 .map(Some)
         } else {
@@ -304,17 +311,27 @@
 }
 
 struct Map<'de> {
-    pairs: Pairs<'de, Rule>,
+    pairs: VecDeque<Pair<'de, Rule>>, // unvisited children: key, value, key, value, ...
+}
+
+impl<'de> Map<'de> {
+    fn new(pair: Pair<'de, Rule>) -> Self { // callers pass the `Rule::object` pair
+        Self { pairs: pair.into_inner().collect() }
+    }
 }
 
 impl<'de> de::MapAccess<'de> for Map<'de> {
     type Error = Error;
 
+    fn size_hint(&self) -> Option<usize> {
+        Some(self.pairs.len() / 2)
+    }
+
     fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>>
     where
         K: de::DeserializeSeed<'de>,
     {
-        if let Some(pair) = self.pairs.next() {
+        if let Some(pair) = self.pairs.pop_front() {
             seed.deserialize(&mut Deserializer::from_pair(pair))
                 .map(Some)
         } else {
@@ -326,7 +343,7 @@
     where
         V: de::DeserializeSeed<'de>,
     {
-        seed.deserialize(&mut Deserializer::from_pair(self.pairs.next().unwrap()))
+        seed.deserialize(&mut Deserializer::from_pair(self.pairs.pop_front().unwrap()))
     }
 }
 
@@ -386,9 +403,7 @@
     {
         match self.pair {
             Some(pair) => match pair.as_rule() {
-                Rule::array => visitor.visit_seq(Seq {
-                    pairs: pair.into_inner(),
-                }),
+                Rule::array => visitor.visit_seq(Seq::new(pair)),
                 _ => Err(de::Error::custom("expected an array")),
             }
             None => Err(de::Error::custom("expected an array")),
@@ -401,9 +416,7 @@
     {
         match self.pair {
             Some(pair) => match pair.as_rule() {
-                Rule::object => visitor.visit_map(Map {
-                    pairs: pair.into_inner(),
-                }),
+                Rule::object => visitor.visit_map(Map::new(pair)),
                 _ => Err(de::Error::custom("expected an object")),
             }
             None => Err(de::Error::custom("expected an object")),
diff --git a/tests/de.rs b/tests/de.rs
index b35f1de..c1dbeb7 100644
--- a/tests/de.rs
+++ b/tests/de.rs
@@ -1,6 +1,8 @@
-use serde_derive::Deserialize;
+use serde::de;
+use serde_derive::{Deserialize};
 
 use std::collections::HashMap;
+use std::fmt;
 
 mod common;
 
@@ -304,7 +306,40 @@
             Val::Bool(true),
             Val::String("hello".to_owned()),
         ],
-    )
+    );
+}
+
+#[test]
+fn deserializes_seq_size_hint() {
+    #[derive(Debug, PartialEq)]
+    struct Size(usize); // the size_hint reported by SeqAccess before any element is read
+    impl<'de> de::Deserialize<'de> for Size {
+        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
+        where
+            D: de::Deserializer<'de>,
+        {
+            struct Visitor;
+            impl<'de> de::Visitor<'de> for Visitor {
+                type Value = Size;
+
+                fn expecting(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+                    f.write_str("array")
+                }
+
+                fn visit_seq<A>(self, seq: A) -> Result<Self::Value, A::Error>
+                where
+                    A: de::SeqAccess<'de>,
+                {
+                    Ok(Size(seq.size_hint().unwrap())) // no element consumed yet
+                }
+            }
+            deserializer.deserialize_seq(Visitor)
+        }
+    }
+
+    deserializes_to("[]", Size(0));
+    deserializes_to("[42, true, 'hello']", Size(3));
+    deserializes_to("[42, true, [1, 2]]", Size(3)); // a nested array counts as one element
 }
 
 #[test]
@@ -335,6 +370,39 @@
 }
 
 #[test]
+fn deserializes_map_size_hint() {
+    #[derive(Debug, PartialEq)]
+    struct Size(usize); // the size_hint reported by MapAccess before any entry is read
+    impl<'de> de::Deserialize<'de> for Size {
+        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
+        where
+            D: de::Deserializer<'de>,
+        {
+            struct Visitor;
+            impl<'de> de::Visitor<'de> for Visitor {
+                type Value = Size;
+
+                fn expecting(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+                    f.write_str("object")
+                }
+
+                fn visit_map<A>(self, map: A) -> Result<Self::Value, A::Error>
+                where
+                    A: de::MapAccess<'de>,
+                {
+                    Ok(Size(map.size_hint().unwrap())) // no entry consumed yet
+                }
+            }
+            deserializer.deserialize_map(Visitor)
+        }
+    }
+
+    deserializes_to("{}", Size(0));
+    deserializes_to("{ a: 1, 'b': 2, \"c\": 3 }", Size(3)); // counts entries, not key+value pairs
+    deserializes_to("{ a: 1, 'b': 2, \"c\": [1, 2] }", Size(3));
+}
+
+#[test]
 fn deserializes_struct() {
     #[derive(Deserialize, PartialEq, Debug)]
     struct S {