use tokio-tls 0.2 api

This commit is contained in:
quininer 2018-08-09 10:46:28 +08:00
parent 32d3f46a9e
commit 37954cd647
5 changed files with 70 additions and 75 deletions

View File

@ -15,7 +15,7 @@ use tokio::io;
use tokio::net::TcpStream; use tokio::net::TcpStream;
use tokio::prelude::*; use tokio::prelude::*;
use clap::{ App, Arg }; use clap::{ App, Arg };
use tokio_rustls::{ ClientConfigExt, rustls::ClientConfig }; use tokio_rustls::{ TlsConnector, rustls::ClientConfig };
fn app() -> App<'static, 'static> { fn app() -> App<'static, 'static> {
App::new("client") App::new("client")
@ -49,7 +49,7 @@ fn main() {
} else { } else {
config.root_store.add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS); config.root_store.add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS);
} }
let arc_config = Arc::new(config); let config = TlsConnector::from(Arc::new(config));
let socket = TcpStream::connect(&addr); let socket = TcpStream::connect(&addr);
@ -70,7 +70,7 @@ fn main() {
socket socket
.and_then(move |stream| { .and_then(move |stream| {
let domain = webpki::DNSNameRef::try_from_ascii_str(&domain).unwrap(); let domain = webpki::DNSNameRef::try_from_ascii_str(&domain).unwrap();
arc_config.connect_async(domain, stream) config.connect(domain, stream)
}) })
.and_then(move |stream| io::write_all(stream, text)) .and_then(move |stream| io::write_all(stream, text))
.and_then(move |(stream, _)| { .and_then(move |(stream, _)| {
@ -93,7 +93,7 @@ fn main() {
socket socket
.and_then(move |stream| { .and_then(move |stream| {
let domain = webpki::DNSNameRef::try_from_ascii_str(&domain).unwrap(); let domain = webpki::DNSNameRef::try_from_ascii_str(&domain).unwrap();
arc_config.connect_async(domain, stream) config.connect(domain, stream)
}) })
.and_then(move |stream| io::write_all(stream, text)) .and_then(move |stream| io::write_all(stream, text))
.and_then(move |(stream, _)| { .and_then(move |(stream, _)| {

View File

@ -7,7 +7,7 @@ use std::net::ToSocketAddrs;
use std::io::BufReader; use std::io::BufReader;
use std::fs::File; use std::fs::File;
use tokio_rustls::{ use tokio_rustls::{
ServerConfigExt, TlsAcceptor,
rustls::{ rustls::{
Certificate, NoClientAuth, PrivateKey, ServerConfig, Certificate, NoClientAuth, PrivateKey, ServerConfig,
internal::pemfile::{ certs, rsa_private_keys } internal::pemfile::{ certs, rsa_private_keys }
@ -49,13 +49,13 @@ fn main() {
let mut config = ServerConfig::new(NoClientAuth::new()); let mut config = ServerConfig::new(NoClientAuth::new());
config.set_single_cert(load_certs(cert_file), load_keys(key_file).remove(0)) config.set_single_cert(load_certs(cert_file), load_keys(key_file).remove(0))
.expect("invalid key or certificate"); .expect("invalid key or certificate");
let arc_config = Arc::new(config); let config = TlsAcceptor::from(Arc::new(config));
let socket = TcpListener::bind(&addr).unwrap(); let socket = TcpListener::bind(&addr).unwrap();
let done = socket.incoming() let done = socket.incoming()
.for_each(move |stream| if flag_echo { .for_each(move |stream| if flag_echo {
let addr = stream.peer_addr().ok(); let addr = stream.peer_addr().ok();
let done = arc_config.accept_async(stream) let done = config.accept(stream)
.and_then(|stream| { .and_then(|stream| {
let (reader, writer) = stream.split(); let (reader, writer) = stream.split();
io::copy(reader, writer) io::copy(reader, writer)
@ -67,7 +67,7 @@ fn main() {
Ok(()) Ok(())
} else { } else {
let addr = stream.peer_addr().ok(); let addr = stream.peer_addr().ok();
let done = arc_config.accept_async(stream) let done = config.accept(stream)
.and_then(|stream| io::write_all( .and_then(|stream| io::write_all(
stream, stream,
&b"HTTP/1.0 200 ok\r\n\ &b"HTTP/1.0 200 ok\r\n\

View File

@ -16,70 +16,69 @@ use rustls::{
}; };
/// Extension trait for the `Arc<ClientConfig>` type in the `rustls` crate. pub struct TlsConnector {
pub trait ClientConfigExt: sealed::Sealed { inner: Arc<ClientConfig>
fn connect_async<S>(&self, domain: DNSNameRef, stream: S)
-> ConnectAsync<S>
where S: io::Read + io::Write;
} }
/// Extension trait for the `Arc<ServerConfig>` type in the `rustls` crate. pub struct TlsAcceptor {
pub trait ServerConfigExt: sealed::Sealed { inner: Arc<ServerConfig>
fn accept_async<S>(&self, stream: S) }
-> AcceptAsync<S>
where S: io::Read + io::Write; impl From<Arc<ClientConfig>> for TlsConnector {
fn from(inner: Arc<ClientConfig>) -> TlsConnector {
TlsConnector { inner }
}
}
impl From<Arc<ServerConfig>> for TlsAcceptor {
fn from(inner: Arc<ServerConfig>) -> TlsAcceptor {
TlsAcceptor { inner }
}
}
impl TlsConnector {
pub fn connect<S>(&self, domain: DNSNameRef, stream: S) -> Connect<S>
where S: io::Read + io::Write
{
Self::connect_with_session(stream, ClientSession::new(&self.inner, domain))
}
#[inline]
pub fn connect_with_session<S>(stream: S, session: ClientSession)
-> Connect<S>
where S: io::Read + io::Write
{
Connect(MidHandshake {
inner: Some(TlsStream { session, io: stream, is_shutdown: false, eof: false })
})
}
}
impl TlsAcceptor {
pub fn accept<S>(&self, stream: S) -> Accept<S>
where S: io::Read + io::Write,
{
Self::accept_with_session(stream, ServerSession::new(&self.inner))
}
#[inline]
pub fn accept_with_session<S>(stream: S, session: ServerSession) -> Accept<S>
where S: io::Read + io::Write
{
Accept(MidHandshake {
inner: Some(TlsStream { session, io: stream, is_shutdown: false, eof: false })
})
}
} }
/// Future returned from `ClientConfigExt::connect_async` which will resolve /// Future returned from `ClientConfigExt::connect_async` which will resolve
/// once the connection handshake has finished. /// once the connection handshake has finished.
pub struct ConnectAsync<S>(MidHandshake<S, ClientSession>); pub struct Connect<S>(MidHandshake<S, ClientSession>);
/// Future returned from `ServerConfigExt::accept_async` which will resolve /// Future returned from `ServerConfigExt::accept_async` which will resolve
/// once the accept handshake has finished. /// once the accept handshake has finished.
pub struct AcceptAsync<S>(MidHandshake<S, ServerSession>); pub struct Accept<S>(MidHandshake<S, ServerSession>);
impl sealed::Sealed for Arc<ClientConfig> {}
impl ClientConfigExt for Arc<ClientConfig> {
fn connect_async<S>(&self, domain: DNSNameRef, stream: S)
-> ConnectAsync<S>
where S: io::Read + io::Write
{
connect_async_with_session(stream, ClientSession::new(self, domain))
}
}
#[inline]
pub fn connect_async_with_session<S>(stream: S, session: ClientSession)
-> ConnectAsync<S>
where S: io::Read + io::Write
{
ConnectAsync(MidHandshake {
inner: Some(TlsStream { session, io: stream, is_shutdown: false, eof: false })
})
}
impl sealed::Sealed for Arc<ServerConfig> {}
impl ServerConfigExt for Arc<ServerConfig> {
fn accept_async<S>(&self, stream: S)
-> AcceptAsync<S>
where S: io::Read + io::Write
{
accept_async_with_session(stream, ServerSession::new(self))
}
}
#[inline]
pub fn accept_async_with_session<S>(stream: S, session: ServerSession)
-> AcceptAsync<S>
where S: io::Read + io::Write
{
AcceptAsync(MidHandshake {
inner: Some(TlsStream { session, io: stream, is_shutdown: false, eof: false })
})
}
struct MidHandshake<S, C> { struct MidHandshake<S, C> {
@ -143,7 +142,3 @@ impl<S, C> io::Write for TlsStream<S, C>
self.io.flush() self.io.flush()
} }
} }
mod sealed {
pub trait Sealed {}
}

View File

@ -6,7 +6,7 @@ use self::tokio::io::{ AsyncRead, AsyncWrite };
use self::tokio::prelude::Poll; use self::tokio::prelude::Poll;
impl<S: AsyncRead + AsyncWrite> Future for ConnectAsync<S> { impl<S: AsyncRead + AsyncWrite> Future for Connect<S> {
type Item = TlsStream<S, ClientSession>; type Item = TlsStream<S, ClientSession>;
type Error = io::Error; type Error = io::Error;
@ -15,7 +15,7 @@ impl<S: AsyncRead + AsyncWrite> Future for ConnectAsync<S> {
} }
} }
impl<S: AsyncRead + AsyncWrite> Future for AcceptAsync<S> { impl<S: AsyncRead + AsyncWrite> Future for Accept<S> {
type Item = TlsStream<S, ServerSession>; type Item = TlsStream<S, ServerSession>;
type Error = io::Error; type Error = io::Error;

View File

@ -13,7 +13,7 @@ use std::net::{ SocketAddr, IpAddr, Ipv4Addr };
use tokio::net::{ TcpListener, TcpStream }; use tokio::net::{ TcpListener, TcpStream };
use rustls::{ Certificate, PrivateKey, ServerConfig, ClientConfig }; use rustls::{ Certificate, PrivateKey, ServerConfig, ClientConfig };
use rustls::internal::pemfile::{ certs, rsa_private_keys }; use rustls::internal::pemfile::{ certs, rsa_private_keys };
use tokio_rustls::{ ClientConfigExt, ServerConfigExt }; use tokio_rustls::{ TlsConnector, TlsAcceptor };
const CERT: &str = include_str!("end.cert"); const CERT: &str = include_str!("end.cert");
const CHAIN: &str = include_str!("end.chain"); const CHAIN: &str = include_str!("end.chain");
@ -28,7 +28,7 @@ fn start_server(cert: Vec<Certificate>, rsa: PrivateKey) -> SocketAddr {
let mut config = ServerConfig::new(rustls::NoClientAuth::new()); let mut config = ServerConfig::new(rustls::NoClientAuth::new());
config.set_single_cert(cert, rsa) config.set_single_cert(cert, rsa)
.expect("invalid key or certificate"); .expect("invalid key or certificate");
let config = Arc::new(config); let config = TlsAcceptor::from(Arc::new(config));
let (send, recv) = channel(); let (send, recv) = channel();
@ -40,7 +40,7 @@ fn start_server(cert: Vec<Certificate>, rsa: PrivateKey) -> SocketAddr {
let done = listener.incoming() let done = listener.incoming()
.for_each(move |stream| { .for_each(move |stream| {
let done = config.accept_async(stream) let done = config.accept(stream)
.and_then(|stream| aio::read_exact(stream, vec![0; HELLO_WORLD.len()])) .and_then(|stream| aio::read_exact(stream, vec![0; HELLO_WORLD.len()]))
.and_then(|(stream, buf)| { .and_then(|(stream, buf)| {
assert_eq!(buf, HELLO_WORLD); assert_eq!(buf, HELLO_WORLD);
@ -68,10 +68,10 @@ fn start_client(addr: &SocketAddr, domain: &str, chain: Option<BufReader<Cursor<
if let Some(mut chain) = chain { if let Some(mut chain) = chain {
config.root_store.add_pem_file(&mut chain).unwrap(); config.root_store.add_pem_file(&mut chain).unwrap();
} }
let config = Arc::new(config); let config = TlsConnector::from(Arc::new(config));
let done = TcpStream::connect(addr) let done = TcpStream::connect(addr)
.and_then(|stream| config.connect_async(domain, stream)) .and_then(|stream| config.connect(domain, stream))
.and_then(|stream| aio::write_all(stream, HELLO_WORLD)) .and_then(|stream| aio::write_all(stream, HELLO_WORLD))
.and_then(|(stream, _)| aio::read_exact(stream, vec![0; HELLO_WORLD.len()])) .and_then(|(stream, _)| aio::read_exact(stream, vec![0; HELLO_WORLD.len()]))
.and_then(|(stream, buf)| { .and_then(|(stream, buf)| {
@ -94,10 +94,10 @@ fn start_client2(addr: &SocketAddr, domain: &str, chain: Option<BufReader<Cursor
if let Some(mut chain) = chain { if let Some(mut chain) = chain {
config.root_store.add_pem_file(&mut chain).unwrap(); config.root_store.add_pem_file(&mut chain).unwrap();
} }
let config = Arc::new(config); let config = TlsConnector::from(Arc::new(config));
let done = TcpStream::connect(addr) let done = TcpStream::connect(addr)
.and_then(|stream| config.connect_async(domain, stream)) .and_then(|stream| config.connect(domain, stream))
.and_then(|stream| stream.write_all(HELLO_WORLD)) .and_then(|stream| stream.write_all(HELLO_WORLD))
.and_then(|(stream, _)| stream.read_exact(vec![0; HELLO_WORLD.len()])) .and_then(|(stream, _)| stream.read_exact(vec![0; HELLO_WORLD.len()]))
.and_then(|(stream, buf)| { .and_then(|(stream, buf)| {