refactor complete_io

This commit is contained in:
quininer 2019-04-16 20:32:11 +08:00
parent 75b94acaba
commit 0485be9e4b
3 changed files with 76 additions and 32 deletions

View File

@ -1,4 +1,5 @@
use super::*; use super::*;
use std::io::Write;
use rustls::Session; use rustls::Session;
@ -197,7 +198,7 @@ where IO: AsyncRead + AsyncWrite
} }
let mut stream = Stream::new(&mut self.io, &mut self.session); let mut stream = Stream::new(&mut self.io, &mut self.session);
try_nb!(stream.complete_io()); try_nb!(stream.flush());
stream.io.shutdown() stream.io.shutdown()
} }
} }

View File

@ -15,51 +15,94 @@ pub trait WriteTls<'a, IO: AsyncRead + AsyncWrite, S: Session>: Read + Write {
fn write_tls(&mut self) -> io::Result<usize>; fn write_tls(&mut self) -> io::Result<usize>;
} }
#[derive(Clone, Copy)]
enum Focus {
Empty,
Readable,
Writable
}
impl<'a, IO: AsyncRead + AsyncWrite, S: Session> Stream<'a, IO, S> { impl<'a, IO: AsyncRead + AsyncWrite, S: Session> Stream<'a, IO, S> {
pub fn new(io: &'a mut IO, session: &'a mut S) -> Self { pub fn new(io: &'a mut IO, session: &'a mut S) -> Self {
Stream { io, session } Stream { io, session }
} }
pub fn complete_io(&mut self) -> io::Result<(usize, usize)> { pub fn complete_io(&mut self) -> io::Result<(usize, usize)> {
// fork from https://github.com/ctz/rustls/blob/master/src/session.rs#L161 self.complete_inner_io(Focus::Empty)
let until_handshaked = self.session.is_handshaking();
let mut eof = false;
let mut wrlen = 0;
let mut rdlen = 0;
loop {
while self.session.wants_write() {
wrlen += self.write_tls()?;
} }
if !until_handshaked && wrlen > 0 { fn complete_read_io(&mut self) -> io::Result<usize> {
return Ok((rdlen, wrlen)); let n = self.session.read_tls(self.io)?;
}
if !eof && self.session.wants_read() { self.session.process_new_packets()
match self.session.read_tls(self.io)? { .map_err(|err| {
0 => eof = true,
n => rdlen += n
}
}
match self.session.process_new_packets() {
Ok(_) => {},
Err(e) => {
// In case we have an alert to send describing this error, // In case we have an alert to send describing this error,
// try a last-gasp write -- but don't predate the primary // try a last-gasp write -- but don't predate the primary
// error. // error.
let _ignored = self.write_tls(); let _ = self.write_tls();
return Err(io::Error::new(io::ErrorKind::InvalidData, e)); io::Error::new(io::ErrorKind::InvalidData, err)
})?;
Ok(n)
}
fn complete_write_io(&mut self) -> io::Result<usize> {
self.write_tls()
}
fn complete_inner_io(&mut self, focus: Focus) -> io::Result<(usize, usize)> {
let mut wrlen = 0;
let mut rdlen = 0;
let mut eof = false;
loop {
let mut write_would_block = false;
let mut read_would_block = false;
while self.session.wants_write() {
match self.complete_write_io() {
Ok(n) => wrlen += n,
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => {
write_would_block = true;
break
}, },
Err(err) => return Err(err)
}
}
if !eof && self.session.wants_read() {
match self.complete_read_io() {
Ok(0) => eof = true,
Ok(n) => rdlen += n,
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => read_would_block = true,
Err(err) => return Err(err)
}
}
let would_block = match focus {
Focus::Empty => write_would_block || read_would_block,
Focus::Readable => read_would_block,
Focus::Writable => write_would_block,
}; };
match (eof, until_handshaked, self.session.is_handshaking()) { match (eof, self.session.is_handshaking(), would_block) {
(_, true, false) => return Ok((rdlen, wrlen)), (true, true, _) => return Err(io::ErrorKind::UnexpectedEof.into()),
(_, false, true) => {
let would_block = match focus {
Focus::Empty => rdlen == 0 && wrlen == 0,
Focus::Readable => rdlen == 0,
Focus::Writable => wrlen == 0
};
return if would_block {
Err(io::ErrorKind::WouldBlock.into())
} else {
Ok((rdlen, wrlen))
};
},
(_, false, _) => return Ok((rdlen, wrlen)), (_, false, _) => return Ok((rdlen, wrlen)),
(true, true, true) => return Err(io::Error::from(io::ErrorKind::UnexpectedEof)), (_, true, true) => return Err(io::ErrorKind::WouldBlock.into()),
(..) => () (..) => ()
} }
} }
@ -92,7 +135,7 @@ impl<'a, IO: AsyncRead + AsyncWrite, S: Session> WriteTls<'a, IO, S> for Stream<
impl<'a, IO: AsyncRead + AsyncWrite, S: Session> Read for Stream<'a, IO, S> { impl<'a, IO: AsyncRead + AsyncWrite, S: Session> Read for Stream<'a, IO, S> {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
while self.session.wants_read() { while self.session.wants_read() {
if let (0, 0) = self.complete_io()? { if let (0, _) = self.complete_inner_io(Focus::Readable)? {
break break
} }
} }
@ -104,7 +147,7 @@ impl<'a, IO: AsyncRead + AsyncWrite, S: Session> Write for Stream<'a, IO, S> {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> { fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
let len = self.session.write(buf)?; let len = self.session.write(buf)?;
while self.session.wants_write() { while self.session.wants_write() {
match self.complete_io() { match self.complete_inner_io(Focus::Writable) {
Ok(_) => (), Ok(_) => (),
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock && len != 0 => break, Err(ref err) if err.kind() == io::ErrorKind::WouldBlock && len != 0 => break,
Err(err) => return Err(err) Err(err) => return Err(err)
@ -126,8 +169,8 @@ impl<'a, IO: AsyncRead + AsyncWrite, S: Session> Write for Stream<'a, IO, S> {
fn flush(&mut self) -> io::Result<()> { fn flush(&mut self) -> io::Result<()> {
self.session.flush()?; self.session.flush()?;
if self.session.wants_write() { while self.session.wants_write() {
self.complete_io()?; self.complete_inner_io(Focus::Writable)?;
} }
Ok(()) Ok(())
} }

View File

@ -85,7 +85,7 @@ fn stream_good() -> io::Result<()> {
let mut buf = Vec::new(); let mut buf = Vec::new();
stream.read_to_end(&mut buf)?; stream.read_to_end(&mut buf)?;
assert_eq!(buf, FILE); assert_eq!(buf, FILE);
stream.write_all(b"Hello World!")? stream.write_all(b"Hello World!")?;
} }
let mut buf = String::new(); let mut buf = String::new();