| |
| use std::collections::VecDeque; |
| |
| use crate::msgs::codec; |
| use crate::msgs::message::{Message, MessagePayload}; |
| use crate::msgs::enums::{ContentType, ProtocolVersion}; |
| use crate::msgs::handshake::HandshakeMessagePayload; |
| |
| const HEADER_SIZE: usize = 1 + 3; |
| |
| /// This works to reconstruct TLS handshake messages |
| /// from individual TLS messages. It's guaranteed that |
| /// TLS messages output from this layer contain precisely |
| /// one handshake payload. |
| pub struct HandshakeJoiner { |
| /// Completed handshake frames for output. |
| pub frames: VecDeque<Message>, |
| |
| /// The message payload we're currently accumulating. |
| buf: Vec<u8>, |
| } |
| |
| impl Default for HandshakeJoiner { |
| fn default() -> Self { Self::new() } |
| } |
| |
| impl HandshakeJoiner { |
| /// Make a new HandshakeJoiner. |
| pub fn new() -> HandshakeJoiner { |
| HandshakeJoiner { |
| frames: VecDeque::new(), |
| buf: Vec::new(), |
| } |
| } |
| |
| /// Do we want to process this message? |
| pub fn want_message(&self, msg: &Message) -> bool { |
| msg.is_content_type(ContentType::Handshake) |
| } |
| |
| /// Do we have any buffered data? |
| pub fn is_empty(&self) -> bool { |
| self.buf.is_empty() |
| } |
| |
| /// Take the message, and join/split it as needed. |
| /// Return the number of new messages added to the |
| /// output deque as a result of this message. |
| /// |
| /// Returns None if msg or a preceding message was corrupt. |
| /// You cannot recover from this situation. Otherwise returns |
| /// a count of how many messages we queued. |
| pub fn take_message(&mut self, mut msg: Message) -> Option<usize> { |
| // Input must be opaque, otherwise we might have already |
| // lost information! |
| let payload = msg.take_opaque_payload().unwrap(); |
| |
| self.buf.extend_from_slice(&payload.0[..]); |
| |
| let mut count = 0; |
| while self.buf_contains_message() { |
| if !self.deframe_one(msg.version) { |
| return None; |
| } |
| |
| count += 1; |
| } |
| |
| Some(count) |
| } |
| |
| /// Does our `buf` contain a full handshake payload? It does if it is big |
| /// enough to contain a header, and that header has a length which falls |
| /// within `buf`. |
| fn buf_contains_message(&self) -> bool { |
| self.buf.len() >= HEADER_SIZE && |
| self.buf.len() >= (codec::u24::decode(&self.buf[1..4]).unwrap().0 as usize) + HEADER_SIZE |
| } |
| |
| /// Take a TLS handshake payload off the front of `buf`, and put it onto |
| /// the back of our `frames` deque inside a normal `Message`. |
| /// |
| /// Returns false if the stream is desynchronised beyond repair. |
| fn deframe_one(&mut self, version: ProtocolVersion) -> bool { |
| let used = { |
| let mut rd = codec::Reader::init(&self.buf); |
| let payload = HandshakeMessagePayload::read_version(&mut rd, version); |
| |
| if payload.is_none() { |
| return false; |
| } |
| |
| let m = Message { |
| typ: ContentType::Handshake, |
| version, |
| payload: MessagePayload::Handshake(payload.unwrap()), |
| }; |
| |
| self.frames.push_back(m); |
| rd.used() |
| }; |
| self.buf = self.buf.split_off(used); |
| true |
| } |
| } |
| |
| #[cfg(test)] |
| mod tests { |
| use super::HandshakeJoiner; |
| use crate::msgs::enums::{ProtocolVersion, ContentType, HandshakeType}; |
| use crate::msgs::handshake::{HandshakeMessagePayload, HandshakePayload}; |
| use crate::msgs::message::{Message, MessagePayload}; |
| use crate::msgs::base::Payload; |
| |
| #[test] |
| fn want() { |
| let hj = HandshakeJoiner::new(); |
| assert_eq!(hj.is_empty(), true); |
| |
| let wanted = Message { |
| typ: ContentType::Handshake, |
| version: ProtocolVersion::TLSv1_2, |
| payload: MessagePayload::new_opaque(b"hello world".to_vec()), |
| }; |
| |
| let unwanted = Message { |
| typ: ContentType::Alert, |
| version: ProtocolVersion::TLSv1_2, |
| payload: MessagePayload::new_opaque(b"ponytown".to_vec()), |
| }; |
| |
| assert_eq!(hj.want_message(&wanted), true); |
| assert_eq!(hj.want_message(&unwanted), false); |
| } |
| |
| fn pop_eq(expect: &Message, hj: &mut HandshakeJoiner) { |
| let got = hj.frames.pop_front().unwrap(); |
| assert_eq!(got.typ, expect.typ); |
| assert_eq!(got.version, expect.version); |
| |
| let (mut left, mut right) = (Vec::new(), Vec::new()); |
| got.payload.encode(&mut left); |
| expect.payload.encode(&mut right); |
| |
| assert_eq!(left, right); |
| } |
| |
| #[test] |
| fn split() { |
| // Check we split two handshake messages within one PDU. |
| let mut hj = HandshakeJoiner::new(); |
| |
| // two HelloRequests |
| let msg = Message { |
| typ: ContentType::Handshake, |
| version: ProtocolVersion::TLSv1_2, |
| payload: MessagePayload::new_opaque(b"\x00\x00\x00\x00\x00\x00\x00\x00".to_vec()), |
| }; |
| |
| assert_eq!(hj.want_message(&msg), true); |
| assert_eq!(hj.take_message(msg), Some(2)); |
| assert_eq!(hj.is_empty(), true); |
| |
| let expect = Message { |
| typ: ContentType::Handshake, |
| version: ProtocolVersion::TLSv1_2, |
| payload: MessagePayload::Handshake(HandshakeMessagePayload { |
| typ: HandshakeType::HelloRequest, |
| payload: HandshakePayload::HelloRequest, |
| }), |
| }; |
| |
| pop_eq(&expect, &mut hj); |
| pop_eq(&expect, &mut hj); |
| } |
| |
| #[test] |
| fn broken() { |
| // Check obvious crap payloads are reported as errors, not panics. |
| let mut hj = HandshakeJoiner::new(); |
| |
| // short ClientHello |
| let msg = Message { |
| typ: ContentType::Handshake, |
| version: ProtocolVersion::TLSv1_2, |
| payload: MessagePayload::new_opaque(b"\x01\x00\x00\x02\xff\xff".to_vec()), |
| }; |
| |
| assert_eq!(hj.want_message(&msg), true); |
| assert_eq!(hj.take_message(msg), None); |
| } |
| |
| #[test] |
| fn join() { |
| // Check we join one handshake message split over two PDUs. |
| let mut hj = HandshakeJoiner::new(); |
| assert_eq!(hj.is_empty(), true); |
| |
| // Introduce Finished of 16 bytes, providing 4. |
| let mut msg = Message { |
| typ: ContentType::Handshake, |
| version: ProtocolVersion::TLSv1_2, |
| payload: MessagePayload::new_opaque(b"\x14\x00\x00\x10\x00\x01\x02\x03\x04".to_vec()), |
| }; |
| |
| assert_eq!(hj.want_message(&msg), true); |
| assert_eq!(hj.take_message(msg), Some(0)); |
| assert_eq!(hj.is_empty(), false); |
| |
| // 11 more bytes. |
| msg = Message { |
| typ: ContentType::Handshake, |
| version: ProtocolVersion::TLSv1_2, |
| payload: MessagePayload::new_opaque(b"\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e".to_vec()), |
| }; |
| |
| assert_eq!(hj.want_message(&msg), true); |
| assert_eq!(hj.take_message(msg), Some(0)); |
| assert_eq!(hj.is_empty(), false); |
| |
| // Final 1 byte. |
| msg = Message { |
| typ: ContentType::Handshake, |
| version: ProtocolVersion::TLSv1_2, |
| payload: MessagePayload::new_opaque(b"\x0f".to_vec()), |
| }; |
| |
| assert_eq!(hj.want_message(&msg), true); |
| assert_eq!(hj.take_message(msg), Some(1)); |
| assert_eq!(hj.is_empty(), true); |
| |
| let payload = b"\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f".to_vec(); |
| let expect = Message { |
| typ: ContentType::Handshake, |
| version: ProtocolVersion::TLSv1_2, |
| payload: MessagePayload::Handshake(HandshakeMessagePayload { |
| typ: HandshakeType::Finished, |
| payload: HandshakePayload::Finished(Payload::new(payload)), |
| }), |
| }; |
| |
| pop_eq(&expect, &mut hj); |
| } |
| } |