Merge pull request #22 from gebressler/size_hint
Implement size_hint for de::SeqAccess and de::MapAccess
diff --git a/src/de.rs b/src/de.rs
index a9f274f..7b3ba04 100644
--- a/src/de.rs
+++ b/src/de.rs
@@ -4,6 +4,7 @@
use serde::de;
use serde::forward_to_deserialize_any;
use std::char;
+use std::collections::VecDeque;
use std::f64;
use crate::error::{Error, Result};
@@ -58,12 +59,8 @@
visitor.visit_f64(parse_number(&pair)?)
}
}
- Rule::array => visitor.visit_seq(Seq {
- pairs: pair.into_inner(),
- }),
- Rule::object => visitor.visit_map(Map {
- pairs: pair.into_inner(),
- }),
+ Rule::array => visitor.visit_seq(Seq::new(pair)),
+ Rule::object => visitor.visit_map(Map::new(pair)),
_ => unreachable!(),
}
}
@@ -284,17 +281,27 @@
}
struct Seq<'de> {
- pairs: Pairs<'de, Rule>,
+    /// Elements not yet handed out by `next_element_seed`. They are buffered up front so that
+    /// `size_hint` can report an exact count.
+    pairs: VecDeque<Pair<'de, Rule>>,
+}
+
+impl<'de> Seq<'de> {
+    /// Buffers the inner pairs of an `array` pair for element-by-element access.
+    pub fn new(pair: Pair<'de, Rule>) -> Self {
+        // `into_inner()` already yields an iterator of `Pair`s; collect it directly.
+        Self {
+            pairs: pair.into_inner().collect(),
+        }
+    }
}
impl<'de> de::SeqAccess<'de> for Seq<'de> {
type Error = Error;
+    /// Exact: every element not yet yielded by `next_element_seed` is still buffered.
+    fn size_hint(&self) -> Option<usize> {
+        let remaining = self.pairs.len();
+        Some(remaining)
+    }
+
fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>>
where
T: de::DeserializeSeed<'de>,
{
- if let Some(pair) = self.pairs.next() {
+ if let Some(pair) = self.pairs.pop_front() {
seed.deserialize(&mut Deserializer::from_pair(pair))
.map(Some)
} else {
@@ -304,17 +311,27 @@
}
struct Map<'de> {
- pairs: Pairs<'de, Rule>,
+    /// Keys and values not yet consumed. `next_key_seed` and `next_value_seed` pop them
+    /// alternately from the front. They are buffered so `size_hint` can report a count.
+    pairs: VecDeque<Pair<'de, Rule>>,
+}
+
+impl<'de> Map<'de> {
+    /// Buffers the inner pairs of an `object` pair for key/value access.
+    pub fn new(pair: Pair<'de, Rule>) -> Self {
+        // `into_inner()` already yields an iterator of `Pair`s; collect it directly.
+        Self {
+            pairs: pair.into_inner().collect(),
+        }
+    }
}
impl<'de> de::MapAccess<'de> for Map<'de> {
type Error = Error;
+    /// Number of complete key/value entries still buffered. Keys and values are stored as
+    /// alternating pairs. After `next_key_seed` a lone pending value remains, and integer
+    /// division rounds it away so it is not counted as an entry.
+    fn size_hint(&self) -> Option<usize> {
+        let entries = self.pairs.len() / 2;
+        Some(entries)
+    }
+
fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>>
where
K: de::DeserializeSeed<'de>,
{
- if let Some(pair) = self.pairs.next() {
+ if let Some(pair) = self.pairs.pop_front() {
seed.deserialize(&mut Deserializer::from_pair(pair))
.map(Some)
} else {
@@ -326,7 +343,7 @@
where
V: de::DeserializeSeed<'de>,
{
- seed.deserialize(&mut Deserializer::from_pair(self.pairs.next().unwrap()))
+ seed.deserialize(&mut Deserializer::from_pair(self.pairs.pop_front().unwrap()))
}
}
@@ -386,9 +403,7 @@
{
match self.pair {
Some(pair) => match pair.as_rule() {
- Rule::array => visitor.visit_seq(Seq {
- pairs: pair.into_inner(),
- }),
+ Rule::array => visitor.visit_seq(Seq::new(pair)),
_ => Err(de::Error::custom("expected an array")),
}
None => Err(de::Error::custom("expected an array")),
@@ -401,9 +416,7 @@
{
match self.pair {
Some(pair) => match pair.as_rule() {
- Rule::object => visitor.visit_map(Map {
- pairs: pair.into_inner(),
- }),
+ Rule::object => visitor.visit_map(Map::new(pair)),
_ => Err(de::Error::custom("expected an object")),
}
None => Err(de::Error::custom("expected an object")),
diff --git a/tests/de.rs b/tests/de.rs
index b35f1de..c1dbeb7 100644
--- a/tests/de.rs
+++ b/tests/de.rs
@@ -1,6 +1,8 @@
-use serde_derive::Deserialize;
+use serde::de;
+use serde_derive::{Deserialize};
use std::collections::HashMap;
+use std::fmt;
mod common;
@@ -304,7 +306,40 @@
Val::Bool(true),
Val::String("hello".to_owned()),
],
- )
+ );
+}
+
+#[test]
+fn deserializes_seq_size_hint() {
+    // Captures the `size_hint` reported by the `SeqAccess` handed to the visitor.
+    #[derive(Debug, PartialEq)]
+    struct Size(usize);
+
+    impl<'de> de::Deserialize<'de> for Size {
+        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
+        where
+            D: de::Deserializer<'de>,
+        {
+            struct Visitor;
+
+            impl<'de> de::Visitor<'de> for Visitor {
+                type Value = Size;
+
+                fn expecting(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+                    f.write_str("array")
+                }
+
+                fn visit_seq<A>(self, seq: A) -> Result<Self::Value, A::Error>
+                where
+                    A: de::SeqAccess<'de>,
+                {
+                    Ok(Size(seq.size_hint().expect("SeqAccess should report a size_hint")))
+                }
+            }
+
+            deserializer.deserialize_seq(Visitor)
+        }
+    }
+
+    deserializes_to("[]", Size(0));
+    deserializes_to("[42, true, 'hello']", Size(3));
+    // A nested array counts as a single element of the outer array.
+    deserializes_to("[42, true, [1, 2]]", Size(3));
}
#[test]
@@ -335,6 +370,39 @@
}
#[test]
+fn deserializes_map_size_hint() {
+    // Captures the `size_hint` reported by the `MapAccess` handed to the visitor.
+    #[derive(Debug, PartialEq)]
+    struct Size(usize);
+
+    impl<'de> de::Deserialize<'de> for Size {
+        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
+        where
+            D: de::Deserializer<'de>,
+        {
+            struct Visitor;
+
+            impl<'de> de::Visitor<'de> for Visitor {
+                type Value = Size;
+
+                // This visitor only accepts maps, so describe an object rather than an array.
+                fn expecting(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+                    f.write_str("object")
+                }
+
+                fn visit_map<A>(self, map: A) -> Result<Self::Value, A::Error>
+                where
+                    A: de::MapAccess<'de>,
+                {
+                    Ok(Size(map.size_hint().expect("MapAccess should report a size_hint")))
+                }
+            }
+
+            deserializer.deserialize_map(Visitor)
+        }
+    }
+
+    deserializes_to("{}", Size(0));
+    // Unquoted, single-quoted and double-quoted keys each count as one entry.
+    deserializes_to("{ a: 1, 'b': 2, \"c\": 3 }", Size(3));
+    // A nested array value does not add entries to the outer object.
+    deserializes_to("{ a: 1, 'b': 2, \"c\": [1, 2] }", Size(3));
+}
+
+#[test]
fn deserializes_struct() {
#[derive(Deserialize, PartialEq, Debug)]
struct S {