blob: 4c667fae60121575182f7dfe86793b1862b48859 [file] [log] [blame]
// Copyright (c) 2016 The Rouille developers
// Licensed under the Apache License, Version 2.0
// <LICENSE-APACHE or
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT
// license <LICENSE-MIT or http://opensource.org/licenses/MIT>,
// at your option. All files in the project carrying such
// notice may not be copied, modified, or distributed except
// according to those terms.
//! Low-level parsing of websocket frames.
//!
//! Usage:
//!
//! - Create a `StateMachine` with `StateMachine::new`.
//! - Whenever data is received on the socket, call `StateMachine::feed`.
//! - The returned iterator produces zero, one or multiple `Element` objects containing what was
//! received.
//! - For `Element::Data`, the `Data` object is an iterator over the decoded bytes.
//! - If `Element::Error` is produced, immediately end the connection.
//!
//! Glossary:
//!
//! - A websocket stream is made of multiple *messages*.
//! - Each message is made of one or more *frames*. See https://tools.ietf.org/html/rfc6455#section-5.4.
//! - Each frame can be received progressively, where each packet is an `Element` object (below).
/// One event decoded by `StateMachine::feed`: the start of a frame, a chunk of that frame's
/// payload, or a fatal protocol error.
#[derive(Debug, PartialEq, Eq)]
pub enum Element<'a> {
    /// The header of a new frame has been fully received.
    FrameStart {
        /// Whether this frame is the final fragment of its message.
        fin: bool,
        /// Payload length of the frame, in bytes.
        length: u64,
        /// Frame opcode (see https://tools.ietf.org/html/rfc6455#section-5.2).
        opcode: u8,
    },
    /// A chunk of the payload of the current frame.
    Data {
        /// The unmasked payload bytes, exposed as an iterator of `u8`.
        data: Data<'a>,
        /// Set on the chunk that completes the current frame.
        last_in_frame: bool,
    },
    /// The stream violated the protocol. The connection must be dropped as soon as possible.
    Error {
        /// Description of the problem. It may or may not be suitable to send back to the
        /// client.
        desc: &'static str,
    },
}
/// Payload bytes of a frame, unmasked lazily as they are iterated. Implements
/// `Iterator<Item = u8>`.
#[derive(Debug, PartialEq, Eq)]
pub struct Data<'a> {
    /// Raw, still-masked bytes that have not been yielded yet.
    data: &'a [u8],
    /// The frame's 32-bit masking key, with the first key byte in the most significant position.
    mask: u32,
    /// Position (0 to 3) within the masking key of the key byte that applies to `data[0]`.
    /// Advanced on every yielded byte, mirroring `StateMachineInner::InData::offset`.
    offset: u8,
}
/// Incremental websocket frame parser. Between calls to `feed`, it remembers where it is in the
/// stream and any partially received frame header.
pub struct StateMachine {
    /// Current parsing position: reading a header or reading a frame payload.
    inner: StateMachineInner,
    /// Leading bytes of a frame header whose remainder has not arrived yet. Must stay empty
    /// while `inner` is `InData`.
    buffer: Vec<u8>, // TODO: use SmallVec?
}
/// Where the parser currently is within the stream.
enum StateMachineInner {
    /// Expecting (the rest of) a frame header. Header bytes received so far are kept in
    /// `StateMachine::buffer`.
    InHeader,
    /// Inside the payload of a frame. `StateMachine::buffer` must be empty in this state.
    InData {
        /// Masking key of the current frame.
        mask: u32,
        /// Number of payload bytes already consumed, modulo 4. This is the index within the
        /// masking key of the next expected byte.
        offset: u8,
        /// Payload bytes of the current frame that have not been received yet.
        remaining_len: u64,
    },
}
impl StateMachine {
/// Initializes a new state machine for a new stream. Expects to see a new frame as the first
/// packet.
pub fn new() -> StateMachine {
StateMachine {
inner: StateMachineInner::InHeader,
buffer: Vec::with_capacity(14),
}
}
/// Feeds data to the state machine. Returns an iterator to the list of elements that were
/// received.
#[inline]
pub fn feed<'a>(&'a mut self, data: &'a [u8]) -> ElementsIter<'a> {
ElementsIter { state: self, data }
}
}
/// Iterator to the list of elements that were received.
///
/// Input is only consumed as the iterator is advanced, so dropping it early loses data.
#[must_use = "the fed data is only processed when the iterator is advanced"]
pub struct ElementsIter<'a> {
    // The state machine being fed; updated as elements are produced.
    state: &'a mut StateMachine,
    // Input that has not been processed yet.
    data: &'a [u8],
}
impl<'a> Iterator for ElementsIter<'a> {
type Item = Element<'a>;
fn next(&mut self) -> Option<Element<'a>> {
if self.data.is_empty() {
return None;
}
match self.state.inner {
// First situation, we are in the header.
StateMachineInner::InHeader => {
// We need at least 6 bytes for a succesful header. Otherwise we just return.
let total_buffered = self.state.buffer.len() + self.data.len();
if total_buffered < 6 {
self.state.buffer.extend_from_slice(self.data);
self.data = &[];
return None;
}
// Retreive the first two bytes of the header.
let (first_byte, second_byte) = {
let mut mask_iter = self.state.buffer.iter().chain(self.data.iter());
let first_byte = *mask_iter.next().unwrap();
let second_byte = *mask_iter.next().unwrap();
(first_byte, second_byte)
};
// Reserved bits must be zero, otherwise error.
if (first_byte & 0x70) != 0 {
return Some(Element::Error {
desc: "Reserved bits must be zero"
});
}
// Client-to-server messages **must** be encoded.
if (second_byte & 0x80) == 0 {
return Some(Element::Error {
desc: "Client-to-server messages must be masked"
});
}
// Find the length of the frame and the mask.
let (length, mask) = match second_byte & 0x7f {
126 => {
if total_buffered < 8 {
self.state.buffer.extend_from_slice(self.data);
self.data = &[];
return None;
}
let mut mask_iter = self.state.buffer.iter().chain(self.data.iter()).skip(2);
let length = {
let a = u64::from(*mask_iter.next().unwrap());
let b = u64::from(*mask_iter.next().unwrap());
(a << 8) | (b << 0)
};
let mask = {
let a = u32::from(*mask_iter.next().unwrap());
let b = u32::from(*mask_iter.next().unwrap());
let c = u32::from(*mask_iter.next().unwrap());
let d = u32::from(*mask_iter.next().unwrap());
(a << 24) | (b << 16) | (c << 8) | (d << 0)
};
(length, mask)
},
127 => {
if total_buffered < 14 {
self.state.buffer.extend_from_slice(self.data);
self.data = &[];
return None;
}
let mut mask_iter = self.state.buffer.iter().chain(self.data.iter()).skip(2);
let length = {
let a = u64::from(*mask_iter.next().unwrap());
let b = u64::from(*mask_iter.next().unwrap());
let c = u64::from(*mask_iter.next().unwrap());
let d = u64::from(*mask_iter.next().unwrap());
let e = u64::from(*mask_iter.next().unwrap());
let f = u64::from(*mask_iter.next().unwrap());
let g = u64::from(*mask_iter.next().unwrap());
let h = u64::from(*mask_iter.next().unwrap());
// The most significant bit must be zero according to the specs.
if (a & 0x80) != 0 {
return Some(Element::Error {
desc: "Most-significant bit of the length must be zero"
});
}
(a << 56) | (b << 48) | (c << 40) | (d << 32) |
(e << 24) | (f << 16) | (g << 8) | (h << 0)
};
let mask = {
let a = u32::from(*mask_iter.next().unwrap());
let b = u32::from(*mask_iter.next().unwrap());
let c = u32::from(*mask_iter.next().unwrap());
let d = u32::from(*mask_iter.next().unwrap());
(a << 24) | (b << 16) | (c << 8) | (d << 0)
};
(length, mask)
},
n => {
let mut mask_iter = self.state.buffer.iter().chain(self.data.iter()).skip(2);
let mask = {
let a = u32::from(*mask_iter.next().unwrap());
let b = u32::from(*mask_iter.next().unwrap());
let c = u32::from(*mask_iter.next().unwrap());
let d = u32::from(*mask_iter.next().unwrap());
(a << 24) | (b << 16) | (c << 8) | (d << 0)
};
(u64::from(n), mask)
},
};
// Builds a slice containing the start of the data.
let data_start = {
let data_start_off = match second_byte & 0x7f {
126 => 8,
127 => 14,
_ => 6
};
assert!(self.state.buffer.len() < data_start_off);
&self.data[(data_start_off - self.state.buffer.len()) ..]
};
// Update ourselves for the next loop and return a FrameStart message.
self.data = data_start;
self.state.buffer.clear();
self.state.inner = StateMachineInner::InData {
mask,
remaining_len: length,
offset: 0
};
Some(Element::FrameStart {
fin: (first_byte & 0x80) != 0,
length,
opcode: first_byte & 0xf,
})
},
// Second situation, we are in the message and we don't have enough data to finish the
// current frame.
StateMachineInner::InData { mask, ref mut remaining_len, ref mut offset }
if *remaining_len > self.data.len() as u64 =>
{
let data = Data {
data: self.data,
mask,
offset: *offset,
};
*offset += (self.data.len() % 4) as u8;
*offset %= 4;
*remaining_len -= self.data.len() as u64;
self.data = &[];
Some(Element::Data { data, last_in_frame: false })
},
// Third situation, we have enough data to finish the frame.
StateMachineInner::InData { mask, remaining_len, offset } => {
debug_assert!(self.data.len() as u64 >= remaining_len);
let data = Data {
data: &self.data[0 .. remaining_len as usize],
mask,
offset,
};
self.data = &self.data[remaining_len as usize ..];
self.state.inner = StateMachineInner::InHeader;
debug_assert!(self.state.buffer.is_empty());
Some(Element::Data { data, last_in_frame: true })
},
}
}
}
impl<'a> Iterator for Data<'a> {
    type Item = u8;

    /// Unmasks and returns the next payload byte, or `None` once every byte has been yielded.
    #[inline]
    fn next(&mut self) -> Option<u8> {
        let bytes: &'a [u8] = self.data;
        let (&masked, rest) = match bytes.split_first() {
            Some(split) => split,
            None => return None,
        };
        // The key is stored big-endian: key byte 0 sits in the top 8 bits of `self.mask`.
        let key_shift = u32::from(3 - self.offset) * 8;
        let key_byte = (self.mask >> key_shift) as u8;
        self.data = rest;
        self.offset = (self.offset + 1) % 4;
        Some(masked ^ key_byte)
    }

    /// Exact: every remaining raw byte yields exactly one decoded byte.
    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        (self.data.len(), Some(self.data.len()))
    }
}
/// `Data::size_hint` always reports the exact number of bytes left in the slice, so `Data` can
/// advertise an exact length.
impl<'a> ExactSizeIterator for Data<'a> {}
#[cfg(test)]
mod tests {
    use super::Element;
    use super::StateMachine;

    // Single-frame masked text message "Hello" (RFC 6455, section 5.7).
    #[test]
    fn basic() {
        let mut machine = StateMachine::new();
        let data = &[0x81, 0x85, 0x37, 0xfa, 0x21, 0x3d, 0x7f, 0x9f, 0x4d, 0x51, 0x58];
        let mut iter = machine.feed(data);
        assert_eq!(iter.next().unwrap(), Element::FrameStart {
            fin: true,
            length: 5,
            opcode: 1
        });
        match iter.next().unwrap() {
            Element::Data { data, last_in_frame } => {
                assert!(last_in_frame);
                assert_eq!(data.collect::<Vec<_>>(), b"Hello");
            }
            _ => panic!()
        }
    }

    // Same frame, delivered one byte per `feed` call: exercises header buffering and the
    // masking-key offset carried across chunks.
    #[test]
    fn byte_by_byte() {
        let mut machine = StateMachine::new();
        let frame = [0x81u8, 0x85, 0x37, 0xfa, 0x21, 0x3d, 0x7f, 0x9f, 0x4d, 0x51, 0x58];
        let mut frame_starts = 0;
        let mut decoded = Vec::new();
        let mut frame_finished = false;
        for chunk in frame.chunks(1) {
            for element in machine.feed(chunk) {
                match element {
                    Element::FrameStart { fin, length, opcode } => {
                        assert_eq!((fin, length, opcode), (true, 5, 1));
                        frame_starts += 1;
                    }
                    Element::Data { data, last_in_frame } => {
                        assert!(!frame_finished);
                        decoded.extend(data);
                        frame_finished = last_in_frame;
                    }
                    Element::Error { desc } => panic!("unexpected error: {}", desc),
                }
            }
        }
        assert_eq!(frame_starts, 1);
        assert!(frame_finished);
        assert_eq!(decoded, b"Hello");
    }

    // 16-bit extended payload length, with the header split in the middle of the masking key.
    #[test]
    fn extended_16_bit_length_split_header() {
        let key = [0x01u8, 0x02, 0x03, 0x04];
        let payload: Vec<u8> = (0..200u32).map(|i| i as u8).collect();
        let mut frame = vec![0x82u8, 0x80 | 126, 0x00, 200];
        frame.extend_from_slice(&key);
        frame.extend(payload.iter().enumerate().map(|(i, &b)| b ^ key[i % 4]));

        let mut machine = StateMachine::new();
        // 7 bytes are not enough for a header with a 16-bit extended length.
        assert!(machine.feed(&frame[..7]).next().is_none());

        let mut iter = machine.feed(&frame[7..]);
        assert_eq!(iter.next().unwrap(), Element::FrameStart { fin: true, length: 200, opcode: 2 });
        match iter.next().unwrap() {
            Element::Data { data, last_in_frame } => {
                assert!(last_in_frame);
                assert_eq!(data.collect::<Vec<_>>(), payload);
            }
            _ => panic!(),
        }
        assert!(iter.next().is_none());
    }

    // 64-bit extended payload length with an all-zero masking key.
    #[test]
    fn extended_64_bit_length() {
        let frame = [0x82u8, 0xff, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0xaa, 0xbb, 0xcc];
        let mut machine = StateMachine::new();
        let mut iter = machine.feed(&frame);
        assert_eq!(iter.next().unwrap(), Element::FrameStart { fin: true, length: 3, opcode: 2 });
        match iter.next().unwrap() {
            Element::Data { data, last_in_frame } => {
                assert!(last_in_frame);
                assert_eq!(data.collect::<Vec<_>>(), vec![0xaa, 0xbb, 0xcc]);
            }
            _ => panic!(),
        }
        assert!(iter.next().is_none());
    }

    // Each malformed header is reported as an `Error` on the first call to `next`.
    #[test]
    fn protocol_errors() {
        let cases: [(&[u8], &'static str); 3] = [
            (&[0xc1u8, 0x85, 0, 0, 0, 0][..], "Reserved bits must be zero"),
            (&[0x81u8, 0x05, 0, 0, 0, 0][..], "Client-to-server messages must be masked"),
            (
                &[0x82u8, 0xff, 0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0][..],
                "Most-significant bit of the length must be zero",
            ),
        ];
        for &(frame, desc) in cases.iter() {
            let mut machine = StateMachine::new();
            assert_eq!(machine.feed(frame).next(), Some(Element::Error { desc }));
        }
    }
}