blob: 5b098bd49b2cc2fb50780cc4f24820f89332963c [file] [log] [blame]
use crate::codec::{Decoder, Encoder};
use tokio::{net::UdpSocket, stream::Stream};
use bytes::{BufMut, BytesMut};
use futures_core::ready;
use futures_sink::Sink;
use std::io;
use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
use std::pin::Pin;
use std::task::{Context, Poll};
/// A unified `Stream` and `Sink` interface to an underlying `UdpSocket`, using
/// the `Encoder` and `Decoder` traits to encode and decode frames.
///
/// Raw UDP sockets work with datagrams, but higher-level code usually wants to
/// work with meaningful chunks, called "frames". This type layers framing on
/// top of a socket by using the `Encoder` and `Decoder` traits to handle
/// encoding and decoding of message frames. Note that the incoming and
/// outgoing frame types may be distinct.
///
/// Every received datagram is handed to the decoder exactly once; whatever
/// bytes the decoder leaves in the buffer are discarded before the next
/// datagram is read. Every item sent is a `(frame, destination)` pair, and
/// every item received is a `(frame, source)` pair.
///
/// `UdpFramed` is a *single* object that is both `Stream` and `Sink`;
/// grouping this into a single object is often useful for layering things which
/// require both read and write access to the underlying object.
///
/// If you want to work more directly with the stream and sink, consider
/// calling `split` on the `UdpFramed`, which will break it into separate
/// objects, allowing them to interact more easily.
#[must_use = "sinks do nothing unless polled"]
#[cfg_attr(docsrs, doc(cfg(all(feature = "codec", feature = "udp"))))]
#[derive(Debug)]
pub struct UdpFramed<C> {
    // The socket datagrams are received from and sent to.
    socket: UdpSocket,
    // Codec used to decode incoming datagrams and encode outgoing frames.
    codec: C,
    // Receive buffer; holds the bytes of the datagram currently being decoded.
    rd: BytesMut,
    // Send buffer; holds the encoded frame that has not been flushed yet.
    wr: BytesMut,
    // Destination address for the bytes currently held in `wr`.
    out_addr: SocketAddr,
    // `true` when `wr` holds nothing that still has to be sent.
    flushed: bool,
}
impl<C: Decoder + Unpin> Stream for UdpFramed<C> {
    type Item = Result<(C::Item, SocketAddr), C::Error>;

    /// Receives datagrams until one of them decodes into a frame.
    ///
    /// Returns `Poll::Pending` while the socket has no datagram available,
    /// `Poll::Ready(Some(Ok((frame, addr))))` once a datagram from `addr`
    /// decodes into `frame`, and `Poll::Ready(Some(Err(_)))` when either the
    /// socket or the decoder reports an error.
    ///
    /// A datagram for which the decoder returns `Ok(None)` is dropped and the
    /// next datagram is awaited. It must not surface as `Poll::Ready(None)`,
    /// because that would tell consumers the stream has terminated even
    /// though the socket is still usable.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let pin = self.get_mut();

        loop {
            // The codec may have split capacity off the buffer while decoding
            // the previous datagram, so make room again before every receive.
            pin.rd.reserve(INITIAL_RD_CAPACITY);

            let addr = unsafe {
                // Read into the buffer without having to initialize the memory.
                //
                // SAFETY: we know tokio::net::UdpSocket never reads from the
                // memory during a recv, and the buffer is only advanced by the
                // number of bytes the socket reports having written into it.
                let res = {
                    let bytes = &mut *(pin.rd.bytes_mut() as *mut _ as *mut [u8]);
                    ready!(Pin::new(&mut pin.socket).poll_recv_from(cx, bytes))
                };

                let (n, addr) = res?;
                pin.rd.advance_mut(n);
                addr
            };

            let frame_res = pin.codec.decode(&mut pin.rd);
            // Whatever the decoder did not consume belongs to this datagram
            // only; never let it leak into the next one.
            pin.rd.clear();

            if let Some(frame) = frame_res? {
                return Poll::Ready(Some(Ok((frame, addr))));
            }
            // No frame in this datagram: go around and wait for the next one.
        }
    }
}
impl<I, C: Encoder<I> + Unpin> Sink<(I, SocketAddr)> for UdpFramed<C> {
    type Error = C::Error;

    /// Ready as soon as no previously queued datagram is waiting to be sent;
    /// otherwise drives the pending flush and reports its outcome.
    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        if self.flushed {
            return Poll::Ready(Ok(()));
        }

        // Pending, success and failure of the flush map one-to-one onto the
        // readiness result, so hand the flush result straight back.
        self.poll_flush(cx)
    }

    /// Encodes `frame` into the send buffer and records its destination.
    ///
    /// Nothing is written to the socket here; the datagram goes out on the
    /// next flush.
    fn start_send(self: Pin<&mut Self>, item: (I, SocketAddr)) -> Result<(), Self::Error> {
        let this = self.get_mut();
        let (frame, destination) = item;

        this.codec.encode(frame, &mut this.wr)?;
        this.out_addr = destination;
        this.flushed = false;

        Ok(())
    }

    /// Sends the buffered bytes to the recorded destination as one datagram.
    ///
    /// Once the socket accepts the send, the buffer is emptied and the sink is
    /// marked flushed, even when fewer bytes than buffered were written; that
    /// short write is reported as an error.
    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        if self.flushed {
            return Poll::Ready(Ok(()));
        }

        let this = &mut *self;
        let sent = ready!(this.socket.poll_send_to(cx, &this.wr, &this.out_addr))?;

        // Compare before clearing: the length is gone afterwards.
        let complete = sent == this.wr.len();
        this.wr.clear();
        this.flushed = true;

        if complete {
            Poll::Ready(Ok(()))
        } else {
            let short_write = io::Error::new(
                io::ErrorKind::Other,
                "failed to write entire datagram to socket",
            );
            Poll::Ready(Err(short_write.into()))
        }
    }

    /// Closing only flushes what is still buffered; the result of that flush
    /// is the result of the close.
    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.poll_flush(cx)
    }
}
// Capacity (64 KiB) given to the read buffer at construction and reserved
// again before each receive.
// NOTE(review): presumably sized so that any single UDP datagram fits - confirm.
const INITIAL_RD_CAPACITY: usize = 64 * 1024;
// Capacity (8 KiB) given to the write buffer at construction.
const INITIAL_WR_CAPACITY: usize = 8 * 1024;
impl<C> UdpFramed<C> {
    /// Wraps `socket` so that datagrams are framed with `codec`.
    ///
    /// Both internal buffers are allocated up front and the sink starts out
    /// flushed, with an unspecified placeholder destination (`0.0.0.0:0`).
    ///
    /// See the struct level documentation for more details.
    pub fn new(socket: UdpSocket, codec: C) -> UdpFramed<C> {
        // Placeholder only: the first `start_send` overwrites it before
        // anything is written to the socket.
        let unspecified = SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0);

        Self {
            socket,
            codec,
            rd: BytesMut::with_capacity(INITIAL_RD_CAPACITY),
            wr: BytesMut::with_capacity(INITIAL_WR_CAPACITY),
            out_addr: SocketAddr::V4(unspecified),
            flushed: true,
        }
    }

    /// Borrows the `UdpSocket` wrapped by this `UdpFramed`.
    ///
    /// # Note
    ///
    /// Take care not to consume or inject datagrams through the socket
    /// directly; doing so may corrupt the sequence of frames this `UdpFramed`
    /// is otherwise working with.
    pub fn get_ref(&self) -> &UdpSocket {
        let Self { socket, .. } = self;
        socket
    }

    /// Mutably borrows the `UdpSocket` wrapped by this `UdpFramed`.
    ///
    /// # Note
    ///
    /// Take care not to consume or inject datagrams through the socket
    /// directly; doing so may corrupt the sequence of frames this `UdpFramed`
    /// is otherwise working with.
    pub fn get_mut(&mut self) -> &mut UdpSocket {
        let Self { socket, .. } = self;
        socket
    }

    /// Consumes the `UdpFramed` and hands back the wrapped `UdpSocket`.
    ///
    /// The codec and anything still held in the internal buffers are dropped.
    pub fn into_inner(self) -> UdpSocket {
        let Self { socket, .. } = self;
        socket
    }
}