refactor complete_io
This commit is contained in:
parent
75b94acaba
commit
0485be9e4b
@ -1,4 +1,5 @@
|
|||||||
use super::*;
|
use super::*;
|
||||||
|
use std::io::Write;
|
||||||
use rustls::Session;
|
use rustls::Session;
|
||||||
|
|
||||||
|
|
||||||
@ -197,7 +198,7 @@ where IO: AsyncRead + AsyncWrite
|
|||||||
}
|
}
|
||||||
|
|
||||||
let mut stream = Stream::new(&mut self.io, &mut self.session);
|
let mut stream = Stream::new(&mut self.io, &mut self.session);
|
||||||
try_nb!(stream.complete_io());
|
try_nb!(stream.flush());
|
||||||
stream.io.shutdown()
|
stream.io.shutdown()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -15,51 +15,94 @@ pub trait WriteTls<'a, IO: AsyncRead + AsyncWrite, S: Session>: Read + Write {
|
|||||||
fn write_tls(&mut self) -> io::Result<usize>;
|
fn write_tls(&mut self) -> io::Result<usize>;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Copy)]
|
||||||
|
enum Focus {
|
||||||
|
Empty,
|
||||||
|
Readable,
|
||||||
|
Writable
|
||||||
|
}
|
||||||
|
|
||||||
impl<'a, IO: AsyncRead + AsyncWrite, S: Session> Stream<'a, IO, S> {
|
impl<'a, IO: AsyncRead + AsyncWrite, S: Session> Stream<'a, IO, S> {
|
||||||
pub fn new(io: &'a mut IO, session: &'a mut S) -> Self {
|
pub fn new(io: &'a mut IO, session: &'a mut S) -> Self {
|
||||||
Stream { io, session }
|
Stream { io, session }
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn complete_io(&mut self) -> io::Result<(usize, usize)> {
|
pub fn complete_io(&mut self) -> io::Result<(usize, usize)> {
|
||||||
// fork from https://github.com/ctz/rustls/blob/master/src/session.rs#L161
|
self.complete_inner_io(Focus::Empty)
|
||||||
|
}
|
||||||
|
|
||||||
let until_handshaked = self.session.is_handshaking();
|
fn complete_read_io(&mut self) -> io::Result<usize> {
|
||||||
let mut eof = false;
|
let n = self.session.read_tls(self.io)?;
|
||||||
|
|
||||||
|
self.session.process_new_packets()
|
||||||
|
.map_err(|err| {
|
||||||
|
// In case we have an alert to send describing this error,
|
||||||
|
// try a last-gasp write -- but don't predate the primary
|
||||||
|
// error.
|
||||||
|
let _ = self.write_tls();
|
||||||
|
|
||||||
|
io::Error::new(io::ErrorKind::InvalidData, err)
|
||||||
|
})?;
|
||||||
|
|
||||||
|
Ok(n)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn complete_write_io(&mut self) -> io::Result<usize> {
|
||||||
|
self.write_tls()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn complete_inner_io(&mut self, focus: Focus) -> io::Result<(usize, usize)> {
|
||||||
let mut wrlen = 0;
|
let mut wrlen = 0;
|
||||||
let mut rdlen = 0;
|
let mut rdlen = 0;
|
||||||
|
let mut eof = false;
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
|
let mut write_would_block = false;
|
||||||
|
let mut read_would_block = false;
|
||||||
|
|
||||||
while self.session.wants_write() {
|
while self.session.wants_write() {
|
||||||
wrlen += self.write_tls()?;
|
match self.complete_write_io() {
|
||||||
}
|
Ok(n) => wrlen += n,
|
||||||
|
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => {
|
||||||
if !until_handshaked && wrlen > 0 {
|
write_would_block = true;
|
||||||
return Ok((rdlen, wrlen));
|
break
|
||||||
}
|
},
|
||||||
|
Err(err) => return Err(err)
|
||||||
if !eof && self.session.wants_read() {
|
|
||||||
match self.session.read_tls(self.io)? {
|
|
||||||
0 => eof = true,
|
|
||||||
n => rdlen += n
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
match self.session.process_new_packets() {
|
if !eof && self.session.wants_read() {
|
||||||
Ok(_) => {},
|
match self.complete_read_io() {
|
||||||
Err(e) => {
|
Ok(0) => eof = true,
|
||||||
// In case we have an alert to send describing this error,
|
Ok(n) => rdlen += n,
|
||||||
// try a last-gasp write -- but don't predate the primary
|
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => read_would_block = true,
|
||||||
// error.
|
Err(err) => return Err(err)
|
||||||
let _ignored = self.write_tls();
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return Err(io::Error::new(io::ErrorKind::InvalidData, e));
|
let would_block = match focus {
|
||||||
},
|
Focus::Empty => write_would_block || read_would_block,
|
||||||
|
Focus::Readable => read_would_block,
|
||||||
|
Focus::Writable => write_would_block,
|
||||||
};
|
};
|
||||||
|
|
||||||
match (eof, until_handshaked, self.session.is_handshaking()) {
|
match (eof, self.session.is_handshaking(), would_block) {
|
||||||
(_, true, false) => return Ok((rdlen, wrlen)),
|
(true, true, _) => return Err(io::ErrorKind::UnexpectedEof.into()),
|
||||||
|
(_, false, true) => {
|
||||||
|
let would_block = match focus {
|
||||||
|
Focus::Empty => rdlen == 0 && wrlen == 0,
|
||||||
|
Focus::Readable => rdlen == 0,
|
||||||
|
Focus::Writable => wrlen == 0
|
||||||
|
};
|
||||||
|
|
||||||
|
return if would_block {
|
||||||
|
Err(io::ErrorKind::WouldBlock.into())
|
||||||
|
} else {
|
||||||
|
Ok((rdlen, wrlen))
|
||||||
|
};
|
||||||
|
},
|
||||||
(_, false, _) => return Ok((rdlen, wrlen)),
|
(_, false, _) => return Ok((rdlen, wrlen)),
|
||||||
(true, true, true) => return Err(io::Error::from(io::ErrorKind::UnexpectedEof)),
|
(_, true, true) => return Err(io::ErrorKind::WouldBlock.into()),
|
||||||
(..) => ()
|
(..) => ()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -92,7 +135,7 @@ impl<'a, IO: AsyncRead + AsyncWrite, S: Session> WriteTls<'a, IO, S> for Stream<
|
|||||||
impl<'a, IO: AsyncRead + AsyncWrite, S: Session> Read for Stream<'a, IO, S> {
|
impl<'a, IO: AsyncRead + AsyncWrite, S: Session> Read for Stream<'a, IO, S> {
|
||||||
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
|
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
|
||||||
while self.session.wants_read() {
|
while self.session.wants_read() {
|
||||||
if let (0, 0) = self.complete_io()? {
|
if let (0, _) = self.complete_inner_io(Focus::Readable)? {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -104,7 +147,7 @@ impl<'a, IO: AsyncRead + AsyncWrite, S: Session> Write for Stream<'a, IO, S> {
|
|||||||
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
|
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
|
||||||
let len = self.session.write(buf)?;
|
let len = self.session.write(buf)?;
|
||||||
while self.session.wants_write() {
|
while self.session.wants_write() {
|
||||||
match self.complete_io() {
|
match self.complete_inner_io(Focus::Writable) {
|
||||||
Ok(_) => (),
|
Ok(_) => (),
|
||||||
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock && len != 0 => break,
|
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock && len != 0 => break,
|
||||||
Err(err) => return Err(err)
|
Err(err) => return Err(err)
|
||||||
@ -126,8 +169,8 @@ impl<'a, IO: AsyncRead + AsyncWrite, S: Session> Write for Stream<'a, IO, S> {
|
|||||||
|
|
||||||
fn flush(&mut self) -> io::Result<()> {
|
fn flush(&mut self) -> io::Result<()> {
|
||||||
self.session.flush()?;
|
self.session.flush()?;
|
||||||
if self.session.wants_write() {
|
while self.session.wants_write() {
|
||||||
self.complete_io()?;
|
self.complete_inner_io(Focus::Writable)?;
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -85,7 +85,7 @@ fn stream_good() -> io::Result<()> {
|
|||||||
let mut buf = Vec::new();
|
let mut buf = Vec::new();
|
||||||
stream.read_to_end(&mut buf)?;
|
stream.read_to_end(&mut buf)?;
|
||||||
assert_eq!(buf, FILE);
|
assert_eq!(buf, FILE);
|
||||||
stream.write_all(b"Hello World!")?
|
stream.write_all(b"Hello World!")?;
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut buf = String::new();
|
let mut buf = String::new();
|
||||||
|
Loading…
Reference in New Issue
Block a user