Ensure errors on TCP writes happen on Windows (#725)
Previously an error in a TCP write might accidentally get covered up as the
`write` function didn't check for `State::Error`. This updates that logic to
propagate the error out to ensure if we see an error it goes upwards.
diff --git a/src/sys/windows/selector.rs b/src/sys/windows/selector.rs
index 49f53e0..bd81fbc 100644
--- a/src/sys/windows/selector.rs
+++ b/src/sys/windows/selector.rs
@@ -1,6 +1,6 @@
#![allow(deprecated)]
-use std::{fmt, io, u32};
+use std::{fmt, io};
use std::cell::UnsafeCell;
use std::os::windows::prelude::*;
use std::sync::{Arc, Mutex};
diff --git a/src/sys/windows/tcp.rs b/src/sys/windows/tcp.rs
index b84b09c..e8712e4 100644
--- a/src/sys/windows/tcp.rs
+++ b/src/sys/windows/tcp.rs
@@ -333,9 +333,13 @@
let mut me = self.inner();
let me = &mut *me;
- match me.write {
+ match mem::replace(&mut me.write, State::Empty) {
State::Empty => {}
- _ => return Err(io::ErrorKind::WouldBlock.into())
+ State::Error(e) => return Err(e),
+ other => {
+ me.write = other;
+ return Err(io::ErrorKind::WouldBlock.into())
+ }
}
if !me.iocp.registered() {
@@ -450,13 +454,14 @@
// About to write, clear any pending level triggered events
me.iocp.set_readiness(me.iocp.readiness() - Ready::writable());
- trace!("scheduling a write");
loop {
+ trace!("scheduling a write of {} bytes", buf[pos..].len());
let ret = unsafe {
self.inner.socket.write_overlapped(&buf[pos..], self.inner.write.as_mut_ptr())
};
match ret {
Ok(Some(transferred_bytes)) if me.instant_notify => {
+ trace!("done immediately with {} bytes", transferred_bytes);
if transferred_bytes == buf.len() - pos {
self.add_readiness(me, Ready::writable());
me.write = State::Empty;
@@ -465,12 +470,14 @@
pos += transferred_bytes;
}
Ok(_) => {
+ trace!("scheduled for later");
// see docs above on StreamImp.inner for rationale on forget
me.write = State::Pending((buf, pos));
mem::forget(self.clone());
break;
}
Err(e) => {
+ trace!("write error: {}", e);
me.write = State::Error(e);
self.add_readiness(me, Ready::writable());
me.iocp.put_buffer(buf);
diff --git a/test/test_tcp.rs b/test/test_tcp.rs
index 2a95b2c..24a816a 100644
--- a/test/test_tcp.rs
+++ b/test/test_tcp.rs
@@ -517,27 +517,80 @@
}
- #[test]
- #[cfg_attr(target_os = "fuchsia", ignore)]
- fn connect_error() {
- let poll = Poll::new().unwrap();
- let mut events = Events::with_capacity(16);
+#[test]
+#[cfg_attr(target_os = "fuchsia", ignore)]
+fn connect_error() {
+ let poll = Poll::new().unwrap();
+ let mut events = Events::with_capacity(16);
- // Pick a "random" port that shouldn't be in use.
- let l = TcpStream::connect(&"127.0.0.1:38381".parse().unwrap()).unwrap();
- poll.register(&l, Token(0), Ready::writable(), PollOpt::edge()).unwrap();
+ // Pick a "random" port that shouldn't be in use.
+ let l = TcpStream::connect(&"127.0.0.1:38381".parse().unwrap()).unwrap();
+ poll.register(&l, Token(0), Ready::writable(), PollOpt::edge()).unwrap();
- 'outer:
- loop {
- poll.poll(&mut events, None).unwrap();
+ 'outer:
+ loop {
+ poll.poll(&mut events, None).unwrap();
- for event in &events {
- if event.token() == Token(0) {
- assert!(event.readiness().is_writable());
- break 'outer
- }
- }
- }
+ for event in &events {
+ if event.token() == Token(0) {
+ assert!(event.readiness().is_writable());
+ break 'outer
+ }
+ }
+ }
- assert_eq!(l.take_error().unwrap().unwrap().kind(), io::ErrorKind::ConnectionRefused);
- }
+ assert!(l.take_error().unwrap().is_some());
+}
+
+#[test]
+fn write_error() {
+ let poll = Poll::new().unwrap();
+ let mut events = Events::with_capacity(16);
+ let (tx, rx) = channel();
+
+ let listener = net::TcpListener::bind("127.0.0.1:0").unwrap();
+ let addr = listener.local_addr().unwrap();
+ let t = thread::spawn(move || {
+ let (conn, _addr) = listener.accept().unwrap();
+ rx.recv().unwrap();
+ drop(conn);
+ });
+
+ let mut s = TcpStream::connect(&addr).unwrap();
+ poll.register(&s,
+ Token(0),
+ Ready::readable() | Ready::writable(),
+ PollOpt::edge()).unwrap();
+
+ let mut wait_writable = || {
+ 'outer:
+ loop {
+ poll.poll(&mut events, None).unwrap();
+
+ for event in &events {
+ if event.token() == Token(0) && event.readiness().is_writable() {
+ break 'outer
+ }
+ }
+ }
+ };
+
+ wait_writable();
+
+ tx.send(()).unwrap();
+ t.join().unwrap();
+
+ let buf = [0; 1024];
+ loop {
+ match s.write(&buf) {
+ Ok(_) => {}
+ Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
+ wait_writable()
+ }
+ Err(e) => {
+ println!("good error: {}", e);
+ break
+ }
+ }
+ }
+}