diff --git a/Cargo.toml b/Cargo.toml index 5a61dcb..f892061 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tokio-rustls" -version = "0.10.0-alpha.2" +version = "0.12.0-alpha" authors = ["quininer kel "] license = "MIT/Apache-2.0" repository = "https://github.com/quininer/tokio-rustls" diff --git a/src/client.rs b/src/client.rs index 616c151..613cd69 100644 --- a/src/client.rs +++ b/src/client.rs @@ -40,160 +40,154 @@ impl TlsStream { impl Future for MidHandshake where - IO: AsyncRead + AsyncWrite, + IO: AsyncRead + AsyncWrite + Unpin, { - type Item = TlsStream; - type Error = io::Error; + type Output = io::Result>; #[inline] - fn poll(&mut self) -> Poll { - if let MidHandshake::Handshaking(stream) = self { + fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { + if let MidHandshake::Handshaking(stream) = &mut *self { let (io, session) = stream.get_mut(); let mut stream = Stream::new(io, session); if stream.session.is_handshaking() { - try_nb!(stream.complete_io()); + try_ready!(stream.complete_io(cx)); } if stream.session.wants_write() { - try_nb!(stream.complete_io()); + try_ready!(stream.complete_io(cx)); } } - match mem::replace(self, MidHandshake::End) { - MidHandshake::Handshaking(stream) => Ok(Async::Ready(stream)), + match mem::replace(&mut *self, MidHandshake::End) { + MidHandshake::Handshaking(stream) => Poll::Ready(Ok(stream)), #[cfg(feature = "early-data")] - MidHandshake::EarlyData(stream) => Ok(Async::Ready(stream)), + MidHandshake::EarlyData(stream) => Poll::Ready(Ok(stream)), MidHandshake::End => panic!(), } } } -impl io::Read for TlsStream +impl AsyncRead for TlsStream where - IO: AsyncRead + AsyncWrite, + IO: AsyncRead + AsyncWrite + Unpin, { - fn read(&mut self, buf: &mut [u8]) -> io::Result { - match self.state { - #[cfg(feature = "early-data")] - TlsState::EarlyData => { - { - let mut stream = Stream::new(&mut self.io, &mut self.session); - let (pos, data) = &mut self.early_data; - - // complete handshake - if stream.session.is_handshaking() { - stream.complete_io()?; - } - - // write early data (fallback) - if !stream.session.is_early_data_accepted() { - while *pos < data.len() { - let len = stream.write(&data[*pos..])?; - *pos += len; - } - } - - // end - self.state = TlsState::Stream; - data.clear(); - } - - self.read(buf) - } - TlsState::Stream | TlsState::WriteShutdown => { - let mut stream = Stream::new(&mut self.io, &mut self.session); - - match stream.read(buf) { - Ok(0) => { - self.state.shutdown_read(); - Ok(0) - } - Ok(n) => Ok(n), - Err(ref e) if e.kind() == io::ErrorKind::ConnectionAborted => { - self.state.shutdown_read(); - if self.state.writeable() { - stream.session.send_close_notify(); - self.state.shutdown_write(); - } - Ok(0) - } - Err(e) => Err(e), - } - } - TlsState::ReadShutdown | TlsState::FullyShutdown => Ok(0), - } + unsafe fn initializer(&self) -> Initializer { + // TODO + Initializer::nop() } -} - -impl io::Write for TlsStream -where - IO: AsyncRead + AsyncWrite, -{ - fn write(&mut self, buf: &[u8]) -> io::Result { - let mut stream = Stream::new(&mut self.io, &mut self.session); + fn poll_read(self: Pin<&mut Self>, cx: &mut Context, buf: &mut [u8]) -> Poll> { match self.state { #[cfg(feature = "early-data")] TlsState::EarlyData => { - let (pos, data) = &mut self.early_data; + let this = self.get_mut(); - // write early data - if let Some(mut early_data) = stream.session.early_data() { - let len = early_data.write(buf)?; - data.extend_from_slice(&buf[..len]); - return Ok(len); - } + let mut stream = Stream::new(&mut this.io, &mut this.session); + let (pos, data) = &mut this.early_data; // complete handshake if stream.session.is_handshaking() { - stream.complete_io()?; + try_ready!(stream.complete_io(cx)); } // write early data (fallback) if !stream.session.is_early_data_accepted() { while *pos < data.len() { - let len = stream.write(&data[*pos..])?; + let len = try_ready!(stream.poll_write(cx, &data[*pos..])); *pos += len; } } // end - self.state = TlsState::Stream; + this.state = TlsState::Stream; data.clear(); - stream.write(buf) + + Pin::new(this).poll_read(cx, buf) } - _ => stream.write(buf), + TlsState::Stream | TlsState::WriteShutdown => { + let this = self.get_mut(); + let mut stream = Stream::new(&mut this.io, &mut this.session); + + match stream.poll_read(cx, buf) { + Poll::Ready(Ok(0)) => { + this.state.shutdown_read(); + Poll::Ready(Ok(0)) + } + Poll::Ready(Ok(n)) => Poll::Ready(Ok(n)), + Poll::Ready(Err(ref e)) if e.kind() == io::ErrorKind::ConnectionAborted => { + this.state.shutdown_read(); + if this.state.writeable() { + stream.session.send_close_notify(); + this.state.shutdown_write(); + } + Poll::Ready(Ok(0)) + } + Poll::Ready(Err(err)) => Poll::Ready(Err(err)), + Poll::Pending => Poll::Pending + } + } + TlsState::ReadShutdown | TlsState::FullyShutdown => Poll::Ready(Ok(0)), } } - - fn flush(&mut self) -> io::Result<()> { - Stream::new(&mut self.io, &mut self.session).flush()?; - self.io.flush() - } -} - -impl AsyncRead for TlsStream -where - IO: AsyncRead + AsyncWrite, -{ - unsafe fn prepare_uninitialized_buffer(&self, _: &mut [u8]) -> bool { - false - } } impl AsyncWrite for TlsStream where - IO: AsyncRead + AsyncWrite, + IO: AsyncRead + AsyncWrite + Unpin, { - fn shutdown(&mut self) -> Poll<(), io::Error> { + fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll> { + let this = self.get_mut(); + let mut stream = Stream::new(&mut this.io, &mut this.session); + + match this.state { + #[cfg(feature = "early-data")] + TlsState::EarlyData => { + let (pos, data) = &mut this.early_data; + + // write early data + if let Some(mut early_data) = stream.session.early_data() { + let len = early_data.write(buf)?; // TODO check pending + data.extend_from_slice(&buf[..len]); + return Poll::Ready(Ok(len)); + } + + // complete handshake + if stream.session.is_handshaking() { + try_ready!(stream.complete_io(cx)); + } + + // write early data (fallback) + if !stream.session.is_early_data_accepted() { + while *pos < data.len() { + let len = try_ready!(stream.poll_write(cx, &data[*pos..])); + *pos += len; + } + } + + // end + this.state = TlsState::Stream; + data.clear(); + stream.poll_write(cx, buf) + } + _ => stream.poll_write(cx, buf), + } + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + let this = self.get_mut(); + Stream::new(&mut this.io, &mut this.session).poll_flush(cx) + } + + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { if self.state.writeable() { self.session.send_close_notify(); self.state.shutdown_write(); } - let mut stream = Stream::new(&mut self.io, &mut self.session); - try_nb!(stream.flush()); - stream.io.shutdown() + let this = self.get_mut(); + let mut stream = Stream::new(&mut this.io, &mut this.session); + try_ready!(stream.poll_flush(cx)); + Pin::new(&mut this.io).poll_close(cx) } } diff --git a/src/common/mod.rs b/src/common/mod.rs index ed29f09..71b442f 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -14,6 +14,7 @@ use smallvec::SmallVec; pub struct Stream<'a, IO, S> { pub io: &'a mut IO, pub session: &'a mut S, + pub eof: bool } pub trait WriteTls { @@ -29,7 +30,18 @@ enum Focus { impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> { pub fn new(io: &'a mut IO, session: &'a mut S) -> Self { - Stream { io, session } + Stream { + io, + session, + // The state so far is only used to detect EOF, so either Stream + // or EarlyData state should both be all right. + eof: false, + } + } + + pub fn set_eof(mut self, eof: bool) -> Self { + self.eof = eof; + self } pub fn complete_io(&mut self, cx: &mut Context) -> Poll> { @@ -82,7 +94,6 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> { fn complete_inner_io(&mut self, cx: &mut Context, focus: Focus) -> Poll> { let mut wrlen = 0; let mut rdlen = 0; - let mut eof = false; loop { let mut write_would_block = false; @@ -99,9 +110,9 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> { } } - if !eof && self.session.wants_read() { + if !self.eof && self.session.wants_read() { match self.complete_read_io(cx) { - Poll::Ready(Ok(0)) => eof = true, + Poll::Ready(Ok(0)) => self.eof = true, Poll::Ready(Ok(n)) => rdlen += n, Poll::Pending => read_would_block = true, Poll::Ready(Err(err)) => return Poll::Ready(Err(err)) @@ -114,7 +125,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> { Focus::Writable => write_would_block, }; - match (eof, self.session.is_handshaking(), would_block) { + match (self.eof, self.session.is_handshaking(), would_block) { (true, true, _) => return Poll::Pending, (_, false, true) => { let would_block = match focus { @@ -167,7 +178,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> WriteTls for Str } impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> { - fn poll_read(&mut self, cx: &mut Context, buf: &mut [u8]) -> Poll> { + pub fn poll_read(&mut self, cx: &mut Context, buf: &mut [u8]) -> Poll> { while self.session.wants_read() { match self.complete_inner_io(cx, Focus::Readable) { Poll::Ready(Ok((0, _))) => break, @@ -181,7 +192,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> { Poll::Ready(self.session.read(buf)) } - fn poll_write(&mut self, cx: &mut Context, buf: &[u8]) -> Poll> { + pub fn poll_write(&mut self, cx: &mut Context, buf: &[u8]) -> Poll> { let len = self.session.write(buf)?; while self.session.wants_write() { match self.complete_inner_io(cx, Focus::Writable) { @@ -204,7 +215,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> { } } - fn poll_flush(&mut self, cx: &mut Context) -> Poll> { + pub fn poll_flush(&mut self, cx: &mut Context) -> Poll> { self.session.flush()?; while self.session.wants_write() { match self.complete_inner_io(cx, Focus::Writable) { @@ -213,7 +224,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> { Poll::Ready(Err(err)) => return Poll::Ready(Err(err)) } } - Poll::Ready(Ok(())) + Pin::new(&mut self.io).poll_flush(cx) } } diff --git a/src/lib.rs b/src/lib.rs index 9e15ed7..19f35dc 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,16 +1,27 @@ //! Asynchronous TLS/SSL streams for Tokio using [Rustls](https://github.com/ctz/rustls). -// pub mod client; +macro_rules! try_ready { + ( $e:expr ) => { + match $e { + Poll::Ready(Ok(output)) => output, + Poll::Ready(Err(err)) => return Poll::Ready(Err(err.into())), + Poll::Pending => return Poll::Pending + } + } +} + +pub mod client; mod common; // pub mod server; -/* use common::Stream; -use futures::{Async, Future, Poll}; +use std::pin::Pin; +use std::task::{ Poll, Context }; +use std::future::Future; +use futures::io::{ AsyncRead, AsyncWrite, Initializer }; use rustls::{ClientConfig, ClientSession, ServerConfig, ServerSession}; use std::sync::Arc; use std::{io, mem}; -use tokio_io::{try_nb, AsyncRead, AsyncWrite}; use webpki::DNSNameRef; #[derive(Debug, Copy, Clone)] @@ -54,6 +65,7 @@ pub struct TlsConnector { early_data: bool, } +/* /// A wrapper around a `rustls::ServerConfig`, providing an async `accept` method. #[derive(Clone)] pub struct TlsAcceptor {