From 37954cd647ea6b7d626f15925518b38f2a858bee Mon Sep 17 00:00:00 2001 From: quininer Date: Thu, 9 Aug 2018 10:46:28 +0800 Subject: [PATCH] use tokio-tls 0.2 api --- examples/client/src/main.rs | 8 +-- examples/server/src/main.rs | 8 +-- src/lib.rs | 111 +++++++++++++++++------------------- src/tokio_impl.rs | 4 +- tests/test.rs | 14 ++--- 5 files changed, 70 insertions(+), 75 deletions(-) diff --git a/examples/client/src/main.rs b/examples/client/src/main.rs index 8499993..0d34f64 100644 --- a/examples/client/src/main.rs +++ b/examples/client/src/main.rs @@ -15,7 +15,7 @@ use tokio::io; use tokio::net::TcpStream; use tokio::prelude::*; use clap::{ App, Arg }; -use tokio_rustls::{ ClientConfigExt, rustls::ClientConfig }; +use tokio_rustls::{ TlsConnector, rustls::ClientConfig }; fn app() -> App<'static, 'static> { App::new("client") @@ -49,7 +49,7 @@ fn main() { } else { config.root_store.add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS); } - let arc_config = Arc::new(config); + let config = TlsConnector::from(Arc::new(config)); let socket = TcpStream::connect(&addr); @@ -70,7 +70,7 @@ fn main() { socket .and_then(move |stream| { let domain = webpki::DNSNameRef::try_from_ascii_str(&domain).unwrap(); - arc_config.connect_async(domain, stream) + config.connect(domain, stream) }) .and_then(move |stream| io::write_all(stream, text)) .and_then(move |(stream, _)| { @@ -93,7 +93,7 @@ fn main() { socket .and_then(move |stream| { let domain = webpki::DNSNameRef::try_from_ascii_str(&domain).unwrap(); - arc_config.connect_async(domain, stream) + config.connect(domain, stream) }) .and_then(move |stream| io::write_all(stream, text)) .and_then(move |(stream, _)| { diff --git a/examples/server/src/main.rs b/examples/server/src/main.rs index 2222c1e..2a94c58 100644 --- a/examples/server/src/main.rs +++ b/examples/server/src/main.rs @@ -7,7 +7,7 @@ use std::net::ToSocketAddrs; use std::io::BufReader; use std::fs::File; use tokio_rustls::{ - ServerConfigExt, + TlsAcceptor, rustls::{ Certificate, NoClientAuth, PrivateKey, ServerConfig, internal::pemfile::{ certs, rsa_private_keys } @@ -49,13 +49,13 @@ fn main() { let mut config = ServerConfig::new(NoClientAuth::new()); config.set_single_cert(load_certs(cert_file), load_keys(key_file).remove(0)) .expect("invalid key or certificate"); - let arc_config = Arc::new(config); + let config = TlsAcceptor::from(Arc::new(config)); let socket = TcpListener::bind(&addr).unwrap(); let done = socket.incoming() .for_each(move |stream| if flag_echo { let addr = stream.peer_addr().ok(); - let done = arc_config.accept_async(stream) + let done = config.accept(stream) .and_then(|stream| { let (reader, writer) = stream.split(); io::copy(reader, writer) @@ -67,7 +67,7 @@ fn main() { Ok(()) } else { let addr = stream.peer_addr().ok(); - let done = arc_config.accept_async(stream) + let done = config.accept(stream) .and_then(|stream| io::write_all( stream, &b"HTTP/1.0 200 ok\r\n\ diff --git a/src/lib.rs b/src/lib.rs index d1c6c7d..8d43c22 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -16,70 +16,69 @@ use rustls::{ }; -/// Extension trait for the `Arc` type in the `rustls` crate. -pub trait ClientConfigExt: sealed::Sealed { - fn connect_async(&self, domain: DNSNameRef, stream: S) - -> ConnectAsync - where S: io::Read + io::Write; +pub struct TlsConnector { + inner: Arc } -/// Extension trait for the `Arc` type in the `rustls` crate. -pub trait ServerConfigExt: sealed::Sealed { - fn accept_async(&self, stream: S) - -> AcceptAsync - where S: io::Read + io::Write; +pub struct TlsAcceptor { + inner: Arc +} + +impl From> for TlsConnector { + fn from(inner: Arc) -> TlsConnector { + TlsConnector { inner } + } +} + +impl From> for TlsAcceptor { + fn from(inner: Arc) -> TlsAcceptor { + TlsAcceptor { inner } + } +} + +impl TlsConnector { + pub fn connect(&self, domain: DNSNameRef, stream: S) -> Connect + where S: io::Read + io::Write + { + Self::connect_with_session(stream, ClientSession::new(&self.inner, domain)) + } + + #[inline] + pub fn connect_with_session(stream: S, session: ClientSession) + -> Connect + where S: io::Read + io::Write + { + Connect(MidHandshake { + inner: Some(TlsStream { session, io: stream, is_shutdown: false, eof: false }) + }) + } +} + +impl TlsAcceptor { + pub fn accept(&self, stream: S) -> Accept + where S: io::Read + io::Write, + { + Self::accept_with_session(stream, ServerSession::new(&self.inner)) + } + + #[inline] + pub fn accept_with_session(stream: S, session: ServerSession) -> Accept + where S: io::Read + io::Write + { + Accept(MidHandshake { + inner: Some(TlsStream { session, io: stream, is_shutdown: false, eof: false }) + }) + } } /// Future returned from `ClientConfigExt::connect_async` which will resolve /// once the connection handshake has finished. -pub struct ConnectAsync(MidHandshake); +pub struct Connect(MidHandshake); /// Future returned from `ServerConfigExt::accept_async` which will resolve /// once the accept handshake has finished. -pub struct AcceptAsync(MidHandshake); - -impl sealed::Sealed for Arc {} - -impl ClientConfigExt for Arc { - fn connect_async(&self, domain: DNSNameRef, stream: S) - -> ConnectAsync - where S: io::Read + io::Write - { - connect_async_with_session(stream, ClientSession::new(self, domain)) - } -} - -#[inline] -pub fn connect_async_with_session(stream: S, session: ClientSession) - -> ConnectAsync - where S: io::Read + io::Write -{ - ConnectAsync(MidHandshake { - inner: Some(TlsStream { session, io: stream, is_shutdown: false, eof: false }) - }) -} - -impl sealed::Sealed for Arc {} - -impl ServerConfigExt for Arc { - fn accept_async(&self, stream: S) - -> AcceptAsync - where S: io::Read + io::Write - { - accept_async_with_session(stream, ServerSession::new(self)) - } -} - -#[inline] -pub fn accept_async_with_session(stream: S, session: ServerSession) - -> AcceptAsync - where S: io::Read + io::Write -{ - AcceptAsync(MidHandshake { - inner: Some(TlsStream { session, io: stream, is_shutdown: false, eof: false }) - }) -} +pub struct Accept(MidHandshake); struct MidHandshake { @@ -143,7 +142,3 @@ impl io::Write for TlsStream self.io.flush() } } - -mod sealed { - pub trait Sealed {} -} diff --git a/src/tokio_impl.rs b/src/tokio_impl.rs index 936c14b..d9598bf 100644 --- a/src/tokio_impl.rs +++ b/src/tokio_impl.rs @@ -6,7 +6,7 @@ use self::tokio::io::{ AsyncRead, AsyncWrite }; use self::tokio::prelude::Poll; -impl Future for ConnectAsync { +impl Future for Connect { type Item = TlsStream; type Error = io::Error; @@ -15,7 +15,7 @@ impl Future for ConnectAsync { } } -impl Future for AcceptAsync { +impl Future for Accept { type Item = TlsStream; type Error = io::Error; diff --git a/tests/test.rs b/tests/test.rs index e64dd82..7eae2af 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -13,7 +13,7 @@ use std::net::{ SocketAddr, IpAddr, Ipv4Addr }; use tokio::net::{ TcpListener, TcpStream }; use rustls::{ Certificate, PrivateKey, ServerConfig, ClientConfig }; use rustls::internal::pemfile::{ certs, rsa_private_keys }; -use tokio_rustls::{ ClientConfigExt, ServerConfigExt }; +use tokio_rustls::{ TlsConnector, TlsAcceptor }; const CERT: &str = include_str!("end.cert"); const CHAIN: &str = include_str!("end.chain"); @@ -28,7 +28,7 @@ fn start_server(cert: Vec, rsa: PrivateKey) -> SocketAddr { let mut config = ServerConfig::new(rustls::NoClientAuth::new()); config.set_single_cert(cert, rsa) .expect("invalid key or certificate"); - let config = Arc::new(config); + let config = TlsAcceptor::from(Arc::new(config)); let (send, recv) = channel(); @@ -40,7 +40,7 @@ fn start_server(cert: Vec, rsa: PrivateKey) -> SocketAddr { let done = listener.incoming() .for_each(move |stream| { - let done = config.accept_async(stream) + let done = config.accept(stream) .and_then(|stream| aio::read_exact(stream, vec![0; HELLO_WORLD.len()])) .and_then(|(stream, buf)| { assert_eq!(buf, HELLO_WORLD); @@ -68,10 +68,10 @@ fn start_client(addr: &SocketAddr, domain: &str, chain: Option