// blob: 219066ac401587536398f0116ec7bccf1ad1439b
// Copyright 2015 The tiny-http Contributors
// Copyright 2015 The rust-chunked-transfer Contributors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::io::Result as IoResult;
use std::io::Read;
use std::io::Error as IoError;
use std::io::ErrorKind;
use std::fmt;
use std::error::Error;
/// Reads HTTP chunks and sends back real data.
///
/// The decoder wraps any [`Read`] source that yields a body encoded with the
/// HTTP "chunked" transfer coding and exposes the decoded payload through its
/// own [`Read`] implementation. Chunk extensions are skipped. The zero-size
/// chunk must be followed directly by `\r\n` (trailer fields are not handled).
///
/// Once the terminating zero-size chunk has been consumed the decoder stays at
/// end-of-stream: further reads return `Ok(0)` and do not touch the source.
///
/// # Errors
///
/// Malformed or truncated input is reported as an `std::io::Error` of kind
/// `InvalidInput`. I/O errors raised by the source are passed through as-is.
///
/// # Example
///
/// ```
/// use chunked_transfer::Decoder;
/// use std::io::Read;
///
/// let encoded = b"3\r\nhel\r\nb\r\nlo world!!!\r\n0\r\n\r\n";
/// let mut decoded = String::new();
///
/// let mut decoder = Decoder::new(encoded as &[u8]);
/// decoder.read_to_string(&mut decoded).unwrap();
///
/// assert_eq!(decoded, "hello world!!!");
/// ```
pub struct Decoder<R> {
    // where the chunks come from
    source: R,

    // remaining size of the chunk being read
    // none if we are not in a chunk
    remaining_chunks_size: Option<usize>,

    // true once the terminating zero-size chunk (and its CRLF) has been read;
    // from then on `read` reports end-of-stream without touching `source`
    finished: bool,
}

impl<R> Decoder<R>
where
    R: Read,
{
    /// Creates a decoder that pulls chunk-encoded bytes from `source`.
    pub fn new(source: R) -> Decoder<R> {
        Decoder {
            source: source,
            remaining_chunks_size: None,
            finished: false,
        }
    }

    /// Reads exactly one byte from the source.
    ///
    /// I/O errors from the source are returned unchanged; running out of
    /// input is reported as an `InvalidInput` decoding error.
    fn read_byte(&mut self) -> IoResult<u8> {
        match self.source.by_ref().bytes().next() {
            Some(byte) => byte,
            None => Err(DecoderError.into()),
        }
    }

    /// Reads one byte and checks that it is `expected`.
    fn expect_byte(&mut self, expected: u8) -> IoResult<()> {
        if self.read_byte()? == expected {
            Ok(())
        } else {
            Err(DecoderError.into())
        }
    }

    /// Reads a chunk-size line (`<hex size>[;extensions]\r\n`) and returns
    /// the decoded size.
    ///
    /// Whitespace around the size is tolerated and extensions are discarded.
    /// Anything other than hexadecimal digits in the size field (including a
    /// leading `+`), a size that does not fit in `usize`, a missing `\r\n`,
    /// or premature end of input yields an `InvalidInput` error.
    fn read_chunk_size(&mut self) -> IoResult<usize> {
        let mut size_field = Vec::new();
        let mut has_ext = false;

        // The size field ends at the CR, or at ';' when extensions follow.
        loop {
            match self.read_byte()? {
                b'\r' => break,
                b';' => {
                    has_ext = true;
                    break;
                }
                byte => size_field.push(byte),
            }
        }

        // Ignore extensions for now: skip everything up to the CR.
        if has_ext {
            while self.read_byte()? != b'\r' {}
        }

        self.read_line_feed()?;

        let size_field =
            String::from_utf8(size_field).map_err(|_| IoError::from(DecoderError))?;
        let digits = size_field.trim();

        // `from_str_radix` would accept a leading '+', which is not a valid
        // chunk size, so insist on hexadecimal digits only.
        if digits.is_empty() || !digits.bytes().all(|b| b.is_ascii_hexdigit()) {
            return Err(DecoderError.into());
        }

        usize::from_str_radix(digits, 16).map_err(|_| DecoderError.into())
    }

    /// Consumes one byte from the source and requires it to be `\r`.
    fn read_carriage_return(&mut self) -> IoResult<()> {
        self.expect_byte(b'\r')
    }

    /// Consumes one byte from the source and requires it to be `\n`.
    fn read_line_feed(&mut self) -> IoResult<()> {
        self.expect_byte(b'\n')
    }
}

impl<R> Read for Decoder<R>
where
    R: Read,
{
    /// Copies decoded payload bytes into `buf`.
    ///
    /// A single call never crosses a chunk boundary, so it may return fewer
    /// bytes than `buf` can hold. `Ok(0)` is returned when `buf` is empty or
    /// when the terminating zero-size chunk has been reached.
    ///
    /// # Errors
    ///
    /// Returns an `InvalidInput` error when the framing is malformed or the
    /// source ends in the middle of a chunk; source I/O errors are propagated.
    fn read(&mut self, buf: &mut [u8]) -> IoResult<usize> {
        if self.finished {
            return Ok(0);
        }

        let remaining = match self.remaining_chunks_size {
            Some(remaining) => remaining,
            None => {
                // we are not in a chunk, so determine the next chunk's size
                let chunk_size = self.read_chunk_size()?;

                // a chunk size of 0 marks the end of the body
                if chunk_size == 0 {
                    self.read_carriage_return()?;
                    self.read_line_feed()?;
                    self.finished = true;
                    return Ok(0);
                }

                // remember the size before touching the payload so the state
                // survives an error from the source
                self.remaining_chunks_size = Some(chunk_size);
                chunk_size
            }
        };

        if buf.is_empty() {
            return Ok(0);
        }

        // never read past the end of the current chunk
        let wanted = if buf.len() < remaining {
            buf.len()
        } else {
            remaining
        };
        let read = self.source.read(&mut buf[..wanted])?;

        // `wanted` is non-zero here, so a zero-length read means the source
        // ended while chunk data was still owed: the body is truncated
        if read == 0 {
            return Err(DecoderError.into());
        }

        if read == remaining {
            // the chunk is complete; it must be followed by CRLF
            self.read_carriage_return()?;
            self.read_line_feed()?;
            self.remaining_chunks_size = None;
        } else {
            self.remaining_chunks_size = Some(remaining - read);
        }

        Ok(read)
    }
}

/// Payload attached to the `InvalidInput` I/O errors produced by the decoder.
#[derive(Debug, Copy, Clone)]
struct DecoderError;

impl fmt::Display for DecoderError {
    fn fmt(&self, fmt: &mut fmt::Formatter) -> Result<(), fmt::Error> {
        write!(fmt, "Error while decoding chunks")
    }
}

// `description` is deprecated in favour of `Display`, but it is kept so that
// callers still using it keep receiving the same message.
#[allow(deprecated)]
impl Error for DecoderError {
    fn description(&self) -> &str {
        "Error while decoding chunks"
    }
}

impl From<DecoderError> for IoError {
    /// Wraps the decoding error in an `InvalidInput` I/O error.
    fn from(err: DecoderError) -> IoError {
        IoError::new(ErrorKind::InvalidInput, err)
    }
}
#[cfg(test)]
mod test {
    use super::Decoder;
    use std::io;
    use std::io::Read;

    /// This unit test is taken from Hyper
    /// https://github.com/hyperium/hyper
    /// Copyright (c) 2014 Sean McArthur
    #[test]
    fn test_read_chunk_size() {
        // Parses `s` as a chunk-size line and expects `expected` as the size.
        fn read(s: &str, expected: usize) {
            let mut decoded = Decoder::new(s.as_bytes());
            let actual = decoded.read_chunk_size().unwrap();
            assert_eq!(expected, actual);
        }

        // Parses `s` as a chunk-size line and expects an `InvalidInput` error.
        fn read_err(s: &str) {
            let mut decoded = Decoder::new(s.as_bytes());
            let err_kind = decoded.read_chunk_size().unwrap_err().kind();
            assert_eq!(err_kind, io::ErrorKind::InvalidInput);
        }

        read("1\r\n", 1);
        read("01\r\n", 1);
        read("0\r\n", 0);
        read("00\r\n", 0);
        read("A\r\n", 10);
        read("a\r\n", 10);
        read("Ff\r\n", 255);
        read("Ff \r\n", 255);
        // Missing LF or CRLF
        read_err("F\rF");
        read_err("F");
        // Invalid hex digit
        read_err("X\r\n");
        read_err("1X\r\n");
        read_err("-\r\n");
        read_err("-1\r\n");
        // Acceptable (if not fully valid) extensions do not influence the size
        read("1;extension\r\n", 1);
        read("a;ext name=value\r\n", 10);
        read("1;extension;extension2\r\n", 1);
        read("1;;; ;\r\n", 1);
        read("2; extension...\r\n", 2);
        read("3 ; extension=123\r\n", 3);
        read("3 ;\r\n", 3);
        read("3 ; \r\n", 3);
        // Invalid extensions cause an error
        read_err("1 invalid extension\r\n");
        read_err("1 A\r\n");
        read_err("1;no CRLF");
    }

    #[test]
    fn test_valid_chunk_decode() {
        let source = io::Cursor::new(
            "3\r\nhel\r\nb\r\nlo world!!!\r\n0\r\n\r\n"
                .to_string()
                .into_bytes(),
        );
        let mut decoded = Decoder::new(source);

        let mut string = String::new();
        decoded.read_to_string(&mut string).unwrap();

        assert_eq!(string, "hello world!!!");
    }

    #[test]
    fn test_decode_zero_length() {
        let mut decoder = Decoder::new(b"0\r\n\r\n" as &[u8]);

        let mut decoded = String::new();
        decoder.read_to_string(&mut decoded).unwrap();

        assert_eq!(decoded, "");
    }

    #[test]
    fn test_decode_invalid_chunk_length() {
        let mut decoder = Decoder::new(b"m\r\n\r\n" as &[u8]);

        let mut decoded = String::new();
        assert!(decoder.read_to_string(&mut decoded).is_err());
    }

    /// The declared chunk size (2) is shorter than the data that follows, so
    /// the byte after the chunk is not the expected CR.
    #[test]
    fn invalid_input1() {
        let source = io::Cursor::new(
            "2\r\nhel\r\nb\r\nlo world!!!\r\n0\r\n"
                .to_string()
                .into_bytes(),
        );
        let mut decoded = Decoder::new(source);

        let mut string = String::new();
        // The outcome must be asserted; a bare `.is_err()` can never fail.
        assert!(decoded.read_to_string(&mut string).is_err());
    }

    /// The first chunk-size line is terminated by a lone CR without LF.
    #[test]
    fn invalid_input2() {
        let source = io::Cursor::new(
            "3\rhel\r\nb\r\nlo world!!!\r\n0\r\n"
                .to_string()
                .into_bytes(),
        );
        let mut decoded = Decoder::new(source);

        let mut string = String::new();
        // The outcome must be asserted; a bare `.is_err()` can never fail.
        assert!(decoded.read_to_string(&mut string).is_err());
    }
}