diff --git a/src/client.rs b/src/client.rs index 613cd69..8527121 100644 --- a/src/client.rs +++ b/src/client.rs @@ -45,10 +45,13 @@ where type Output = io::Result>; #[inline] - fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { - if let MidHandshake::Handshaking(stream) = &mut *self { + fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { + let this = self.get_mut(); + + if let MidHandshake::Handshaking(stream) = this { + let eof = !stream.state.readable(); let (io, session) = stream.get_mut(); - let mut stream = Stream::new(io, session); + let mut stream = Stream::new(io, session).set_eof(eof); if stream.session.is_handshaking() { try_ready!(stream.complete_io(cx)); @@ -59,7 +62,7 @@ where } } - match mem::replace(&mut *self, MidHandshake::End) { + match mem::replace(this, MidHandshake::End) { MidHandshake::Handshaking(stream) => Poll::Ready(Ok(stream)), #[cfg(feature = "early-data")] MidHandshake::EarlyData(stream) => Poll::Ready(Ok(stream)), @@ -83,7 +86,8 @@ where TlsState::EarlyData => { let this = self.get_mut(); - let mut stream = Stream::new(&mut this.io, &mut this.session); + let mut stream = Stream::new(&mut this.io, &mut this.session) + .set_eof(!this.state.readable()); let (pos, data) = &mut this.early_data; // complete handshake @@ -107,7 +111,8 @@ where } TlsState::Stream | TlsState::WriteShutdown => { let this = self.get_mut(); - let mut stream = Stream::new(&mut this.io, &mut this.session); + let mut stream = Stream::new(&mut this.io, &mut this.session) + .set_eof(!this.state.readable()); match stream.poll_read(cx, buf) { Poll::Ready(Ok(0)) => { @@ -136,9 +141,10 @@ impl AsyncWrite for TlsStream where IO: AsyncRead + AsyncWrite + Unpin, { - fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll> { + fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll> { let this = self.get_mut(); - let mut stream = Stream::new(&mut this.io, &mut this.session); + let mut stream = Stream::new(&mut this.io, &mut this.session) + .set_eof(!this.state.readable()); match this.state { #[cfg(feature = "early-data")] @@ -174,9 +180,11 @@ where } } - fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { let this = self.get_mut(); - Stream::new(&mut this.io, &mut this.session).poll_flush(cx) + Stream::new(&mut this.io, &mut this.session) + .set_eof(!this.state.readable()) + .poll_flush(cx) } fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { @@ -186,7 +194,8 @@ where } let this = self.get_mut(); - let mut stream = Stream::new(&mut this.io, &mut this.session); + let mut stream = Stream::new(&mut this.io, &mut this.session) + .set_eof(!this.state.readable()); try_ready!(stream.poll_flush(cx)); Pin::new(&mut this.io).poll_close(cx) } diff --git a/src/common/mod.rs b/src/common/mod.rs index 71b442f..dabdcbc 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -3,7 +3,7 @@ use std::pin::Pin; use std::task::Poll; use std::marker::Unpin; -use std::io::{ self, Read, Write }; +use std::io::{ self, Read }; use rustls::Session; use rustls::WriteV; use futures::task::Context; diff --git a/src/lib.rs b/src/lib.rs index 19f35dc..f962fc0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -12,7 +12,7 @@ macro_rules! try_ready { pub mod client; mod common; -// pub mod server; +pub mod server; use common::Stream; use std::pin::Pin; @@ -25,7 +25,7 @@ use std::{io, mem}; use webpki::DNSNameRef; #[derive(Debug, Copy, Clone)] -pub enum TlsState { +enum TlsState { #[cfg(feature = "early-data")] EarlyData, Stream, @@ -35,26 +35,33 @@ pub enum TlsState { } impl TlsState { - pub(crate) fn shutdown_read(&mut self) { + fn shutdown_read(&mut self) { match *self { TlsState::WriteShutdown | TlsState::FullyShutdown => *self = TlsState::FullyShutdown, _ => *self = TlsState::ReadShutdown, } } - pub(crate) fn shutdown_write(&mut self) { + fn shutdown_write(&mut self) { match *self { TlsState::ReadShutdown | TlsState::FullyShutdown => *self = TlsState::FullyShutdown, _ => *self = TlsState::WriteShutdown, } } - pub(crate) fn writeable(&self) -> bool { + fn writeable(&self) -> bool { match *self { TlsState::WriteShutdown | TlsState::FullyShutdown => false, _ => true, } } + + fn readable(self) -> bool { + match self { + TlsState::ReadShutdown | TlsState::FullyShutdown => false, + _ => true, + } + } } /// A wrapper around a `rustls::ClientConfig`, providing an async `connect` method. @@ -65,7 +72,6 @@ pub struct TlsConnector { early_data: bool, } -/* /// A wrapper around a `rustls::ServerConfig`, providing an async `accept` method. #[derive(Clone)] pub struct TlsAcceptor { @@ -170,32 +176,31 @@ impl TlsAcceptor { } } -/// Future returned from `ClientConfigExt::connect_async` which will resolve +/// Future returned from `TlsConnector::connect` which will resolve /// once the connection handshake has finished. pub struct Connect(client::MidHandshake); -/// Future returned from `ServerConfigExt::accept_async` which will resolve +/// Future returned from `TlsAcceptor::accept` which will resolve /// once the accept handshake has finished. pub struct Accept(server::MidHandshake); -impl Future for Connect { - type Item = client::TlsStream; - type Error = io::Error; +impl Future for Connect { + type Output = io::Result>; - fn poll(&mut self) -> Poll { - self.0.poll() + fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { + Pin::new(&mut self.0).poll(cx) } } -impl Future for Accept { - type Item = server::TlsStream; - type Error = io::Error; +impl Future for Accept { + type Output = io::Result>; - fn poll(&mut self) -> Poll { - self.0.poll() + fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { + Pin::new(&mut self.0).poll(cx) } } +/* #[cfg(feature = "early-data")] #[cfg(test)] mod test_0rtt; diff --git a/src/server.rs b/src/server.rs index 1568414..9db4867 100644 --- a/src/server.rs +++ b/src/server.rs @@ -34,100 +34,102 @@ impl TlsStream { impl Future for MidHandshake where - IO: AsyncRead + AsyncWrite, + IO: AsyncRead + AsyncWrite + Unpin, { - type Item = TlsStream; - type Error = io::Error; + type Output = io::Result>; #[inline] - fn poll(&mut self) -> Poll { - if let MidHandshake::Handshaking(stream) = self { + fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { + let this = self.get_mut(); + + if let MidHandshake::Handshaking(stream) = this { + let eof = !stream.state.readable(); let (io, session) = stream.get_mut(); - let mut stream = Stream::new(io, session); + let mut stream = Stream::new(io, session).set_eof(eof); if stream.session.is_handshaking() { - try_nb!(stream.complete_io()); + try_ready!(stream.complete_io(cx)); } if stream.session.wants_write() { - try_nb!(stream.complete_io()); + try_ready!(stream.complete_io(cx)); } } - match mem::replace(self, MidHandshake::End) { - MidHandshake::Handshaking(stream) => Ok(Async::Ready(stream)), + match mem::replace(this, MidHandshake::End) { + MidHandshake::Handshaking(stream) => Poll::Ready(Ok(stream)), MidHandshake::End => panic!(), } } } -impl io::Read for TlsStream +impl AsyncRead for TlsStream where - IO: AsyncRead + AsyncWrite, + IO: AsyncRead + AsyncWrite + Unpin, { - fn read(&mut self, buf: &mut [u8]) -> io::Result { - let mut stream = Stream::new(&mut self.io, &mut self.session); + unsafe fn initializer(&self) -> Initializer { + // TODO + Initializer::nop() + } - match self.state { - TlsState::Stream | TlsState::WriteShutdown => match stream.read(buf) { - Ok(0) => { - self.state.shutdown_read(); - Ok(0) + fn poll_read(self: Pin<&mut Self>, cx: &mut Context, buf: &mut [u8]) -> Poll> { + let this = self.get_mut(); + let mut stream = Stream::new(&mut this.io, &mut this.session) + .set_eof(!this.state.readable()); + + match this.state { + TlsState::Stream | TlsState::WriteShutdown => match stream.poll_read(cx, buf) { + Poll::Ready(Ok(0)) => { + this.state.shutdown_read(); + Poll::Ready(Ok(0)) } - Ok(n) => Ok(n), - Err(ref e) if e.kind() == io::ErrorKind::ConnectionAborted => { - self.state.shutdown_read(); - if self.state.writeable() { + Poll::Ready(Ok(n)) => Poll::Ready(Ok(n)), + Poll::Ready(Err(ref err)) if err.kind() == io::ErrorKind::ConnectionAborted => { + this.state.shutdown_read(); + if this.state.writeable() { stream.session.send_close_notify(); - self.state.shutdown_write(); + this.state.shutdown_write(); } - Ok(0) + Poll::Ready(Ok(0)) } - Err(e) => Err(e), + Poll::Ready(Err(e)) => Poll::Ready(Err(e)), + Poll::Pending => Poll::Pending }, - TlsState::ReadShutdown | TlsState::FullyShutdown => Ok(0), + TlsState::ReadShutdown | TlsState::FullyShutdown => Poll::Ready(Ok(0)), #[cfg(feature = "early-data")] s => unreachable!("server TLS can not hit this state: {:?}", s), } } } -impl io::Write for TlsStream -where - IO: AsyncRead + AsyncWrite, -{ - fn write(&mut self, buf: &[u8]) -> io::Result { - let mut stream = Stream::new(&mut self.io, &mut self.session); - stream.write(buf) - } - - fn flush(&mut self) -> io::Result<()> { - Stream::new(&mut self.io, &mut self.session).flush()?; - self.io.flush() - } -} - -impl AsyncRead for TlsStream -where - IO: AsyncRead + AsyncWrite, -{ - unsafe fn prepare_uninitialized_buffer(&self, _: &mut [u8]) -> bool { - false - } -} - impl AsyncWrite for TlsStream where - IO: AsyncRead + AsyncWrite, + IO: AsyncRead + AsyncWrite + Unpin, { - fn shutdown(&mut self) -> Poll<(), io::Error> { + fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll> { + let this = self.get_mut(); + Stream::new(&mut this.io, &mut this.session) + .set_eof(!this.state.readable()) + .poll_write(cx, buf) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + let this = self.get_mut(); + Stream::new(&mut this.io, &mut this.session) + .set_eof(!this.state.readable()) + .poll_flush(cx) + } + + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { if self.state.writeable() { self.session.send_close_notify(); self.state.shutdown_write(); } - let mut stream = Stream::new(&mut self.io, &mut self.session); - try_nb!(stream.complete_io()); - stream.io.shutdown() + let this = self.get_mut(); + let mut stream = Stream::new(&mut this.io, &mut this.session) + .set_eof(!this.state.readable()); + try_ready!(stream.complete_io(cx)); + Pin::new(&mut this.io).poll_close(cx) } }