From 518ad51376ace135487d29dd1e41e0f1392b9c40 Mon Sep 17 00:00:00 2001 From: quininer Date: Thu, 16 Aug 2018 15:29:16 +0800 Subject: [PATCH] impl complete_io --- src/common.rs | 97 +++++++++++++++++++++++++++++++++++++++++++---- src/lib.rs | 4 +- src/tokio_impl.rs | 18 +++++---- 3 files changed, 103 insertions(+), 16 deletions(-) diff --git a/src/common.rs b/src/common.rs index 799ee7e..df83537 100644 --- a/src/common.rs +++ b/src/common.rs @@ -12,16 +12,98 @@ pub struct Stream<'a, S: 'a, IO: 'a> { io: &'a mut IO } -/* -impl<'a, S: Session, IO: Write> Stream<'a, S, IO> { - pub default fn write_tls(&mut self) -> io::Result { - self.session.write_tls(self.io) +pub trait CompleteIo<'a, S: Session, IO: Read + Write>: Read + Write { + fn write_tls(&mut self) -> io::Result; + fn complete_io(&mut self) -> io::Result<(usize, usize)>; +} + +impl<'a, S: Session, IO: Read + Write> Stream<'a, S, IO> { + pub fn new(session: &'a mut S, io: &'a mut IO) -> Self { + Stream { session, io } } } -*/ -impl<'a, S: Session, IO: AsyncWrite> Stream<'a, S, IO> { - pub fn write_tls(&mut self) -> io::Result { +impl<'a, S: Session, IO: Read + Write> CompleteIo<'a, S, IO> for Stream<'a, S, IO> { + default fn write_tls(&mut self) -> io::Result { + self.session.write_tls(self.io) + } + + fn complete_io(&mut self) -> io::Result<(usize, usize)> { + // fork from https://github.com/ctz/rustls/blob/master/src/session.rs#L161 + + let until_handshaked = self.session.is_handshaking(); + let mut eof = false; + let mut wrlen = 0; + let mut rdlen = 0; + + loop { + while self.session.wants_write() { + wrlen += self.write_tls()?; + } + + if !until_handshaked && wrlen > 0 { + return Ok((rdlen, wrlen)); + } + + if !eof && self.session.wants_read() { + match self.session.read_tls(self.io)? { + 0 => eof = true, + n => rdlen += n + } + } + + match self.session.process_new_packets() { + Ok(_) => {}, + Err(e) => { + // In case we have an alert to send describing this error, + // try a last-gasp write -- but don't predate the primary + // error. + let _ignored = self.write_tls(); + + return Err(io::Error::new(io::ErrorKind::InvalidData, e)); + }, + }; + + match (eof, until_handshaked, self.session.is_handshaking()) { + (_, true, false) => return Ok((rdlen, wrlen)), + (_, false, _) => return Ok((rdlen, wrlen)), + (true, true, true) => return Err(io::Error::from(io::ErrorKind::UnexpectedEof)), + (..) => () + } + } + } +} + +impl<'a, S: Session, IO: Read + Write> Read for Stream<'a, S, IO> { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + while self.session.wants_read() { + if let (0, 0) = self.complete_io()? { + break + } + } + + self.session.read(buf) + } +} + +impl<'a, S: Session, IO: Read + Write> io::Write for Stream<'a, S, IO> { + fn write(&mut self, buf: &[u8]) -> io::Result { + let len = self.session.write(buf)?; + self.complete_io()?; + Ok(len) + } + + fn flush(&mut self) -> io::Result<()> { + self.session.flush()?; + if self.session.wants_write() { + self.complete_io()?; + } + Ok(()) + } +} + +impl<'a, S: Session, IO: Read + AsyncWrite> CompleteIo<'a, S, IO> for Stream<'a, S, IO> { + fn write_tls(&mut self) -> io::Result { struct V<'a, IO: 'a>(&'a mut IO); impl<'a, IO: AsyncWrite> WriteV for V<'a, IO> { @@ -41,6 +123,7 @@ impl<'a, S: Session, IO: AsyncWrite> Stream<'a, S, IO> { } +// TODO test struct VecBuf<'a, 'b: 'a> { pos: usize, cur: usize, diff --git a/src/lib.rs b/src/lib.rs index 81da5fe..f8432d9 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,7 @@ //! Asynchronous TLS/SSL streams for Tokio using [Rustls](https://github.com/ctz/rustls). +#![feature(specialization)] + pub extern crate rustls; pub extern crate webpki; @@ -18,8 +20,8 @@ use webpki::DNSNameRef; use rustls::{ Session, ClientSession, ServerSession, ClientConfig, ServerConfig, - Stream }; +use common::Stream; /// Extension trait for the `Arc` type in the `rustls` crate. diff --git a/src/tokio_impl.rs b/src/tokio_impl.rs index e9a00a9..663d6ca 100644 --- a/src/tokio_impl.rs +++ b/src/tokio_impl.rs @@ -2,6 +2,7 @@ use super::*; use tokio::prelude::*; use tokio::io::{ AsyncRead, AsyncWrite }; use tokio::prelude::Poll; +use common::{ Stream, CompleteIo }; impl Future for ConnectAsync { @@ -29,16 +30,17 @@ impl Future for MidHandshake type Error = io::Error; fn poll(&mut self) -> Poll { - loop { + { let stream = self.inner.as_mut().unwrap(); - if !stream.session.is_handshaking() { break }; + if stream.session.is_handshaking() { + let (io, session) = stream.get_mut(); + let mut stream = Stream::new(session, io); - let (io, session) = stream.get_mut(); - - match session.complete_io(io) { - Ok(_) => (), - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => return Ok(Async::NotReady), - Err(e) => return Err(e) + match stream.complete_io() { + Ok(_) => (), + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => return Ok(Async::NotReady), + Err(e) => return Err(e) + } } }