From 0485be9e4bb9bea80398d5a2c220ec21a5a96443 Mon Sep 17 00:00:00 2001 From: quininer Date: Tue, 16 Apr 2019 20:32:11 +0800 Subject: [PATCH] refactor complete_io --- src/client.rs | 3 +- src/common/mod.rs | 103 +++++++++++++++++++++++++++----------- src/common/test_stream.rs | 2 +- 3 files changed, 76 insertions(+), 32 deletions(-) diff --git a/src/client.rs b/src/client.rs index 27ab944..c4d93ee 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,4 +1,5 @@ use super::*; +use std::io::Write; use rustls::Session; @@ -197,7 +198,7 @@ where IO: AsyncRead + AsyncWrite } let mut stream = Stream::new(&mut self.io, &mut self.session); - try_nb!(stream.complete_io()); + try_nb!(stream.flush()); stream.io.shutdown() } } diff --git a/src/common/mod.rs b/src/common/mod.rs index e4e25cb..14d2f71 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -15,51 +15,94 @@ pub trait WriteTls<'a, IO: AsyncRead + AsyncWrite, S: Session>: Read + Write { fn write_tls(&mut self) -> io::Result; } +#[derive(Clone, Copy)] +enum Focus { + Empty, + Readable, + Writable +} + impl<'a, IO: AsyncRead + AsyncWrite, S: Session> Stream<'a, IO, S> { pub fn new(io: &'a mut IO, session: &'a mut S) -> Self { Stream { io, session } } pub fn complete_io(&mut self) -> io::Result<(usize, usize)> { - // fork from https://github.com/ctz/rustls/blob/master/src/session.rs#L161 + self.complete_inner_io(Focus::Empty) + } - let until_handshaked = self.session.is_handshaking(); - let mut eof = false; + fn complete_read_io(&mut self) -> io::Result { + let n = self.session.read_tls(self.io)?; + + self.session.process_new_packets() + .map_err(|err| { + // In case we have an alert to send describing this error, + // try a last-gasp write -- but don't predate the primary + // error. + let _ = self.write_tls(); + + io::Error::new(io::ErrorKind::InvalidData, err) + })?; + + Ok(n) + } + + fn complete_write_io(&mut self) -> io::Result { + self.write_tls() + } + + fn complete_inner_io(&mut self, focus: Focus) -> io::Result<(usize, usize)> { let mut wrlen = 0; let mut rdlen = 0; + let mut eof = false; loop { + let mut write_would_block = false; + let mut read_would_block = false; + while self.session.wants_write() { - wrlen += self.write_tls()?; - } - - if !until_handshaked && wrlen > 0 { - return Ok((rdlen, wrlen)); - } - - if !eof && self.session.wants_read() { - match self.session.read_tls(self.io)? { - 0 => eof = true, - n => rdlen += n + match self.complete_write_io() { + Ok(n) => wrlen += n, + Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => { + write_would_block = true; + break + }, + Err(err) => return Err(err) } } - match self.session.process_new_packets() { - Ok(_) => {}, - Err(e) => { - // In case we have an alert to send describing this error, - // try a last-gasp write -- but don't predate the primary - // error. - let _ignored = self.write_tls(); + if !eof && self.session.wants_read() { + match self.complete_read_io() { + Ok(0) => eof = true, + Ok(n) => rdlen += n, + Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => read_would_block = true, + Err(err) => return Err(err) + } + } - return Err(io::Error::new(io::ErrorKind::InvalidData, e)); - }, + let would_block = match focus { + Focus::Empty => write_would_block || read_would_block, + Focus::Readable => read_would_block, + Focus::Writable => write_would_block, }; - match (eof, until_handshaked, self.session.is_handshaking()) { - (_, true, false) => return Ok((rdlen, wrlen)), + match (eof, self.session.is_handshaking(), would_block) { + (true, true, _) => return Err(io::ErrorKind::UnexpectedEof.into()), + (_, false, true) => { + let would_block = match focus { + Focus::Empty => rdlen == 0 && wrlen == 0, + Focus::Readable => rdlen == 0, + Focus::Writable => wrlen == 0 + }; + + return if would_block { + Err(io::ErrorKind::WouldBlock.into()) + } else { + Ok((rdlen, wrlen)) + }; + }, (_, false, _) => return Ok((rdlen, wrlen)), - (true, true, true) => return Err(io::Error::from(io::ErrorKind::UnexpectedEof)), + (_, true, true) => return Err(io::ErrorKind::WouldBlock.into()), (..) => () } } @@ -92,7 +135,7 @@ impl<'a, IO: AsyncRead + AsyncWrite, S: Session> WriteTls<'a, IO, S> for Stream< impl<'a, IO: AsyncRead + AsyncWrite, S: Session> Read for Stream<'a, IO, S> { fn read(&mut self, buf: &mut [u8]) -> io::Result { while self.session.wants_read() { - if let (0, 0) = self.complete_io()? { + if let (0, _) = self.complete_inner_io(Focus::Readable)? { break } } @@ -104,7 +147,7 @@ impl<'a, IO: AsyncRead + AsyncWrite, S: Session> Write for Stream<'a, IO, S> { fn write(&mut self, buf: &[u8]) -> io::Result { let len = self.session.write(buf)?; while self.session.wants_write() { - match self.complete_io() { + match self.complete_inner_io(Focus::Writable) { Ok(_) => (), Err(ref err) if err.kind() == io::ErrorKind::WouldBlock && len != 0 => break, Err(err) => return Err(err) @@ -126,8 +169,8 @@ impl<'a, IO: AsyncRead + AsyncWrite, S: Session> Write for Stream<'a, IO, S> { fn flush(&mut self) -> io::Result<()> { self.session.flush()?; - if self.session.wants_write() { - self.complete_io()?; + while self.session.wants_write() { + self.complete_inner_io(Focus::Writable)?; } Ok(()) } diff --git a/src/common/test_stream.rs b/src/common/test_stream.rs index a43622c..744758a 100644 --- a/src/common/test_stream.rs +++ b/src/common/test_stream.rs @@ -85,7 +85,7 @@ fn stream_good() -> io::Result<()> { let mut buf = Vec::new(); stream.read_to_end(&mut buf)?; assert_eq!(buf, FILE); - stream.write_all(b"Hello World!")? + stream.write_all(b"Hello World!")?; } let mut buf = String::new();