blob: 8e4a5b1e7a83bb7de76287fbde1180a61c61f109 [file] [log] [blame]
#[cfg(feature = "nightly")]
#[cfg(feature = "tokio-support")]
mod vecbuf;
use std::io::{ self, Read, Write };
#[cfg(feature = "nightly")]
use std::io::Initializer;
use rustls::Session;
#[cfg(feature = "nightly")]
use rustls::WriteV;
#[cfg(feature = "nightly")]
#[cfg(feature = "tokio-support")]
use tokio_io::AsyncWrite;
/// Pairs a TLS session with the transport its records travel over.
///
/// Both halves are borrowed mutably for `'a`, so a `Stream` is a short-lived
/// view that is built on demand around state owned elsewhere.
pub struct Stream<'a, S, IO>
where
    S: 'a,
    IO: 'a,
{
    /// The TLS session whose records are being moved.
    pub session: &'a mut S,
    /// The transport the session reads records from and writes records to.
    pub io: &'a mut IO,
}
/// Something that can push a session's queued TLS records out to its transport.
///
/// The `'a`, `S` and `IO` parameters mirror those of `Stream`, so different
/// transport kinds can get their own implementation.
pub trait WriteTls<'a, S, IO>: Read + Write
where
    S: Session,
    IO: Read + Write,
{
    /// Writes pending TLS data and returns the number of bytes written.
    fn write_tls(&mut self) -> io::Result<usize>;
}
impl<'a, S: Session, IO: Read + Write> Stream<'a, S, IO> {
    /// Builds a stream view over a borrowed session and transport.
    pub fn new(session: &'a mut S, io: &'a mut IO) -> Self {
        Stream { session, io }
    }

    /// Drives the session's TLS I/O over `io`.
    ///
    /// Returns `(bytes_read, bytes_written)` of raw TLS data moved between the
    /// session and the transport.
    ///
    /// Each pass does three things: it writes every queued record, reads once
    /// from `io` if the session wants more data, and processes whatever arrived.
    ///
    /// * If the session was handshaking on entry, passes repeat until the
    ///   handshake completes.
    /// * Otherwise the call returns after the first pass, or immediately after
    ///   the write phase if anything was written.
    ///
    /// # Errors
    ///
    /// * Any I/O error from `io` is propagated as-is (including `WouldBlock`).
    /// * `InvalidData` if the session rejects the received data.
    /// * `UnexpectedEof` if `io` reaches EOF before the handshake completes.
    /// * `WriteZero` if `io` accepts zero bytes while records are still queued.
    ///   Without this check the write loop would spin forever.
    pub fn complete_io(&mut self) -> io::Result<(usize, usize)> {
        // fork from https://github.com/ctz/rustls/blob/master/src/session.rs#L161
        let until_handshaked = self.session.is_handshaking();
        let mut eof = false;
        let mut wrlen = 0;
        let mut rdlen = 0;

        loop {
            while self.session.wants_write() {
                match self.write_tls()? {
                    // Records are pending but the transport took none of them.
                    // `wants_write` would stay true forever, so report it the
                    // same way `Write::write_all` does instead of spinning.
                    0 => return Err(io::Error::from(io::ErrorKind::WriteZero)),
                    n => wrlen += n,
                }
            }

            // Outside a handshake, one successful flush of records is enough.
            if !until_handshaked && wrlen > 0 {
                return Ok((rdlen, wrlen));
            }

            if !eof && self.session.wants_read() {
                match self.session.read_tls(self.io)? {
                    0 => eof = true,
                    n => rdlen += n,
                }
            }

            match self.session.process_new_packets() {
                Ok(_) => {}
                Err(e) => {
                    // The session may have queued an alert describing this
                    // error, so make a best-effort attempt to send it. Any
                    // failure of that write is ignored so the caller sees the
                    // original protocol error.
                    let _ignored = self.write_tls();
                    return Err(io::Error::new(io::ErrorKind::InvalidData, e));
                }
            };

            match (eof, until_handshaked, self.session.is_handshaking()) {
                // The handshake we were waiting for has finished.
                (_, true, false) => return Ok((rdlen, wrlen)),
                // Not waiting for a handshake: a single pass is enough.
                (_, false, _) => return Ok((rdlen, wrlen)),
                // The peer went away mid-handshake.
                (true, true, true) => return Err(io::Error::from(io::ErrorKind::UnexpectedEof)),
                // Still handshaking with the transport open: go round again.
                (..) => (),
            }
        }
    }
}
#[cfg(not(feature = "nightly"))]
impl<'a, S: Session, IO: Read + Write> WriteTls<'a, S, IO> for Stream<'a, S, IO> {
fn write_tls(&mut self) -> io::Result<usize> {
self.session.write_tls(self.io)
}
}
#[cfg(feature = "nightly")]
impl<'a, S: Session, IO: Read + Write> WriteTls<'a, S, IO> for Stream<'a, S, IO> {
    /// Specialisable fallback: hands the session's queued TLS records to the
    /// transport through the session's own `write_tls`. More specific impls
    /// (e.g. for async transports) may override this via `default fn`.
    default fn write_tls(&mut self) -> io::Result<usize> {
        let Stream { ref mut session, ref mut io } = *self;
        session.write_tls(&mut **io)
    }
}
#[cfg(feature = "nightly")]
#[cfg(feature = "tokio-support")]
impl<'a, S: Session, IO: Read + AsyncWrite> WriteTls<'a, S, IO> for Stream<'a, S, IO> {
    /// `write_tls` specialised for `AsyncWrite` transports.
    ///
    /// The session's `writev_tls` hands its queued records over as a batch of
    /// slices. These are forwarded to the transport's `write_buf`, and
    /// `Async::NotReady` is translated into a `WouldBlock` I/O error.
    fn write_tls(&mut self) -> io::Result<usize> {
        use futures::Async;
        use self::vecbuf::VecBuf;

        // Adapter that lets rustls write a batch of slices into an
        // `AsyncWrite` transport.
        struct Sink<'w, T: 'w>(&'w mut T);

        impl<'w, T: AsyncWrite> WriteV for Sink<'w, T> {
            fn writev(&mut self, vbytes: &[&[u8]]) -> io::Result<usize> {
                let mut chunks = VecBuf::new(vbytes);
                match self.0.write_buf(&mut chunks)? {
                    Async::Ready(written) => Ok(written),
                    // "Not ready yet" expressed in std::io terms.
                    Async::NotReady => Err(io::ErrorKind::WouldBlock.into()),
                }
            }
        }

        let mut sink = Sink(self.io);
        self.session.writev_tls(&mut sink)
    }
}
impl<'a, S: Session, IO: Read + Write> Read for Stream<'a, S, IO> {
    // Nightly-only hook on the unstable `Read::initializer` API. `nop()` lets
    // callers skip zeroing the buffer passed to `read`. That is only sound if
    // `read` never inspects `buf`'s prior contents. Here `buf` is only handed
    // to `session.read`, so this relies on the session only writing into it.
    // NOTE(review): confirm against the rustls version in use.
    #[cfg(feature = "nightly")]
    unsafe fn initializer(&self) -> Initializer {
        Initializer::nop()
    }

    /// Reads plaintext from the session into `buf`.
    ///
    /// While the session wants more TLS data, this keeps driving
    /// `complete_io`. It stops early once a call moves no bytes in either
    /// direction, then returns whatever the session's own `read` yields.
    /// Errors from `complete_io` are propagated.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        loop {
            if !self.session.wants_read() {
                break;
            }
            let progress = self.complete_io()?;
            if progress == (0, 0) {
                break;
            }
        }
        self.session.read(buf)
    }
}
impl<'a, S: Session, IO: Read + Write> io::Write for Stream<'a, S, IO> {
    /// Hands `buf` to the session as plaintext, then tries to push the
    /// resulting TLS records to the transport.
    ///
    /// Returns the number of plaintext bytes the session accepted. If the
    /// transport reports `WouldBlock` after some plaintext was accepted, that
    /// count is still returned and the remaining records stay queued in the
    /// session for a later `write`/`flush`. If nothing was accepted, the
    /// `WouldBlock` error is returned. Any other error is propagated.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        let len = self.session.write(buf)?;
        while self.session.wants_write() {
            match self.complete_io() {
                Ok(_) => (),
                // The plaintext is already inside the session. Reporting an
                // error now would make the caller send it twice.
                Err(ref err) if err.kind() == io::ErrorKind::WouldBlock && len != 0 => break,
                Err(err) => return Err(err),
            }
        }
        Ok(len)
    }

    /// Flushes the session, writes every queued TLS record to the transport,
    /// and finally flushes the transport itself.
    fn flush(&mut self) -> io::Result<()> {
        self.session.flush()?;
        // `complete_io` can return while records are still queued (for
        // example right after it finishes a handshake), so loop until the
        // session has nothing left to write.
        while self.session.wants_write() {
            self.complete_io()?;
        }
        // Propagate the flush: a buffering transport would otherwise keep the
        // records we just handed it.
        self.io.flush()
    }
}
#[cfg(test)]
mod test_stream;