blob: aeb27a76a9e390c9eb37ac1fadecceecb294a0f8 [file] [log] [blame]
// Copyright 2018 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
use std::io::{self, Read, Write};
use byteorder::{NativeEndian, ReadBytesExt, WriteBytesExt};
use failure::Fail;
use fuchsia_zircon as zx;
use crate::{Fixed, NewId, ObjectId};
#[derive(Copy, Clone, Debug)]
pub enum ArgKind {
Int,
Uint,
Fixed,
String,
Object,
NewId,
Array,
Handle,
}
#[derive(Debug)]
pub enum Arg {
Int(i32),
Uint(u32),
Fixed(Fixed),
String(String),
Object(ObjectId),
NewId(NewId),
Array(Vec<u8>),
Handle(zx::Handle),
}
impl Arg {
pub fn kind(&self) -> ArgKind {
match self {
Arg::Int(_) => ArgKind::Int,
Arg::Uint(_) => ArgKind::Uint,
Arg::Fixed(_) => ArgKind::Fixed,
Arg::String(_) => ArgKind::String,
Arg::Object(_) => ArgKind::Object,
Arg::NewId(_) => ArgKind::NewId,
Arg::Array(_) => ArgKind::Array,
Arg::Handle(_) => ArgKind::Handle,
}
}
}
macro_rules! impl_unwrap_arg(
($name:ident, $type:ty, $enumtype:ident) => (
pub fn $name(self) -> $type {
if let Arg::$enumtype(x) = self {
x
} else {
panic!("Argument is not of the required type: \
expected {:?}, found {:?}",
stringify!($enumtype), self);
}
}
)
);
impl Arg {
impl_unwrap_arg!(unwrap_int, i32, Int);
impl_unwrap_arg!(unwrap_uint, u32, Uint);
impl_unwrap_arg!(unwrap_fixed, Fixed, Fixed);
impl_unwrap_arg!(unwrap_object, ObjectId, Object);
impl_unwrap_arg!(unwrap_new_id, ObjectId, NewId);
impl_unwrap_arg!(unwrap_string, String, String);
impl_unwrap_arg!(unwrap_array, Vec<u8>, Array);
impl_unwrap_arg!(unwrap_handle, zx::Handle, Handle);
}
#[derive(Debug, Fail)]
#[fail(
display = "Argument is not of the required type: expected {:?}, found {:?}",
expected, found
)]
pub struct MismatchedArgKind {
pub expected: ArgKind,
pub found: ArgKind,
}
macro_rules! impl_as_arg(
($name:ident, $type:ty, $enumtype:ident) => (
pub fn $name(self) -> Result<$type, MismatchedArgKind> {
if let Arg::$enumtype(x) = self {
Ok(x)
} else {
Err(MismatchedArgKind {
expected: ArgKind::$enumtype,
found: self.kind(),
})
}
}
)
);
impl Arg {
impl_as_arg!(as_int, i32, Int);
impl_as_arg!(as_uint, u32, Uint);
impl_as_arg!(as_fixed, Fixed, Fixed);
impl_as_arg!(as_object, ObjectId, Object);
impl_as_arg!(as_new_id, ObjectId, NewId);
impl_as_arg!(as_string, String, String);
impl_as_arg!(as_array, Vec<u8>, Array);
impl_as_arg!(as_handle, zx::Handle, Handle);
}
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub struct MessageHeader {
pub sender: u32,
pub opcode: u16,
pub length: u16,
}
pub struct Message {
byte_buf: io::Cursor<Vec<u8>>,
handle_buf: Vec<zx::Handle>,
}
fn compute_padding(size: u64) -> usize {
(-(size as i64) & 3) as usize
}
impl Message {
/// Initialize an empty Message.
pub fn new() -> Self {
Message {
byte_buf: io::Cursor::new(Vec::new()),
handle_buf: Vec::new(),
}
}
pub fn from_parts(bytes: Vec<u8>, handles: Vec<zx::Handle>) -> Self {
Message {
byte_buf: io::Cursor::new(bytes),
handle_buf: handles,
}
}
/// Returns |true| iff this buffer has no more data (bytes or handles).
pub fn is_empty(&self) -> bool {
self.byte_buf.get_ref().len() as u64 == self.byte_buf.position()
&& self.handle_buf.is_empty()
}
pub fn clear(&mut self) {
self.byte_buf.set_position(0);
self.byte_buf.get_mut().truncate(0);
self.handle_buf.truncate(0);
}
pub fn rewind(&mut self) {
self.byte_buf.set_position(0);
}
pub fn bytes(&self) -> &[u8] {
self.byte_buf.get_ref().as_slice()
}
pub fn take(self) -> (Vec<u8>, Vec<zx::Handle>) {
(self.byte_buf.into_inner(), self.handle_buf)
}
/// Converts the value in |arg| into the appropriate wire format
/// serialization.
pub fn write_arg(&mut self, arg: Arg) -> io::Result<()> {
match arg {
Arg::Int(i) => self.byte_buf.write_i32::<NativeEndian>(i),
Arg::Uint(i) => self.byte_buf.write_u32::<NativeEndian>(i),
Arg::Fixed(i) => self.byte_buf.write_i32::<NativeEndian>(i.bits()),
Arg::Object(i) => self.byte_buf.write_u32::<NativeEndian>(i),
Arg::NewId(i) => self.byte_buf.write_u32::<NativeEndian>(i),
Arg::String(s) => self.write_slice(s.as_bytes(), true),
Arg::Array(vec) => self.write_slice(vec.as_slice(), false),
Arg::Handle(h) => {
self.handle_buf.push(h);
Ok(())
}
}
}
/// Reads an Arg out of this Message and return the value.
pub fn read_arg(&mut self, arg: ArgKind) -> io::Result<Arg> {
match arg {
ArgKind::Int => self.byte_buf.read_i32::<NativeEndian>().map(Arg::Int),
ArgKind::Uint => self.byte_buf.read_u32::<NativeEndian>().map(Arg::Uint),
ArgKind::Fixed => self.byte_buf.read_i32::<NativeEndian>().map(|i| Arg::Fixed(i.into())),
ArgKind::Object => self.byte_buf.read_u32::<NativeEndian>().map(Arg::Object),
ArgKind::NewId => self.byte_buf.read_u32::<NativeEndian>().map(Arg::NewId),
ArgKind::String => self
.read_slice(true)
.and_then(|vec| {
String::from_utf8(vec).map_err(|_| {
io::Error::new(io::ErrorKind::InvalidData, "Unable to decode string")
})
})
.map(Arg::String),
ArgKind::Array => self.read_slice(false).map(Arg::Array),
ArgKind::Handle => {
if !self.handle_buf.is_empty() {
Ok(Arg::Handle(self.handle_buf.remove(0)))
} else {
Err(io::Error::new(
io::ErrorKind::UnexpectedEof,
"Unable to read handle from Message",
))
}
}
}
}
/// Reads the set of arguments specified by the given &[ArgKind].
pub fn read_args(&mut self, args: &[ArgKind]) -> io::Result<Vec<Arg>> {
args.iter().map(|arg| self.read_arg(*arg)).collect()
}
/// Reads the set of arguments specified by the given &[ArgKind].
pub fn write_args(&mut self, args: Vec<Arg>) -> io::Result<()> {
args.into_iter().try_for_each(|arg| self.write_arg(arg))?;
Ok(())
}
/// Reads a |MessageHeader| without advancing the read pointer in the
/// underlying buffer.
pub fn peek_header(&mut self) -> io::Result<MessageHeader> {
let pos = self.byte_buf.position();
let header = self.read_header();
self.byte_buf.set_position(pos);
header
}
pub fn read_header(&mut self) -> io::Result<MessageHeader> {
let sender = self.byte_buf.read_u32::<NativeEndian>()?;
let word = self.byte_buf.read_u32::<NativeEndian>()?;
Ok(MessageHeader {
sender,
length: (word >> 16) as u16,
opcode: word as u16,
})
}
pub fn write_header(&mut self, header: &MessageHeader) -> io::Result<()> {
self.byte_buf.write_u32::<NativeEndian>(header.sender)?;
self.byte_buf
.write_u32::<NativeEndian>((header.length as u32) << 16 | header.opcode as u32)?;
Ok(())
}
fn read_slice(&mut self, null_term: bool) -> io::Result<Vec<u8>> {
let pos = self.byte_buf.position();
let len = self.byte_buf.read_u32::<NativeEndian>()?;
let mut vec: Vec<u8> = Vec::with_capacity(len as usize);
if len == 0 {
return Ok(vec);
}
vec.resize(len as usize, 0);
self.byte_buf.read_exact(vec.as_mut_slice())?;
if null_term {
match vec.pop() {
Some(term) => {
if term != b'\0' {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
format!("Expected null terminator; found {}", term),
));
}
}
None => {
return Err(io::Error::new(
io::ErrorKind::UnexpectedEof,
"Missing null terminator on string argument",
));
}
}
}
let pad = compute_padding(self.byte_buf.position() - pos);
if pad > 0 {
self.byte_buf.read_uint::<NativeEndian>(pad)?;
}
assert!(self.byte_buf.position() % 4 == 0);
Ok(vec)
}
fn write_slice(&mut self, s: &[u8], null_term: bool) -> io::Result<()> {
let pos = self.byte_buf.position();
let mut len: u32 = s.len() as u32;
if null_term {
len += 1;
}
self.byte_buf.write_u32::<NativeEndian>(len)?;
self.byte_buf.write_all(s)?;
if null_term {
self.byte_buf.write_u8(0)?;
}
// Pad to 32-bit boundary.
let pad = compute_padding(self.byte_buf.position() - pos);
if pad > 0 {
self.byte_buf.write_uint::<NativeEndian>(0, pad)?;
}
assert!(self.byte_buf.position() % 4 == 0);
Ok(())
}
}
/// Convert from the zx::MessageBuf we get out of the channel.
impl From<zx::MessageBuf> for Message {
fn from(buf: zx::MessageBuf) -> Self {
let (bytes, handles) = buf.split();
Message::from_parts(bytes, handles)
}
}
#[cfg(test)]
mod tests {
use super::*;
use failure::Error;
/// Helper to assist asserting a single match branch.
///
/// Ex:
///
/// let arg = Arg::Uint(8);
/// assert_matches(arg, Arg::Uint(x) => assert_eq!(x, 8));
macro_rules! assert_matches(
($e:expr, $p:pat => $a:expr) => (
match $e {
$p => $a,
_ => panic!("Failed to match!"),
}
)
);
const UINT_VALUE: u32 = 0x1234567;
const INT_VALUE: i32 = -12345678;
const FIXED_VALUE: i32 = 0x11223344;
const OBJECT_VALUE: u32 = 0x88775566;
const NEW_ID_VALUE: u32 = 0x55443322;
const STRING_VALUE: &str = "Hello from a test";
const ARRAY_VALUE: &[u8] = &[0, 1, 2, 3, 4, 5, 6];
/// Emit one argument of each type and read it back out.
#[test]
fn sanity() {
let (h1, _h2) = zx::Socket::create(zx::SocketOpts::STREAM).unwrap();
let mut message = Message::new();
assert!(message.write_arg(Arg::Uint(UINT_VALUE)).is_ok());
assert!(message.write_arg(Arg::Int(INT_VALUE)).is_ok());
assert!(message.write_arg(Arg::Fixed(Fixed::from_bits(FIXED_VALUE))).is_ok());
assert!(message.write_arg(Arg::Object(OBJECT_VALUE)).is_ok());
assert!(message.write_arg(Arg::NewId(NEW_ID_VALUE)).is_ok());
assert!(message
.write_arg(Arg::String(STRING_VALUE.to_owned()))
.is_ok());
assert!(message
.write_arg(Arg::Array(ARRAY_VALUE.to_owned()))
.is_ok());
assert!(message.write_arg(Arg::Handle(h1.into())).is_ok());
let (bytes, handles) = message.take();
assert_eq!(1, handles.len());
const INT_SIZE: u32 = 4;
const UINT_SIZE: u32 = 4;
const FIXED_SIZE: u32 = 4;
const OBJECT_SIZE: u32 = 4;
const NEW_ID_SIZE: u32 = 4;
const STRING_SIZE: u32 = 24;
const VEC_SIZE: u32 = 12;
let expected_size =
INT_SIZE + UINT_SIZE + FIXED_SIZE + OBJECT_SIZE + NEW_ID_SIZE + STRING_SIZE + VEC_SIZE;
assert_eq!(expected_size as usize, bytes.len());
let mut message = Message::from_parts(bytes, handles);
let arg = message.read_arg(ArgKind::Uint).unwrap();
assert_matches!(arg, Arg::Uint(x) => assert_eq!(x, UINT_VALUE));
let arg = message.read_arg(ArgKind::Int).unwrap();
assert_matches!(arg, Arg::Int(x) => assert_eq!(x, INT_VALUE));
let arg = message.read_arg(ArgKind::Fixed).unwrap();
assert_matches!(arg, Arg::Fixed(x) => assert_eq!(x, Fixed::from_bits(FIXED_VALUE)));
let arg = message.read_arg(ArgKind::Object).unwrap();
assert_matches!(arg, Arg::Object(x) => assert_eq!(x, OBJECT_VALUE));
let arg = message.read_arg(ArgKind::NewId).unwrap();
assert_matches!(arg, Arg::NewId(x) => assert_eq!(x, NEW_ID_VALUE));
let arg = message.read_arg(ArgKind::String).unwrap();
assert_matches!(arg, Arg::String(ref x) => assert_eq!(x, STRING_VALUE));
let arg = message.read_arg(ArgKind::Array).unwrap();
assert_matches!(arg, Arg::Array(ref x) => assert_eq!(x.as_slice(), ARRAY_VALUE));
}
/// Test some of the padding edge-cases for string values.
#[test]
fn string_test() {
let mut message = Message::new();
// 4 bytes length + 1 null terminating byte + 3 pad bytes
let s0 = "";
message.clear();
assert!(message.write_arg(Arg::String(s0.to_owned())).is_ok());
assert_eq!(8, message.bytes().len());
message.rewind();
let arg = message.read_arg(ArgKind::String).unwrap();
assert_matches!(arg, Arg::String(ref s) => assert_eq!(s, s0));
// 4 bytes length + 1 char + 1 null terminating byte + 2 pad bytes
let s1 = "1";
message.clear();
assert!(message.write_arg(Arg::String(s1.to_owned())).is_ok());
assert_eq!(8, message.bytes().len());
message.rewind();
let arg = message.read_arg(ArgKind::String).unwrap();
assert_matches!(arg, Arg::String(ref s) => assert_eq!(s, s1));
// 4 bytes length + 2 char + 1 null terminating byte + 1 pad bytes
let s2 = "22";
message.clear();
assert!(message.write_arg(Arg::String(s2.to_owned())).is_ok());
assert_eq!(8, message.bytes().len());
message.rewind();
let arg = message.read_arg(ArgKind::String).unwrap();
assert_matches!(arg, Arg::String(ref s) => assert_eq!(s, s2));
// 4 bytes length + 3 char + 1 null terminating byte + 0 pad bytes
let s3 = "333";
message.clear();
assert!(message.write_arg(Arg::String(s3.to_owned())).is_ok());
assert_eq!(8, message.bytes().len());
message.rewind();
let arg = message.read_arg(ArgKind::String).unwrap();
assert_matches!(arg, Arg::String(ref s) => assert_eq!(s, s3));
// 4 bytes length + 4 char + 1 null terminating byte + 3 pad bytes
let s4 = "4444";
message.clear();
assert!(message.write_arg(Arg::String(s4.to_owned())).is_ok());
assert_eq!(12, message.bytes().len());
message.rewind();
let arg = message.read_arg(ArgKind::String).unwrap();
assert_matches!(arg, Arg::String(ref s) => assert_eq!(s, s4));
// 4 bytes length + 5 char + 1 null terminating byte + 2 pad bytes
let s5 = "55555";
message.clear();
assert!(message.write_arg(Arg::String(s5.to_owned())).is_ok());
assert_eq!(12, message.bytes().len());
message.rewind();
let arg = message.read_arg(ArgKind::String).unwrap();
assert_matches!(arg, Arg::String(ref s) => assert_eq!(s, s5));
// Null string (no \0 termiator). This is distinct from the empty
// string (which is encoded as a single null terminator character).
// Note that once decoded, a null and an empty string will have the
// same representation, which is a Rust 'String' with length 0, but
// this test ensures that we decode the null string correctly.
//
// 4 bytes length
message.clear();
assert!(message.write_arg(Arg::Uint(0)).is_ok());
assert_eq!(4, message.bytes().len());
message.rewind();
let arg = message.read_arg(ArgKind::String).unwrap();
assert_matches!(arg, Arg::String(ref s) => assert_eq!(s, ""));
}
#[test]
fn peek_header() -> Result<(), Error> {
// Write just a message header.
let header = MessageHeader {
sender: 3,
opcode: 2,
length: 8,
};
let mut message = Message::new();
message.write_header(&header)?;
message.rewind();
// Verify peek doesn't advance the read pointer.
assert_eq!(header, message.peek_header()?);
assert_eq!(header, message.peek_header()?);
assert_eq!(header, message.peek_header()?);
assert_eq!(header, message.peek_header()?);
Ok(())
}
#[test]
fn empty_message() {
let (h1, h2) = zx::Channel::create().unwrap();
let message = Message::new();
assert!(message.is_empty());
let message = Message::from_parts(vec![], vec![]);
assert!(message.is_empty());
let message = Message::from_parts(vec![1], vec![]);
assert!(!message.is_empty());
let message = Message::from_parts(vec![], vec![h1.into()]);
assert!(!message.is_empty());
let mut message = Message::from_parts(vec![0, 0, 0, 0], vec![h2.into()]);
assert!(!message.is_empty());
let _ = message.read_arg(ArgKind::Uint).unwrap();
assert!(!message.is_empty());
let _ = message.read_arg(ArgKind::Handle).unwrap();
assert!(message.is_empty());
}
}