refactor: more read an write

This commit is contained in:
quininer 2019-10-01 14:36:16 +08:00
parent 66f17e3b18
commit 86171d34a8
4 changed files with 88 additions and 57 deletions

View File

@ -52,7 +52,7 @@ where
let (io, session) = stream.get_mut();
let mut stream = Stream::new(io, session).set_eof(eof);
if stream.session.is_handshaking() {
while stream.session.is_handshaking() {
futures::ready!(stream.handshake(cx))?;
}
@ -127,7 +127,7 @@ where
// write early data
if let Some(mut early_data) = stream.session.early_data() {
let len = match dbg!(early_data.write(buf)) {
let len = match early_data.write(buf) {
Ok(n) => n,
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock =>
return Poll::Pending,
@ -138,7 +138,7 @@ where
}
// complete handshake
if stream.session.is_handshaking() {
while stream.session.is_handshaking() {
futures::ready!(stream.handshake(cx))?;
}
@ -166,7 +166,7 @@ where
#[cfg(feature = "early-data")] {
// complete handshake
if stream.session.is_handshaking() {
while stream.session.is_handshaking() {
futures::ready!(stream.handshake(cx))?;
}
}

View File

@ -33,6 +33,18 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> {
Pin::new(self)
}
fn process_new_packets(&mut self, cx: &mut Context) -> io::Result<()> {
self.session.process_new_packets()
.map_err(|err| {
// In case we have an alert to send describing this error,
// try a last-gasp write -- but don't predate the primary
// error.
let _ = self.write_io(cx);
io::Error::new(io::ErrorKind::InvalidData, err)
})
}
fn read_io(&mut self, cx: &mut Context) -> Poll<io::Result<usize>> {
struct Reader<'a, 'b, T> {
io: &'a mut T,
@ -56,16 +68,6 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> {
Err(err) => return Poll::Ready(Err(err))
};
self.session.process_new_packets()
.map_err(|err| {
// In case we have an alert to send describing this error,
// try a last-gasp write -- but don't predate the primary
// error.
let _ = self.write_io(cx);
io::Error::new(io::ErrorKind::InvalidData, err)
})?;
Poll::Ready(Ok(n))
}
@ -118,29 +120,31 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> {
}
}
if !self.eof && self.session.wants_read() {
while !self.eof && self.session.wants_read() {
match self.read_io(cx) {
Poll::Ready(Ok(0)) => self.eof = true,
Poll::Ready(Ok(n)) => rdlen += n,
Poll::Pending => read_would_block = true,
Poll::Pending => {
read_would_block = true;
break
},
Poll::Ready(Err(err)) => return Poll::Ready(Err(err))
}
}
let would_block = write_would_block || read_would_block;
self.process_new_packets(cx)?;
return match (self.eof, self.session.is_handshaking(), would_block) {
(true, true, _) => {
return match (self.eof, self.session.is_handshaking()) {
(true, true) => {
let err = io::Error::new(io::ErrorKind::UnexpectedEof, "tls handshake eof");
Poll::Ready(Err(err))
},
(_, false, true) => if rdlen != 0 || wrlen != 0 {
(_, false) => Poll::Ready(Ok((rdlen, wrlen))),
(_, true) if write_would_block || read_would_block => if rdlen != 0 || wrlen != 0 {
Poll::Ready(Ok((rdlen, wrlen)))
} else {
Poll::Pending
},
(_, false, _) => Poll::Ready(Ok((rdlen, wrlen))),
(_, true, true) => Poll::Pending,
(..) => continue
}
}
@ -150,53 +154,80 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> {
impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncRead for Stream<'a, IO, S> {
fn poll_read(self: Pin<&mut Self>, cx: &mut Context, buf: &mut [u8]) -> Poll<io::Result<usize>> {
let this = self.get_mut();
let mut pos = 0;
while this.session.wants_read() {
match this.read_io(cx) {
Poll::Ready(Ok(0)) => break,
Poll::Ready(Ok(_)) => (),
Poll::Pending => return Poll::Pending,
Poll::Ready(Err(err)) => return Poll::Ready(Err(err))
while pos != buf.len() {
let mut would_block = false;
// read a packet
while this.session.wants_read() {
match this.read_io(cx) {
Poll::Ready(Ok(0)) => {
this.eof = true;
break
},
Poll::Ready(Ok(_)) => (),
Poll::Pending => {
would_block = true;
break
},
Poll::Ready(Err(err)) => return Poll::Ready(Err(err))
}
}
this.process_new_packets(cx)?;
return match this.session.read(&mut buf[pos..]) {
Ok(0) if pos == 0 && would_block => Poll::Pending,
Ok(n) if this.eof || would_block => Poll::Ready(Ok(pos + n)),
Ok(n) => {
pos += n;
continue
},
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
Err(ref err) if err.kind() == io::ErrorKind::ConnectionAborted && pos != 0 =>
Poll::Ready(Ok(pos)),
Err(err) => Poll::Ready(Err(err))
}
}
match this.session.read(buf) {
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
result => Poll::Ready(result)
}
Poll::Ready(Ok(pos))
}
}
impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncWrite for Stream<'a, IO, S> {
fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
let this = self.get_mut();
let mut pos = 0;
let len = match this.session.write(buf) {
Ok(n) => n,
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock =>
return Poll::Pending,
Err(err) => return Poll::Ready(Err(err))
};
while this.session.wants_write() {
match this.write_io(cx) {
Poll::Ready(Ok(_)) => (),
Poll::Pending if len != 0 => break,
Poll::Pending => return Poll::Pending,
Poll::Ready(Err(err)) => return Poll::Ready(Err(err))
while pos != buf.len() {
let mut would_block = false;
match this.session.write(&buf[pos..]) {
Ok(n) => pos += n,
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => (),
Err(err) => return Poll::Ready(Err(err))
};
while this.session.wants_write() {
match this.write_io(cx) {
Poll::Ready(Ok(0)) | Poll::Pending => {
would_block = true;
break
},
Poll::Ready(Ok(_)) => (),
Poll::Ready(Err(err)) => return Poll::Ready(Err(err))
}
}
return match (pos, would_block) {
(0, true) => Poll::Pending,
(n, true) => Poll::Ready(Ok(n)),
(_, false) => continue
}
}
if len != 0 || buf.is_empty() {
Poll::Ready(Ok(len))
} else {
// not write zero
match this.session.write(buf) {
Ok(0) => Poll::Pending,
Ok(n) => Poll::Ready(Ok(n)),
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
Err(err) => Poll::Ready(Err(err))
}
}
Poll::Ready(Ok(pos))
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {

View File

@ -187,7 +187,7 @@ fn do_handshake(client: &mut ClientSession, server: &mut ServerSession, cx: &mut
let mut good = Good(server);
let mut stream = Stream::new(&mut good, client);
if stream.session.is_handshaking() {
while stream.session.is_handshaking() {
ready!(stream.handshake(cx))?;
}

View File

@ -47,7 +47,7 @@ where
let (io, session) = stream.get_mut();
let mut stream = Stream::new(io, session).set_eof(eof);
if stream.session.is_handshaking() {
while stream.session.is_handshaking() {
futures::ready!(stream.handshake(cx))?;
}