feat: split tokio_impl/futures_impl

This commit is contained in:
quininer 2018-03-21 13:08:47 +08:00
parent 9f78454cf1
commit 8c79329c7a
5 changed files with 168 additions and 95 deletions

View File

@ -15,15 +15,15 @@ travis-ci = { repository = "quininer/tokio-rustls" }
appveyor = { repository = "quininer/tokio-rustls" }
[dependencies]
futures = "0.2.0-alpha"
tokio = { version = "0.1", features = [ "unstable-futures" ] }
futures = { version = "0.2.0-alpha", optional = true }
tokio = { version = "0.1", optional = true }
rustls = "0.12"
webpki = "0.18.0-alpha"
[dev-dependencies]
tokio = { version = "0.1", features = [ "unstable-futures" ] }
tokio = "0.1"
clap = "2.26"
webpki-roots = "0.14"
[patch.crates-io]
tokio = { git = "https://github.com/tokio-rs/tokio" }
[features]
default = [ "futures", "tokio" ]

80
src/futures_impl.rs Normal file
View File

@ -0,0 +1,80 @@
use super::*;
use futures::{ Future, Poll, Async };
use futures::io::{ Error, AsyncRead, AsyncWrite };
use futures::task::Context;
impl<S: io::Read + io::Write> Future for ConnectAsync<S> {
type Item = TlsStream<S, ClientSession>;
type Error = io::Error;
fn poll(&mut self, ctx: &mut Context) -> Poll<Self::Item, Self::Error> {
self.0.poll(ctx)
}
}
impl<S: io::Read + io::Write> Future for AcceptAsync<S> {
type Item = TlsStream<S, ServerSession>;
type Error = io::Error;
fn poll(&mut self, ctx: &mut Context) -> Poll<Self::Item, Self::Error> {
self.0.poll(ctx)
}
}
impl<S, C> Future for MidHandshake<S, C>
where S: io::Read + io::Write, C: Session
{
type Item = TlsStream<S, C>;
type Error = io::Error;
fn poll(&mut self, _: &mut Context) -> Poll<Self::Item, Self::Error> {
loop {
let stream = self.inner.as_mut().unwrap();
if !stream.session.is_handshaking() { break };
match stream.do_io() {
Ok(()) => match (stream.eof, stream.session.is_handshaking()) {
(true, true) => return Err(io::Error::from(io::ErrorKind::UnexpectedEof)),
(false, true) => continue,
(..) => break
},
Err(e) => match (e.kind(), stream.session.is_handshaking()) {
(io::ErrorKind::WouldBlock, true) => return Ok(Async::Pending),
(io::ErrorKind::WouldBlock, false) => break,
(..) => return Err(e)
}
}
}
Ok(Async::Ready(self.inner.take().unwrap()))
}
}
impl<S, C> AsyncRead for TlsStream<S, C>
where
S: AsyncRead + AsyncWrite,
C: Session
{
fn poll_read(&mut self, _: &mut Context, buf: &mut [u8]) -> Poll<usize, Error> {
unimplemented!()
}
}
impl<S, C> AsyncWrite for TlsStream<S, C>
where
S: AsyncRead + AsyncWrite,
C: Session
{
fn poll_write(&mut self, _: &mut Context, buf: &[u8]) -> Poll<usize, Error> {
unimplemented!()
}
fn poll_flush(&mut self, _: &mut Context) -> Poll<(), Error> {
unimplemented!()
}
fn poll_close(&mut self, _: &mut Context) -> Poll<(), Error> {
unimplemented!()
}
}

View File

@ -5,10 +5,11 @@ extern crate tokio;
extern crate rustls;
extern crate webpki;
mod tokio_impl;
mod futures_impl;
use std::io;
use std::sync::Arc;
use futures::{ Future, Poll, Async };
use futures::task::Context;
use rustls::{
Session, ClientSession, ServerSession,
ClientConfig, ServerConfig
@ -77,58 +78,11 @@ pub fn accept_async_with_session<S>(stream: S, session: ServerSession)
})
}
impl<S: io::Read + io::Write> Future for ConnectAsync<S> {
type Item = TlsStream<S, ClientSession>;
type Error = io::Error;
fn poll(&mut self, ctx: &mut Context) -> Poll<Self::Item, Self::Error> {
self.0.poll(ctx)
}
}
impl<S: io::Read + io::Write> Future for AcceptAsync<S> {
type Item = TlsStream<S, ServerSession>;
type Error = io::Error;
fn poll(&mut self, ctx: &mut Context) -> Poll<Self::Item, Self::Error> {
self.0.poll(ctx)
}
}
struct MidHandshake<S, C> {
inner: Option<TlsStream<S, C>>
}
impl<S, C> Future for MidHandshake<S, C>
where S: io::Read + io::Write, C: Session
{
type Item = TlsStream<S, C>;
type Error = io::Error;
fn poll(&mut self, _: &mut Context) -> Poll<Self::Item, Self::Error> {
loop {
let stream = self.inner.as_mut().unwrap();
if !stream.session.is_handshaking() { break };
match stream.do_io() {
Ok(()) => match (stream.eof, stream.session.is_handshaking()) {
(true, true) => return Err(io::Error::from(io::ErrorKind::UnexpectedEof)),
(false, true) => continue,
(..) => break
},
Err(e) => match (e.kind(), stream.session.is_handshaking()) {
(io::ErrorKind::WouldBlock, true) => return Ok(Async::Pending),
(io::ErrorKind::WouldBlock, false) => break,
(..) => return Err(e)
}
}
}
Ok(Async::Ready(self.inner.take().unwrap()))
}
}
/// A wrapper around an underlying raw stream which implements the TLS or SSL
/// protocol.
@ -268,41 +222,3 @@ impl<S, C> io::Write for TlsStream<S, C>
self.io.flush()
}
}
mod tokio_impl {
use super::*;
use tokio::io::{ AsyncRead, AsyncWrite };
use tokio::prelude::Poll;
impl<S, C> AsyncRead for TlsStream<S, C>
where
S: AsyncRead + AsyncWrite,
C: Session
{}
impl<S, C> AsyncWrite for TlsStream<S, C>
where
S: AsyncRead + AsyncWrite,
C: Session
{
fn shutdown(&mut self) -> Poll<(), io::Error> {
if !self.is_shutdown {
self.session.send_close_notify();
self.is_shutdown = true;
}
while self.session.wants_write() {
self.session.write_tls(&mut self.io)?;
}
self.io.flush()?;
self.io.shutdown()
}
}
}
mod futures_impl {
use super::*;
use futures::io::{ AsyncRead, AsyncWrite };
// TODO
}

76
src/tokio_impl.rs Normal file
View File

@ -0,0 +1,76 @@
use super::*;
use tokio::prelude::*;
use tokio::io::{ AsyncRead, AsyncWrite };
use tokio::prelude::Poll;
impl<S: AsyncRead + AsyncWrite> Future for ConnectAsync<S> {
type Item = TlsStream<S, ClientSession>;
type Error = io::Error;
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
self.0.poll()
}
}
impl<S: AsyncRead + AsyncWrite> Future for AcceptAsync<S> {
type Item = TlsStream<S, ServerSession>;
type Error = io::Error;
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
self.0.poll()
}
}
impl<S, C> Future for MidHandshake<S, C>
where S: io::Read + io::Write, C: Session
{
type Item = TlsStream<S, C>;
type Error = io::Error;
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
loop {
let stream = self.inner.as_mut().unwrap();
if !stream.session.is_handshaking() { break };
match stream.do_io() {
Ok(()) => match (stream.eof, stream.session.is_handshaking()) {
(true, true) => return Err(io::Error::from(io::ErrorKind::UnexpectedEof)),
(false, true) => continue,
(..) => break
},
Err(e) => match (e.kind(), stream.session.is_handshaking()) {
(io::ErrorKind::WouldBlock, true) => return Ok(Async::NotReady),
(io::ErrorKind::WouldBlock, false) => break,
(..) => return Err(e)
}
}
}
Ok(Async::Ready(self.inner.take().unwrap()))
}
}
impl<S, C> AsyncRead for TlsStream<S, C>
where
S: AsyncRead + AsyncWrite,
C: Session
{}
impl<S, C> AsyncWrite for TlsStream<S, C>
where
S: AsyncRead + AsyncWrite,
C: Session
{
fn shutdown(&mut self) -> Poll<(), io::Error> {
if !self.is_shutdown {
self.session.send_close_notify();
self.is_shutdown = true;
}
while self.session.wants_write() {
self.session.write_tls(&mut self.io)?;
}
self.io.flush()?;
self.io.shutdown()
}
}

View File

@ -9,7 +9,8 @@ use std::io::{ BufReader, Cursor };
use std::sync::Arc;
use std::sync::mpsc::channel;
use std::net::{ SocketAddr, IpAddr, Ipv4Addr };
use futures::{ FutureExt, StreamExt };
use tokio::prelude::*;
// use futures::{ FutureExt, StreamExt };
use tokio::net::{ TcpListener, TcpStream };
use tokio::io as aio;
use rustls::{ Certificate, PrivateKey, ServerConfig, ClientConfig };
@ -46,12 +47,12 @@ fn start_server(cert: Vec<Certificate>, rsa: PrivateKey) -> SocketAddr {
.map(drop)
.map_err(drop);
tokio::spawn2(done);
tokio::spawn(done);
Ok(())
})
.then(|_| Ok(()));
tokio::runtime::run2(done);
tokio::runtime::run(done);
});
recv.recv().unwrap()