blob: 7d0ab73d8f55f1abe67250a0ef95aeb6f2b0a5c4 [file] [log] [blame]
mod error;
mod framed_read;
mod framed_write;
pub use self::error::{RecvError, SendError, UserError};
use self::framed_read::FramedRead;
use self::framed_write::FramedWrite;
use crate::frame::{self, Data, Frame};
use bytes::Buf;
use futures_core::Stream;
use futures_sink::Sink;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_util::codec::length_delimited;
use std::io;
/// Frame codec layered over an I/O resource of type `T`.
///
/// The codec holds a `FramedRead` wrapped around a `FramedWrite`, which in
/// turn wraps the `T` passed to the constructor. `B` is the buffer type
/// parameter of the `FramedWrite` and of the `Frame<B>` values accepted for
/// sending.
#[derive(Debug)]
pub struct Codec<T, B> {
// Read half stacked on top of the write half; the write half holds the `T`.
inner: FramedRead<FramedWrite<T, B>>,
}
impl<T, B> Codec<T, B>
where
    T: AsyncRead + AsyncWrite + Unpin,
    B: Buf,
{
    /// Builds a `Codec` around `io`, limiting received frames to
    /// `frame::DEFAULT_MAX_FRAME_SIZE`.
    #[inline]
    pub fn new(io: T) -> Self {
        let default_limit = frame::DEFAULT_MAX_FRAME_SIZE as usize;
        Self::with_max_recv_frame_size(io, default_limit)
    }

    /// Builds a `Codec` around `io`, limiting received frames to
    /// `max_frame_size`.
    pub fn with_max_recv_frame_size(io: T, max_frame_size: usize) -> Self {
        // The write half is the innermost layer and takes the I/O resource.
        let writer = FramedWrite::new(io);

        // Frame boundaries come from a 3-byte big-endian length prefix. The
        // prefix value is adjusted by 9 to obtain the full frame length, and
        // no bytes are skipped so that the header stays in the decoded frame.
        let mut framing = length_delimited::Builder::new();
        let configured = framing
            .big_endian()
            .length_field_length(3)
            .length_adjustment(9)
            .num_skip(0);
        let delimited = configured.new_read(writer);

        // Go through `FramedRead`'s own setter instead of configuring the
        // limit elsewhere, since that setter checks the value is in range.
        let mut inner = FramedRead::new(delimited);
        inner.set_max_frame_size(max_frame_size);

        Self { inner }
    }
}
impl<T, B> Codec<T, B> {
    /// Changes the limit on the size of frames accepted from the wire.
    ///
    /// The new limit applies starting with the next frame that is decoded. A
    /// frame whose decoding is already under way is still judged against the
    /// limit that was in effect before this call, so it is accepted even when
    /// it is larger than `val`.
    #[inline]
    pub fn set_max_recv_frame_size(&mut self, val: usize) {
        let reader = &mut self.inner;
        reader.set_max_frame_size(val)
    }

    /// Reports the limit currently applied to received frames.
    ///
    /// Frames arriving from the wire that exceed this size are rejected.
    #[cfg(feature = "unstable")]
    #[inline]
    pub fn max_recv_frame_size(&self) -> usize {
        let reader = &self.inner;
        reader.max_frame_size()
    }

    /// Reports the largest frame size that may be sent to the peer.
    pub fn max_send_frame_size(&self) -> usize {
        let writer = self.inner.get_ref();
        writer.max_frame_size()
    }

    /// Records the peer's max frame size, i.e. the limit for sent frames.
    pub fn set_max_send_frame_size(&mut self, val: usize) {
        let writer = self.framed_write();
        writer.set_max_frame_size(val)
    }

    /// Records the peer's header table size.
    pub fn set_send_header_table_size(&mut self, val: usize) {
        let writer = self.framed_write();
        writer.set_header_table_size(val)
    }

    /// Sets the largest header list size that will be accepted when receiving.
    pub fn set_max_recv_header_list_size(&mut self, val: usize) {
        let reader = &mut self.inner;
        reader.set_max_header_list_size(val);
    }

    /// Borrows the underlying I/O resource.
    #[cfg(feature = "unstable")]
    pub fn get_ref(&self) -> &T {
        let writer = self.inner.get_ref();
        writer.get_ref()
    }

    /// Mutably borrows the underlying I/O resource.
    pub fn get_mut(&mut self) -> &mut T {
        let writer = self.inner.get_mut();
        writer.get_mut()
    }

    /// Removes and returns the data frame that was last fully written to the
    /// socket, if the write half is holding one.
    pub(crate) fn take_last_data_frame(&mut self) -> Option<Data<B>> {
        let writer = self.framed_write();
        writer.take_last_data_frame()
    }

    /// Mutable access to the write half nested inside the read half.
    fn framed_write(&mut self) -> &mut FramedWrite<T, B> {
        let reader = &mut self.inner;
        reader.get_mut()
    }
}
impl<T, B> Codec<T, B>
where
    T: AsyncWrite + Unpin,
    B: Buf,
{
    /// Polls the write half; `Ready` means a frame can be buffered.
    pub fn poll_ready(&mut self, cx: &mut Context) -> Poll<io::Result<()>> {
        let writer = self.framed_write();
        writer.poll_ready(cx)
    }

    /// Hands `item` to the write half for buffering.
    ///
    /// Call `poll_ready` beforehand to make sure a frame can be accepted.
    ///
    /// TODO: Rename this to avoid conflicts with Sink::buffer
    pub fn buffer(&mut self, item: Frame<B>) -> Result<(), UserError> {
        let writer = self.framed_write();
        writer.buffer(item)
    }

    /// Pushes buffered data out to the wire.
    pub fn flush(&mut self, cx: &mut Context) -> Poll<io::Result<()>> {
        let writer = self.framed_write();
        writer.flush(cx)
    }

    /// Shuts down the sending half.
    pub fn shutdown(&mut self, cx: &mut Context) -> Poll<io::Result<()>> {
        let writer = self.framed_write();
        writer.shutdown(cx)
    }
}
impl<T, B> Stream for Codec<T, B>
where
    T: AsyncRead + Unpin,
{
    type Item = Result<Frame, RecvError>;

    /// Polls the read half for the next decoded frame.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // Unwrap the pin to reach the field, then re-pin the read half so its
        // own `Stream` implementation can be driven.
        let codec = Pin::get_mut(self);
        let reader = Pin::new(&mut codec.inner);
        reader.poll_next(cx)
    }
}
impl<T, B> Sink<Frame<B>> for Codec<T, B>
where
T: AsyncWrite + Unpin,
B: Buf,
{
type Error = SendError;
fn start_send(mut self: Pin<&mut Self>, item: Frame<B>) -> Result<(), Self::Error> {
Codec::buffer(&mut self, item)?;
Ok(())
}
/// Returns `Ready` when the codec can buffer a frame
fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.framed_write().poll_ready(cx).map_err(Into::into)
}
/// Flush buffered data to the wire
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.framed_write().flush(cx).map_err(Into::into)
}
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
ready!(self.shutdown(cx))?;
Poll::Ready(Ok(()))
}
}
// TODO: remove (or improve) this
impl<T> From<T> for Codec<T, bytes::Bytes>
where
T: AsyncRead + AsyncWrite + Unpin,
{
fn from(src: T) -> Self {
Self::new(src)
}
}