Rename more tests (#1)

* Rename more tests

* Clean up smoke test

* fmt

* Clean up ci and remove all-features test
This commit is contained in:
Lucio Franco 2020-02-27 18:32:52 -05:00 committed by GitHub
parent 01fdb7ccf4
commit 7e41beaff4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 402 additions and 818 deletions

View File

@ -1,6 +1,10 @@
name: CI name: CI
on: [push, pull_request] on:
push:
branches:
- master
pull_request: {}
jobs: jobs:
check: check:
@ -40,7 +44,7 @@ jobs:
profile: minimal profile: minimal
- uses: actions/checkout@master - uses: actions/checkout@master
- name: Test - name: Test
run: cargo test --all --all-features run: cargo test --all
lints: lints:
name: Lints name: Lints

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@ -1,500 +1,126 @@
#![warn(rust_2018_idioms)]
use cfg_if::cfg_if;
use env_logger;
use futures::join; use futures::join;
use native_tls; use native_tls::{Certificate, Identity};
use native_tls::{Identity, TlsAcceptor, TlsConnector}; use std::io::Error;
use std::io::Write; use tokio::{
use std::marker::Unpin; io::{AsyncReadExt, AsyncWrite, AsyncWriteExt},
use std::process::Command; net::{TcpListener, TcpStream},
use std::ptr; };
use tokio::io::{AsyncReadExt, AsyncWrite, AsyncWriteExt, Error, ErrorKind}; use tokio_native_tls::{TlsAcceptor, TlsConnector};
use tokio::net::{TcpListener, TcpStream};
use tokio::stream::StreamExt;
macro_rules! t { #[tokio::test]
($e:expr) => { async fn client_to_server() {
match $e { let mut srv = TcpListener::bind("127.0.0.1:0").await.unwrap();
Ok(e) => e, let addr = srv.local_addr().unwrap();
Err(e) => panic!("{} failed with {:?}", stringify!($e), e),
} let (server_tls, client_tls) = context();
// Create a future to accept one socket, connect the ssl stream, and then
// read all the data from it.
let server = async move {
let (socket, _) = srv.accept().await.unwrap();
let mut socket = server_tls.accept(socket).await.unwrap();
let mut data = Vec::new();
socket.read_to_end(&mut data).await.unwrap();
data
}; };
// Create a future to connect to our server, connect the ssl stream, and
// then write a bunch of data to it.
let client = async move {
let socket = TcpStream::connect(&addr).await.unwrap();
let socket = client_tls.connect("foobar.com", socket).await.unwrap();
copy_data(socket).await
};
// Finally, run everything!
let (data, _) = join!(server, client);
// assert_eq!(amt, AMT);
assert!(data == vec![9; AMT]);
} }
#[allow(dead_code)] #[tokio::test]
struct Keys { async fn server_to_client() {
cert_der: Vec<u8>, // Create a server listening on a port, then figure out what that port is
pkey_der: Vec<u8>, let mut srv = TcpListener::bind("127.0.0.1:0").await.unwrap();
pkcs12_der: Vec<u8>, let addr = srv.local_addr().unwrap();
let (server_tls, client_tls) = context();
let server = async move {
let (socket, _) = srv.accept().await.unwrap();
let socket = server_tls.accept(socket).await.unwrap();
copy_data(socket).await
};
let client = async move {
let socket = TcpStream::connect(&addr).await.unwrap();
let mut socket = client_tls.connect("foobar.com", socket).await.unwrap();
let mut data = Vec::new();
socket.read_to_end(&mut data).await.unwrap();
data
};
// Finally, run everything!
let (_, data) = join!(server, client);
assert!(data == vec![9; AMT]);
} }
#[allow(dead_code)] #[tokio::test]
fn openssl_keys() -> &'static Keys { async fn one_byte_at_a_time() {
static INIT: Once = Once::new(); const AMT: usize = 1024;
static mut KEYS: *mut Keys = ptr::null_mut();
INIT.call_once(|| { let mut srv = TcpListener::bind("127.0.0.1:0").await.unwrap();
let path = t!(env::current_exe()); let addr = srv.local_addr().unwrap();
let path = path.parent().unwrap();
let keyfile = path.join("test.key");
let certfile = path.join("test.crt");
let config = path.join("openssl.config");
File::create(&config) let (server_tls, client_tls) = context();
.unwrap()
.write_all(
b"\
[req]\n\
distinguished_name=dn\n\
[ dn ]\n\
CN=localhost\n\
[ ext ]\n\
basicConstraints=CA:FALSE,pathlen:0\n\
subjectAltName = @alt_names
extendedKeyUsage=serverAuth,clientAuth
[alt_names]
DNS.1 = localhost
",
)
.unwrap();
let subj = "/C=US/ST=Denial/L=Sprintfield/O=Dis/CN=localhost"; let server = async move {
let output = t!(Command::new("openssl") let (socket, _) = srv.accept().await.unwrap();
.arg("req") let mut socket = server_tls.accept(socket).await.unwrap();
.arg("-nodes") let mut amt = 0;
.arg("-x509") for b in std::iter::repeat(9).take(AMT) {
.arg("-newkey") let data = [b as u8];
.arg("rsa:2048") socket.write_all(&data).await.unwrap();
.arg("-config") amt += 1;
.arg(&config)
.arg("-extensions")
.arg("ext")
.arg("-subj")
.arg(subj)
.arg("-keyout")
.arg(&keyfile)
.arg("-out")
.arg(&certfile)
.arg("-days")
.arg("1")
.output());
assert!(output.status.success());
let crtout = t!(Command::new("openssl")
.arg("x509")
.arg("-outform")
.arg("der")
.arg("-in")
.arg(&certfile)
.output());
assert!(crtout.status.success());
let keyout = t!(Command::new("openssl")
.arg("rsa")
.arg("-outform")
.arg("der")
.arg("-in")
.arg(&keyfile)
.output());
assert!(keyout.status.success());
let pkcs12out = t!(Command::new("openssl")
.arg("pkcs12")
.arg("-export")
.arg("-nodes")
.arg("-inkey")
.arg(&keyfile)
.arg("-in")
.arg(&certfile)
.arg("-password")
.arg("pass:foobar")
.output());
assert!(pkcs12out.status.success());
let keys = Box::new(Keys {
cert_der: crtout.stdout,
pkey_der: keyout.stdout,
pkcs12_der: pkcs12out.stdout,
});
unsafe {
KEYS = Box::into_raw(keys);
} }
}); amt
unsafe { &*KEYS } };
let client = async move {
let socket = TcpStream::connect(&addr).await.unwrap();
let mut socket = client_tls.connect("foobar.com", socket).await.unwrap();
let mut data = Vec::new();
loop {
let mut buf = [0; 1];
match socket.read_exact(&mut buf).await {
Ok(_) => data.extend_from_slice(&buf),
Err(ref err) if err.kind() == std::io::ErrorKind::UnexpectedEof => break,
Err(err) => panic!(err),
}
}
data
};
let (amt, data) = join!(server, client);
assert_eq!(amt, AMT);
assert!(data == vec![9; AMT as usize]);
} }
cfg_if! { fn context() -> (TlsAcceptor, TlsConnector) {
if #[cfg(feature = "rustls")] { // Certs borrowed from `rust-native-tls/tests`
use webpki; let pkcs12 = include_bytes!("identity.p12");
use untrusted; let der = include_bytes!("root-ca.der");
use std::env;
use std::fs::File;
use std::process::Command;
use std::sync::Once;
use untrusted::Input; let identity = Identity::from_pkcs12(pkcs12, "mypass").unwrap();
use webpki::trust_anchor_util; let acceptor = native_tls::TlsAcceptor::builder(identity).build().unwrap();
fn server_cx() -> io::Result<ServerContext> { let cert = Certificate::from_der(der).unwrap();
let mut cx = ServerContext::new(); let connector = native_tls::TlsConnector::builder()
.add_root_certificate(cert)
.build()
.unwrap();
let (cert, key) = keys(); (acceptor.into(), connector.into())
cx.config_mut()
.set_single_cert(vec![cert.to_vec()], key.to_vec());
Ok(cx)
}
fn configure_client(cx: &mut ClientContext) {
let (cert, _key) = keys();
let cert = Input::from(cert);
let anchor = trust_anchor_util::cert_der_as_trust_anchor(cert).unwrap();
cx.config_mut().root_store.add_trust_anchors(&[anchor]);
}
// Like OpenSSL we generate certificates on the fly, but for OSX we
// also have to put them into a specific keychain. We put both the
// certificates and the keychain next to our binary.
//
// Right now I don't know of a way to programmatically create a
// self-signed certificate, so we just fork out to the `openssl` binary.
fn keys() -> (&'static [u8], &'static [u8]) {
static INIT: Once = Once::new();
static mut KEYS: *mut (Vec<u8>, Vec<u8>) = ptr::null_mut();
INIT.call_once(|| {
let (key, cert) = openssl_keys();
let path = t!(env::current_exe());
let path = path.parent().unwrap();
let keyfile = path.join("test.key");
let certfile = path.join("test.crt");
let config = path.join("openssl.config");
File::create(&config).unwrap().write_all(b"\
[req]\n\
distinguished_name=dn\n\
[ dn ]\n\
CN=localhost\n\
[ ext ]\n\
basicConstraints=CA:FALSE,pathlen:0\n\
subjectAltName = @alt_names
[alt_names]
DNS.1 = localhost
").unwrap();
let subj = "/C=US/ST=Denial/L=Sprintfield/O=Dis/CN=localhost";
let output = t!(Command::new("openssl")
.arg("req")
.arg("-nodes")
.arg("-x509")
.arg("-newkey").arg("rsa:2048")
.arg("-config").arg(&config)
.arg("-extensions").arg("ext")
.arg("-subj").arg(subj)
.arg("-keyout").arg(&keyfile)
.arg("-out").arg(&certfile)
.arg("-days").arg("1")
.output());
assert!(output.status.success());
let crtout = t!(Command::new("openssl")
.arg("x509")
.arg("-outform").arg("der")
.arg("-in").arg(&certfile)
.output());
assert!(crtout.status.success());
let keyout = t!(Command::new("openssl")
.arg("rsa")
.arg("-outform").arg("der")
.arg("-in").arg(&keyfile)
.output());
assert!(keyout.status.success());
let cert = crtout.stdout;
let key = keyout.stdout;
unsafe {
KEYS = Box::into_raw(Box::new((cert, key)));
}
});
unsafe {
(&(*KEYS).0, &(*KEYS).1)
}
}
} else if #[cfg(any(feature = "force-openssl",
all(not(target_os = "macos"),
not(target_os = "windows"),
not(target_os = "ios"))))] {
use std::fs::File;
use std::env;
use std::sync::Once;
fn contexts() -> (tokio_native_tls::TlsAcceptor, tokio_native_tls::TlsConnector) {
let keys = openssl_keys();
let pkcs12 = t!(Identity::from_pkcs12(&keys.pkcs12_der, "foobar"));
let srv = TlsAcceptor::builder(pkcs12);
let cert = t!(native_tls::Certificate::from_der(&keys.cert_der));
let mut client = TlsConnector::builder();
t!(client.add_root_certificate(cert).build());
(t!(srv.build()).into(), t!(client.build()).into())
}
} else if #[cfg(any(target_os = "macos", target_os = "ios"))] {
use std::env;
use std::fs::File;
use std::sync::Once;
fn contexts() -> (tokio_tls::TlsAcceptor, tokio_tls::TlsConnector) {
let keys = openssl_keys();
let pkcs12 = t!(Identity::from_pkcs12(&keys.pkcs12_der, "foobar"));
let srv = TlsAcceptor::builder(pkcs12);
let cert = native_tls::Certificate::from_der(&keys.cert_der).unwrap();
let mut client = TlsConnector::builder();
client.add_root_certificate(cert);
(t!(srv.build()).into(), t!(client.build()).into())
}
} else {
use schannel;
use winapi;
use std::env;
use std::fs::File;
use std::io;
use std::mem;
use std::sync::Once;
use schannel::cert_context::CertContext;
use schannel::cert_store::{CertStore, CertAdd, Memory};
use winapi::shared::basetsd::*;
use winapi::shared::lmcons::*;
use winapi::shared::minwindef::*;
use winapi::shared::ntdef::WCHAR;
use winapi::um::minwinbase::*;
use winapi::um::sysinfoapi::*;
use winapi::um::timezoneapi::*;
use winapi::um::wincrypt::*;
const FRIENDLY_NAME: &'static str = "tokio-tls localhost testing cert";
fn contexts() -> (tokio_tls::TlsAcceptor, tokio_tls::TlsConnector) {
let cert = localhost_cert();
let mut store = t!(Memory::new()).into_store();
t!(store.add_cert(&cert, CertAdd::Always));
let pkcs12_der = t!(store.export_pkcs12("foobar"));
let pkcs12 = t!(Identity::from_pkcs12(&pkcs12_der, "foobar"));
let srv = TlsAcceptor::builder(pkcs12);
let client = TlsConnector::builder();
(t!(srv.build()).into(), t!(client.build()).into())
}
// ====================================================================
// Magic!
//
// Lots of magic is happening here to wrangle certificates for running
// these tests on Windows. For more information see the test suite
// in the schannel-rs crate as this is just coyping that.
//
// The general gist of this though is that the only way to add custom
// trusted certificates is to add it to the system store of trust. To
// do that we go through the whole rigamarole here to generate a new
// self-signed certificate and then insert that into the system store.
//
// This generates some dialogs, so we print what we're doing sometimes,
// and otherwise we just manage the ephemeral certificates. Because
// they're in the system store we always ensure that they're only valid
// for a small period of time (e.g. 1 day).
fn localhost_cert() -> CertContext {
static INIT: Once = Once::new();
INIT.call_once(|| {
for cert in local_root_store().certs() {
let name = match cert.friendly_name() {
Ok(name) => name,
Err(_) => continue,
};
if name != FRIENDLY_NAME {
continue
}
if !cert.is_time_valid().unwrap() {
io::stdout().write_all(br#"
The tokio-tls test suite is about to delete an old copy of one of its
certificates from your root trust store. This certificate was only valid for one
day and it is no longer needed. The host should be "localhost" and the
description should mention "tokio-tls".
"#).unwrap();
cert.delete().unwrap();
} else {
return
}
}
install_certificate().unwrap();
});
for cert in local_root_store().certs() {
let name = match cert.friendly_name() {
Ok(name) => name,
Err(_) => continue,
};
if name == FRIENDLY_NAME {
return cert
}
}
panic!("couldn't find a cert");
}
fn local_root_store() -> CertStore {
if env::var("CI").is_ok() {
CertStore::open_local_machine("Root").unwrap()
} else {
CertStore::open_current_user("Root").unwrap()
}
}
fn install_certificate() -> io::Result<CertContext> {
unsafe {
let mut provider = 0;
let mut hkey = 0;
let mut buffer = "tokio-tls test suite".encode_utf16()
.chain(Some(0))
.collect::<Vec<_>>();
let res = CryptAcquireContextW(&mut provider,
buffer.as_ptr(),
ptr::null_mut(),
PROV_RSA_FULL,
CRYPT_MACHINE_KEYSET);
if res != TRUE {
// create a new key container (since it does not exist)
let res = CryptAcquireContextW(&mut provider,
buffer.as_ptr(),
ptr::null_mut(),
PROV_RSA_FULL,
CRYPT_NEWKEYSET | CRYPT_MACHINE_KEYSET);
if res != TRUE {
return Err(Error::last_os_error())
}
}
// create a new keypair (RSA-2048)
let res = CryptGenKey(provider,
AT_SIGNATURE,
0x0800<<16 | CRYPT_EXPORTABLE,
&mut hkey);
if res != TRUE {
return Err(Error::last_os_error());
}
// start creating the certificate
let name = "CN=localhost,O=tokio-tls,OU=tokio-tls,\
G=tokio_tls".encode_utf16()
.chain(Some(0))
.collect::<Vec<_>>();
let mut cname_buffer: [WCHAR; UNLEN as usize + 1] = mem::zeroed();
let mut cname_len = cname_buffer.len() as DWORD;
let res = CertStrToNameW(X509_ASN_ENCODING,
name.as_ptr(),
CERT_X500_NAME_STR,
ptr::null_mut(),
cname_buffer.as_mut_ptr() as *mut u8,
&mut cname_len,
ptr::null_mut());
if res != TRUE {
return Err(Error::last_os_error());
}
let mut subject_issuer = CERT_NAME_BLOB {
cbData: cname_len,
pbData: cname_buffer.as_ptr() as *mut u8,
};
let mut key_provider = CRYPT_KEY_PROV_INFO {
pwszContainerName: buffer.as_mut_ptr(),
pwszProvName: ptr::null_mut(),
dwProvType: PROV_RSA_FULL,
dwFlags: CRYPT_MACHINE_KEYSET,
cProvParam: 0,
rgProvParam: ptr::null_mut(),
dwKeySpec: AT_SIGNATURE,
};
let mut sig_algorithm = CRYPT_ALGORITHM_IDENTIFIER {
pszObjId: szOID_RSA_SHA256RSA.as_ptr() as *mut _,
Parameters: mem::zeroed(),
};
let mut expiration_date: SYSTEMTIME = mem::zeroed();
GetSystemTime(&mut expiration_date);
let mut file_time: FILETIME = mem::zeroed();
let res = SystemTimeToFileTime(&mut expiration_date,
&mut file_time);
if res != TRUE {
return Err(Error::last_os_error());
}
let mut timestamp: u64 = file_time.dwLowDateTime as u64 |
(file_time.dwHighDateTime as u64) << 32;
// one day, timestamp unit is in 100 nanosecond intervals
timestamp += (1E9 as u64) / 100 * (60 * 60 * 24);
file_time.dwLowDateTime = timestamp as u32;
file_time.dwHighDateTime = (timestamp >> 32) as u32;
let res = FileTimeToSystemTime(&file_time,
&mut expiration_date);
if res != TRUE {
return Err(Error::last_os_error());
}
// create a self signed certificate
let cert_context = CertCreateSelfSignCertificate(
0 as ULONG_PTR,
&mut subject_issuer,
0,
&mut key_provider,
&mut sig_algorithm,
ptr::null_mut(),
&mut expiration_date,
ptr::null_mut());
if cert_context.is_null() {
return Err(Error::last_os_error());
}
// TODO: this is.. a terrible hack. Right now `schannel`
// doesn't provide a public method to go from a raw
// cert context pointer to the `CertContext` structure it
// has, so we just fake it here with a transmute. This'll
// probably break at some point, but hopefully by then
// it'll have a method to do this!
struct MyCertContext<T>(T);
impl<T> Drop for MyCertContext<T> {
fn drop(&mut self) {}
}
let cert_context = MyCertContext(cert_context);
let cert_context: CertContext = mem::transmute(cert_context);
cert_context.set_friendly_name(FRIENDLY_NAME)?;
// install the certificate to the machine's local store
io::stdout().write_all(br#"
The tokio-tls test suite is about to add a certificate to your set of root
and trusted certificates. This certificate should be for the domain "localhost"
with the description related to "tokio-tls". This certificate is only valid
for one day and will be automatically deleted if you re-run the tokio-tls
test suite later.
"#).unwrap();
local_root_store().add_cert(&cert_context,
CertAdd::ReplaceExisting)?;
Ok(cert_context)
}
}
}
} }
const AMT: usize = 128 * 1024; const AMT: usize = 128 * 1024;
@ -517,112 +143,3 @@ async fn copy_data<W: AsyncWrite + Unpin>(mut w: W) -> Result<usize, Error> {
} }
Ok(amt) Ok(amt)
} }
#[tokio::test]
async fn client_to_server() {
drop(env_logger::try_init());
// Create a server listening on a port, then figure out what that port is
let mut srv = t!(TcpListener::bind("127.0.0.1:0").await);
let addr = t!(srv.local_addr());
let (server_cx, client_cx) = contexts();
// Create a future to accept one socket, connect the ssl stream, and then
// read all the data from it.
let server = async move {
let mut incoming = srv.incoming();
let socket = t!(incoming.next().await.unwrap());
let mut socket = t!(server_cx.accept(socket).await);
let mut data = Vec::new();
t!(socket.read_to_end(&mut data).await);
data
};
// Create a future to connect to our server, connect the ssl stream, and
// then write a bunch of data to it.
let client = async move {
let socket = t!(TcpStream::connect(&addr).await);
let socket = t!(client_cx.connect("localhost", socket).await);
copy_data(socket).await
};
// Finally, run everything!
let (data, _) = join!(server, client);
// assert_eq!(amt, AMT);
assert!(data == vec![9; AMT]);
}
#[tokio::test]
async fn server_to_client() {
drop(env_logger::try_init());
// Create a server listening on a port, then figure out what that port is
let mut srv = t!(TcpListener::bind("127.0.0.1:0").await);
let addr = t!(srv.local_addr());
let (server_cx, client_cx) = contexts();
let server = async move {
let mut incoming = srv.incoming();
let socket = t!(incoming.next().await.unwrap());
let socket = t!(server_cx.accept(socket).await);
copy_data(socket).await
};
let client = async move {
let socket = t!(TcpStream::connect(&addr).await);
let mut socket = t!(client_cx.connect("localhost", socket).await);
let mut data = Vec::new();
t!(socket.read_to_end(&mut data).await);
data
};
// Finally, run everything!
let (_, data) = join!(server, client);
// assert_eq!(amt, AMT);
assert!(data == vec![9; AMT]);
}
#[tokio::test]
async fn one_byte_at_a_time() {
const AMT: usize = 1024;
drop(env_logger::try_init());
let mut srv = t!(TcpListener::bind("127.0.0.1:0").await);
let addr = t!(srv.local_addr());
let (server_cx, client_cx) = contexts();
let server = async move {
let mut incoming = srv.incoming();
let socket = t!(incoming.next().await.unwrap());
let mut socket = t!(server_cx.accept(socket).await);
let mut amt = 0;
for b in std::iter::repeat(9).take(AMT) {
let data = [b as u8];
t!(socket.write_all(&data).await);
amt += 1;
}
amt
};
let client = async move {
let socket = t!(TcpStream::connect(&addr).await);
let mut socket = t!(client_cx.connect("localhost", socket).await);
let mut data = Vec::new();
loop {
let mut buf = [0; 1];
match socket.read_exact(&mut buf).await {
Ok(_) => data.extend_from_slice(&buf),
Err(ref err) if err.kind() == ErrorKind::UnexpectedEof => break,
Err(err) => panic!(err),
}
}
data
};
let (amt, data) = join!(server, client);
assert_eq!(amt, AMT);
assert!(data == vec![9; AMT as usize]);
}

View File

@ -1,7 +1,6 @@
use super::*; use super::*;
use rustls::Session;
use crate::common::IoSession; use crate::common::IoSession;
use rustls::Session;
/// A wrapper around an underlying raw stream which implements the TLS or SSL /// A wrapper around an underlying raw stream which implements the TLS or SSL
/// protocol. /// protocol.
@ -58,20 +57,24 @@ where
false false
} }
fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> { fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
match self.state { match self.state {
#[cfg(feature = "early-data")] #[cfg(feature = "early-data")]
TlsState::EarlyData(..) => Poll::Pending, TlsState::EarlyData(..) => Poll::Pending,
TlsState::Stream | TlsState::WriteShutdown => { TlsState::Stream | TlsState::WriteShutdown => {
let this = self.get_mut(); let this = self.get_mut();
let mut stream = Stream::new(&mut this.io, &mut this.session) let mut stream =
.set_eof(!this.state.readable()); Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
match stream.as_mut_pin().poll_read(cx, buf) { match stream.as_mut_pin().poll_read(cx, buf) {
Poll::Ready(Ok(0)) => { Poll::Ready(Ok(0)) => {
this.state.shutdown_read(); this.state.shutdown_read();
Poll::Ready(Ok(0)) Poll::Ready(Ok(0))
}, }
Poll::Ready(Ok(n)) => Poll::Ready(Ok(n)), Poll::Ready(Ok(n)) => Poll::Ready(Ok(n)),
Poll::Ready(Err(ref e)) if e.kind() == io::ErrorKind::ConnectionAborted => { Poll::Ready(Err(ref e)) if e.kind() == io::ErrorKind::ConnectionAborted => {
this.state.shutdown_read(); this.state.shutdown_read();
@ -80,8 +83,8 @@ where
this.state.shutdown_write(); this.state.shutdown_write();
} }
Poll::Ready(Ok(0)) Poll::Ready(Ok(0))
}, }
output => output output => output,
} }
} }
TlsState::ReadShutdown | TlsState::FullyShutdown => Poll::Ready(Ok(0)), TlsState::ReadShutdown | TlsState::FullyShutdown => Poll::Ready(Ok(0)),
@ -95,10 +98,14 @@ where
{ {
/// Note: that it does not guarantee the final data to be sent. /// Note: that it does not guarantee the final data to be sent.
/// To be cautious, you must manually call `flush`. /// To be cautious, you must manually call `flush`.
fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> { fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
let this = self.get_mut(); let this = self.get_mut();
let mut stream = Stream::new(&mut this.io, &mut this.session) let mut stream =
.set_eof(!this.state.readable()); Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
match this.state { match this.state {
#[cfg(feature = "early-data")] #[cfg(feature = "early-data")]
@ -110,9 +117,10 @@ where
if let Some(mut early_data) = stream.session.early_data() { if let Some(mut early_data) = stream.session.early_data() {
let len = match early_data.write(buf) { let len = match early_data.write(buf) {
Ok(n) => n, Ok(n) => n,
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => {
return Poll::Pending, return Poll::Pending
Err(err) => return Poll::Ready(Err(err)) }
Err(err) => return Poll::Ready(Err(err)),
}; };
if len != 0 { if len != 0 {
data.extend_from_slice(&buf[..len]); data.extend_from_slice(&buf[..len]);
@ -143,10 +151,11 @@ where
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
let this = self.get_mut(); let this = self.get_mut();
let mut stream = Stream::new(&mut this.io, &mut this.session) let mut stream =
.set_eof(!this.state.readable()); Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
#[cfg(feature = "early-data")] { #[cfg(feature = "early-data")]
{
use futures_core::ready; use futures_core::ready;
if let TlsState::EarlyData(ref mut pos, ref mut data) = this.state { if let TlsState::EarlyData(ref mut pos, ref mut data) = this.state {
@ -176,7 +185,8 @@ where
self.state.shutdown_write(); self.state.shutdown_write();
} }
#[cfg(feature = "early-data")] { #[cfg(feature = "early-data")]
{
// we skip the handshake // we skip the handshake
if let TlsState::EarlyData(..) = self.state { if let TlsState::EarlyData(..) = self.state {
return Pin::new(&mut self.io).poll_shutdown(cx); return Pin::new(&mut self.io).poll_shutdown(cx);
@ -184,8 +194,8 @@ where
} }
let this = self.get_mut(); let this = self.get_mut();
let mut stream = Stream::new(&mut this.io, &mut this.session) let mut stream =
.set_eof(!this.state.readable()); Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
stream.as_mut_pin().poll_shutdown(cx) stream.as_mut_pin().poll_shutdown(cx)
} }
} }

View File

@ -1,12 +1,11 @@
use std::{ io, mem }; use crate::common::{Stream, TlsState};
use std::pin::Pin;
use std::future::Future;
use std::task::{ Context, Poll };
use futures_core::future::FusedFuture; use futures_core::future::FusedFuture;
use tokio::io::{ AsyncRead, AsyncWrite };
use rustls::Session; use rustls::Session;
use crate::common::{ TlsState, Stream }; use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::{io, mem};
use tokio::io::{AsyncRead, AsyncWrite};
pub(crate) trait IoSession { pub(crate) trait IoSession {
type Io; type Io;
@ -26,7 +25,7 @@ impl<IS> FusedFuture for MidHandshake<IS>
where where
IS: IoSession + Unpin, IS: IoSession + Unpin,
IS::Io: AsyncRead + AsyncWrite + Unpin, IS::Io: AsyncRead + AsyncWrite + Unpin,
IS::Session: Session + Unpin IS::Session: Session + Unpin,
{ {
fn is_terminated(&self) -> bool { fn is_terminated(&self) -> bool {
if let MidHandshake::End = self { if let MidHandshake::End = self {
@ -41,7 +40,7 @@ impl<IS> Future for MidHandshake<IS>
where where
IS: IoSession + Unpin, IS: IoSession + Unpin,
IS::Io: AsyncRead + AsyncWrite + Unpin, IS::Io: AsyncRead + AsyncWrite + Unpin,
IS::Session: Session + Unpin IS::Session: Session + Unpin,
{ {
type Output = Result<IS, (io::Error, IS::Io)>; type Output = Result<IS, (io::Error, IS::Io)>;
@ -51,20 +50,21 @@ where
if let MidHandshake::Handshaking(mut stream) = mem::replace(this, MidHandshake::End) { if let MidHandshake::Handshaking(mut stream) = mem::replace(this, MidHandshake::End) {
if !stream.skip_handshake() { if !stream.skip_handshake() {
let (state, io, session) = stream.get_mut(); let (state, io, session) = stream.get_mut();
let mut tls_stream = Stream::new(io, session) let mut tls_stream = Stream::new(io, session).set_eof(!state.readable());
.set_eof(!state.readable());
macro_rules! try_poll { macro_rules! try_poll {
( $e:expr ) => { ( $e:expr ) => {
match $e { match $e {
Poll::Ready(Ok(_)) => (), Poll::Ready(Ok(_)) => (),
Poll::Ready(Err(err)) => return Poll::Ready(Err((err, stream.into_io()))), Poll::Ready(Err(err)) => {
return Poll::Ready(Err((err, stream.into_io())))
}
Poll::Pending => { Poll::Pending => {
*this = MidHandshake::Handshaking(stream); *this = MidHandshake::Handshaking(stream);
return Poll::Pending; return Poll::Pending;
} }
} }
} };
} }
while tls_stream.session.is_handshaking() { while tls_stream.session.is_handshaking() {

View File

@ -3,14 +3,13 @@ mod handshake;
#[cfg(feature = "unstable")] #[cfg(feature = "unstable")]
mod vecbuf; mod vecbuf;
use std::pin::Pin;
use std::task::{ Poll, Context };
use std::io::{ self, Read };
use rustls::Session;
use tokio::io::{ AsyncRead, AsyncWrite };
use futures_core as futures; use futures_core as futures;
pub(crate) use handshake::{ IoSession, MidHandshake }; pub(crate) use handshake::{IoSession, MidHandshake};
use rustls::Session;
use std::io::{self, Read};
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite};
#[derive(Debug)] #[derive(Debug)]
pub enum TlsState { pub enum TlsState {
@ -26,8 +25,7 @@ impl TlsState {
#[inline] #[inline]
pub fn shutdown_read(&mut self) { pub fn shutdown_read(&mut self) {
match *self { match *self {
TlsState::WriteShutdown | TlsState::FullyShutdown => TlsState::WriteShutdown | TlsState::FullyShutdown => *self = TlsState::FullyShutdown,
*self = TlsState::FullyShutdown,
_ => *self = TlsState::ReadShutdown, _ => *self = TlsState::ReadShutdown,
} }
} }
@ -35,8 +33,7 @@ impl TlsState {
#[inline] #[inline]
pub fn shutdown_write(&mut self) { pub fn shutdown_write(&mut self) {
match *self { match *self {
TlsState::ReadShutdown | TlsState::FullyShutdown => TlsState::ReadShutdown | TlsState::FullyShutdown => *self = TlsState::FullyShutdown,
*self = TlsState::FullyShutdown,
_ => *self = TlsState::WriteShutdown, _ => *self = TlsState::WriteShutdown,
} }
} }
@ -62,7 +59,7 @@ impl TlsState {
pub fn is_early_data(&self) -> bool { pub fn is_early_data(&self) -> bool {
match self { match self {
TlsState::EarlyData(..) => true, TlsState::EarlyData(..) => true,
_ => false _ => false,
} }
} }
@ -76,7 +73,7 @@ impl TlsState {
pub struct Stream<'a, IO, S> { pub struct Stream<'a, IO, S> {
pub io: &'a mut IO, pub io: &'a mut IO,
pub session: &'a mut S, pub session: &'a mut S,
pub eof: bool pub eof: bool,
} }
impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> { impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> {
@ -100,28 +97,27 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> {
} }
pub fn process_new_packets(&mut self, cx: &mut Context) -> io::Result<()> { pub fn process_new_packets(&mut self, cx: &mut Context) -> io::Result<()> {
self.session.process_new_packets() self.session.process_new_packets().map_err(|err| {
.map_err(|err| { // In case we have an alert to send describing this error,
// In case we have an alert to send describing this error, // try a last-gasp write -- but don't predate the primary
// try a last-gasp write -- but don't predate the primary // error.
// error. let _ = self.write_io(cx);
let _ = self.write_io(cx);
io::Error::new(io::ErrorKind::InvalidData, err) io::Error::new(io::ErrorKind::InvalidData, err)
}) })
} }
pub fn read_io(&mut self, cx: &mut Context) -> Poll<io::Result<usize>> { pub fn read_io(&mut self, cx: &mut Context) -> Poll<io::Result<usize>> {
struct Reader<'a, 'b, T> { struct Reader<'a, 'b, T> {
io: &'a mut T, io: &'a mut T,
cx: &'a mut Context<'b> cx: &'a mut Context<'b>,
} }
impl<'a, 'b, T: AsyncRead + Unpin> Read for Reader<'a, 'b, T> { impl<'a, 'b, T: AsyncRead + Unpin> Read for Reader<'a, 'b, T> {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
match Pin::new(&mut self.io).poll_read(self.cx, buf) { match Pin::new(&mut self.io).poll_read(self.cx, buf) {
Poll::Ready(result) => result, Poll::Ready(result) => result,
Poll::Pending => Err(io::ErrorKind::WouldBlock.into()) Poll::Pending => Err(io::ErrorKind::WouldBlock.into()),
} }
} }
} }
@ -131,7 +127,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> {
let n = match self.session.read_tls(&mut reader) { let n = match self.session.read_tls(&mut reader) {
Ok(n) => n, Ok(n) => n,
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => return Poll::Pending, Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => return Poll::Pending,
Err(err) => return Poll::Ready(Err(err)) Err(err) => return Poll::Ready(Err(err)),
}; };
Poll::Ready(Ok(n)) Poll::Ready(Ok(n))
@ -143,21 +139,21 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> {
struct Writer<'a, 'b, T> { struct Writer<'a, 'b, T> {
io: &'a mut T, io: &'a mut T,
cx: &'a mut Context<'b> cx: &'a mut Context<'b>,
} }
impl<'a, 'b, T: AsyncWrite + Unpin> Write for Writer<'a, 'b, T> { impl<'a, 'b, T: AsyncWrite + Unpin> Write for Writer<'a, 'b, T> {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> { fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
match Pin::new(&mut self.io).poll_write(self.cx, buf) { match Pin::new(&mut self.io).poll_write(self.cx, buf) {
Poll::Ready(result) => result, Poll::Ready(result) => result,
Poll::Pending => Err(io::ErrorKind::WouldBlock.into()) Poll::Pending => Err(io::ErrorKind::WouldBlock.into()),
} }
} }
fn flush(&mut self) -> io::Result<()> { fn flush(&mut self) -> io::Result<()> {
match Pin::new(&mut self.io).poll_flush(self.cx) { match Pin::new(&mut self.io).poll_flush(self.cx) {
Poll::Ready(result) => result, Poll::Ready(result) => result,
Poll::Pending => Err(io::ErrorKind::WouldBlock.into()) Poll::Pending => Err(io::ErrorKind::WouldBlock.into()),
} }
} }
} }
@ -166,7 +162,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> {
match self.session.write_tls(&mut writer) { match self.session.write_tls(&mut writer) {
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending, Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
result => Poll::Ready(result) result => Poll::Ready(result),
} }
} }
@ -176,7 +172,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> {
struct Writer<'a, 'b, T> { struct Writer<'a, 'b, T> {
io: &'a mut T, io: &'a mut T,
cx: &'a mut Context<'b> cx: &'a mut Context<'b>,
} }
impl<'a, 'b, T: AsyncWrite + Unpin> WriteV for Writer<'a, 'b, T> { impl<'a, 'b, T: AsyncWrite + Unpin> WriteV for Writer<'a, 'b, T> {
@ -187,7 +183,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> {
match Pin::new(&mut self.io).poll_write_buf(self.cx, &mut vbuf) { match Pin::new(&mut self.io).poll_write_buf(self.cx, &mut vbuf) {
Poll::Ready(result) => result, Poll::Ready(result) => result,
Poll::Pending => Err(io::ErrorKind::WouldBlock.into()) Poll::Pending => Err(io::ErrorKind::WouldBlock.into()),
} }
} }
} }
@ -196,7 +192,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> {
match self.session.writev_tls(&mut writer) { match self.session.writev_tls(&mut writer) {
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending, Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
result => Poll::Ready(result) result => Poll::Ready(result),
} }
} }
@ -213,9 +209,9 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> {
Poll::Ready(Ok(n)) => wrlen += n, Poll::Ready(Ok(n)) => wrlen += n,
Poll::Pending => { Poll::Pending => {
write_would_block = true; write_would_block = true;
break break;
}, }
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)) Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
} }
} }
@ -225,9 +221,9 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> {
Poll::Ready(Ok(n)) => rdlen += n, Poll::Ready(Ok(n)) => rdlen += n,
Poll::Pending => { Poll::Pending => {
read_would_block = true; read_would_block = true;
break break;
}, }
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)) Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
} }
} }
@ -237,21 +233,27 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> {
(true, true) => { (true, true) => {
let err = io::Error::new(io::ErrorKind::UnexpectedEof, "tls handshake eof"); let err = io::Error::new(io::ErrorKind::UnexpectedEof, "tls handshake eof");
Poll::Ready(Err(err)) Poll::Ready(Err(err))
}, }
(_, false) => Poll::Ready(Ok((rdlen, wrlen))), (_, false) => Poll::Ready(Ok((rdlen, wrlen))),
(_, true) if write_would_block || read_would_block => if rdlen != 0 || wrlen != 0 { (_, true) if write_would_block || read_would_block => {
Poll::Ready(Ok((rdlen, wrlen))) if rdlen != 0 || wrlen != 0 {
} else { Poll::Ready(Ok((rdlen, wrlen)))
Poll::Pending } else {
}, Poll::Pending
(..) => continue }
} }
(..) => continue,
};
} }
} }
} }
impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncRead for Stream<'a, IO, S> { impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncRead for Stream<'a, IO, S> {
fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context, buf: &mut [u8]) -> Poll<io::Result<usize>> { fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
let mut pos = 0; let mut pos = 0;
while pos != buf.len() { while pos != buf.len() {
@ -262,14 +264,14 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncRead for Stream<'a
match self.read_io(cx) { match self.read_io(cx) {
Poll::Ready(Ok(0)) => { Poll::Ready(Ok(0)) => {
self.eof = true; self.eof = true;
break break;
}, }
Poll::Ready(Ok(_)) => (), Poll::Ready(Ok(_)) => (),
Poll::Pending => { Poll::Pending => {
would_block = true; would_block = true;
break break;
}, }
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)) Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
} }
} }
@ -280,13 +282,14 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncRead for Stream<'a
Ok(n) if self.eof || would_block => Poll::Ready(Ok(pos + n)), Ok(n) if self.eof || would_block => Poll::Ready(Ok(pos + n)),
Ok(n) => { Ok(n) => {
pos += n; pos += n;
continue continue;
}, }
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending, Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
Err(ref err) if err.kind() == io::ErrorKind::ConnectionAborted && pos != 0 => Err(ref err) if err.kind() == io::ErrorKind::ConnectionAborted && pos != 0 => {
Poll::Ready(Ok(pos)), Poll::Ready(Ok(pos))
Err(err) => Poll::Ready(Err(err)) }
} Err(err) => Poll::Ready(Err(err)),
};
} }
Poll::Ready(Ok(pos)) Poll::Ready(Ok(pos))
@ -294,7 +297,11 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncRead for Stream<'a
} }
impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncWrite for Stream<'a, IO, S> { impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncWrite for Stream<'a, IO, S> {
fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> { fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context,
buf: &[u8],
) -> Poll<io::Result<usize>> {
let mut pos = 0; let mut pos = 0;
while pos != buf.len() { while pos != buf.len() {
@ -303,25 +310,25 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncWrite for Stream<'
match self.session.write(&buf[pos..]) { match self.session.write(&buf[pos..]) {
Ok(n) => pos += n, Ok(n) => pos += n,
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => (), Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => (),
Err(err) => return Poll::Ready(Err(err)) Err(err) => return Poll::Ready(Err(err)),
}; };
while self.session.wants_write() { while self.session.wants_write() {
match self.write_io(cx) { match self.write_io(cx) {
Poll::Ready(Ok(0)) | Poll::Pending => { Poll::Ready(Ok(0)) | Poll::Pending => {
would_block = true; would_block = true;
break break;
}, }
Poll::Ready(Ok(_)) => (), Poll::Ready(Ok(_)) => (),
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)) Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
} }
} }
return match (pos, would_block) { return match (pos, would_block) {
(0, true) => Poll::Pending, (0, true) => Poll::Pending,
(n, true) => Poll::Ready(Ok(n)), (n, true) => Poll::Ready(Ok(n)),
(_, false) => continue (_, false) => continue,
} };
} }
Poll::Ready(Ok(pos)) Poll::Ready(Ok(pos))

View File

@ -1,39 +1,44 @@
use std::pin::Pin; use super::Stream;
use std::sync::Arc;
use std::task::{ Poll, Context };
use futures_core::ready; use futures_core::ready;
use futures_util::future::poll_fn; use futures_util::future::poll_fn;
use futures_util::task::noop_waker_ref; use futures_util::task::noop_waker_ref;
use tokio::io::{ AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt }; use rustls::internal::pemfile::{certs, rsa_private_keys};
use std::io::{ self, Read, Write, BufReader, Cursor }; use rustls::{ClientConfig, ClientSession, NoClientAuth, ServerConfig, ServerSession, Session};
use std::io::{self, BufReader, Cursor, Read, Write};
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use webpki::DNSNameRef; use webpki::DNSNameRef;
use rustls::internal::pemfile::{ certs, rsa_private_keys };
use rustls::{
ServerConfig, ClientConfig,
ServerSession, ClientSession,
Session, NoClientAuth
};
use super::Stream;
struct Good<'a>(&'a mut dyn Session); struct Good<'a>(&'a mut dyn Session);
impl<'a> AsyncRead for Good<'a> { impl<'a> AsyncRead for Good<'a> {
fn poll_read(mut self: Pin<&mut Self>, _cx: &mut Context<'_>, mut buf: &mut [u8]) -> Poll<io::Result<usize>> { fn poll_read(
mut self: Pin<&mut Self>,
_cx: &mut Context<'_>,
mut buf: &mut [u8],
) -> Poll<io::Result<usize>> {
Poll::Ready(self.0.write_tls(buf.by_ref())) Poll::Ready(self.0.write_tls(buf.by_ref()))
} }
} }
impl<'a> AsyncWrite for Good<'a> { impl<'a> AsyncWrite for Good<'a> {
fn poll_write(mut self: Pin<&mut Self>, _cx: &mut Context<'_>, mut buf: &[u8]) -> Poll<io::Result<usize>> { fn poll_write(
mut self: Pin<&mut Self>,
_cx: &mut Context<'_>,
mut buf: &[u8],
) -> Poll<io::Result<usize>> {
let len = self.0.read_tls(buf.by_ref())?; let len = self.0.read_tls(buf.by_ref())?;
self.0.process_new_packets() self.0
.process_new_packets()
.map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?; .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?;
Poll::Ready(Ok(len)) Poll::Ready(Ok(len))
} }
fn poll_flush(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> { fn poll_flush(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
self.0.process_new_packets() self.0
.process_new_packets()
.map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?; .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?;
Poll::Ready(Ok(())) Poll::Ready(Ok(()))
} }
@ -47,13 +52,21 @@ impl<'a> AsyncWrite for Good<'a> {
struct Pending; struct Pending;
impl AsyncRead for Pending { impl AsyncRead for Pending {
fn poll_read(self: Pin<&mut Self>, _cx: &mut Context<'_>, _: &mut [u8]) -> Poll<io::Result<usize>> { fn poll_read(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
_: &mut [u8],
) -> Poll<io::Result<usize>> {
Poll::Pending Poll::Pending
} }
} }
impl AsyncWrite for Pending { impl AsyncWrite for Pending {
fn poll_write(self: Pin<&mut Self>, _cx: &mut Context<'_>, _buf: &[u8]) -> Poll<io::Result<usize>> { fn poll_write(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
_buf: &[u8],
) -> Poll<io::Result<usize>> {
Poll::Pending Poll::Pending
} }
@ -69,13 +82,21 @@ impl AsyncWrite for Pending {
struct Eof; struct Eof;
impl AsyncRead for Eof { impl AsyncRead for Eof {
fn poll_read(self: Pin<&mut Self>, _cx: &mut Context<'_>, _: &mut [u8]) -> Poll<io::Result<usize>> { fn poll_read(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
_: &mut [u8],
) -> Poll<io::Result<usize>> {
Poll::Ready(Ok(0)) Poll::Ready(Ok(0))
} }
} }
impl AsyncWrite for Eof { impl AsyncWrite for Eof {
fn poll_write(self: Pin<&mut Self>, _cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> { fn poll_write(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
Poll::Ready(Ok(buf.len())) Poll::Ready(Ok(buf.len()))
} }
@ -122,8 +143,14 @@ async fn stream_bad() -> io::Result<()> {
let mut bad = Pending; let mut bad = Pending;
let mut stream = Stream::new(&mut bad, &mut client); let mut stream = Stream::new(&mut bad, &mut client);
assert_eq!(poll_fn(|cx| stream.as_mut_pin().poll_write(cx, &[0x42; 8])).await?, 8); assert_eq!(
assert_eq!(poll_fn(|cx| stream.as_mut_pin().poll_write(cx, &[0x42; 8])).await?, 8); poll_fn(|cx| stream.as_mut_pin().poll_write(cx, &[0x42; 8])).await?,
8
);
assert_eq!(
poll_fn(|cx| stream.as_mut_pin().poll_write(cx, &[0x42; 8])).await?,
8
);
let r = poll_fn(|cx| stream.as_mut_pin().poll_write(cx, &[0x00; 1024])).await?; // fill buffer let r = poll_fn(|cx| stream.as_mut_pin().poll_write(cx, &[0x00; 1024])).await?; // fill buffer
assert!(r < 1024); assert!(r < 1024);
@ -164,7 +191,10 @@ async fn stream_handshake_eof() -> io::Result<()> {
let mut cx = Context::from_waker(noop_waker_ref()); let mut cx = Context::from_waker(noop_waker_ref());
let r = stream.handshake(&mut cx); let r = stream.handshake(&mut cx);
assert_eq!(r.map_err(|err| err.kind()), Poll::Ready(Err(io::ErrorKind::UnexpectedEof))); assert_eq!(
r.map_err(|err| err.kind()),
Poll::Ready(Err(io::ErrorKind::UnexpectedEof))
);
Ok(()) as io::Result<()> Ok(()) as io::Result<()>
} }
@ -204,7 +234,11 @@ fn make_pair() -> (ServerSession, ClientSession) {
(server, client) (server, client)
} }
fn do_handshake(client: &mut ClientSession, server: &mut ServerSession, cx: &mut Context<'_>) -> Poll<io::Result<()>> { fn do_handshake(
client: &mut ClientSession,
server: &mut ServerSession,
cx: &mut Context<'_>,
) -> Poll<io::Result<()>> {
let mut good = Good(server); let mut good = Good(server);
let mut stream = Stream::new(&mut good, client); let mut stream = Stream::new(&mut good, client);

View File

@ -1,23 +1,27 @@
use std::io::IoSlice;
use std::cmp::{ self, Ordering };
use bytes::Buf; use bytes::Buf;
use std::cmp::{self, Ordering};
use std::io::IoSlice;
pub struct VecBuf<'a, 'b: 'a> { pub struct VecBuf<'a, 'b: 'a> {
pos: usize, pos: usize,
cur: usize, cur: usize,
inner: &'a [&'b [u8]] inner: &'a [&'b [u8]],
} }
impl<'a, 'b> VecBuf<'a, 'b> { impl<'a, 'b> VecBuf<'a, 'b> {
pub fn new(vbytes: &'a [&'b [u8]]) -> Self { pub fn new(vbytes: &'a [&'b [u8]]) -> Self {
VecBuf { pos: 0, cur: 0, inner: vbytes } VecBuf {
pos: 0,
cur: 0,
inner: vbytes,
}
} }
} }
impl<'a, 'b> Buf for VecBuf<'a, 'b> { impl<'a, 'b> Buf for VecBuf<'a, 'b> {
fn remaining(&self) -> usize { fn remaining(&self) -> usize {
let sum = self.inner let sum = self
.inner
.iter() .iter()
.skip(self.pos) .skip(self.pos)
.map(|bytes| bytes.len()) .map(|bytes| bytes.len())
@ -32,19 +36,21 @@ impl<'a, 'b> Buf for VecBuf<'a, 'b> {
fn advance(&mut self, cnt: usize) { fn advance(&mut self, cnt: usize) {
let current = self.inner[self.pos].len(); let current = self.inner[self.pos].len();
match (self.cur + cnt).cmp(&current) { match (self.cur + cnt).cmp(&current) {
Ordering::Equal => if self.pos + 1 < self.inner.len() { Ordering::Equal => {
self.pos += 1; if self.pos + 1 < self.inner.len() {
self.cur = 0; self.pos += 1;
} else { self.cur = 0;
self.cur += cnt; } else {
}, self.cur += cnt;
}
}
Ordering::Greater => { Ordering::Greater => {
if self.pos + 1 < self.inner.len() { if self.pos + 1 < self.inner.len() {
self.pos += 1; self.pos += 1;
} }
let remaining = self.cur + cnt - current; let remaining = self.cur + cnt - current;
self.advance(remaining); self.advance(remaining);
}, }
Ordering::Less => self.cur += cnt, Ordering::Less => self.cur += cnt,
} }
} }
@ -120,8 +126,7 @@ mod test_vecbuf {
let b1: &[u8] = &mut [0]; let b1: &[u8] = &mut [0];
let b2: &[u8] = &mut [0]; let b2: &[u8] = &mut [0];
let mut dst: [IoSlice; 2] = let mut dst: [IoSlice; 2] = [IoSlice::new(b1), IoSlice::new(b2)];
[IoSlice::new(b1), IoSlice::new(b2)];
assert_eq!(2, buf.bytes_vectored(&mut dst[..])); assert_eq!(2, buf.bytes_vectored(&mut dst[..]));
} }

View File

@ -1,19 +1,19 @@
//! Asynchronous TLS/SSL streams for Tokio using [Rustls](https://github.com/ctz/rustls). //! Asynchronous TLS/SSL streams for Tokio using [Rustls](https://github.com/ctz/rustls).
mod common;
pub mod client; pub mod client;
mod common;
pub mod server; pub mod server;
use common::{MidHandshake, Stream, TlsState};
use futures_core::future::FusedFuture;
use rustls::{ClientConfig, ClientSession, ServerConfig, ServerSession, Session};
use std::future::Future;
use std::io; use std::io;
use std::pin::Pin; use std::pin::Pin;
use std::sync::Arc; use std::sync::Arc;
use std::future::Future; use std::task::{Context, Poll};
use std::task::{ Context, Poll }; use tokio::io::{AsyncRead, AsyncWrite};
use futures_core::future::FusedFuture;
use tokio::io::{ AsyncRead, AsyncWrite };
use webpki::DNSNameRef; use webpki::DNSNameRef;
use rustls::{ ClientConfig, ClientSession, ServerConfig, ServerSession, Session };
use common::{ Stream, TlsState, MidHandshake };
pub use rustls; pub use rustls;
pub use webpki; pub use webpki;
@ -88,7 +88,7 @@ impl TlsConnector {
TlsState::Stream TlsState::Stream
}, },
session session,
})) }))
} }
} }
@ -151,9 +151,7 @@ impl<IO: AsyncRead + AsyncWrite + Unpin> Future for Connect<IO> {
#[inline] #[inline]
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
Pin::new(&mut self.0) Pin::new(&mut self.0).poll(cx).map_err(|(err, _)| err)
.poll(cx)
.map_err(|(err, _)| err)
} }
} }
@ -169,9 +167,7 @@ impl<IO: AsyncRead + AsyncWrite + Unpin> Future for Accept<IO> {
#[inline] #[inline]
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
Pin::new(&mut self.0) Pin::new(&mut self.0).poll(cx).map_err(|(err, _)| err)
.poll(cx)
.map_err(|(err, _)| err)
} }
} }

View File

@ -1,6 +1,6 @@
use super::*; use super::*;
use rustls::Session;
use crate::common::IoSession; use crate::common::IoSession;
use rustls::Session;
/// A wrapper around an underlying raw stream which implements the TLS or SSL /// A wrapper around an underlying raw stream which implements the TLS or SSL
/// protocol. /// protocol.
@ -60,29 +60,35 @@ where
false false
} }
fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> { fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
let this = self.get_mut(); let this = self.get_mut();
let mut stream = Stream::new(&mut this.io, &mut this.session) let mut stream =
.set_eof(!this.state.readable()); Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
match &this.state { match &this.state {
TlsState::Stream | TlsState::WriteShutdown => match stream.as_mut_pin().poll_read(cx, buf) { TlsState::Stream | TlsState::WriteShutdown => {
Poll::Ready(Ok(0)) => { match stream.as_mut_pin().poll_read(cx, buf) {
this.state.shutdown_read(); Poll::Ready(Ok(0)) => {
Poll::Ready(Ok(0)) this.state.shutdown_read();
} Poll::Ready(Ok(0))
Poll::Ready(Ok(n)) => Poll::Ready(Ok(n)),
Poll::Ready(Err(ref err)) if err.kind() == io::ErrorKind::ConnectionAborted => {
this.state.shutdown_read();
if this.state.writeable() {
stream.session.send_close_notify();
this.state.shutdown_write();
} }
Poll::Ready(Ok(0)) Poll::Ready(Ok(n)) => Poll::Ready(Ok(n)),
Poll::Ready(Err(ref err)) if err.kind() == io::ErrorKind::ConnectionAborted => {
this.state.shutdown_read();
if this.state.writeable() {
stream.session.send_close_notify();
this.state.shutdown_write();
}
Poll::Ready(Ok(0))
}
Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
Poll::Pending => Poll::Pending,
} }
Poll::Ready(Err(e)) => Poll::Ready(Err(e)), }
Poll::Pending => Poll::Pending
},
TlsState::ReadShutdown | TlsState::FullyShutdown => Poll::Ready(Ok(0)), TlsState::ReadShutdown | TlsState::FullyShutdown => Poll::Ready(Ok(0)),
#[cfg(feature = "early-data")] #[cfg(feature = "early-data")]
s => unreachable!("server TLS can not hit this state: {:?}", s), s => unreachable!("server TLS can not hit this state: {:?}", s),
@ -96,17 +102,21 @@ where
{ {
/// Note: that it does not guarantee the final data to be sent. /// Note: that it does not guarantee the final data to be sent.
/// To be cautious, you must manually call `flush`. /// To be cautious, you must manually call `flush`.
fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> { fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
let this = self.get_mut(); let this = self.get_mut();
let mut stream = Stream::new(&mut this.io, &mut this.session) let mut stream =
.set_eof(!this.state.readable()); Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
stream.as_mut_pin().poll_write(cx, buf) stream.as_mut_pin().poll_write(cx, buf)
} }
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
let this = self.get_mut(); let this = self.get_mut();
let mut stream = Stream::new(&mut this.io, &mut this.session) let mut stream =
.set_eof(!this.state.readable()); Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
stream.as_mut_pin().poll_flush(cx) stream.as_mut_pin().poll_flush(cx)
} }
@ -117,8 +127,8 @@ where
} }
let this = self.get_mut(); let this = self.get_mut();
let mut stream = Stream::new(&mut this.io, &mut this.session) let mut stream =
.set_eof(!this.state.readable()); Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
stream.as_mut_pin().poll_shutdown(cx) stream.as_mut_pin().poll_shutdown(cx)
} }
} }

View File

@ -1,21 +1,20 @@
use std::io;
use std::sync::Arc;
use std::net::ToSocketAddrs;
use tokio::prelude::*;
use tokio::net::TcpStream;
use rustls::ClientConfig; use rustls::ClientConfig;
use tokio_rustls::{ TlsConnector, client::TlsStream }; use std::io;
use std::net::ToSocketAddrs;
use std::sync::Arc;
use tokio::net::TcpStream;
use tokio::prelude::*;
use tokio_rustls::{client::TlsStream, TlsConnector};
async fn get(
async fn get(config: Arc<ClientConfig>, domain: &str, port: u16) config: Arc<ClientConfig>,
-> io::Result<(TlsStream<TcpStream>, String)> domain: &str,
{ port: u16,
) -> io::Result<(TlsStream<TcpStream>, String)> {
let connector = TlsConnector::from(config); let connector = TlsConnector::from(config);
let input = format!("GET / HTTP/1.0\r\nHost: {}\r\n\r\n", domain); let input = format!("GET / HTTP/1.0\r\nHost: {}\r\n\r\n", domain);
let addr = (domain, port) let addr = (domain, port).to_socket_addrs()?.next().unwrap();
.to_socket_addrs()?
.next().unwrap();
let domain = webpki::DNSNameRef::try_from_ascii_str(&domain).unwrap(); let domain = webpki::DNSNameRef::try_from_ascii_str(&domain).unwrap();
let mut buf = Vec::new(); let mut buf = Vec::new();
@ -31,7 +30,9 @@ async fn get(config: Arc<ClientConfig>, domain: &str, port: u16)
#[tokio::test] #[tokio::test]
async fn test_tls12() -> io::Result<()> { async fn test_tls12() -> io::Result<()> {
let mut config = ClientConfig::new(); let mut config = ClientConfig::new();
config.root_store.add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS); config
.root_store
.add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS);
config.versions = vec![rustls::ProtocolVersion::TLSv1_2]; config.versions = vec![rustls::ProtocolVersion::TLSv1_2];
let config = Arc::new(config); let config = Arc::new(config);
let domain = "tls-v1-2.badssl.com"; let domain = "tls-v1-2.badssl.com";
@ -52,7 +53,9 @@ fn test_tls13() {
#[tokio::test] #[tokio::test]
async fn test_modern() -> io::Result<()> { async fn test_modern() -> io::Result<()> {
let mut config = ClientConfig::new(); let mut config = ClientConfig::new();
config.root_store.add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS); config
.root_store
.add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS);
let config = Arc::new(config); let config = Arc::new(config);
let domain = "mozilla-modern.badssl.com"; let domain = "mozilla-modern.badssl.com";

View File

@ -1,21 +1,17 @@
#![cfg(feature = "early-data")] #![cfg(feature = "early-data")]
use std::io::{ self, BufRead, BufReader, Cursor }; use std::io::{self, BufRead, BufReader, Cursor};
use std::process::{ Command, Child, Stdio };
use std::net::SocketAddr; use std::net::SocketAddr;
use std::pin::Pin;
use std::process::{Child, Command, Stdio};
use std::process::{Child, Command, Stdio};
use std::sync::Arc; use std::sync::Arc;
use std::marker::Unpin; use std::task::{Context, Poll};
use std::pin::{ Pin };
use std::task::{ Context, Poll };
use std::time::Duration; use std::time::Duration;
use tokio::prelude::*;
use tokio::net::TcpStream; use tokio::net::TcpStream;
use tokio::prelude::*;
use tokio::time::delay_for; use tokio::time::delay_for;
use futures_util::{ future, ready }; use tokio_rustls::{client::TlsStream, TlsConnector};
use rustls::ClientConfig;
use tokio_rustls::{ TlsConnector, client::TlsStream };
use std::future::Future;
struct Read1<T>(T); struct Read1<T>(T);
@ -29,11 +25,12 @@ impl<T: AsyncRead + Unpin> Future for Read1<T> {
} }
} }
async fn send(config: Arc<ClientConfig>, addr: SocketAddr, data: &[u8]) async fn send(
-> io::Result<TlsStream<TcpStream>> config: Arc<ClientConfig>,
{ addr: SocketAddr,
let connector = TlsConnector::from(config) data: &[u8],
.early_data(true); ) -> io::Result<TlsStream<TcpStream>> {
let connector = TlsConnector::from(config).early_data(true);
let stream = TcpStream::connect(&addr).await?; let stream = TcpStream::connect(&addr).await?;
let domain = webpki::DNSNameRef::try_from_ascii_str("testserver.com").unwrap(); let domain = webpki::DNSNameRef::try_from_ascii_str("testserver.com").unwrap();
@ -98,10 +95,8 @@ async fn test_0rtt() -> io::Result<()> {
let stdout = handle.0.stdout.as_mut().unwrap(); let stdout = handle.0.stdout.as_mut().unwrap();
let mut lines = BufReader::new(stdout).lines(); let mut lines = BufReader::new(stdout).lines();
let has_msg1 = lines.by_ref() let has_msg1 = lines.by_ref().any(|line| line.unwrap().contains("hello"));
.any(|line| line.unwrap().contains("hello")); let has_msg2 = lines.by_ref().any(|line| line.unwrap().contains("world!"));
let has_msg2 = lines.by_ref()
.any(|line| line.unwrap().contains("world!"));
assert!(has_msg1 && has_msg2); assert!(has_msg1 && has_msg2);

View File

@ -1,29 +1,30 @@
use std::{ io, thread };
use std::io::{ BufReader, Cursor };
use std::sync::Arc;
use std::sync::mpsc::channel;
use std::net::SocketAddr;
use futures_util::future::TryFutureExt; use futures_util::future::TryFutureExt;
use lazy_static::lazy_static; use lazy_static::lazy_static;
use rustls::internal::pemfile::{certs, rsa_private_keys};
use rustls::{ClientConfig, ServerConfig};
use std::io::{BufReader, Cursor};
use std::net::SocketAddr;
use std::sync::mpsc::channel;
use std::sync::Arc;
use std::{io, thread};
use tokio::io::{copy, split};
use tokio::net::{TcpListener, TcpStream};
use tokio::prelude::*; use tokio::prelude::*;
use tokio::runtime; use tokio::runtime;
use tokio::io::{ copy, split }; use tokio_rustls::{TlsAcceptor, TlsConnector};
use tokio::net::{ TcpListener, TcpStream };
use rustls::{ ServerConfig, ClientConfig };
use rustls::internal::pemfile::{ certs, rsa_private_keys };
use tokio_rustls::{ TlsConnector, TlsAcceptor };
const CERT: &str = include_str!("end.cert"); const CERT: &str = include_str!("end.cert");
const CHAIN: &str = include_str!("end.chain"); const CHAIN: &str = include_str!("end.chain");
const RSA: &str = include_str!("end.rsa"); const RSA: &str = include_str!("end.rsa");
lazy_static!{ lazy_static! {
static ref TEST_SERVER: (SocketAddr, &'static str, &'static str) = { static ref TEST_SERVER: (SocketAddr, &'static str, &'static str) = {
let cert = certs(&mut BufReader::new(Cursor::new(CERT))).unwrap(); let cert = certs(&mut BufReader::new(Cursor::new(CERT))).unwrap();
let mut keys = rsa_private_keys(&mut BufReader::new(Cursor::new(RSA))).unwrap(); let mut keys = rsa_private_keys(&mut BufReader::new(Cursor::new(RSA))).unwrap();
let mut config = ServerConfig::new(rustls::NoClientAuth::new()); let mut config = ServerConfig::new(rustls::NoClientAuth::new());
config.set_single_cert(cert, keys.pop().unwrap()) config
.set_single_cert(cert, keys.pop().unwrap())
.expect("invalid key or certificate"); .expect("invalid key or certificate");
let acceptor = TlsAcceptor::from(Arc::new(config)); let acceptor = TlsAcceptor::from(Arc::new(config));
@ -55,11 +56,13 @@ lazy_static!{
copy(&mut reader, &mut writer).await?; copy(&mut reader, &mut writer).await?;
Ok(()) as io::Result<()> Ok(()) as io::Result<()>
}.unwrap_or_else(|err| eprintln!("server: {:?}", err)); }
.unwrap_or_else(|err| eprintln!("server: {:?}", err));
handle.spawn(fut); handle.spawn(fut);
} }
}.unwrap_or_else(|err: io::Error| eprintln!("server: {:?}", err)); }
.unwrap_or_else(|err: io::Error| eprintln!("server: {:?}", err));
runtime.block_on(done); runtime.block_on(done);
}); });