Rename more tests (#1)

* Rename more tests

* Clean up smoke test

* fmt

* Clean up ci and remove all-features test
This commit is contained in:
Lucio Franco 2020-02-27 18:32:52 -05:00 committed by GitHub
parent 01fdb7ccf4
commit 7e41beaff4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 402 additions and 818 deletions

View File

@ -1,6 +1,10 @@
name: CI
on: [push, pull_request]
on:
push:
branches:
- master
pull_request: {}
jobs:
check:
@ -40,7 +44,7 @@ jobs:
profile: minimal
- uses: actions/checkout@master
- name: Test
run: cargo test --all --all-features
run: cargo test --all
lints:
name: Lints

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@ -1,500 +1,126 @@
#![warn(rust_2018_idioms)]
use cfg_if::cfg_if;
use env_logger;
use futures::join;
use native_tls;
use native_tls::{Identity, TlsAcceptor, TlsConnector};
use std::io::Write;
use std::marker::Unpin;
use std::process::Command;
use std::ptr;
use tokio::io::{AsyncReadExt, AsyncWrite, AsyncWriteExt, Error, ErrorKind};
use tokio::net::{TcpListener, TcpStream};
use tokio::stream::StreamExt;
use native_tls::{Certificate, Identity};
use std::io::Error;
use tokio::{
io::{AsyncReadExt, AsyncWrite, AsyncWriteExt},
net::{TcpListener, TcpStream},
};
use tokio_native_tls::{TlsAcceptor, TlsConnector};
macro_rules! t {
($e:expr) => {
match $e {
Ok(e) => e,
Err(e) => panic!("{} failed with {:?}", stringify!($e), e),
}
#[tokio::test]
async fn client_to_server() {
let mut srv = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = srv.local_addr().unwrap();
let (server_tls, client_tls) = context();
// Create a future to accept one socket, connect the ssl stream, and then
// read all the data from it.
let server = async move {
let (socket, _) = srv.accept().await.unwrap();
let mut socket = server_tls.accept(socket).await.unwrap();
let mut data = Vec::new();
socket.read_to_end(&mut data).await.unwrap();
data
};
// Create a future to connect to our server, connect the ssl stream, and
// then write a bunch of data to it.
let client = async move {
let socket = TcpStream::connect(&addr).await.unwrap();
let socket = client_tls.connect("foobar.com", socket).await.unwrap();
copy_data(socket).await
};
// Finally, run everything!
let (data, _) = join!(server, client);
// assert_eq!(amt, AMT);
assert!(data == vec![9; AMT]);
}
#[allow(dead_code)]
struct Keys {
cert_der: Vec<u8>,
pkey_der: Vec<u8>,
pkcs12_der: Vec<u8>,
#[tokio::test]
async fn server_to_client() {
// Create a server listening on a port, then figure out what that port is
let mut srv = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = srv.local_addr().unwrap();
let (server_tls, client_tls) = context();
let server = async move {
let (socket, _) = srv.accept().await.unwrap();
let socket = server_tls.accept(socket).await.unwrap();
copy_data(socket).await
};
let client = async move {
let socket = TcpStream::connect(&addr).await.unwrap();
let mut socket = client_tls.connect("foobar.com", socket).await.unwrap();
let mut data = Vec::new();
socket.read_to_end(&mut data).await.unwrap();
data
};
// Finally, run everything!
let (_, data) = join!(server, client);
assert!(data == vec![9; AMT]);
}
#[allow(dead_code)]
fn openssl_keys() -> &'static Keys {
static INIT: Once = Once::new();
static mut KEYS: *mut Keys = ptr::null_mut();
#[tokio::test]
async fn one_byte_at_a_time() {
const AMT: usize = 1024;
INIT.call_once(|| {
let path = t!(env::current_exe());
let path = path.parent().unwrap();
let keyfile = path.join("test.key");
let certfile = path.join("test.crt");
let config = path.join("openssl.config");
let mut srv = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = srv.local_addr().unwrap();
File::create(&config)
.unwrap()
.write_all(
b"\
[req]\n\
distinguished_name=dn\n\
[ dn ]\n\
CN=localhost\n\
[ ext ]\n\
basicConstraints=CA:FALSE,pathlen:0\n\
subjectAltName = @alt_names
extendedKeyUsage=serverAuth,clientAuth
[alt_names]
DNS.1 = localhost
",
)
.unwrap();
let (server_tls, client_tls) = context();
let subj = "/C=US/ST=Denial/L=Sprintfield/O=Dis/CN=localhost";
let output = t!(Command::new("openssl")
.arg("req")
.arg("-nodes")
.arg("-x509")
.arg("-newkey")
.arg("rsa:2048")
.arg("-config")
.arg(&config)
.arg("-extensions")
.arg("ext")
.arg("-subj")
.arg(subj)
.arg("-keyout")
.arg(&keyfile)
.arg("-out")
.arg(&certfile)
.arg("-days")
.arg("1")
.output());
assert!(output.status.success());
let crtout = t!(Command::new("openssl")
.arg("x509")
.arg("-outform")
.arg("der")
.arg("-in")
.arg(&certfile)
.output());
assert!(crtout.status.success());
let keyout = t!(Command::new("openssl")
.arg("rsa")
.arg("-outform")
.arg("der")
.arg("-in")
.arg(&keyfile)
.output());
assert!(keyout.status.success());
let pkcs12out = t!(Command::new("openssl")
.arg("pkcs12")
.arg("-export")
.arg("-nodes")
.arg("-inkey")
.arg(&keyfile)
.arg("-in")
.arg(&certfile)
.arg("-password")
.arg("pass:foobar")
.output());
assert!(pkcs12out.status.success());
let keys = Box::new(Keys {
cert_der: crtout.stdout,
pkey_der: keyout.stdout,
pkcs12_der: pkcs12out.stdout,
});
unsafe {
KEYS = Box::into_raw(keys);
let server = async move {
let (socket, _) = srv.accept().await.unwrap();
let mut socket = server_tls.accept(socket).await.unwrap();
let mut amt = 0;
for b in std::iter::repeat(9).take(AMT) {
let data = [b as u8];
socket.write_all(&data).await.unwrap();
amt += 1;
}
});
unsafe { &*KEYS }
amt
};
let client = async move {
let socket = TcpStream::connect(&addr).await.unwrap();
let mut socket = client_tls.connect("foobar.com", socket).await.unwrap();
let mut data = Vec::new();
loop {
let mut buf = [0; 1];
match socket.read_exact(&mut buf).await {
Ok(_) => data.extend_from_slice(&buf),
Err(ref err) if err.kind() == std::io::ErrorKind::UnexpectedEof => break,
Err(err) => panic!(err),
}
}
data
};
let (amt, data) = join!(server, client);
assert_eq!(amt, AMT);
assert!(data == vec![9; AMT as usize]);
}
cfg_if! {
if #[cfg(feature = "rustls")] {
use webpki;
use untrusted;
use std::env;
use std::fs::File;
use std::process::Command;
use std::sync::Once;
fn context() -> (TlsAcceptor, TlsConnector) {
// Certs borrowed from `rust-native-tls/tests`
let pkcs12 = include_bytes!("identity.p12");
let der = include_bytes!("root-ca.der");
use untrusted::Input;
use webpki::trust_anchor_util;
let identity = Identity::from_pkcs12(pkcs12, "mypass").unwrap();
let acceptor = native_tls::TlsAcceptor::builder(identity).build().unwrap();
fn server_cx() -> io::Result<ServerContext> {
let mut cx = ServerContext::new();
let cert = Certificate::from_der(der).unwrap();
let connector = native_tls::TlsConnector::builder()
.add_root_certificate(cert)
.build()
.unwrap();
let (cert, key) = keys();
cx.config_mut()
.set_single_cert(vec![cert.to_vec()], key.to_vec());
Ok(cx)
}
fn configure_client(cx: &mut ClientContext) {
let (cert, _key) = keys();
let cert = Input::from(cert);
let anchor = trust_anchor_util::cert_der_as_trust_anchor(cert).unwrap();
cx.config_mut().root_store.add_trust_anchors(&[anchor]);
}
// Like OpenSSL we generate certificates on the fly, but for OSX we
// also have to put them into a specific keychain. We put both the
// certificates and the keychain next to our binary.
//
// Right now I don't know of a way to programmatically create a
// self-signed certificate, so we just fork out to the `openssl` binary.
fn keys() -> (&'static [u8], &'static [u8]) {
static INIT: Once = Once::new();
static mut KEYS: *mut (Vec<u8>, Vec<u8>) = ptr::null_mut();
INIT.call_once(|| {
let (key, cert) = openssl_keys();
let path = t!(env::current_exe());
let path = path.parent().unwrap();
let keyfile = path.join("test.key");
let certfile = path.join("test.crt");
let config = path.join("openssl.config");
File::create(&config).unwrap().write_all(b"\
[req]\n\
distinguished_name=dn\n\
[ dn ]\n\
CN=localhost\n\
[ ext ]\n\
basicConstraints=CA:FALSE,pathlen:0\n\
subjectAltName = @alt_names
[alt_names]
DNS.1 = localhost
").unwrap();
let subj = "/C=US/ST=Denial/L=Sprintfield/O=Dis/CN=localhost";
let output = t!(Command::new("openssl")
.arg("req")
.arg("-nodes")
.arg("-x509")
.arg("-newkey").arg("rsa:2048")
.arg("-config").arg(&config)
.arg("-extensions").arg("ext")
.arg("-subj").arg(subj)
.arg("-keyout").arg(&keyfile)
.arg("-out").arg(&certfile)
.arg("-days").arg("1")
.output());
assert!(output.status.success());
let crtout = t!(Command::new("openssl")
.arg("x509")
.arg("-outform").arg("der")
.arg("-in").arg(&certfile)
.output());
assert!(crtout.status.success());
let keyout = t!(Command::new("openssl")
.arg("rsa")
.arg("-outform").arg("der")
.arg("-in").arg(&keyfile)
.output());
assert!(keyout.status.success());
let cert = crtout.stdout;
let key = keyout.stdout;
unsafe {
KEYS = Box::into_raw(Box::new((cert, key)));
}
});
unsafe {
(&(*KEYS).0, &(*KEYS).1)
}
}
} else if #[cfg(any(feature = "force-openssl",
all(not(target_os = "macos"),
not(target_os = "windows"),
not(target_os = "ios"))))] {
use std::fs::File;
use std::env;
use std::sync::Once;
fn contexts() -> (tokio_native_tls::TlsAcceptor, tokio_native_tls::TlsConnector) {
let keys = openssl_keys();
let pkcs12 = t!(Identity::from_pkcs12(&keys.pkcs12_der, "foobar"));
let srv = TlsAcceptor::builder(pkcs12);
let cert = t!(native_tls::Certificate::from_der(&keys.cert_der));
let mut client = TlsConnector::builder();
t!(client.add_root_certificate(cert).build());
(t!(srv.build()).into(), t!(client.build()).into())
}
} else if #[cfg(any(target_os = "macos", target_os = "ios"))] {
use std::env;
use std::fs::File;
use std::sync::Once;
fn contexts() -> (tokio_tls::TlsAcceptor, tokio_tls::TlsConnector) {
let keys = openssl_keys();
let pkcs12 = t!(Identity::from_pkcs12(&keys.pkcs12_der, "foobar"));
let srv = TlsAcceptor::builder(pkcs12);
let cert = native_tls::Certificate::from_der(&keys.cert_der).unwrap();
let mut client = TlsConnector::builder();
client.add_root_certificate(cert);
(t!(srv.build()).into(), t!(client.build()).into())
}
} else {
use schannel;
use winapi;
use std::env;
use std::fs::File;
use std::io;
use std::mem;
use std::sync::Once;
use schannel::cert_context::CertContext;
use schannel::cert_store::{CertStore, CertAdd, Memory};
use winapi::shared::basetsd::*;
use winapi::shared::lmcons::*;
use winapi::shared::minwindef::*;
use winapi::shared::ntdef::WCHAR;
use winapi::um::minwinbase::*;
use winapi::um::sysinfoapi::*;
use winapi::um::timezoneapi::*;
use winapi::um::wincrypt::*;
const FRIENDLY_NAME: &'static str = "tokio-tls localhost testing cert";
fn contexts() -> (tokio_tls::TlsAcceptor, tokio_tls::TlsConnector) {
let cert = localhost_cert();
let mut store = t!(Memory::new()).into_store();
t!(store.add_cert(&cert, CertAdd::Always));
let pkcs12_der = t!(store.export_pkcs12("foobar"));
let pkcs12 = t!(Identity::from_pkcs12(&pkcs12_der, "foobar"));
let srv = TlsAcceptor::builder(pkcs12);
let client = TlsConnector::builder();
(t!(srv.build()).into(), t!(client.build()).into())
}
// ====================================================================
// Magic!
//
// Lots of magic is happening here to wrangle certificates for running
// these tests on Windows. For more information see the test suite
// in the schannel-rs crate as this is just coyping that.
//
// The general gist of this though is that the only way to add custom
// trusted certificates is to add it to the system store of trust. To
// do that we go through the whole rigamarole here to generate a new
// self-signed certificate and then insert that into the system store.
//
// This generates some dialogs, so we print what we're doing sometimes,
// and otherwise we just manage the ephemeral certificates. Because
// they're in the system store we always ensure that they're only valid
// for a small period of time (e.g. 1 day).
fn localhost_cert() -> CertContext {
static INIT: Once = Once::new();
INIT.call_once(|| {
for cert in local_root_store().certs() {
let name = match cert.friendly_name() {
Ok(name) => name,
Err(_) => continue,
};
if name != FRIENDLY_NAME {
continue
}
if !cert.is_time_valid().unwrap() {
io::stdout().write_all(br#"
The tokio-tls test suite is about to delete an old copy of one of its
certificates from your root trust store. This certificate was only valid for one
day and it is no longer needed. The host should be "localhost" and the
description should mention "tokio-tls".
"#).unwrap();
cert.delete().unwrap();
} else {
return
}
}
install_certificate().unwrap();
});
for cert in local_root_store().certs() {
let name = match cert.friendly_name() {
Ok(name) => name,
Err(_) => continue,
};
if name == FRIENDLY_NAME {
return cert
}
}
panic!("couldn't find a cert");
}
fn local_root_store() -> CertStore {
if env::var("CI").is_ok() {
CertStore::open_local_machine("Root").unwrap()
} else {
CertStore::open_current_user("Root").unwrap()
}
}
fn install_certificate() -> io::Result<CertContext> {
unsafe {
let mut provider = 0;
let mut hkey = 0;
let mut buffer = "tokio-tls test suite".encode_utf16()
.chain(Some(0))
.collect::<Vec<_>>();
let res = CryptAcquireContextW(&mut provider,
buffer.as_ptr(),
ptr::null_mut(),
PROV_RSA_FULL,
CRYPT_MACHINE_KEYSET);
if res != TRUE {
// create a new key container (since it does not exist)
let res = CryptAcquireContextW(&mut provider,
buffer.as_ptr(),
ptr::null_mut(),
PROV_RSA_FULL,
CRYPT_NEWKEYSET | CRYPT_MACHINE_KEYSET);
if res != TRUE {
return Err(Error::last_os_error())
}
}
// create a new keypair (RSA-2048)
let res = CryptGenKey(provider,
AT_SIGNATURE,
0x0800<<16 | CRYPT_EXPORTABLE,
&mut hkey);
if res != TRUE {
return Err(Error::last_os_error());
}
// start creating the certificate
let name = "CN=localhost,O=tokio-tls,OU=tokio-tls,\
G=tokio_tls".encode_utf16()
.chain(Some(0))
.collect::<Vec<_>>();
let mut cname_buffer: [WCHAR; UNLEN as usize + 1] = mem::zeroed();
let mut cname_len = cname_buffer.len() as DWORD;
let res = CertStrToNameW(X509_ASN_ENCODING,
name.as_ptr(),
CERT_X500_NAME_STR,
ptr::null_mut(),
cname_buffer.as_mut_ptr() as *mut u8,
&mut cname_len,
ptr::null_mut());
if res != TRUE {
return Err(Error::last_os_error());
}
let mut subject_issuer = CERT_NAME_BLOB {
cbData: cname_len,
pbData: cname_buffer.as_ptr() as *mut u8,
};
let mut key_provider = CRYPT_KEY_PROV_INFO {
pwszContainerName: buffer.as_mut_ptr(),
pwszProvName: ptr::null_mut(),
dwProvType: PROV_RSA_FULL,
dwFlags: CRYPT_MACHINE_KEYSET,
cProvParam: 0,
rgProvParam: ptr::null_mut(),
dwKeySpec: AT_SIGNATURE,
};
let mut sig_algorithm = CRYPT_ALGORITHM_IDENTIFIER {
pszObjId: szOID_RSA_SHA256RSA.as_ptr() as *mut _,
Parameters: mem::zeroed(),
};
let mut expiration_date: SYSTEMTIME = mem::zeroed();
GetSystemTime(&mut expiration_date);
let mut file_time: FILETIME = mem::zeroed();
let res = SystemTimeToFileTime(&mut expiration_date,
&mut file_time);
if res != TRUE {
return Err(Error::last_os_error());
}
let mut timestamp: u64 = file_time.dwLowDateTime as u64 |
(file_time.dwHighDateTime as u64) << 32;
// one day, timestamp unit is in 100 nanosecond intervals
timestamp += (1E9 as u64) / 100 * (60 * 60 * 24);
file_time.dwLowDateTime = timestamp as u32;
file_time.dwHighDateTime = (timestamp >> 32) as u32;
let res = FileTimeToSystemTime(&file_time,
&mut expiration_date);
if res != TRUE {
return Err(Error::last_os_error());
}
// create a self signed certificate
let cert_context = CertCreateSelfSignCertificate(
0 as ULONG_PTR,
&mut subject_issuer,
0,
&mut key_provider,
&mut sig_algorithm,
ptr::null_mut(),
&mut expiration_date,
ptr::null_mut());
if cert_context.is_null() {
return Err(Error::last_os_error());
}
// TODO: this is.. a terrible hack. Right now `schannel`
// doesn't provide a public method to go from a raw
// cert context pointer to the `CertContext` structure it
// has, so we just fake it here with a transmute. This'll
// probably break at some point, but hopefully by then
// it'll have a method to do this!
struct MyCertContext<T>(T);
impl<T> Drop for MyCertContext<T> {
fn drop(&mut self) {}
}
let cert_context = MyCertContext(cert_context);
let cert_context: CertContext = mem::transmute(cert_context);
cert_context.set_friendly_name(FRIENDLY_NAME)?;
// install the certificate to the machine's local store
io::stdout().write_all(br#"
The tokio-tls test suite is about to add a certificate to your set of root
and trusted certificates. This certificate should be for the domain "localhost"
with the description related to "tokio-tls". This certificate is only valid
for one day and will be automatically deleted if you re-run the tokio-tls
test suite later.
"#).unwrap();
local_root_store().add_cert(&cert_context,
CertAdd::ReplaceExisting)?;
Ok(cert_context)
}
}
}
(acceptor.into(), connector.into())
}
const AMT: usize = 128 * 1024;
@ -517,112 +143,3 @@ async fn copy_data<W: AsyncWrite + Unpin>(mut w: W) -> Result<usize, Error> {
}
Ok(amt)
}
#[tokio::test]
async fn client_to_server() {
drop(env_logger::try_init());
// Create a server listening on a port, then figure out what that port is
let mut srv = t!(TcpListener::bind("127.0.0.1:0").await);
let addr = t!(srv.local_addr());
let (server_cx, client_cx) = contexts();
// Create a future to accept one socket, connect the ssl stream, and then
// read all the data from it.
let server = async move {
let mut incoming = srv.incoming();
let socket = t!(incoming.next().await.unwrap());
let mut socket = t!(server_cx.accept(socket).await);
let mut data = Vec::new();
t!(socket.read_to_end(&mut data).await);
data
};
// Create a future to connect to our server, connect the ssl stream, and
// then write a bunch of data to it.
let client = async move {
let socket = t!(TcpStream::connect(&addr).await);
let socket = t!(client_cx.connect("localhost", socket).await);
copy_data(socket).await
};
// Finally, run everything!
let (data, _) = join!(server, client);
// assert_eq!(amt, AMT);
assert!(data == vec![9; AMT]);
}
#[tokio::test]
async fn server_to_client() {
drop(env_logger::try_init());
// Create a server listening on a port, then figure out what that port is
let mut srv = t!(TcpListener::bind("127.0.0.1:0").await);
let addr = t!(srv.local_addr());
let (server_cx, client_cx) = contexts();
let server = async move {
let mut incoming = srv.incoming();
let socket = t!(incoming.next().await.unwrap());
let socket = t!(server_cx.accept(socket).await);
copy_data(socket).await
};
let client = async move {
let socket = t!(TcpStream::connect(&addr).await);
let mut socket = t!(client_cx.connect("localhost", socket).await);
let mut data = Vec::new();
t!(socket.read_to_end(&mut data).await);
data
};
// Finally, run everything!
let (_, data) = join!(server, client);
// assert_eq!(amt, AMT);
assert!(data == vec![9; AMT]);
}
#[tokio::test]
async fn one_byte_at_a_time() {
const AMT: usize = 1024;
drop(env_logger::try_init());
let mut srv = t!(TcpListener::bind("127.0.0.1:0").await);
let addr = t!(srv.local_addr());
let (server_cx, client_cx) = contexts();
let server = async move {
let mut incoming = srv.incoming();
let socket = t!(incoming.next().await.unwrap());
let mut socket = t!(server_cx.accept(socket).await);
let mut amt = 0;
for b in std::iter::repeat(9).take(AMT) {
let data = [b as u8];
t!(socket.write_all(&data).await);
amt += 1;
}
amt
};
let client = async move {
let socket = t!(TcpStream::connect(&addr).await);
let mut socket = t!(client_cx.connect("localhost", socket).await);
let mut data = Vec::new();
loop {
let mut buf = [0; 1];
match socket.read_exact(&mut buf).await {
Ok(_) => data.extend_from_slice(&buf),
Err(ref err) if err.kind() == ErrorKind::UnexpectedEof => break,
Err(err) => panic!(err),
}
}
data
};
let (amt, data) = join!(server, client);
assert_eq!(amt, AMT);
assert!(data == vec![9; AMT as usize]);
}

View File

@ -1,7 +1,6 @@
use super::*;
use rustls::Session;
use crate::common::IoSession;
use rustls::Session;
/// A wrapper around an underlying raw stream which implements the TLS or SSL
/// protocol.
@ -58,20 +57,24 @@ where
false
}
fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
match self.state {
#[cfg(feature = "early-data")]
TlsState::EarlyData(..) => Poll::Pending,
TlsState::Stream | TlsState::WriteShutdown => {
let this = self.get_mut();
let mut stream = Stream::new(&mut this.io, &mut this.session)
.set_eof(!this.state.readable());
let mut stream =
Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
match stream.as_mut_pin().poll_read(cx, buf) {
Poll::Ready(Ok(0)) => {
this.state.shutdown_read();
Poll::Ready(Ok(0))
},
}
Poll::Ready(Ok(n)) => Poll::Ready(Ok(n)),
Poll::Ready(Err(ref e)) if e.kind() == io::ErrorKind::ConnectionAborted => {
this.state.shutdown_read();
@ -80,8 +83,8 @@ where
this.state.shutdown_write();
}
Poll::Ready(Ok(0))
},
output => output
}
output => output,
}
}
TlsState::ReadShutdown | TlsState::FullyShutdown => Poll::Ready(Ok(0)),
@ -95,10 +98,14 @@ where
{
/// Note: that it does not guarantee the final data to be sent.
/// To be cautious, you must manually call `flush`.
fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
let this = self.get_mut();
let mut stream = Stream::new(&mut this.io, &mut this.session)
.set_eof(!this.state.readable());
let mut stream =
Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
match this.state {
#[cfg(feature = "early-data")]
@ -110,9 +117,10 @@ where
if let Some(mut early_data) = stream.session.early_data() {
let len = match early_data.write(buf) {
Ok(n) => n,
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock =>
return Poll::Pending,
Err(err) => return Poll::Ready(Err(err))
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => {
return Poll::Pending
}
Err(err) => return Poll::Ready(Err(err)),
};
if len != 0 {
data.extend_from_slice(&buf[..len]);
@ -143,10 +151,11 @@ where
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
let this = self.get_mut();
let mut stream = Stream::new(&mut this.io, &mut this.session)
.set_eof(!this.state.readable());
let mut stream =
Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
#[cfg(feature = "early-data")] {
#[cfg(feature = "early-data")]
{
use futures_core::ready;
if let TlsState::EarlyData(ref mut pos, ref mut data) = this.state {
@ -176,7 +185,8 @@ where
self.state.shutdown_write();
}
#[cfg(feature = "early-data")] {
#[cfg(feature = "early-data")]
{
// we skip the handshake
if let TlsState::EarlyData(..) = self.state {
return Pin::new(&mut self.io).poll_shutdown(cx);
@ -184,8 +194,8 @@ where
}
let this = self.get_mut();
let mut stream = Stream::new(&mut this.io, &mut this.session)
.set_eof(!this.state.readable());
let mut stream =
Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
stream.as_mut_pin().poll_shutdown(cx)
}
}

View File

@ -1,12 +1,11 @@
use std::{ io, mem };
use std::pin::Pin;
use std::future::Future;
use std::task::{ Context, Poll };
use crate::common::{Stream, TlsState};
use futures_core::future::FusedFuture;
use tokio::io::{ AsyncRead, AsyncWrite };
use rustls::Session;
use crate::common::{ TlsState, Stream };
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::{io, mem};
use tokio::io::{AsyncRead, AsyncWrite};
pub(crate) trait IoSession {
type Io;
@ -26,7 +25,7 @@ impl<IS> FusedFuture for MidHandshake<IS>
where
IS: IoSession + Unpin,
IS::Io: AsyncRead + AsyncWrite + Unpin,
IS::Session: Session + Unpin
IS::Session: Session + Unpin,
{
fn is_terminated(&self) -> bool {
if let MidHandshake::End = self {
@ -41,7 +40,7 @@ impl<IS> Future for MidHandshake<IS>
where
IS: IoSession + Unpin,
IS::Io: AsyncRead + AsyncWrite + Unpin,
IS::Session: Session + Unpin
IS::Session: Session + Unpin,
{
type Output = Result<IS, (io::Error, IS::Io)>;
@ -51,20 +50,21 @@ where
if let MidHandshake::Handshaking(mut stream) = mem::replace(this, MidHandshake::End) {
if !stream.skip_handshake() {
let (state, io, session) = stream.get_mut();
let mut tls_stream = Stream::new(io, session)
.set_eof(!state.readable());
let mut tls_stream = Stream::new(io, session).set_eof(!state.readable());
macro_rules! try_poll {
( $e:expr ) => {
match $e {
Poll::Ready(Ok(_)) => (),
Poll::Ready(Err(err)) => return Poll::Ready(Err((err, stream.into_io()))),
Poll::Ready(Err(err)) => {
return Poll::Ready(Err((err, stream.into_io())))
}
Poll::Pending => {
*this = MidHandshake::Handshaking(stream);
return Poll::Pending;
}
}
}
};
}
while tls_stream.session.is_handshaking() {

View File

@ -3,14 +3,13 @@ mod handshake;
#[cfg(feature = "unstable")]
mod vecbuf;
use std::pin::Pin;
use std::task::{ Poll, Context };
use std::io::{ self, Read };
use rustls::Session;
use tokio::io::{ AsyncRead, AsyncWrite };
use futures_core as futures;
pub(crate) use handshake::{ IoSession, MidHandshake };
pub(crate) use handshake::{IoSession, MidHandshake};
use rustls::Session;
use std::io::{self, Read};
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite};
#[derive(Debug)]
pub enum TlsState {
@ -26,8 +25,7 @@ impl TlsState {
#[inline]
pub fn shutdown_read(&mut self) {
match *self {
TlsState::WriteShutdown | TlsState::FullyShutdown =>
*self = TlsState::FullyShutdown,
TlsState::WriteShutdown | TlsState::FullyShutdown => *self = TlsState::FullyShutdown,
_ => *self = TlsState::ReadShutdown,
}
}
@ -35,8 +33,7 @@ impl TlsState {
#[inline]
pub fn shutdown_write(&mut self) {
match *self {
TlsState::ReadShutdown | TlsState::FullyShutdown =>
*self = TlsState::FullyShutdown,
TlsState::ReadShutdown | TlsState::FullyShutdown => *self = TlsState::FullyShutdown,
_ => *self = TlsState::WriteShutdown,
}
}
@ -62,7 +59,7 @@ impl TlsState {
pub fn is_early_data(&self) -> bool {
match self {
TlsState::EarlyData(..) => true,
_ => false
_ => false,
}
}
@ -76,7 +73,7 @@ impl TlsState {
pub struct Stream<'a, IO, S> {
pub io: &'a mut IO,
pub session: &'a mut S,
pub eof: bool
pub eof: bool,
}
impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> {
@ -100,28 +97,27 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> {
}
pub fn process_new_packets(&mut self, cx: &mut Context) -> io::Result<()> {
self.session.process_new_packets()
.map_err(|err| {
// In case we have an alert to send describing this error,
// try a last-gasp write -- but don't predate the primary
// error.
let _ = self.write_io(cx);
self.session.process_new_packets().map_err(|err| {
// In case we have an alert to send describing this error,
// try a last-gasp write -- but don't predate the primary
// error.
let _ = self.write_io(cx);
io::Error::new(io::ErrorKind::InvalidData, err)
})
io::Error::new(io::ErrorKind::InvalidData, err)
})
}
pub fn read_io(&mut self, cx: &mut Context) -> Poll<io::Result<usize>> {
struct Reader<'a, 'b, T> {
io: &'a mut T,
cx: &'a mut Context<'b>
cx: &'a mut Context<'b>,
}
impl<'a, 'b, T: AsyncRead + Unpin> Read for Reader<'a, 'b, T> {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
match Pin::new(&mut self.io).poll_read(self.cx, buf) {
Poll::Ready(result) => result,
Poll::Pending => Err(io::ErrorKind::WouldBlock.into())
Poll::Pending => Err(io::ErrorKind::WouldBlock.into()),
}
}
}
@ -131,7 +127,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> {
let n = match self.session.read_tls(&mut reader) {
Ok(n) => n,
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => return Poll::Pending,
Err(err) => return Poll::Ready(Err(err))
Err(err) => return Poll::Ready(Err(err)),
};
Poll::Ready(Ok(n))
@ -143,21 +139,21 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> {
struct Writer<'a, 'b, T> {
io: &'a mut T,
cx: &'a mut Context<'b>
cx: &'a mut Context<'b>,
}
impl<'a, 'b, T: AsyncWrite + Unpin> Write for Writer<'a, 'b, T> {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
match Pin::new(&mut self.io).poll_write(self.cx, buf) {
Poll::Ready(result) => result,
Poll::Pending => Err(io::ErrorKind::WouldBlock.into())
Poll::Pending => Err(io::ErrorKind::WouldBlock.into()),
}
}
fn flush(&mut self) -> io::Result<()> {
match Pin::new(&mut self.io).poll_flush(self.cx) {
Poll::Ready(result) => result,
Poll::Pending => Err(io::ErrorKind::WouldBlock.into())
Poll::Pending => Err(io::ErrorKind::WouldBlock.into()),
}
}
}
@ -166,7 +162,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> {
match self.session.write_tls(&mut writer) {
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
result => Poll::Ready(result)
result => Poll::Ready(result),
}
}
@ -176,7 +172,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> {
struct Writer<'a, 'b, T> {
io: &'a mut T,
cx: &'a mut Context<'b>
cx: &'a mut Context<'b>,
}
impl<'a, 'b, T: AsyncWrite + Unpin> WriteV for Writer<'a, 'b, T> {
@ -187,7 +183,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> {
match Pin::new(&mut self.io).poll_write_buf(self.cx, &mut vbuf) {
Poll::Ready(result) => result,
Poll::Pending => Err(io::ErrorKind::WouldBlock.into())
Poll::Pending => Err(io::ErrorKind::WouldBlock.into()),
}
}
}
@ -196,7 +192,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> {
match self.session.writev_tls(&mut writer) {
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
result => Poll::Ready(result)
result => Poll::Ready(result),
}
}
@ -213,9 +209,9 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> {
Poll::Ready(Ok(n)) => wrlen += n,
Poll::Pending => {
write_would_block = true;
break
},
Poll::Ready(Err(err)) => return Poll::Ready(Err(err))
break;
}
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
}
}
@ -225,9 +221,9 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> {
Poll::Ready(Ok(n)) => rdlen += n,
Poll::Pending => {
read_would_block = true;
break
},
Poll::Ready(Err(err)) => return Poll::Ready(Err(err))
break;
}
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
}
}
@ -237,21 +233,27 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> {
(true, true) => {
let err = io::Error::new(io::ErrorKind::UnexpectedEof, "tls handshake eof");
Poll::Ready(Err(err))
},
}
(_, false) => Poll::Ready(Ok((rdlen, wrlen))),
(_, true) if write_would_block || read_would_block => if rdlen != 0 || wrlen != 0 {
Poll::Ready(Ok((rdlen, wrlen)))
} else {
Poll::Pending
},
(..) => continue
}
(_, true) if write_would_block || read_would_block => {
if rdlen != 0 || wrlen != 0 {
Poll::Ready(Ok((rdlen, wrlen)))
} else {
Poll::Pending
}
}
(..) => continue,
};
}
}
}
impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncRead for Stream<'a, IO, S> {
fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context, buf: &mut [u8]) -> Poll<io::Result<usize>> {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
let mut pos = 0;
while pos != buf.len() {
@ -262,14 +264,14 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncRead for Stream<'a
match self.read_io(cx) {
Poll::Ready(Ok(0)) => {
self.eof = true;
break
},
break;
}
Poll::Ready(Ok(_)) => (),
Poll::Pending => {
would_block = true;
break
},
Poll::Ready(Err(err)) => return Poll::Ready(Err(err))
break;
}
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
}
}
@ -280,13 +282,14 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncRead for Stream<'a
Ok(n) if self.eof || would_block => Poll::Ready(Ok(pos + n)),
Ok(n) => {
pos += n;
continue
},
continue;
}
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
Err(ref err) if err.kind() == io::ErrorKind::ConnectionAborted && pos != 0 =>
Poll::Ready(Ok(pos)),
Err(err) => Poll::Ready(Err(err))
}
Err(ref err) if err.kind() == io::ErrorKind::ConnectionAborted && pos != 0 => {
Poll::Ready(Ok(pos))
}
Err(err) => Poll::Ready(Err(err)),
};
}
Poll::Ready(Ok(pos))
@ -294,7 +297,11 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncRead for Stream<'a
}
impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncWrite for Stream<'a, IO, S> {
fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context,
buf: &[u8],
) -> Poll<io::Result<usize>> {
let mut pos = 0;
while pos != buf.len() {
@ -303,25 +310,25 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncWrite for Stream<'
match self.session.write(&buf[pos..]) {
Ok(n) => pos += n,
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => (),
Err(err) => return Poll::Ready(Err(err))
Err(err) => return Poll::Ready(Err(err)),
};
while self.session.wants_write() {
match self.write_io(cx) {
Poll::Ready(Ok(0)) | Poll::Pending => {
would_block = true;
break
},
break;
}
Poll::Ready(Ok(_)) => (),
Poll::Ready(Err(err)) => return Poll::Ready(Err(err))
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
}
}
return match (pos, would_block) {
(0, true) => Poll::Pending,
(n, true) => Poll::Ready(Ok(n)),
(_, false) => continue
}
(_, false) => continue,
};
}
Poll::Ready(Ok(pos))

View File

@ -1,39 +1,44 @@
use std::pin::Pin;
use std::sync::Arc;
use std::task::{ Poll, Context };
use super::Stream;
use futures_core::ready;
use futures_util::future::poll_fn;
use futures_util::task::noop_waker_ref;
use tokio::io::{ AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt };
use std::io::{ self, Read, Write, BufReader, Cursor };
use rustls::internal::pemfile::{certs, rsa_private_keys};
use rustls::{ClientConfig, ClientSession, NoClientAuth, ServerConfig, ServerSession, Session};
use std::io::{self, BufReader, Cursor, Read, Write};
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use webpki::DNSNameRef;
use rustls::internal::pemfile::{ certs, rsa_private_keys };
use rustls::{
ServerConfig, ClientConfig,
ServerSession, ClientSession,
Session, NoClientAuth
};
use super::Stream;
struct Good<'a>(&'a mut dyn Session);
impl<'a> AsyncRead for Good<'a> {
fn poll_read(mut self: Pin<&mut Self>, _cx: &mut Context<'_>, mut buf: &mut [u8]) -> Poll<io::Result<usize>> {
fn poll_read(
mut self: Pin<&mut Self>,
_cx: &mut Context<'_>,
mut buf: &mut [u8],
) -> Poll<io::Result<usize>> {
Poll::Ready(self.0.write_tls(buf.by_ref()))
}
}
impl<'a> AsyncWrite for Good<'a> {
fn poll_write(mut self: Pin<&mut Self>, _cx: &mut Context<'_>, mut buf: &[u8]) -> Poll<io::Result<usize>> {
fn poll_write(
mut self: Pin<&mut Self>,
_cx: &mut Context<'_>,
mut buf: &[u8],
) -> Poll<io::Result<usize>> {
let len = self.0.read_tls(buf.by_ref())?;
self.0.process_new_packets()
self.0
.process_new_packets()
.map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?;
Poll::Ready(Ok(len))
}
fn poll_flush(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
self.0.process_new_packets()
self.0
.process_new_packets()
.map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?;
Poll::Ready(Ok(()))
}
@ -47,13 +52,21 @@ impl<'a> AsyncWrite for Good<'a> {
struct Pending;
impl AsyncRead for Pending {
fn poll_read(self: Pin<&mut Self>, _cx: &mut Context<'_>, _: &mut [u8]) -> Poll<io::Result<usize>> {
fn poll_read(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
_: &mut [u8],
) -> Poll<io::Result<usize>> {
Poll::Pending
}
}
impl AsyncWrite for Pending {
fn poll_write(self: Pin<&mut Self>, _cx: &mut Context<'_>, _buf: &[u8]) -> Poll<io::Result<usize>> {
fn poll_write(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
_buf: &[u8],
) -> Poll<io::Result<usize>> {
Poll::Pending
}
@ -69,13 +82,21 @@ impl AsyncWrite for Pending {
struct Eof;
impl AsyncRead for Eof {
fn poll_read(self: Pin<&mut Self>, _cx: &mut Context<'_>, _: &mut [u8]) -> Poll<io::Result<usize>> {
fn poll_read(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
_: &mut [u8],
) -> Poll<io::Result<usize>> {
Poll::Ready(Ok(0))
}
}
impl AsyncWrite for Eof {
fn poll_write(self: Pin<&mut Self>, _cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
fn poll_write(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
Poll::Ready(Ok(buf.len()))
}
@ -122,8 +143,14 @@ async fn stream_bad() -> io::Result<()> {
let mut bad = Pending;
let mut stream = Stream::new(&mut bad, &mut client);
assert_eq!(poll_fn(|cx| stream.as_mut_pin().poll_write(cx, &[0x42; 8])).await?, 8);
assert_eq!(poll_fn(|cx| stream.as_mut_pin().poll_write(cx, &[0x42; 8])).await?, 8);
assert_eq!(
poll_fn(|cx| stream.as_mut_pin().poll_write(cx, &[0x42; 8])).await?,
8
);
assert_eq!(
poll_fn(|cx| stream.as_mut_pin().poll_write(cx, &[0x42; 8])).await?,
8
);
let r = poll_fn(|cx| stream.as_mut_pin().poll_write(cx, &[0x00; 1024])).await?; // fill buffer
assert!(r < 1024);
@ -164,7 +191,10 @@ async fn stream_handshake_eof() -> io::Result<()> {
let mut cx = Context::from_waker(noop_waker_ref());
let r = stream.handshake(&mut cx);
assert_eq!(r.map_err(|err| err.kind()), Poll::Ready(Err(io::ErrorKind::UnexpectedEof)));
assert_eq!(
r.map_err(|err| err.kind()),
Poll::Ready(Err(io::ErrorKind::UnexpectedEof))
);
Ok(()) as io::Result<()>
}
@ -204,7 +234,11 @@ fn make_pair() -> (ServerSession, ClientSession) {
(server, client)
}
fn do_handshake(client: &mut ClientSession, server: &mut ServerSession, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
fn do_handshake(
client: &mut ClientSession,
server: &mut ServerSession,
cx: &mut Context<'_>,
) -> Poll<io::Result<()>> {
let mut good = Good(server);
let mut stream = Stream::new(&mut good, client);

View File

@ -1,23 +1,27 @@
use std::io::IoSlice;
use std::cmp::{ self, Ordering };
use bytes::Buf;
use std::cmp::{self, Ordering};
use std::io::IoSlice;
pub struct VecBuf<'a, 'b: 'a> {
pos: usize,
cur: usize,
inner: &'a [&'b [u8]]
inner: &'a [&'b [u8]],
}
impl<'a, 'b> VecBuf<'a, 'b> {
pub fn new(vbytes: &'a [&'b [u8]]) -> Self {
VecBuf { pos: 0, cur: 0, inner: vbytes }
VecBuf {
pos: 0,
cur: 0,
inner: vbytes,
}
}
}
impl<'a, 'b> Buf for VecBuf<'a, 'b> {
fn remaining(&self) -> usize {
let sum = self.inner
let sum = self
.inner
.iter()
.skip(self.pos)
.map(|bytes| bytes.len())
@ -32,19 +36,21 @@ impl<'a, 'b> Buf for VecBuf<'a, 'b> {
fn advance(&mut self, cnt: usize) {
let current = self.inner[self.pos].len();
match (self.cur + cnt).cmp(&current) {
Ordering::Equal => if self.pos + 1 < self.inner.len() {
self.pos += 1;
self.cur = 0;
} else {
self.cur += cnt;
},
Ordering::Equal => {
if self.pos + 1 < self.inner.len() {
self.pos += 1;
self.cur = 0;
} else {
self.cur += cnt;
}
}
Ordering::Greater => {
if self.pos + 1 < self.inner.len() {
self.pos += 1;
}
let remaining = self.cur + cnt - current;
self.advance(remaining);
},
}
Ordering::Less => self.cur += cnt,
}
}
@ -120,8 +126,7 @@ mod test_vecbuf {
let b1: &[u8] = &mut [0];
let b2: &[u8] = &mut [0];
let mut dst: [IoSlice; 2] =
[IoSlice::new(b1), IoSlice::new(b2)];
let mut dst: [IoSlice; 2] = [IoSlice::new(b1), IoSlice::new(b2)];
assert_eq!(2, buf.bytes_vectored(&mut dst[..]));
}

View File

@ -1,19 +1,19 @@
//! Asynchronous TLS/SSL streams for Tokio using [Rustls](https://github.com/ctz/rustls).
mod common;
pub mod client;
mod common;
pub mod server;
use common::{MidHandshake, Stream, TlsState};
use futures_core::future::FusedFuture;
use rustls::{ClientConfig, ClientSession, ServerConfig, ServerSession, Session};
use std::future::Future;
use std::io;
use std::pin::Pin;
use std::sync::Arc;
use std::future::Future;
use std::task::{ Context, Poll };
use futures_core::future::FusedFuture;
use tokio::io::{ AsyncRead, AsyncWrite };
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite};
use webpki::DNSNameRef;
use rustls::{ ClientConfig, ClientSession, ServerConfig, ServerSession, Session };
use common::{ Stream, TlsState, MidHandshake };
pub use rustls;
pub use webpki;
@ -88,7 +88,7 @@ impl TlsConnector {
TlsState::Stream
},
session
session,
}))
}
}
@ -151,9 +151,7 @@ impl<IO: AsyncRead + AsyncWrite + Unpin> Future for Connect<IO> {
#[inline]
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
Pin::new(&mut self.0)
.poll(cx)
.map_err(|(err, _)| err)
Pin::new(&mut self.0).poll(cx).map_err(|(err, _)| err)
}
}
@ -169,9 +167,7 @@ impl<IO: AsyncRead + AsyncWrite + Unpin> Future for Accept<IO> {
#[inline]
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
Pin::new(&mut self.0)
.poll(cx)
.map_err(|(err, _)| err)
Pin::new(&mut self.0).poll(cx).map_err(|(err, _)| err)
}
}

View File

@ -1,6 +1,6 @@
use super::*;
use rustls::Session;
use crate::common::IoSession;
use rustls::Session;
/// A wrapper around an underlying raw stream which implements the TLS or SSL
/// protocol.
@ -60,29 +60,35 @@ where
false
}
fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
let this = self.get_mut();
let mut stream = Stream::new(&mut this.io, &mut this.session)
.set_eof(!this.state.readable());
let mut stream =
Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
match &this.state {
TlsState::Stream | TlsState::WriteShutdown => match stream.as_mut_pin().poll_read(cx, buf) {
Poll::Ready(Ok(0)) => {
this.state.shutdown_read();
Poll::Ready(Ok(0))
}
Poll::Ready(Ok(n)) => Poll::Ready(Ok(n)),
Poll::Ready(Err(ref err)) if err.kind() == io::ErrorKind::ConnectionAborted => {
this.state.shutdown_read();
if this.state.writeable() {
stream.session.send_close_notify();
this.state.shutdown_write();
TlsState::Stream | TlsState::WriteShutdown => {
match stream.as_mut_pin().poll_read(cx, buf) {
Poll::Ready(Ok(0)) => {
this.state.shutdown_read();
Poll::Ready(Ok(0))
}
Poll::Ready(Ok(0))
Poll::Ready(Ok(n)) => Poll::Ready(Ok(n)),
Poll::Ready(Err(ref err)) if err.kind() == io::ErrorKind::ConnectionAborted => {
this.state.shutdown_read();
if this.state.writeable() {
stream.session.send_close_notify();
this.state.shutdown_write();
}
Poll::Ready(Ok(0))
}
Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
Poll::Pending => Poll::Pending,
}
Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
Poll::Pending => Poll::Pending
},
}
TlsState::ReadShutdown | TlsState::FullyShutdown => Poll::Ready(Ok(0)),
#[cfg(feature = "early-data")]
s => unreachable!("server TLS can not hit this state: {:?}", s),
@ -96,17 +102,21 @@ where
{
/// Note: that it does not guarantee the final data to be sent.
/// To be cautious, you must manually call `flush`.
fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
let this = self.get_mut();
let mut stream = Stream::new(&mut this.io, &mut this.session)
.set_eof(!this.state.readable());
let mut stream =
Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
stream.as_mut_pin().poll_write(cx, buf)
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
let this = self.get_mut();
let mut stream = Stream::new(&mut this.io, &mut this.session)
.set_eof(!this.state.readable());
let mut stream =
Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
stream.as_mut_pin().poll_flush(cx)
}
@ -117,8 +127,8 @@ where
}
let this = self.get_mut();
let mut stream = Stream::new(&mut this.io, &mut this.session)
.set_eof(!this.state.readable());
let mut stream =
Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
stream.as_mut_pin().poll_shutdown(cx)
}
}

View File

@ -1,21 +1,20 @@
use std::io;
use std::sync::Arc;
use std::net::ToSocketAddrs;
use tokio::prelude::*;
use tokio::net::TcpStream;
use rustls::ClientConfig;
use tokio_rustls::{ TlsConnector, client::TlsStream };
use std::io;
use std::net::ToSocketAddrs;
use std::sync::Arc;
use tokio::net::TcpStream;
use tokio::prelude::*;
use tokio_rustls::{client::TlsStream, TlsConnector};
async fn get(config: Arc<ClientConfig>, domain: &str, port: u16)
-> io::Result<(TlsStream<TcpStream>, String)>
{
async fn get(
config: Arc<ClientConfig>,
domain: &str,
port: u16,
) -> io::Result<(TlsStream<TcpStream>, String)> {
let connector = TlsConnector::from(config);
let input = format!("GET / HTTP/1.0\r\nHost: {}\r\n\r\n", domain);
let addr = (domain, port)
.to_socket_addrs()?
.next().unwrap();
let addr = (domain, port).to_socket_addrs()?.next().unwrap();
let domain = webpki::DNSNameRef::try_from_ascii_str(&domain).unwrap();
let mut buf = Vec::new();
@ -31,7 +30,9 @@ async fn get(config: Arc<ClientConfig>, domain: &str, port: u16)
#[tokio::test]
async fn test_tls12() -> io::Result<()> {
let mut config = ClientConfig::new();
config.root_store.add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS);
config
.root_store
.add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS);
config.versions = vec![rustls::ProtocolVersion::TLSv1_2];
let config = Arc::new(config);
let domain = "tls-v1-2.badssl.com";
@ -52,7 +53,9 @@ fn test_tls13() {
#[tokio::test]
async fn test_modern() -> io::Result<()> {
let mut config = ClientConfig::new();
config.root_store.add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS);
config
.root_store
.add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS);
let config = Arc::new(config);
let domain = "mozilla-modern.badssl.com";

View File

@ -1,21 +1,17 @@
#![cfg(feature = "early-data")]
use std::io::{ self, BufRead, BufReader, Cursor };
use std::process::{ Command, Child, Stdio };
use std::io::{self, BufRead, BufReader, Cursor};
use std::net::SocketAddr;
use std::pin::Pin;
use std::process::{Child, Command, Stdio};
use std::process::{Child, Command, Stdio};
use std::sync::Arc;
use std::marker::Unpin;
use std::pin::{ Pin };
use std::task::{ Context, Poll };
use std::task::{Context, Poll};
use std::time::Duration;
use tokio::prelude::*;
use tokio::net::TcpStream;
use tokio::prelude::*;
use tokio::time::delay_for;
use futures_util::{ future, ready };
use rustls::ClientConfig;
use tokio_rustls::{ TlsConnector, client::TlsStream };
use std::future::Future;
use tokio_rustls::{client::TlsStream, TlsConnector};
struct Read1<T>(T);
@ -29,11 +25,12 @@ impl<T: AsyncRead + Unpin> Future for Read1<T> {
}
}
async fn send(config: Arc<ClientConfig>, addr: SocketAddr, data: &[u8])
-> io::Result<TlsStream<TcpStream>>
{
let connector = TlsConnector::from(config)
.early_data(true);
async fn send(
config: Arc<ClientConfig>,
addr: SocketAddr,
data: &[u8],
) -> io::Result<TlsStream<TcpStream>> {
let connector = TlsConnector::from(config).early_data(true);
let stream = TcpStream::connect(&addr).await?;
let domain = webpki::DNSNameRef::try_from_ascii_str("testserver.com").unwrap();
@ -98,10 +95,8 @@ async fn test_0rtt() -> io::Result<()> {
let stdout = handle.0.stdout.as_mut().unwrap();
let mut lines = BufReader::new(stdout).lines();
let has_msg1 = lines.by_ref()
.any(|line| line.unwrap().contains("hello"));
let has_msg2 = lines.by_ref()
.any(|line| line.unwrap().contains("world!"));
let has_msg1 = lines.by_ref().any(|line| line.unwrap().contains("hello"));
let has_msg2 = lines.by_ref().any(|line| line.unwrap().contains("world!"));
assert!(has_msg1 && has_msg2);

View File

@ -1,29 +1,30 @@
use std::{ io, thread };
use std::io::{ BufReader, Cursor };
use std::sync::Arc;
use std::sync::mpsc::channel;
use std::net::SocketAddr;
use futures_util::future::TryFutureExt;
use lazy_static::lazy_static;
use rustls::internal::pemfile::{certs, rsa_private_keys};
use rustls::{ClientConfig, ServerConfig};
use std::io::{BufReader, Cursor};
use std::net::SocketAddr;
use std::sync::mpsc::channel;
use std::sync::Arc;
use std::{io, thread};
use tokio::io::{copy, split};
use tokio::net::{TcpListener, TcpStream};
use tokio::prelude::*;
use tokio::runtime;
use tokio::io::{ copy, split };
use tokio::net::{ TcpListener, TcpStream };
use rustls::{ ServerConfig, ClientConfig };
use rustls::internal::pemfile::{ certs, rsa_private_keys };
use tokio_rustls::{ TlsConnector, TlsAcceptor };
use tokio_rustls::{TlsAcceptor, TlsConnector};
const CERT: &str = include_str!("end.cert");
const CHAIN: &str = include_str!("end.chain");
const RSA: &str = include_str!("end.rsa");
lazy_static!{
lazy_static! {
static ref TEST_SERVER: (SocketAddr, &'static str, &'static str) = {
let cert = certs(&mut BufReader::new(Cursor::new(CERT))).unwrap();
let mut keys = rsa_private_keys(&mut BufReader::new(Cursor::new(RSA))).unwrap();
let mut config = ServerConfig::new(rustls::NoClientAuth::new());
config.set_single_cert(cert, keys.pop().unwrap())
config
.set_single_cert(cert, keys.pop().unwrap())
.expect("invalid key or certificate");
let acceptor = TlsAcceptor::from(Arc::new(config));
@ -55,11 +56,13 @@ lazy_static!{
copy(&mut reader, &mut writer).await?;
Ok(()) as io::Result<()>
}.unwrap_or_else(|err| eprintln!("server: {:?}", err));
}
.unwrap_or_else(|err| eprintln!("server: {:?}", err));
handle.spawn(fut);
}
}.unwrap_or_else(|err: io::Error| eprintln!("server: {:?}", err));
}
.unwrap_or_else(|err: io::Error| eprintln!("server: {:?}", err));
runtime.block_on(done);
});