refactor: more read an write

This commit is contained in:
quininer 2019-10-01 14:36:16 +08:00
parent 66f17e3b18
commit 86171d34a8
4 changed files with 88 additions and 57 deletions

View File

@ -52,7 +52,7 @@ where
let (io, session) = stream.get_mut(); let (io, session) = stream.get_mut();
let mut stream = Stream::new(io, session).set_eof(eof); let mut stream = Stream::new(io, session).set_eof(eof);
if stream.session.is_handshaking() { while stream.session.is_handshaking() {
futures::ready!(stream.handshake(cx))?; futures::ready!(stream.handshake(cx))?;
} }
@ -127,7 +127,7 @@ where
// write early data // write early data
if let Some(mut early_data) = stream.session.early_data() { if let Some(mut early_data) = stream.session.early_data() {
let len = match dbg!(early_data.write(buf)) { let len = match early_data.write(buf) {
Ok(n) => n, Ok(n) => n,
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => Err(ref err) if err.kind() == io::ErrorKind::WouldBlock =>
return Poll::Pending, return Poll::Pending,
@ -138,7 +138,7 @@ where
} }
// complete handshake // complete handshake
if stream.session.is_handshaking() { while stream.session.is_handshaking() {
futures::ready!(stream.handshake(cx))?; futures::ready!(stream.handshake(cx))?;
} }
@ -166,7 +166,7 @@ where
#[cfg(feature = "early-data")] { #[cfg(feature = "early-data")] {
// complete handshake // complete handshake
if stream.session.is_handshaking() { while stream.session.is_handshaking() {
futures::ready!(stream.handshake(cx))?; futures::ready!(stream.handshake(cx))?;
} }
} }

View File

@ -33,6 +33,18 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> {
Pin::new(self) Pin::new(self)
} }
fn process_new_packets(&mut self, cx: &mut Context) -> io::Result<()> {
self.session.process_new_packets()
.map_err(|err| {
// In case we have an alert to send describing this error,
// try a last-gasp write -- but don't predate the primary
// error.
let _ = self.write_io(cx);
io::Error::new(io::ErrorKind::InvalidData, err)
})
}
fn read_io(&mut self, cx: &mut Context) -> Poll<io::Result<usize>> { fn read_io(&mut self, cx: &mut Context) -> Poll<io::Result<usize>> {
struct Reader<'a, 'b, T> { struct Reader<'a, 'b, T> {
io: &'a mut T, io: &'a mut T,
@ -56,16 +68,6 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> {
Err(err) => return Poll::Ready(Err(err)) Err(err) => return Poll::Ready(Err(err))
}; };
self.session.process_new_packets()
.map_err(|err| {
// In case we have an alert to send describing this error,
// try a last-gasp write -- but don't predate the primary
// error.
let _ = self.write_io(cx);
io::Error::new(io::ErrorKind::InvalidData, err)
})?;
Poll::Ready(Ok(n)) Poll::Ready(Ok(n))
} }
@ -118,29 +120,31 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> {
} }
} }
if !self.eof && self.session.wants_read() { while !self.eof && self.session.wants_read() {
match self.read_io(cx) { match self.read_io(cx) {
Poll::Ready(Ok(0)) => self.eof = true, Poll::Ready(Ok(0)) => self.eof = true,
Poll::Ready(Ok(n)) => rdlen += n, Poll::Ready(Ok(n)) => rdlen += n,
Poll::Pending => read_would_block = true, Poll::Pending => {
read_would_block = true;
break
},
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)) Poll::Ready(Err(err)) => return Poll::Ready(Err(err))
} }
} }
let would_block = write_would_block || read_would_block; self.process_new_packets(cx)?;
return match (self.eof, self.session.is_handshaking(), would_block) { return match (self.eof, self.session.is_handshaking()) {
(true, true, _) => { (true, true) => {
let err = io::Error::new(io::ErrorKind::UnexpectedEof, "tls handshake eof"); let err = io::Error::new(io::ErrorKind::UnexpectedEof, "tls handshake eof");
Poll::Ready(Err(err)) Poll::Ready(Err(err))
}, },
(_, false, true) => if rdlen != 0 || wrlen != 0 { (_, false) => Poll::Ready(Ok((rdlen, wrlen))),
(_, true) if write_would_block || read_would_block => if rdlen != 0 || wrlen != 0 {
Poll::Ready(Ok((rdlen, wrlen))) Poll::Ready(Ok((rdlen, wrlen)))
} else { } else {
Poll::Pending Poll::Pending
}, },
(_, false, _) => Poll::Ready(Ok((rdlen, wrlen))),
(_, true, true) => Poll::Pending,
(..) => continue (..) => continue
} }
} }
@ -150,53 +154,80 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> {
impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncRead for Stream<'a, IO, S> { impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncRead for Stream<'a, IO, S> {
fn poll_read(self: Pin<&mut Self>, cx: &mut Context, buf: &mut [u8]) -> Poll<io::Result<usize>> { fn poll_read(self: Pin<&mut Self>, cx: &mut Context, buf: &mut [u8]) -> Poll<io::Result<usize>> {
let this = self.get_mut(); let this = self.get_mut();
let mut pos = 0;
while this.session.wants_read() { while pos != buf.len() {
match this.read_io(cx) { let mut would_block = false;
Poll::Ready(Ok(0)) => break,
Poll::Ready(Ok(_)) => (), // read a packet
Poll::Pending => return Poll::Pending, while this.session.wants_read() {
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)) match this.read_io(cx) {
Poll::Ready(Ok(0)) => {
this.eof = true;
break
},
Poll::Ready(Ok(_)) => (),
Poll::Pending => {
would_block = true;
break
},
Poll::Ready(Err(err)) => return Poll::Ready(Err(err))
}
}
this.process_new_packets(cx)?;
return match this.session.read(&mut buf[pos..]) {
Ok(0) if pos == 0 && would_block => Poll::Pending,
Ok(n) if this.eof || would_block => Poll::Ready(Ok(pos + n)),
Ok(n) => {
pos += n;
continue
},
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
Err(ref err) if err.kind() == io::ErrorKind::ConnectionAborted && pos != 0 =>
Poll::Ready(Ok(pos)),
Err(err) => Poll::Ready(Err(err))
} }
} }
match this.session.read(buf) { Poll::Ready(Ok(pos))
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
result => Poll::Ready(result)
}
} }
} }
impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncWrite for Stream<'a, IO, S> { impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncWrite for Stream<'a, IO, S> {
fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> { fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
let this = self.get_mut(); let this = self.get_mut();
let mut pos = 0;
let len = match this.session.write(buf) { while pos != buf.len() {
Ok(n) => n, let mut would_block = false;
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock =>
return Poll::Pending, match this.session.write(&buf[pos..]) {
Err(err) => return Poll::Ready(Err(err)) Ok(n) => pos += n,
}; Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => (),
while this.session.wants_write() { Err(err) => return Poll::Ready(Err(err))
match this.write_io(cx) { };
Poll::Ready(Ok(_)) => (),
Poll::Pending if len != 0 => break, while this.session.wants_write() {
Poll::Pending => return Poll::Pending, match this.write_io(cx) {
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)) Poll::Ready(Ok(0)) | Poll::Pending => {
would_block = true;
break
},
Poll::Ready(Ok(_)) => (),
Poll::Ready(Err(err)) => return Poll::Ready(Err(err))
}
}
return match (pos, would_block) {
(0, true) => Poll::Pending,
(n, true) => Poll::Ready(Ok(n)),
(_, false) => continue
} }
} }
if len != 0 || buf.is_empty() { Poll::Ready(Ok(pos))
Poll::Ready(Ok(len))
} else {
// not write zero
match this.session.write(buf) {
Ok(0) => Poll::Pending,
Ok(n) => Poll::Ready(Ok(n)),
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
Err(err) => Poll::Ready(Err(err))
}
}
} }
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> { fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {

View File

@ -187,7 +187,7 @@ fn do_handshake(client: &mut ClientSession, server: &mut ServerSession, cx: &mut
let mut good = Good(server); let mut good = Good(server);
let mut stream = Stream::new(&mut good, client); let mut stream = Stream::new(&mut good, client);
if stream.session.is_handshaking() { while stream.session.is_handshaking() {
ready!(stream.handshake(cx))?; ready!(stream.handshake(cx))?;
} }

View File

@ -47,7 +47,7 @@ where
let (io, session) = stream.get_mut(); let (io, session) = stream.get_mut();
let mut stream = Stream::new(io, session).set_eof(eof); let mut stream = Stream::new(io, session).set_eof(eof);
if stream.session.is_handshaking() { while stream.session.is_handshaking() {
futures::ready!(stream.handshake(cx))?; futures::ready!(stream.handshake(cx))?;
} }