feat: try futures 0.2

This commit is contained in:
quininer 2018-03-20 20:17:44 +08:00
parent daac8f585f
commit 9f78454cf1
3 changed files with 64 additions and 58 deletions

View File

@ -15,17 +15,15 @@ travis-ci = { repository = "quininer/tokio-rustls" }
appveyor = { repository = "quininer/tokio-rustls" } appveyor = { repository = "quininer/tokio-rustls" }
[dependencies] [dependencies]
futures = "0.1.15" futures = "0.2.0-alpha"
tokio-io = "0.1.3" tokio = { version = "0.1", features = [ "unstable-futures" ] }
rustls = "0.12" rustls = "0.12"
webpki = "0.18.0-alpha" webpki = "0.18.0-alpha"
tokio-proto = { version = "0.1.1", optional = true }
[dev-dependencies] [dev-dependencies]
tokio-core = "0.1" tokio = { version = "0.1", features = [ "unstable-futures" ] }
tokio = "0.1"
clap = "2.26" clap = "2.26"
webpki-roots = "0.14" webpki-roots = "0.14"
[target.'cfg(unix)'.dev-dependencies] [patch.crates-io]
tokio-file-unix = "0.4" tokio = { git = "https://github.com/tokio-rs/tokio" }

View File

@ -1,17 +1,14 @@
//! Asynchronous TLS/SSL streams for Tokio using [Rustls](https://github.com/ctz/rustls). //! Asynchronous TLS/SSL streams for Tokio using [Rustls](https://github.com/ctz/rustls).
extern crate futures;
#[cfg_attr(feature = "tokio-proto", macro_use)] extern crate futures; extern crate tokio;
#[macro_use] extern crate tokio_io;
extern crate rustls; extern crate rustls;
extern crate webpki; extern crate webpki;
pub mod proto;
use std::io; use std::io;
use std::sync::Arc; use std::sync::Arc;
use futures::{ Future, Poll, Async }; use futures::{ Future, Poll, Async };
use tokio_io::{ AsyncRead, AsyncWrite }; use futures::task::Context;
use rustls::{ use rustls::{
Session, ClientSession, ServerSession, Session, ClientSession, ServerSession,
ClientConfig, ServerConfig ClientConfig, ServerConfig
@ -22,14 +19,14 @@ use rustls::{
pub trait ClientConfigExt { pub trait ClientConfigExt {
fn connect_async<S>(&self, domain: webpki::DNSNameRef, stream: S) fn connect_async<S>(&self, domain: webpki::DNSNameRef, stream: S)
-> ConnectAsync<S> -> ConnectAsync<S>
where S: AsyncRead + AsyncWrite; where S: io::Read + io::Write;
} }
/// Extension trait for the `Arc<ServerConfig>` type in the `rustls` crate. /// Extension trait for the `Arc<ServerConfig>` type in the `rustls` crate.
pub trait ServerConfigExt { pub trait ServerConfigExt {
fn accept_async<S>(&self, stream: S) fn accept_async<S>(&self, stream: S)
-> AcceptAsync<S> -> AcceptAsync<S>
where S: AsyncRead + AsyncWrite; where S: io::Read + io::Write;
} }
@ -45,7 +42,7 @@ pub struct AcceptAsync<S>(MidHandshake<S, ServerSession>);
impl ClientConfigExt for Arc<ClientConfig> { impl ClientConfigExt for Arc<ClientConfig> {
fn connect_async<S>(&self, domain: webpki::DNSNameRef, stream: S) fn connect_async<S>(&self, domain: webpki::DNSNameRef, stream: S)
-> ConnectAsync<S> -> ConnectAsync<S>
where S: AsyncRead + AsyncWrite where S: io::Read + io::Write
{ {
connect_async_with_session(stream, ClientSession::new(self, domain)) connect_async_with_session(stream, ClientSession::new(self, domain))
} }
@ -54,7 +51,7 @@ impl ClientConfigExt for Arc<ClientConfig> {
#[inline] #[inline]
pub fn connect_async_with_session<S>(stream: S, session: ClientSession) pub fn connect_async_with_session<S>(stream: S, session: ClientSession)
-> ConnectAsync<S> -> ConnectAsync<S>
where S: AsyncRead + AsyncWrite where S: io::Read + io::Write
{ {
ConnectAsync(MidHandshake { ConnectAsync(MidHandshake {
inner: Some(TlsStream::new(stream, session)) inner: Some(TlsStream::new(stream, session))
@ -64,7 +61,7 @@ pub fn connect_async_with_session<S>(stream: S, session: ClientSession)
impl ServerConfigExt for Arc<ServerConfig> { impl ServerConfigExt for Arc<ServerConfig> {
fn accept_async<S>(&self, stream: S) fn accept_async<S>(&self, stream: S)
-> AcceptAsync<S> -> AcceptAsync<S>
where S: AsyncRead + AsyncWrite where S: io::Read + io::Write
{ {
accept_async_with_session(stream, ServerSession::new(self)) accept_async_with_session(stream, ServerSession::new(self))
} }
@ -73,28 +70,28 @@ impl ServerConfigExt for Arc<ServerConfig> {
#[inline] #[inline]
pub fn accept_async_with_session<S>(stream: S, session: ServerSession) pub fn accept_async_with_session<S>(stream: S, session: ServerSession)
-> AcceptAsync<S> -> AcceptAsync<S>
where S: AsyncRead + AsyncWrite where S: io::Read + io::Write
{ {
AcceptAsync(MidHandshake { AcceptAsync(MidHandshake {
inner: Some(TlsStream::new(stream, session)) inner: Some(TlsStream::new(stream, session))
}) })
} }
impl<S: AsyncRead + AsyncWrite> Future for ConnectAsync<S> { impl<S: io::Read + io::Write> Future for ConnectAsync<S> {
type Item = TlsStream<S, ClientSession>; type Item = TlsStream<S, ClientSession>;
type Error = io::Error; type Error = io::Error;
fn poll(&mut self) -> Poll<Self::Item, Self::Error> { fn poll(&mut self, ctx: &mut Context) -> Poll<Self::Item, Self::Error> {
self.0.poll() self.0.poll(ctx)
} }
} }
impl<S: AsyncRead + AsyncWrite> Future for AcceptAsync<S> { impl<S: io::Read + io::Write> Future for AcceptAsync<S> {
type Item = TlsStream<S, ServerSession>; type Item = TlsStream<S, ServerSession>;
type Error = io::Error; type Error = io::Error;
fn poll(&mut self) -> Poll<Self::Item, Self::Error> { fn poll(&mut self, ctx: &mut Context) -> Poll<Self::Item, Self::Error> {
self.0.poll() self.0.poll(ctx)
} }
} }
@ -104,12 +101,12 @@ struct MidHandshake<S, C> {
} }
impl<S, C> Future for MidHandshake<S, C> impl<S, C> Future for MidHandshake<S, C>
where S: AsyncRead + AsyncWrite, C: Session where S: io::Read + io::Write, C: Session
{ {
type Item = TlsStream<S, C>; type Item = TlsStream<S, C>;
type Error = io::Error; type Error = io::Error;
fn poll(&mut self) -> Poll<Self::Item, Self::Error> { fn poll(&mut self, _: &mut Context) -> Poll<Self::Item, Self::Error> {
loop { loop {
let stream = self.inner.as_mut().unwrap(); let stream = self.inner.as_mut().unwrap();
if !stream.session.is_handshaking() { break }; if !stream.session.is_handshaking() { break };
@ -121,7 +118,7 @@ impl<S, C> Future for MidHandshake<S, C>
(..) => break (..) => break
}, },
Err(e) => match (e.kind(), stream.session.is_handshaking()) { Err(e) => match (e.kind(), stream.session.is_handshaking()) {
(io::ErrorKind::WouldBlock, true) => return Ok(Async::NotReady), (io::ErrorKind::WouldBlock, true) => return Ok(Async::Pending),
(io::ErrorKind::WouldBlock, false) => break, (io::ErrorKind::WouldBlock, false) => break,
(..) => return Err(e) (..) => return Err(e)
} }
@ -154,7 +151,7 @@ impl<S, C> TlsStream<S, C> {
} }
impl<S, C> TlsStream<S, C> impl<S, C> TlsStream<S, C>
where S: AsyncRead + AsyncWrite, C: Session where S: io::Read + io::Write, C: Session
{ {
#[inline] #[inline]
pub fn new(io: S, session: C) -> TlsStream<S, C> { pub fn new(io: S, session: C) -> TlsStream<S, C> {
@ -214,7 +211,7 @@ impl<S, C> TlsStream<S, C>
} }
impl<S, C> io::Read for TlsStream<S, C> impl<S, C> io::Read for TlsStream<S, C>
where S: AsyncRead + AsyncWrite, C: Session where S: io::Read + io::Write, C: Session
{ {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
loop { loop {
@ -233,7 +230,7 @@ impl<S, C> io::Read for TlsStream<S, C>
} }
impl<S, C> io::Write for TlsStream<S, C> impl<S, C> io::Write for TlsStream<S, C>
where S: AsyncRead + AsyncWrite, C: Session where S: io::Read + io::Write, C: Session
{ {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> { fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
if buf.is_empty() { if buf.is_empty() {
@ -272,26 +269,40 @@ impl<S, C> io::Write for TlsStream<S, C>
} }
} }
impl<S, C> AsyncRead for TlsStream<S, C>
where
S: AsyncRead + AsyncWrite,
C: Session
{}
impl<S, C> AsyncWrite for TlsStream<S, C> mod tokio_impl {
where use super::*;
S: AsyncRead + AsyncWrite, use tokio::io::{ AsyncRead, AsyncWrite };
C: Session use tokio::prelude::Poll;
{
fn shutdown(&mut self) -> Poll<(), io::Error> { impl<S, C> AsyncRead for TlsStream<S, C>
if !self.is_shutdown { where
self.session.send_close_notify(); S: AsyncRead + AsyncWrite,
self.is_shutdown = true; C: Session
{}
impl<S, C> AsyncWrite for TlsStream<S, C>
where
S: AsyncRead + AsyncWrite,
C: Session
{
fn shutdown(&mut self) -> Poll<(), io::Error> {
if !self.is_shutdown {
self.session.send_close_notify();
self.is_shutdown = true;
}
while self.session.wants_write() {
self.session.write_tls(&mut self.io)?;
}
self.io.flush()?;
self.io.shutdown()
} }
while self.session.wants_write() {
try_nb!(self.session.write_tls(&mut self.io));
}
try_nb!(self.io.flush());
self.io.shutdown()
} }
} }
mod futures_impl {
use super::*;
use futures::io::{ AsyncRead, AsyncWrite };
// TODO
}

View File

@ -1,7 +1,6 @@
extern crate rustls; extern crate rustls;
extern crate futures; extern crate futures;
extern crate tokio; extern crate tokio;
extern crate tokio_io;
extern crate tokio_rustls; extern crate tokio_rustls;
extern crate webpki; extern crate webpki;
@ -10,10 +9,9 @@ use std::io::{ BufReader, Cursor };
use std::sync::Arc; use std::sync::Arc;
use std::sync::mpsc::channel; use std::sync::mpsc::channel;
use std::net::{ SocketAddr, IpAddr, Ipv4Addr }; use std::net::{ SocketAddr, IpAddr, Ipv4Addr };
use futures::{ Future, Stream }; use futures::{ FutureExt, StreamExt };
use tokio::executor::current_thread;
use tokio::net::{ TcpListener, TcpStream }; use tokio::net::{ TcpListener, TcpStream };
use tokio_io::io as aio; use tokio::io as aio;
use rustls::{ Certificate, PrivateKey, ServerConfig, ClientConfig }; use rustls::{ Certificate, PrivateKey, ServerConfig, ClientConfig };
use rustls::internal::pemfile::{ certs, rsa_private_keys }; use rustls::internal::pemfile::{ certs, rsa_private_keys };
use tokio_rustls::{ ClientConfigExt, ServerConfigExt }; use tokio_rustls::{ ClientConfigExt, ServerConfigExt };
@ -48,13 +46,12 @@ fn start_server(cert: Vec<Certificate>, rsa: PrivateKey) -> SocketAddr {
.map(drop) .map(drop)
.map_err(drop); .map_err(drop);
current_thread::spawn(done); tokio::spawn2(done);
Ok(()) Ok(())
}) })
.map(drop) .then(|_| Ok(()));
.map_err(drop);
current_thread::run(|_| current_thread::spawn(done)); tokio::runtime::run2(done);
}); });
recv.recv().unwrap() recv.recv().unwrap()