Rename more tests (#1)
* Rename more tests * Clean up smoke test * fmt * Clean up ci and remove all-features test
This commit is contained in:
parent
01fdb7ccf4
commit
7e41beaff4
8
.github/workflows/CI.yml
vendored
8
.github/workflows/CI.yml
vendored
@ -1,6 +1,10 @@
|
|||||||
name: CI
|
name: CI
|
||||||
|
|
||||||
on: [push, pull_request]
|
on:
|
||||||
|
push:
|
||||||
|
branches:
|
||||||
|
- master
|
||||||
|
pull_request: {}
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
check:
|
check:
|
||||||
@ -40,7 +44,7 @@ jobs:
|
|||||||
profile: minimal
|
profile: minimal
|
||||||
- uses: actions/checkout@master
|
- uses: actions/checkout@master
|
||||||
- name: Test
|
- name: Test
|
||||||
run: cargo test --all --all-features
|
run: cargo test --all
|
||||||
|
|
||||||
lints:
|
lints:
|
||||||
name: Lints
|
name: Lints
|
||||||
|
BIN
tokio-native-tls/tests/cert.der
Normal file
BIN
tokio-native-tls/tests/cert.der
Normal file
Binary file not shown.
BIN
tokio-native-tls/tests/identity.p12
Normal file
BIN
tokio-native-tls/tests/identity.p12
Normal file
Binary file not shown.
BIN
tokio-native-tls/tests/root-ca.der
Normal file
BIN
tokio-native-tls/tests/root-ca.der
Normal file
Binary file not shown.
@ -1,500 +1,126 @@
|
|||||||
#![warn(rust_2018_idioms)]
|
|
||||||
|
|
||||||
use cfg_if::cfg_if;
|
|
||||||
use env_logger;
|
|
||||||
use futures::join;
|
use futures::join;
|
||||||
use native_tls;
|
use native_tls::{Certificate, Identity};
|
||||||
use native_tls::{Identity, TlsAcceptor, TlsConnector};
|
use std::io::Error;
|
||||||
use std::io::Write;
|
use tokio::{
|
||||||
use std::marker::Unpin;
|
io::{AsyncReadExt, AsyncWrite, AsyncWriteExt},
|
||||||
use std::process::Command;
|
net::{TcpListener, TcpStream},
|
||||||
use std::ptr;
|
};
|
||||||
use tokio::io::{AsyncReadExt, AsyncWrite, AsyncWriteExt, Error, ErrorKind};
|
use tokio_native_tls::{TlsAcceptor, TlsConnector};
|
||||||
use tokio::net::{TcpListener, TcpStream};
|
|
||||||
use tokio::stream::StreamExt;
|
|
||||||
|
|
||||||
macro_rules! t {
|
#[tokio::test]
|
||||||
($e:expr) => {
|
async fn client_to_server() {
|
||||||
match $e {
|
let mut srv = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||||
Ok(e) => e,
|
let addr = srv.local_addr().unwrap();
|
||||||
Err(e) => panic!("{} failed with {:?}", stringify!($e), e),
|
|
||||||
}
|
let (server_tls, client_tls) = context();
|
||||||
|
|
||||||
|
// Create a future to accept one socket, connect the ssl stream, and then
|
||||||
|
// read all the data from it.
|
||||||
|
let server = async move {
|
||||||
|
let (socket, _) = srv.accept().await.unwrap();
|
||||||
|
let mut socket = server_tls.accept(socket).await.unwrap();
|
||||||
|
let mut data = Vec::new();
|
||||||
|
socket.read_to_end(&mut data).await.unwrap();
|
||||||
|
data
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Create a future to connect to our server, connect the ssl stream, and
|
||||||
|
// then write a bunch of data to it.
|
||||||
|
let client = async move {
|
||||||
|
let socket = TcpStream::connect(&addr).await.unwrap();
|
||||||
|
let socket = client_tls.connect("foobar.com", socket).await.unwrap();
|
||||||
|
copy_data(socket).await
|
||||||
|
};
|
||||||
|
|
||||||
|
// Finally, run everything!
|
||||||
|
let (data, _) = join!(server, client);
|
||||||
|
// assert_eq!(amt, AMT);
|
||||||
|
assert!(data == vec![9; AMT]);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(dead_code)]
|
#[tokio::test]
|
||||||
struct Keys {
|
async fn server_to_client() {
|
||||||
cert_der: Vec<u8>,
|
// Create a server listening on a port, then figure out what that port is
|
||||||
pkey_der: Vec<u8>,
|
let mut srv = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||||
pkcs12_der: Vec<u8>,
|
let addr = srv.local_addr().unwrap();
|
||||||
|
|
||||||
|
let (server_tls, client_tls) = context();
|
||||||
|
|
||||||
|
let server = async move {
|
||||||
|
let (socket, _) = srv.accept().await.unwrap();
|
||||||
|
let socket = server_tls.accept(socket).await.unwrap();
|
||||||
|
copy_data(socket).await
|
||||||
|
};
|
||||||
|
|
||||||
|
let client = async move {
|
||||||
|
let socket = TcpStream::connect(&addr).await.unwrap();
|
||||||
|
let mut socket = client_tls.connect("foobar.com", socket).await.unwrap();
|
||||||
|
let mut data = Vec::new();
|
||||||
|
socket.read_to_end(&mut data).await.unwrap();
|
||||||
|
data
|
||||||
|
};
|
||||||
|
|
||||||
|
// Finally, run everything!
|
||||||
|
let (_, data) = join!(server, client);
|
||||||
|
assert!(data == vec![9; AMT]);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(dead_code)]
|
#[tokio::test]
|
||||||
fn openssl_keys() -> &'static Keys {
|
async fn one_byte_at_a_time() {
|
||||||
static INIT: Once = Once::new();
|
const AMT: usize = 1024;
|
||||||
static mut KEYS: *mut Keys = ptr::null_mut();
|
|
||||||
|
|
||||||
INIT.call_once(|| {
|
let mut srv = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||||
let path = t!(env::current_exe());
|
let addr = srv.local_addr().unwrap();
|
||||||
let path = path.parent().unwrap();
|
|
||||||
let keyfile = path.join("test.key");
|
|
||||||
let certfile = path.join("test.crt");
|
|
||||||
let config = path.join("openssl.config");
|
|
||||||
|
|
||||||
File::create(&config)
|
let (server_tls, client_tls) = context();
|
||||||
.unwrap()
|
|
||||||
.write_all(
|
|
||||||
b"\
|
|
||||||
[req]\n\
|
|
||||||
distinguished_name=dn\n\
|
|
||||||
[ dn ]\n\
|
|
||||||
CN=localhost\n\
|
|
||||||
[ ext ]\n\
|
|
||||||
basicConstraints=CA:FALSE,pathlen:0\n\
|
|
||||||
subjectAltName = @alt_names
|
|
||||||
extendedKeyUsage=serverAuth,clientAuth
|
|
||||||
[alt_names]
|
|
||||||
DNS.1 = localhost
|
|
||||||
",
|
|
||||||
)
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
let subj = "/C=US/ST=Denial/L=Sprintfield/O=Dis/CN=localhost";
|
let server = async move {
|
||||||
let output = t!(Command::new("openssl")
|
let (socket, _) = srv.accept().await.unwrap();
|
||||||
.arg("req")
|
let mut socket = server_tls.accept(socket).await.unwrap();
|
||||||
.arg("-nodes")
|
let mut amt = 0;
|
||||||
.arg("-x509")
|
for b in std::iter::repeat(9).take(AMT) {
|
||||||
.arg("-newkey")
|
let data = [b as u8];
|
||||||
.arg("rsa:2048")
|
socket.write_all(&data).await.unwrap();
|
||||||
.arg("-config")
|
amt += 1;
|
||||||
.arg(&config)
|
|
||||||
.arg("-extensions")
|
|
||||||
.arg("ext")
|
|
||||||
.arg("-subj")
|
|
||||||
.arg(subj)
|
|
||||||
.arg("-keyout")
|
|
||||||
.arg(&keyfile)
|
|
||||||
.arg("-out")
|
|
||||||
.arg(&certfile)
|
|
||||||
.arg("-days")
|
|
||||||
.arg("1")
|
|
||||||
.output());
|
|
||||||
assert!(output.status.success());
|
|
||||||
|
|
||||||
let crtout = t!(Command::new("openssl")
|
|
||||||
.arg("x509")
|
|
||||||
.arg("-outform")
|
|
||||||
.arg("der")
|
|
||||||
.arg("-in")
|
|
||||||
.arg(&certfile)
|
|
||||||
.output());
|
|
||||||
assert!(crtout.status.success());
|
|
||||||
let keyout = t!(Command::new("openssl")
|
|
||||||
.arg("rsa")
|
|
||||||
.arg("-outform")
|
|
||||||
.arg("der")
|
|
||||||
.arg("-in")
|
|
||||||
.arg(&keyfile)
|
|
||||||
.output());
|
|
||||||
assert!(keyout.status.success());
|
|
||||||
|
|
||||||
let pkcs12out = t!(Command::new("openssl")
|
|
||||||
.arg("pkcs12")
|
|
||||||
.arg("-export")
|
|
||||||
.arg("-nodes")
|
|
||||||
.arg("-inkey")
|
|
||||||
.arg(&keyfile)
|
|
||||||
.arg("-in")
|
|
||||||
.arg(&certfile)
|
|
||||||
.arg("-password")
|
|
||||||
.arg("pass:foobar")
|
|
||||||
.output());
|
|
||||||
assert!(pkcs12out.status.success());
|
|
||||||
|
|
||||||
let keys = Box::new(Keys {
|
|
||||||
cert_der: crtout.stdout,
|
|
||||||
pkey_der: keyout.stdout,
|
|
||||||
pkcs12_der: pkcs12out.stdout,
|
|
||||||
});
|
|
||||||
unsafe {
|
|
||||||
KEYS = Box::into_raw(keys);
|
|
||||||
}
|
}
|
||||||
});
|
amt
|
||||||
unsafe { &*KEYS }
|
};
|
||||||
|
|
||||||
|
let client = async move {
|
||||||
|
let socket = TcpStream::connect(&addr).await.unwrap();
|
||||||
|
let mut socket = client_tls.connect("foobar.com", socket).await.unwrap();
|
||||||
|
let mut data = Vec::new();
|
||||||
|
loop {
|
||||||
|
let mut buf = [0; 1];
|
||||||
|
match socket.read_exact(&mut buf).await {
|
||||||
|
Ok(_) => data.extend_from_slice(&buf),
|
||||||
|
Err(ref err) if err.kind() == std::io::ErrorKind::UnexpectedEof => break,
|
||||||
|
Err(err) => panic!(err),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
data
|
||||||
|
};
|
||||||
|
|
||||||
|
let (amt, data) = join!(server, client);
|
||||||
|
assert_eq!(amt, AMT);
|
||||||
|
assert!(data == vec![9; AMT as usize]);
|
||||||
}
|
}
|
||||||
|
|
||||||
cfg_if! {
|
fn context() -> (TlsAcceptor, TlsConnector) {
|
||||||
if #[cfg(feature = "rustls")] {
|
// Certs borrowed from `rust-native-tls/tests`
|
||||||
use webpki;
|
let pkcs12 = include_bytes!("identity.p12");
|
||||||
use untrusted;
|
let der = include_bytes!("root-ca.der");
|
||||||
use std::env;
|
|
||||||
use std::fs::File;
|
|
||||||
use std::process::Command;
|
|
||||||
use std::sync::Once;
|
|
||||||
|
|
||||||
use untrusted::Input;
|
let identity = Identity::from_pkcs12(pkcs12, "mypass").unwrap();
|
||||||
use webpki::trust_anchor_util;
|
let acceptor = native_tls::TlsAcceptor::builder(identity).build().unwrap();
|
||||||
|
|
||||||
fn server_cx() -> io::Result<ServerContext> {
|
let cert = Certificate::from_der(der).unwrap();
|
||||||
let mut cx = ServerContext::new();
|
let connector = native_tls::TlsConnector::builder()
|
||||||
|
.add_root_certificate(cert)
|
||||||
|
.build()
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
let (cert, key) = keys();
|
(acceptor.into(), connector.into())
|
||||||
cx.config_mut()
|
|
||||||
.set_single_cert(vec![cert.to_vec()], key.to_vec());
|
|
||||||
|
|
||||||
Ok(cx)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn configure_client(cx: &mut ClientContext) {
|
|
||||||
let (cert, _key) = keys();
|
|
||||||
let cert = Input::from(cert);
|
|
||||||
let anchor = trust_anchor_util::cert_der_as_trust_anchor(cert).unwrap();
|
|
||||||
cx.config_mut().root_store.add_trust_anchors(&[anchor]);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Like OpenSSL we generate certificates on the fly, but for OSX we
|
|
||||||
// also have to put them into a specific keychain. We put both the
|
|
||||||
// certificates and the keychain next to our binary.
|
|
||||||
//
|
|
||||||
// Right now I don't know of a way to programmatically create a
|
|
||||||
// self-signed certificate, so we just fork out to the `openssl` binary.
|
|
||||||
fn keys() -> (&'static [u8], &'static [u8]) {
|
|
||||||
static INIT: Once = Once::new();
|
|
||||||
static mut KEYS: *mut (Vec<u8>, Vec<u8>) = ptr::null_mut();
|
|
||||||
|
|
||||||
INIT.call_once(|| {
|
|
||||||
let (key, cert) = openssl_keys();
|
|
||||||
let path = t!(env::current_exe());
|
|
||||||
let path = path.parent().unwrap();
|
|
||||||
let keyfile = path.join("test.key");
|
|
||||||
let certfile = path.join("test.crt");
|
|
||||||
let config = path.join("openssl.config");
|
|
||||||
|
|
||||||
File::create(&config).unwrap().write_all(b"\
|
|
||||||
[req]\n\
|
|
||||||
distinguished_name=dn\n\
|
|
||||||
[ dn ]\n\
|
|
||||||
CN=localhost\n\
|
|
||||||
[ ext ]\n\
|
|
||||||
basicConstraints=CA:FALSE,pathlen:0\n\
|
|
||||||
subjectAltName = @alt_names
|
|
||||||
[alt_names]
|
|
||||||
DNS.1 = localhost
|
|
||||||
").unwrap();
|
|
||||||
|
|
||||||
let subj = "/C=US/ST=Denial/L=Sprintfield/O=Dis/CN=localhost";
|
|
||||||
let output = t!(Command::new("openssl")
|
|
||||||
.arg("req")
|
|
||||||
.arg("-nodes")
|
|
||||||
.arg("-x509")
|
|
||||||
.arg("-newkey").arg("rsa:2048")
|
|
||||||
.arg("-config").arg(&config)
|
|
||||||
.arg("-extensions").arg("ext")
|
|
||||||
.arg("-subj").arg(subj)
|
|
||||||
.arg("-keyout").arg(&keyfile)
|
|
||||||
.arg("-out").arg(&certfile)
|
|
||||||
.arg("-days").arg("1")
|
|
||||||
.output());
|
|
||||||
assert!(output.status.success());
|
|
||||||
|
|
||||||
let crtout = t!(Command::new("openssl")
|
|
||||||
.arg("x509")
|
|
||||||
.arg("-outform").arg("der")
|
|
||||||
.arg("-in").arg(&certfile)
|
|
||||||
.output());
|
|
||||||
assert!(crtout.status.success());
|
|
||||||
let keyout = t!(Command::new("openssl")
|
|
||||||
.arg("rsa")
|
|
||||||
.arg("-outform").arg("der")
|
|
||||||
.arg("-in").arg(&keyfile)
|
|
||||||
.output());
|
|
||||||
assert!(keyout.status.success());
|
|
||||||
|
|
||||||
let cert = crtout.stdout;
|
|
||||||
let key = keyout.stdout;
|
|
||||||
unsafe {
|
|
||||||
KEYS = Box::into_raw(Box::new((cert, key)));
|
|
||||||
}
|
|
||||||
});
|
|
||||||
unsafe {
|
|
||||||
(&(*KEYS).0, &(*KEYS).1)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else if #[cfg(any(feature = "force-openssl",
|
|
||||||
all(not(target_os = "macos"),
|
|
||||||
not(target_os = "windows"),
|
|
||||||
not(target_os = "ios"))))] {
|
|
||||||
use std::fs::File;
|
|
||||||
use std::env;
|
|
||||||
use std::sync::Once;
|
|
||||||
|
|
||||||
fn contexts() -> (tokio_native_tls::TlsAcceptor, tokio_native_tls::TlsConnector) {
|
|
||||||
let keys = openssl_keys();
|
|
||||||
|
|
||||||
let pkcs12 = t!(Identity::from_pkcs12(&keys.pkcs12_der, "foobar"));
|
|
||||||
let srv = TlsAcceptor::builder(pkcs12);
|
|
||||||
|
|
||||||
let cert = t!(native_tls::Certificate::from_der(&keys.cert_der));
|
|
||||||
|
|
||||||
let mut client = TlsConnector::builder();
|
|
||||||
t!(client.add_root_certificate(cert).build());
|
|
||||||
|
|
||||||
(t!(srv.build()).into(), t!(client.build()).into())
|
|
||||||
}
|
|
||||||
} else if #[cfg(any(target_os = "macos", target_os = "ios"))] {
|
|
||||||
use std::env;
|
|
||||||
use std::fs::File;
|
|
||||||
use std::sync::Once;
|
|
||||||
|
|
||||||
fn contexts() -> (tokio_tls::TlsAcceptor, tokio_tls::TlsConnector) {
|
|
||||||
let keys = openssl_keys();
|
|
||||||
|
|
||||||
let pkcs12 = t!(Identity::from_pkcs12(&keys.pkcs12_der, "foobar"));
|
|
||||||
let srv = TlsAcceptor::builder(pkcs12);
|
|
||||||
|
|
||||||
let cert = native_tls::Certificate::from_der(&keys.cert_der).unwrap();
|
|
||||||
let mut client = TlsConnector::builder();
|
|
||||||
client.add_root_certificate(cert);
|
|
||||||
|
|
||||||
(t!(srv.build()).into(), t!(client.build()).into())
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
use schannel;
|
|
||||||
use winapi;
|
|
||||||
|
|
||||||
use std::env;
|
|
||||||
use std::fs::File;
|
|
||||||
use std::io;
|
|
||||||
use std::mem;
|
|
||||||
use std::sync::Once;
|
|
||||||
|
|
||||||
use schannel::cert_context::CertContext;
|
|
||||||
use schannel::cert_store::{CertStore, CertAdd, Memory};
|
|
||||||
use winapi::shared::basetsd::*;
|
|
||||||
use winapi::shared::lmcons::*;
|
|
||||||
use winapi::shared::minwindef::*;
|
|
||||||
use winapi::shared::ntdef::WCHAR;
|
|
||||||
use winapi::um::minwinbase::*;
|
|
||||||
use winapi::um::sysinfoapi::*;
|
|
||||||
use winapi::um::timezoneapi::*;
|
|
||||||
use winapi::um::wincrypt::*;
|
|
||||||
|
|
||||||
const FRIENDLY_NAME: &'static str = "tokio-tls localhost testing cert";
|
|
||||||
|
|
||||||
fn contexts() -> (tokio_tls::TlsAcceptor, tokio_tls::TlsConnector) {
|
|
||||||
let cert = localhost_cert();
|
|
||||||
let mut store = t!(Memory::new()).into_store();
|
|
||||||
t!(store.add_cert(&cert, CertAdd::Always));
|
|
||||||
let pkcs12_der = t!(store.export_pkcs12("foobar"));
|
|
||||||
let pkcs12 = t!(Identity::from_pkcs12(&pkcs12_der, "foobar"));
|
|
||||||
|
|
||||||
let srv = TlsAcceptor::builder(pkcs12);
|
|
||||||
let client = TlsConnector::builder();
|
|
||||||
(t!(srv.build()).into(), t!(client.build()).into())
|
|
||||||
}
|
|
||||||
|
|
||||||
// ====================================================================
|
|
||||||
// Magic!
|
|
||||||
//
|
|
||||||
// Lots of magic is happening here to wrangle certificates for running
|
|
||||||
// these tests on Windows. For more information see the test suite
|
|
||||||
// in the schannel-rs crate as this is just coyping that.
|
|
||||||
//
|
|
||||||
// The general gist of this though is that the only way to add custom
|
|
||||||
// trusted certificates is to add it to the system store of trust. To
|
|
||||||
// do that we go through the whole rigamarole here to generate a new
|
|
||||||
// self-signed certificate and then insert that into the system store.
|
|
||||||
//
|
|
||||||
// This generates some dialogs, so we print what we're doing sometimes,
|
|
||||||
// and otherwise we just manage the ephemeral certificates. Because
|
|
||||||
// they're in the system store we always ensure that they're only valid
|
|
||||||
// for a small period of time (e.g. 1 day).
|
|
||||||
|
|
||||||
fn localhost_cert() -> CertContext {
|
|
||||||
static INIT: Once = Once::new();
|
|
||||||
INIT.call_once(|| {
|
|
||||||
for cert in local_root_store().certs() {
|
|
||||||
let name = match cert.friendly_name() {
|
|
||||||
Ok(name) => name,
|
|
||||||
Err(_) => continue,
|
|
||||||
};
|
|
||||||
if name != FRIENDLY_NAME {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if !cert.is_time_valid().unwrap() {
|
|
||||||
io::stdout().write_all(br#"
|
|
||||||
|
|
||||||
The tokio-tls test suite is about to delete an old copy of one of its
|
|
||||||
certificates from your root trust store. This certificate was only valid for one
|
|
||||||
day and it is no longer needed. The host should be "localhost" and the
|
|
||||||
description should mention "tokio-tls".
|
|
||||||
|
|
||||||
"#).unwrap();
|
|
||||||
cert.delete().unwrap();
|
|
||||||
} else {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
install_certificate().unwrap();
|
|
||||||
});
|
|
||||||
|
|
||||||
for cert in local_root_store().certs() {
|
|
||||||
let name = match cert.friendly_name() {
|
|
||||||
Ok(name) => name,
|
|
||||||
Err(_) => continue,
|
|
||||||
};
|
|
||||||
if name == FRIENDLY_NAME {
|
|
||||||
return cert
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
panic!("couldn't find a cert");
|
|
||||||
}
|
|
||||||
|
|
||||||
fn local_root_store() -> CertStore {
|
|
||||||
if env::var("CI").is_ok() {
|
|
||||||
CertStore::open_local_machine("Root").unwrap()
|
|
||||||
} else {
|
|
||||||
CertStore::open_current_user("Root").unwrap()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn install_certificate() -> io::Result<CertContext> {
|
|
||||||
unsafe {
|
|
||||||
let mut provider = 0;
|
|
||||||
let mut hkey = 0;
|
|
||||||
|
|
||||||
let mut buffer = "tokio-tls test suite".encode_utf16()
|
|
||||||
.chain(Some(0))
|
|
||||||
.collect::<Vec<_>>();
|
|
||||||
let res = CryptAcquireContextW(&mut provider,
|
|
||||||
buffer.as_ptr(),
|
|
||||||
ptr::null_mut(),
|
|
||||||
PROV_RSA_FULL,
|
|
||||||
CRYPT_MACHINE_KEYSET);
|
|
||||||
if res != TRUE {
|
|
||||||
// create a new key container (since it does not exist)
|
|
||||||
let res = CryptAcquireContextW(&mut provider,
|
|
||||||
buffer.as_ptr(),
|
|
||||||
ptr::null_mut(),
|
|
||||||
PROV_RSA_FULL,
|
|
||||||
CRYPT_NEWKEYSET | CRYPT_MACHINE_KEYSET);
|
|
||||||
if res != TRUE {
|
|
||||||
return Err(Error::last_os_error())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// create a new keypair (RSA-2048)
|
|
||||||
let res = CryptGenKey(provider,
|
|
||||||
AT_SIGNATURE,
|
|
||||||
0x0800<<16 | CRYPT_EXPORTABLE,
|
|
||||||
&mut hkey);
|
|
||||||
if res != TRUE {
|
|
||||||
return Err(Error::last_os_error());
|
|
||||||
}
|
|
||||||
|
|
||||||
// start creating the certificate
|
|
||||||
let name = "CN=localhost,O=tokio-tls,OU=tokio-tls,\
|
|
||||||
G=tokio_tls".encode_utf16()
|
|
||||||
.chain(Some(0))
|
|
||||||
.collect::<Vec<_>>();
|
|
||||||
let mut cname_buffer: [WCHAR; UNLEN as usize + 1] = mem::zeroed();
|
|
||||||
let mut cname_len = cname_buffer.len() as DWORD;
|
|
||||||
let res = CertStrToNameW(X509_ASN_ENCODING,
|
|
||||||
name.as_ptr(),
|
|
||||||
CERT_X500_NAME_STR,
|
|
||||||
ptr::null_mut(),
|
|
||||||
cname_buffer.as_mut_ptr() as *mut u8,
|
|
||||||
&mut cname_len,
|
|
||||||
ptr::null_mut());
|
|
||||||
if res != TRUE {
|
|
||||||
return Err(Error::last_os_error());
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut subject_issuer = CERT_NAME_BLOB {
|
|
||||||
cbData: cname_len,
|
|
||||||
pbData: cname_buffer.as_ptr() as *mut u8,
|
|
||||||
};
|
|
||||||
let mut key_provider = CRYPT_KEY_PROV_INFO {
|
|
||||||
pwszContainerName: buffer.as_mut_ptr(),
|
|
||||||
pwszProvName: ptr::null_mut(),
|
|
||||||
dwProvType: PROV_RSA_FULL,
|
|
||||||
dwFlags: CRYPT_MACHINE_KEYSET,
|
|
||||||
cProvParam: 0,
|
|
||||||
rgProvParam: ptr::null_mut(),
|
|
||||||
dwKeySpec: AT_SIGNATURE,
|
|
||||||
};
|
|
||||||
let mut sig_algorithm = CRYPT_ALGORITHM_IDENTIFIER {
|
|
||||||
pszObjId: szOID_RSA_SHA256RSA.as_ptr() as *mut _,
|
|
||||||
Parameters: mem::zeroed(),
|
|
||||||
};
|
|
||||||
let mut expiration_date: SYSTEMTIME = mem::zeroed();
|
|
||||||
GetSystemTime(&mut expiration_date);
|
|
||||||
let mut file_time: FILETIME = mem::zeroed();
|
|
||||||
let res = SystemTimeToFileTime(&mut expiration_date,
|
|
||||||
&mut file_time);
|
|
||||||
if res != TRUE {
|
|
||||||
return Err(Error::last_os_error());
|
|
||||||
}
|
|
||||||
let mut timestamp: u64 = file_time.dwLowDateTime as u64 |
|
|
||||||
(file_time.dwHighDateTime as u64) << 32;
|
|
||||||
// one day, timestamp unit is in 100 nanosecond intervals
|
|
||||||
timestamp += (1E9 as u64) / 100 * (60 * 60 * 24);
|
|
||||||
file_time.dwLowDateTime = timestamp as u32;
|
|
||||||
file_time.dwHighDateTime = (timestamp >> 32) as u32;
|
|
||||||
let res = FileTimeToSystemTime(&file_time,
|
|
||||||
&mut expiration_date);
|
|
||||||
if res != TRUE {
|
|
||||||
return Err(Error::last_os_error());
|
|
||||||
}
|
|
||||||
|
|
||||||
// create a self signed certificate
|
|
||||||
let cert_context = CertCreateSelfSignCertificate(
|
|
||||||
0 as ULONG_PTR,
|
|
||||||
&mut subject_issuer,
|
|
||||||
0,
|
|
||||||
&mut key_provider,
|
|
||||||
&mut sig_algorithm,
|
|
||||||
ptr::null_mut(),
|
|
||||||
&mut expiration_date,
|
|
||||||
ptr::null_mut());
|
|
||||||
if cert_context.is_null() {
|
|
||||||
return Err(Error::last_os_error());
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: this is.. a terrible hack. Right now `schannel`
|
|
||||||
// doesn't provide a public method to go from a raw
|
|
||||||
// cert context pointer to the `CertContext` structure it
|
|
||||||
// has, so we just fake it here with a transmute. This'll
|
|
||||||
// probably break at some point, but hopefully by then
|
|
||||||
// it'll have a method to do this!
|
|
||||||
struct MyCertContext<T>(T);
|
|
||||||
impl<T> Drop for MyCertContext<T> {
|
|
||||||
fn drop(&mut self) {}
|
|
||||||
}
|
|
||||||
|
|
||||||
let cert_context = MyCertContext(cert_context);
|
|
||||||
let cert_context: CertContext = mem::transmute(cert_context);
|
|
||||||
|
|
||||||
cert_context.set_friendly_name(FRIENDLY_NAME)?;
|
|
||||||
|
|
||||||
// install the certificate to the machine's local store
|
|
||||||
io::stdout().write_all(br#"
|
|
||||||
|
|
||||||
The tokio-tls test suite is about to add a certificate to your set of root
|
|
||||||
and trusted certificates. This certificate should be for the domain "localhost"
|
|
||||||
with the description related to "tokio-tls". This certificate is only valid
|
|
||||||
for one day and will be automatically deleted if you re-run the tokio-tls
|
|
||||||
test suite later.
|
|
||||||
|
|
||||||
"#).unwrap();
|
|
||||||
local_root_store().add_cert(&cert_context,
|
|
||||||
CertAdd::ReplaceExisting)?;
|
|
||||||
Ok(cert_context)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const AMT: usize = 128 * 1024;
|
const AMT: usize = 128 * 1024;
|
||||||
@ -517,112 +143,3 @@ async fn copy_data<W: AsyncWrite + Unpin>(mut w: W) -> Result<usize, Error> {
|
|||||||
}
|
}
|
||||||
Ok(amt)
|
Ok(amt)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn client_to_server() {
|
|
||||||
drop(env_logger::try_init());
|
|
||||||
|
|
||||||
// Create a server listening on a port, then figure out what that port is
|
|
||||||
let mut srv = t!(TcpListener::bind("127.0.0.1:0").await);
|
|
||||||
let addr = t!(srv.local_addr());
|
|
||||||
|
|
||||||
let (server_cx, client_cx) = contexts();
|
|
||||||
|
|
||||||
// Create a future to accept one socket, connect the ssl stream, and then
|
|
||||||
// read all the data from it.
|
|
||||||
let server = async move {
|
|
||||||
let mut incoming = srv.incoming();
|
|
||||||
let socket = t!(incoming.next().await.unwrap());
|
|
||||||
let mut socket = t!(server_cx.accept(socket).await);
|
|
||||||
let mut data = Vec::new();
|
|
||||||
t!(socket.read_to_end(&mut data).await);
|
|
||||||
data
|
|
||||||
};
|
|
||||||
|
|
||||||
// Create a future to connect to our server, connect the ssl stream, and
|
|
||||||
// then write a bunch of data to it.
|
|
||||||
let client = async move {
|
|
||||||
let socket = t!(TcpStream::connect(&addr).await);
|
|
||||||
let socket = t!(client_cx.connect("localhost", socket).await);
|
|
||||||
copy_data(socket).await
|
|
||||||
};
|
|
||||||
|
|
||||||
// Finally, run everything!
|
|
||||||
let (data, _) = join!(server, client);
|
|
||||||
// assert_eq!(amt, AMT);
|
|
||||||
assert!(data == vec![9; AMT]);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn server_to_client() {
|
|
||||||
drop(env_logger::try_init());
|
|
||||||
|
|
||||||
// Create a server listening on a port, then figure out what that port is
|
|
||||||
let mut srv = t!(TcpListener::bind("127.0.0.1:0").await);
|
|
||||||
let addr = t!(srv.local_addr());
|
|
||||||
|
|
||||||
let (server_cx, client_cx) = contexts();
|
|
||||||
|
|
||||||
let server = async move {
|
|
||||||
let mut incoming = srv.incoming();
|
|
||||||
let socket = t!(incoming.next().await.unwrap());
|
|
||||||
let socket = t!(server_cx.accept(socket).await);
|
|
||||||
copy_data(socket).await
|
|
||||||
};
|
|
||||||
|
|
||||||
let client = async move {
|
|
||||||
let socket = t!(TcpStream::connect(&addr).await);
|
|
||||||
let mut socket = t!(client_cx.connect("localhost", socket).await);
|
|
||||||
let mut data = Vec::new();
|
|
||||||
t!(socket.read_to_end(&mut data).await);
|
|
||||||
data
|
|
||||||
};
|
|
||||||
|
|
||||||
// Finally, run everything!
|
|
||||||
let (_, data) = join!(server, client);
|
|
||||||
// assert_eq!(amt, AMT);
|
|
||||||
assert!(data == vec![9; AMT]);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn one_byte_at_a_time() {
|
|
||||||
const AMT: usize = 1024;
|
|
||||||
drop(env_logger::try_init());
|
|
||||||
|
|
||||||
let mut srv = t!(TcpListener::bind("127.0.0.1:0").await);
|
|
||||||
let addr = t!(srv.local_addr());
|
|
||||||
|
|
||||||
let (server_cx, client_cx) = contexts();
|
|
||||||
|
|
||||||
let server = async move {
|
|
||||||
let mut incoming = srv.incoming();
|
|
||||||
let socket = t!(incoming.next().await.unwrap());
|
|
||||||
let mut socket = t!(server_cx.accept(socket).await);
|
|
||||||
let mut amt = 0;
|
|
||||||
for b in std::iter::repeat(9).take(AMT) {
|
|
||||||
let data = [b as u8];
|
|
||||||
t!(socket.write_all(&data).await);
|
|
||||||
amt += 1;
|
|
||||||
}
|
|
||||||
amt
|
|
||||||
};
|
|
||||||
|
|
||||||
let client = async move {
|
|
||||||
let socket = t!(TcpStream::connect(&addr).await);
|
|
||||||
let mut socket = t!(client_cx.connect("localhost", socket).await);
|
|
||||||
let mut data = Vec::new();
|
|
||||||
loop {
|
|
||||||
let mut buf = [0; 1];
|
|
||||||
match socket.read_exact(&mut buf).await {
|
|
||||||
Ok(_) => data.extend_from_slice(&buf),
|
|
||||||
Err(ref err) if err.kind() == ErrorKind::UnexpectedEof => break,
|
|
||||||
Err(err) => panic!(err),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
data
|
|
||||||
};
|
|
||||||
|
|
||||||
let (amt, data) = join!(server, client);
|
|
||||||
assert_eq!(amt, AMT);
|
|
||||||
assert!(data == vec![9; AMT as usize]);
|
|
||||||
}
|
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
use super::*;
|
use super::*;
|
||||||
use rustls::Session;
|
|
||||||
use crate::common::IoSession;
|
use crate::common::IoSession;
|
||||||
|
use rustls::Session;
|
||||||
|
|
||||||
/// A wrapper around an underlying raw stream which implements the TLS or SSL
|
/// A wrapper around an underlying raw stream which implements the TLS or SSL
|
||||||
/// protocol.
|
/// protocol.
|
||||||
@ -58,20 +57,24 @@ where
|
|||||||
false
|
false
|
||||||
}
|
}
|
||||||
|
|
||||||
fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> {
|
fn poll_read(
|
||||||
|
self: Pin<&mut Self>,
|
||||||
|
cx: &mut Context<'_>,
|
||||||
|
buf: &mut [u8],
|
||||||
|
) -> Poll<io::Result<usize>> {
|
||||||
match self.state {
|
match self.state {
|
||||||
#[cfg(feature = "early-data")]
|
#[cfg(feature = "early-data")]
|
||||||
TlsState::EarlyData(..) => Poll::Pending,
|
TlsState::EarlyData(..) => Poll::Pending,
|
||||||
TlsState::Stream | TlsState::WriteShutdown => {
|
TlsState::Stream | TlsState::WriteShutdown => {
|
||||||
let this = self.get_mut();
|
let this = self.get_mut();
|
||||||
let mut stream = Stream::new(&mut this.io, &mut this.session)
|
let mut stream =
|
||||||
.set_eof(!this.state.readable());
|
Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
|
||||||
|
|
||||||
match stream.as_mut_pin().poll_read(cx, buf) {
|
match stream.as_mut_pin().poll_read(cx, buf) {
|
||||||
Poll::Ready(Ok(0)) => {
|
Poll::Ready(Ok(0)) => {
|
||||||
this.state.shutdown_read();
|
this.state.shutdown_read();
|
||||||
Poll::Ready(Ok(0))
|
Poll::Ready(Ok(0))
|
||||||
},
|
}
|
||||||
Poll::Ready(Ok(n)) => Poll::Ready(Ok(n)),
|
Poll::Ready(Ok(n)) => Poll::Ready(Ok(n)),
|
||||||
Poll::Ready(Err(ref e)) if e.kind() == io::ErrorKind::ConnectionAborted => {
|
Poll::Ready(Err(ref e)) if e.kind() == io::ErrorKind::ConnectionAborted => {
|
||||||
this.state.shutdown_read();
|
this.state.shutdown_read();
|
||||||
@ -80,8 +83,8 @@ where
|
|||||||
this.state.shutdown_write();
|
this.state.shutdown_write();
|
||||||
}
|
}
|
||||||
Poll::Ready(Ok(0))
|
Poll::Ready(Ok(0))
|
||||||
},
|
}
|
||||||
output => output
|
output => output,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
TlsState::ReadShutdown | TlsState::FullyShutdown => Poll::Ready(Ok(0)),
|
TlsState::ReadShutdown | TlsState::FullyShutdown => Poll::Ready(Ok(0)),
|
||||||
@ -95,10 +98,14 @@ where
|
|||||||
{
|
{
|
||||||
/// Note: that it does not guarantee the final data to be sent.
|
/// Note: that it does not guarantee the final data to be sent.
|
||||||
/// To be cautious, you must manually call `flush`.
|
/// To be cautious, you must manually call `flush`.
|
||||||
fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
|
fn poll_write(
|
||||||
|
self: Pin<&mut Self>,
|
||||||
|
cx: &mut Context<'_>,
|
||||||
|
buf: &[u8],
|
||||||
|
) -> Poll<io::Result<usize>> {
|
||||||
let this = self.get_mut();
|
let this = self.get_mut();
|
||||||
let mut stream = Stream::new(&mut this.io, &mut this.session)
|
let mut stream =
|
||||||
.set_eof(!this.state.readable());
|
Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
|
||||||
|
|
||||||
match this.state {
|
match this.state {
|
||||||
#[cfg(feature = "early-data")]
|
#[cfg(feature = "early-data")]
|
||||||
@ -110,9 +117,10 @@ where
|
|||||||
if let Some(mut early_data) = stream.session.early_data() {
|
if let Some(mut early_data) = stream.session.early_data() {
|
||||||
let len = match early_data.write(buf) {
|
let len = match early_data.write(buf) {
|
||||||
Ok(n) => n,
|
Ok(n) => n,
|
||||||
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock =>
|
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => {
|
||||||
return Poll::Pending,
|
return Poll::Pending
|
||||||
Err(err) => return Poll::Ready(Err(err))
|
}
|
||||||
|
Err(err) => return Poll::Ready(Err(err)),
|
||||||
};
|
};
|
||||||
if len != 0 {
|
if len != 0 {
|
||||||
data.extend_from_slice(&buf[..len]);
|
data.extend_from_slice(&buf[..len]);
|
||||||
@ -143,10 +151,11 @@ where
|
|||||||
|
|
||||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||||
let this = self.get_mut();
|
let this = self.get_mut();
|
||||||
let mut stream = Stream::new(&mut this.io, &mut this.session)
|
let mut stream =
|
||||||
.set_eof(!this.state.readable());
|
Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
|
||||||
|
|
||||||
#[cfg(feature = "early-data")] {
|
#[cfg(feature = "early-data")]
|
||||||
|
{
|
||||||
use futures_core::ready;
|
use futures_core::ready;
|
||||||
|
|
||||||
if let TlsState::EarlyData(ref mut pos, ref mut data) = this.state {
|
if let TlsState::EarlyData(ref mut pos, ref mut data) = this.state {
|
||||||
@ -176,7 +185,8 @@ where
|
|||||||
self.state.shutdown_write();
|
self.state.shutdown_write();
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(feature = "early-data")] {
|
#[cfg(feature = "early-data")]
|
||||||
|
{
|
||||||
// we skip the handshake
|
// we skip the handshake
|
||||||
if let TlsState::EarlyData(..) = self.state {
|
if let TlsState::EarlyData(..) = self.state {
|
||||||
return Pin::new(&mut self.io).poll_shutdown(cx);
|
return Pin::new(&mut self.io).poll_shutdown(cx);
|
||||||
@ -184,8 +194,8 @@ where
|
|||||||
}
|
}
|
||||||
|
|
||||||
let this = self.get_mut();
|
let this = self.get_mut();
|
||||||
let mut stream = Stream::new(&mut this.io, &mut this.session)
|
let mut stream =
|
||||||
.set_eof(!this.state.readable());
|
Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
|
||||||
stream.as_mut_pin().poll_shutdown(cx)
|
stream.as_mut_pin().poll_shutdown(cx)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,12 +1,11 @@
|
|||||||
use std::{ io, mem };
|
use crate::common::{Stream, TlsState};
|
||||||
use std::pin::Pin;
|
|
||||||
use std::future::Future;
|
|
||||||
use std::task::{ Context, Poll };
|
|
||||||
use futures_core::future::FusedFuture;
|
use futures_core::future::FusedFuture;
|
||||||
use tokio::io::{ AsyncRead, AsyncWrite };
|
|
||||||
use rustls::Session;
|
use rustls::Session;
|
||||||
use crate::common::{ TlsState, Stream };
|
use std::future::Future;
|
||||||
|
use std::pin::Pin;
|
||||||
|
use std::task::{Context, Poll};
|
||||||
|
use std::{io, mem};
|
||||||
|
use tokio::io::{AsyncRead, AsyncWrite};
|
||||||
|
|
||||||
pub(crate) trait IoSession {
|
pub(crate) trait IoSession {
|
||||||
type Io;
|
type Io;
|
||||||
@ -26,7 +25,7 @@ impl<IS> FusedFuture for MidHandshake<IS>
|
|||||||
where
|
where
|
||||||
IS: IoSession + Unpin,
|
IS: IoSession + Unpin,
|
||||||
IS::Io: AsyncRead + AsyncWrite + Unpin,
|
IS::Io: AsyncRead + AsyncWrite + Unpin,
|
||||||
IS::Session: Session + Unpin
|
IS::Session: Session + Unpin,
|
||||||
{
|
{
|
||||||
fn is_terminated(&self) -> bool {
|
fn is_terminated(&self) -> bool {
|
||||||
if let MidHandshake::End = self {
|
if let MidHandshake::End = self {
|
||||||
@ -41,7 +40,7 @@ impl<IS> Future for MidHandshake<IS>
|
|||||||
where
|
where
|
||||||
IS: IoSession + Unpin,
|
IS: IoSession + Unpin,
|
||||||
IS::Io: AsyncRead + AsyncWrite + Unpin,
|
IS::Io: AsyncRead + AsyncWrite + Unpin,
|
||||||
IS::Session: Session + Unpin
|
IS::Session: Session + Unpin,
|
||||||
{
|
{
|
||||||
type Output = Result<IS, (io::Error, IS::Io)>;
|
type Output = Result<IS, (io::Error, IS::Io)>;
|
||||||
|
|
||||||
@ -51,20 +50,21 @@ where
|
|||||||
if let MidHandshake::Handshaking(mut stream) = mem::replace(this, MidHandshake::End) {
|
if let MidHandshake::Handshaking(mut stream) = mem::replace(this, MidHandshake::End) {
|
||||||
if !stream.skip_handshake() {
|
if !stream.skip_handshake() {
|
||||||
let (state, io, session) = stream.get_mut();
|
let (state, io, session) = stream.get_mut();
|
||||||
let mut tls_stream = Stream::new(io, session)
|
let mut tls_stream = Stream::new(io, session).set_eof(!state.readable());
|
||||||
.set_eof(!state.readable());
|
|
||||||
|
|
||||||
macro_rules! try_poll {
|
macro_rules! try_poll {
|
||||||
( $e:expr ) => {
|
( $e:expr ) => {
|
||||||
match $e {
|
match $e {
|
||||||
Poll::Ready(Ok(_)) => (),
|
Poll::Ready(Ok(_)) => (),
|
||||||
Poll::Ready(Err(err)) => return Poll::Ready(Err((err, stream.into_io()))),
|
Poll::Ready(Err(err)) => {
|
||||||
|
return Poll::Ready(Err((err, stream.into_io())))
|
||||||
|
}
|
||||||
Poll::Pending => {
|
Poll::Pending => {
|
||||||
*this = MidHandshake::Handshaking(stream);
|
*this = MidHandshake::Handshaking(stream);
|
||||||
return Poll::Pending;
|
return Poll::Pending;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
while tls_stream.session.is_handshaking() {
|
while tls_stream.session.is_handshaking() {
|
||||||
|
@ -3,14 +3,13 @@ mod handshake;
|
|||||||
#[cfg(feature = "unstable")]
|
#[cfg(feature = "unstable")]
|
||||||
mod vecbuf;
|
mod vecbuf;
|
||||||
|
|
||||||
use std::pin::Pin;
|
|
||||||
use std::task::{ Poll, Context };
|
|
||||||
use std::io::{ self, Read };
|
|
||||||
use rustls::Session;
|
|
||||||
use tokio::io::{ AsyncRead, AsyncWrite };
|
|
||||||
use futures_core as futures;
|
use futures_core as futures;
|
||||||
pub(crate) use handshake::{ IoSession, MidHandshake };
|
pub(crate) use handshake::{IoSession, MidHandshake};
|
||||||
|
use rustls::Session;
|
||||||
|
use std::io::{self, Read};
|
||||||
|
use std::pin::Pin;
|
||||||
|
use std::task::{Context, Poll};
|
||||||
|
use tokio::io::{AsyncRead, AsyncWrite};
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub enum TlsState {
|
pub enum TlsState {
|
||||||
@ -26,8 +25,7 @@ impl TlsState {
|
|||||||
#[inline]
|
#[inline]
|
||||||
pub fn shutdown_read(&mut self) {
|
pub fn shutdown_read(&mut self) {
|
||||||
match *self {
|
match *self {
|
||||||
TlsState::WriteShutdown | TlsState::FullyShutdown =>
|
TlsState::WriteShutdown | TlsState::FullyShutdown => *self = TlsState::FullyShutdown,
|
||||||
*self = TlsState::FullyShutdown,
|
|
||||||
_ => *self = TlsState::ReadShutdown,
|
_ => *self = TlsState::ReadShutdown,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -35,8 +33,7 @@ impl TlsState {
|
|||||||
#[inline]
|
#[inline]
|
||||||
pub fn shutdown_write(&mut self) {
|
pub fn shutdown_write(&mut self) {
|
||||||
match *self {
|
match *self {
|
||||||
TlsState::ReadShutdown | TlsState::FullyShutdown =>
|
TlsState::ReadShutdown | TlsState::FullyShutdown => *self = TlsState::FullyShutdown,
|
||||||
*self = TlsState::FullyShutdown,
|
|
||||||
_ => *self = TlsState::WriteShutdown,
|
_ => *self = TlsState::WriteShutdown,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -62,7 +59,7 @@ impl TlsState {
|
|||||||
pub fn is_early_data(&self) -> bool {
|
pub fn is_early_data(&self) -> bool {
|
||||||
match self {
|
match self {
|
||||||
TlsState::EarlyData(..) => true,
|
TlsState::EarlyData(..) => true,
|
||||||
_ => false
|
_ => false,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -76,7 +73,7 @@ impl TlsState {
|
|||||||
pub struct Stream<'a, IO, S> {
|
pub struct Stream<'a, IO, S> {
|
||||||
pub io: &'a mut IO,
|
pub io: &'a mut IO,
|
||||||
pub session: &'a mut S,
|
pub session: &'a mut S,
|
||||||
pub eof: bool
|
pub eof: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> {
|
impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> {
|
||||||
@ -100,28 +97,27 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn process_new_packets(&mut self, cx: &mut Context) -> io::Result<()> {
|
pub fn process_new_packets(&mut self, cx: &mut Context) -> io::Result<()> {
|
||||||
self.session.process_new_packets()
|
self.session.process_new_packets().map_err(|err| {
|
||||||
.map_err(|err| {
|
// In case we have an alert to send describing this error,
|
||||||
// In case we have an alert to send describing this error,
|
// try a last-gasp write -- but don't predate the primary
|
||||||
// try a last-gasp write -- but don't predate the primary
|
// error.
|
||||||
// error.
|
let _ = self.write_io(cx);
|
||||||
let _ = self.write_io(cx);
|
|
||||||
|
|
||||||
io::Error::new(io::ErrorKind::InvalidData, err)
|
io::Error::new(io::ErrorKind::InvalidData, err)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn read_io(&mut self, cx: &mut Context) -> Poll<io::Result<usize>> {
|
pub fn read_io(&mut self, cx: &mut Context) -> Poll<io::Result<usize>> {
|
||||||
struct Reader<'a, 'b, T> {
|
struct Reader<'a, 'b, T> {
|
||||||
io: &'a mut T,
|
io: &'a mut T,
|
||||||
cx: &'a mut Context<'b>
|
cx: &'a mut Context<'b>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, 'b, T: AsyncRead + Unpin> Read for Reader<'a, 'b, T> {
|
impl<'a, 'b, T: AsyncRead + Unpin> Read for Reader<'a, 'b, T> {
|
||||||
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
|
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
|
||||||
match Pin::new(&mut self.io).poll_read(self.cx, buf) {
|
match Pin::new(&mut self.io).poll_read(self.cx, buf) {
|
||||||
Poll::Ready(result) => result,
|
Poll::Ready(result) => result,
|
||||||
Poll::Pending => Err(io::ErrorKind::WouldBlock.into())
|
Poll::Pending => Err(io::ErrorKind::WouldBlock.into()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -131,7 +127,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> {
|
|||||||
let n = match self.session.read_tls(&mut reader) {
|
let n = match self.session.read_tls(&mut reader) {
|
||||||
Ok(n) => n,
|
Ok(n) => n,
|
||||||
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => return Poll::Pending,
|
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => return Poll::Pending,
|
||||||
Err(err) => return Poll::Ready(Err(err))
|
Err(err) => return Poll::Ready(Err(err)),
|
||||||
};
|
};
|
||||||
|
|
||||||
Poll::Ready(Ok(n))
|
Poll::Ready(Ok(n))
|
||||||
@ -143,21 +139,21 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> {
|
|||||||
|
|
||||||
struct Writer<'a, 'b, T> {
|
struct Writer<'a, 'b, T> {
|
||||||
io: &'a mut T,
|
io: &'a mut T,
|
||||||
cx: &'a mut Context<'b>
|
cx: &'a mut Context<'b>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, 'b, T: AsyncWrite + Unpin> Write for Writer<'a, 'b, T> {
|
impl<'a, 'b, T: AsyncWrite + Unpin> Write for Writer<'a, 'b, T> {
|
||||||
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
|
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
|
||||||
match Pin::new(&mut self.io).poll_write(self.cx, buf) {
|
match Pin::new(&mut self.io).poll_write(self.cx, buf) {
|
||||||
Poll::Ready(result) => result,
|
Poll::Ready(result) => result,
|
||||||
Poll::Pending => Err(io::ErrorKind::WouldBlock.into())
|
Poll::Pending => Err(io::ErrorKind::WouldBlock.into()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn flush(&mut self) -> io::Result<()> {
|
fn flush(&mut self) -> io::Result<()> {
|
||||||
match Pin::new(&mut self.io).poll_flush(self.cx) {
|
match Pin::new(&mut self.io).poll_flush(self.cx) {
|
||||||
Poll::Ready(result) => result,
|
Poll::Ready(result) => result,
|
||||||
Poll::Pending => Err(io::ErrorKind::WouldBlock.into())
|
Poll::Pending => Err(io::ErrorKind::WouldBlock.into()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -166,7 +162,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> {
|
|||||||
|
|
||||||
match self.session.write_tls(&mut writer) {
|
match self.session.write_tls(&mut writer) {
|
||||||
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
|
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
|
||||||
result => Poll::Ready(result)
|
result => Poll::Ready(result),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -176,7 +172,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> {
|
|||||||
|
|
||||||
struct Writer<'a, 'b, T> {
|
struct Writer<'a, 'b, T> {
|
||||||
io: &'a mut T,
|
io: &'a mut T,
|
||||||
cx: &'a mut Context<'b>
|
cx: &'a mut Context<'b>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, 'b, T: AsyncWrite + Unpin> WriteV for Writer<'a, 'b, T> {
|
impl<'a, 'b, T: AsyncWrite + Unpin> WriteV for Writer<'a, 'b, T> {
|
||||||
@ -187,7 +183,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> {
|
|||||||
|
|
||||||
match Pin::new(&mut self.io).poll_write_buf(self.cx, &mut vbuf) {
|
match Pin::new(&mut self.io).poll_write_buf(self.cx, &mut vbuf) {
|
||||||
Poll::Ready(result) => result,
|
Poll::Ready(result) => result,
|
||||||
Poll::Pending => Err(io::ErrorKind::WouldBlock.into())
|
Poll::Pending => Err(io::ErrorKind::WouldBlock.into()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -196,7 +192,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> {
|
|||||||
|
|
||||||
match self.session.writev_tls(&mut writer) {
|
match self.session.writev_tls(&mut writer) {
|
||||||
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
|
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
|
||||||
result => Poll::Ready(result)
|
result => Poll::Ready(result),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -213,9 +209,9 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> {
|
|||||||
Poll::Ready(Ok(n)) => wrlen += n,
|
Poll::Ready(Ok(n)) => wrlen += n,
|
||||||
Poll::Pending => {
|
Poll::Pending => {
|
||||||
write_would_block = true;
|
write_would_block = true;
|
||||||
break
|
break;
|
||||||
},
|
}
|
||||||
Poll::Ready(Err(err)) => return Poll::Ready(Err(err))
|
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -225,9 +221,9 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> {
|
|||||||
Poll::Ready(Ok(n)) => rdlen += n,
|
Poll::Ready(Ok(n)) => rdlen += n,
|
||||||
Poll::Pending => {
|
Poll::Pending => {
|
||||||
read_would_block = true;
|
read_would_block = true;
|
||||||
break
|
break;
|
||||||
},
|
}
|
||||||
Poll::Ready(Err(err)) => return Poll::Ready(Err(err))
|
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -237,21 +233,27 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> {
|
|||||||
(true, true) => {
|
(true, true) => {
|
||||||
let err = io::Error::new(io::ErrorKind::UnexpectedEof, "tls handshake eof");
|
let err = io::Error::new(io::ErrorKind::UnexpectedEof, "tls handshake eof");
|
||||||
Poll::Ready(Err(err))
|
Poll::Ready(Err(err))
|
||||||
},
|
}
|
||||||
(_, false) => Poll::Ready(Ok((rdlen, wrlen))),
|
(_, false) => Poll::Ready(Ok((rdlen, wrlen))),
|
||||||
(_, true) if write_would_block || read_would_block => if rdlen != 0 || wrlen != 0 {
|
(_, true) if write_would_block || read_would_block => {
|
||||||
Poll::Ready(Ok((rdlen, wrlen)))
|
if rdlen != 0 || wrlen != 0 {
|
||||||
} else {
|
Poll::Ready(Ok((rdlen, wrlen)))
|
||||||
Poll::Pending
|
} else {
|
||||||
},
|
Poll::Pending
|
||||||
(..) => continue
|
}
|
||||||
}
|
}
|
||||||
|
(..) => continue,
|
||||||
|
};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncRead for Stream<'a, IO, S> {
|
impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncRead for Stream<'a, IO, S> {
|
||||||
fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context, buf: &mut [u8]) -> Poll<io::Result<usize>> {
|
fn poll_read(
|
||||||
|
mut self: Pin<&mut Self>,
|
||||||
|
cx: &mut Context,
|
||||||
|
buf: &mut [u8],
|
||||||
|
) -> Poll<io::Result<usize>> {
|
||||||
let mut pos = 0;
|
let mut pos = 0;
|
||||||
|
|
||||||
while pos != buf.len() {
|
while pos != buf.len() {
|
||||||
@ -262,14 +264,14 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncRead for Stream<'a
|
|||||||
match self.read_io(cx) {
|
match self.read_io(cx) {
|
||||||
Poll::Ready(Ok(0)) => {
|
Poll::Ready(Ok(0)) => {
|
||||||
self.eof = true;
|
self.eof = true;
|
||||||
break
|
break;
|
||||||
},
|
}
|
||||||
Poll::Ready(Ok(_)) => (),
|
Poll::Ready(Ok(_)) => (),
|
||||||
Poll::Pending => {
|
Poll::Pending => {
|
||||||
would_block = true;
|
would_block = true;
|
||||||
break
|
break;
|
||||||
},
|
}
|
||||||
Poll::Ready(Err(err)) => return Poll::Ready(Err(err))
|
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -280,13 +282,14 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncRead for Stream<'a
|
|||||||
Ok(n) if self.eof || would_block => Poll::Ready(Ok(pos + n)),
|
Ok(n) if self.eof || would_block => Poll::Ready(Ok(pos + n)),
|
||||||
Ok(n) => {
|
Ok(n) => {
|
||||||
pos += n;
|
pos += n;
|
||||||
continue
|
continue;
|
||||||
},
|
}
|
||||||
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
|
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
|
||||||
Err(ref err) if err.kind() == io::ErrorKind::ConnectionAborted && pos != 0 =>
|
Err(ref err) if err.kind() == io::ErrorKind::ConnectionAborted && pos != 0 => {
|
||||||
Poll::Ready(Ok(pos)),
|
Poll::Ready(Ok(pos))
|
||||||
Err(err) => Poll::Ready(Err(err))
|
}
|
||||||
}
|
Err(err) => Poll::Ready(Err(err)),
|
||||||
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
Poll::Ready(Ok(pos))
|
Poll::Ready(Ok(pos))
|
||||||
@ -294,7 +297,11 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncRead for Stream<'a
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncWrite for Stream<'a, IO, S> {
|
impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncWrite for Stream<'a, IO, S> {
|
||||||
fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
|
fn poll_write(
|
||||||
|
mut self: Pin<&mut Self>,
|
||||||
|
cx: &mut Context,
|
||||||
|
buf: &[u8],
|
||||||
|
) -> Poll<io::Result<usize>> {
|
||||||
let mut pos = 0;
|
let mut pos = 0;
|
||||||
|
|
||||||
while pos != buf.len() {
|
while pos != buf.len() {
|
||||||
@ -303,25 +310,25 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncWrite for Stream<'
|
|||||||
match self.session.write(&buf[pos..]) {
|
match self.session.write(&buf[pos..]) {
|
||||||
Ok(n) => pos += n,
|
Ok(n) => pos += n,
|
||||||
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => (),
|
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => (),
|
||||||
Err(err) => return Poll::Ready(Err(err))
|
Err(err) => return Poll::Ready(Err(err)),
|
||||||
};
|
};
|
||||||
|
|
||||||
while self.session.wants_write() {
|
while self.session.wants_write() {
|
||||||
match self.write_io(cx) {
|
match self.write_io(cx) {
|
||||||
Poll::Ready(Ok(0)) | Poll::Pending => {
|
Poll::Ready(Ok(0)) | Poll::Pending => {
|
||||||
would_block = true;
|
would_block = true;
|
||||||
break
|
break;
|
||||||
},
|
}
|
||||||
Poll::Ready(Ok(_)) => (),
|
Poll::Ready(Ok(_)) => (),
|
||||||
Poll::Ready(Err(err)) => return Poll::Ready(Err(err))
|
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return match (pos, would_block) {
|
return match (pos, would_block) {
|
||||||
(0, true) => Poll::Pending,
|
(0, true) => Poll::Pending,
|
||||||
(n, true) => Poll::Ready(Ok(n)),
|
(n, true) => Poll::Ready(Ok(n)),
|
||||||
(_, false) => continue
|
(_, false) => continue,
|
||||||
}
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
Poll::Ready(Ok(pos))
|
Poll::Ready(Ok(pos))
|
||||||
|
@ -1,39 +1,44 @@
|
|||||||
use std::pin::Pin;
|
use super::Stream;
|
||||||
use std::sync::Arc;
|
|
||||||
use std::task::{ Poll, Context };
|
|
||||||
use futures_core::ready;
|
use futures_core::ready;
|
||||||
use futures_util::future::poll_fn;
|
use futures_util::future::poll_fn;
|
||||||
use futures_util::task::noop_waker_ref;
|
use futures_util::task::noop_waker_ref;
|
||||||
use tokio::io::{ AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt };
|
use rustls::internal::pemfile::{certs, rsa_private_keys};
|
||||||
use std::io::{ self, Read, Write, BufReader, Cursor };
|
use rustls::{ClientConfig, ClientSession, NoClientAuth, ServerConfig, ServerSession, Session};
|
||||||
|
use std::io::{self, BufReader, Cursor, Read, Write};
|
||||||
|
use std::pin::Pin;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::task::{Context, Poll};
|
||||||
|
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||||
use webpki::DNSNameRef;
|
use webpki::DNSNameRef;
|
||||||
use rustls::internal::pemfile::{ certs, rsa_private_keys };
|
|
||||||
use rustls::{
|
|
||||||
ServerConfig, ClientConfig,
|
|
||||||
ServerSession, ClientSession,
|
|
||||||
Session, NoClientAuth
|
|
||||||
};
|
|
||||||
use super::Stream;
|
|
||||||
|
|
||||||
|
|
||||||
struct Good<'a>(&'a mut dyn Session);
|
struct Good<'a>(&'a mut dyn Session);
|
||||||
|
|
||||||
impl<'a> AsyncRead for Good<'a> {
|
impl<'a> AsyncRead for Good<'a> {
|
||||||
fn poll_read(mut self: Pin<&mut Self>, _cx: &mut Context<'_>, mut buf: &mut [u8]) -> Poll<io::Result<usize>> {
|
fn poll_read(
|
||||||
|
mut self: Pin<&mut Self>,
|
||||||
|
_cx: &mut Context<'_>,
|
||||||
|
mut buf: &mut [u8],
|
||||||
|
) -> Poll<io::Result<usize>> {
|
||||||
Poll::Ready(self.0.write_tls(buf.by_ref()))
|
Poll::Ready(self.0.write_tls(buf.by_ref()))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a> AsyncWrite for Good<'a> {
|
impl<'a> AsyncWrite for Good<'a> {
|
||||||
fn poll_write(mut self: Pin<&mut Self>, _cx: &mut Context<'_>, mut buf: &[u8]) -> Poll<io::Result<usize>> {
|
fn poll_write(
|
||||||
|
mut self: Pin<&mut Self>,
|
||||||
|
_cx: &mut Context<'_>,
|
||||||
|
mut buf: &[u8],
|
||||||
|
) -> Poll<io::Result<usize>> {
|
||||||
let len = self.0.read_tls(buf.by_ref())?;
|
let len = self.0.read_tls(buf.by_ref())?;
|
||||||
self.0.process_new_packets()
|
self.0
|
||||||
|
.process_new_packets()
|
||||||
.map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?;
|
.map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?;
|
||||||
Poll::Ready(Ok(len))
|
Poll::Ready(Ok(len))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn poll_flush(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
fn poll_flush(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||||
self.0.process_new_packets()
|
self.0
|
||||||
|
.process_new_packets()
|
||||||
.map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?;
|
.map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?;
|
||||||
Poll::Ready(Ok(()))
|
Poll::Ready(Ok(()))
|
||||||
}
|
}
|
||||||
@ -47,13 +52,21 @@ impl<'a> AsyncWrite for Good<'a> {
|
|||||||
struct Pending;
|
struct Pending;
|
||||||
|
|
||||||
impl AsyncRead for Pending {
|
impl AsyncRead for Pending {
|
||||||
fn poll_read(self: Pin<&mut Self>, _cx: &mut Context<'_>, _: &mut [u8]) -> Poll<io::Result<usize>> {
|
fn poll_read(
|
||||||
|
self: Pin<&mut Self>,
|
||||||
|
_cx: &mut Context<'_>,
|
||||||
|
_: &mut [u8],
|
||||||
|
) -> Poll<io::Result<usize>> {
|
||||||
Poll::Pending
|
Poll::Pending
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl AsyncWrite for Pending {
|
impl AsyncWrite for Pending {
|
||||||
fn poll_write(self: Pin<&mut Self>, _cx: &mut Context<'_>, _buf: &[u8]) -> Poll<io::Result<usize>> {
|
fn poll_write(
|
||||||
|
self: Pin<&mut Self>,
|
||||||
|
_cx: &mut Context<'_>,
|
||||||
|
_buf: &[u8],
|
||||||
|
) -> Poll<io::Result<usize>> {
|
||||||
Poll::Pending
|
Poll::Pending
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -69,13 +82,21 @@ impl AsyncWrite for Pending {
|
|||||||
struct Eof;
|
struct Eof;
|
||||||
|
|
||||||
impl AsyncRead for Eof {
|
impl AsyncRead for Eof {
|
||||||
fn poll_read(self: Pin<&mut Self>, _cx: &mut Context<'_>, _: &mut [u8]) -> Poll<io::Result<usize>> {
|
fn poll_read(
|
||||||
|
self: Pin<&mut Self>,
|
||||||
|
_cx: &mut Context<'_>,
|
||||||
|
_: &mut [u8],
|
||||||
|
) -> Poll<io::Result<usize>> {
|
||||||
Poll::Ready(Ok(0))
|
Poll::Ready(Ok(0))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl AsyncWrite for Eof {
|
impl AsyncWrite for Eof {
|
||||||
fn poll_write(self: Pin<&mut Self>, _cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
|
fn poll_write(
|
||||||
|
self: Pin<&mut Self>,
|
||||||
|
_cx: &mut Context<'_>,
|
||||||
|
buf: &[u8],
|
||||||
|
) -> Poll<io::Result<usize>> {
|
||||||
Poll::Ready(Ok(buf.len()))
|
Poll::Ready(Ok(buf.len()))
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -122,8 +143,14 @@ async fn stream_bad() -> io::Result<()> {
|
|||||||
|
|
||||||
let mut bad = Pending;
|
let mut bad = Pending;
|
||||||
let mut stream = Stream::new(&mut bad, &mut client);
|
let mut stream = Stream::new(&mut bad, &mut client);
|
||||||
assert_eq!(poll_fn(|cx| stream.as_mut_pin().poll_write(cx, &[0x42; 8])).await?, 8);
|
assert_eq!(
|
||||||
assert_eq!(poll_fn(|cx| stream.as_mut_pin().poll_write(cx, &[0x42; 8])).await?, 8);
|
poll_fn(|cx| stream.as_mut_pin().poll_write(cx, &[0x42; 8])).await?,
|
||||||
|
8
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
poll_fn(|cx| stream.as_mut_pin().poll_write(cx, &[0x42; 8])).await?,
|
||||||
|
8
|
||||||
|
);
|
||||||
let r = poll_fn(|cx| stream.as_mut_pin().poll_write(cx, &[0x00; 1024])).await?; // fill buffer
|
let r = poll_fn(|cx| stream.as_mut_pin().poll_write(cx, &[0x00; 1024])).await?; // fill buffer
|
||||||
assert!(r < 1024);
|
assert!(r < 1024);
|
||||||
|
|
||||||
@ -164,7 +191,10 @@ async fn stream_handshake_eof() -> io::Result<()> {
|
|||||||
|
|
||||||
let mut cx = Context::from_waker(noop_waker_ref());
|
let mut cx = Context::from_waker(noop_waker_ref());
|
||||||
let r = stream.handshake(&mut cx);
|
let r = stream.handshake(&mut cx);
|
||||||
assert_eq!(r.map_err(|err| err.kind()), Poll::Ready(Err(io::ErrorKind::UnexpectedEof)));
|
assert_eq!(
|
||||||
|
r.map_err(|err| err.kind()),
|
||||||
|
Poll::Ready(Err(io::ErrorKind::UnexpectedEof))
|
||||||
|
);
|
||||||
|
|
||||||
Ok(()) as io::Result<()>
|
Ok(()) as io::Result<()>
|
||||||
}
|
}
|
||||||
@ -204,7 +234,11 @@ fn make_pair() -> (ServerSession, ClientSession) {
|
|||||||
(server, client)
|
(server, client)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn do_handshake(client: &mut ClientSession, server: &mut ServerSession, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
fn do_handshake(
|
||||||
|
client: &mut ClientSession,
|
||||||
|
server: &mut ServerSession,
|
||||||
|
cx: &mut Context<'_>,
|
||||||
|
) -> Poll<io::Result<()>> {
|
||||||
let mut good = Good(server);
|
let mut good = Good(server);
|
||||||
let mut stream = Stream::new(&mut good, client);
|
let mut stream = Stream::new(&mut good, client);
|
||||||
|
|
||||||
|
@ -1,23 +1,27 @@
|
|||||||
use std::io::IoSlice;
|
|
||||||
use std::cmp::{ self, Ordering };
|
|
||||||
use bytes::Buf;
|
use bytes::Buf;
|
||||||
|
use std::cmp::{self, Ordering};
|
||||||
|
use std::io::IoSlice;
|
||||||
|
|
||||||
pub struct VecBuf<'a, 'b: 'a> {
|
pub struct VecBuf<'a, 'b: 'a> {
|
||||||
pos: usize,
|
pos: usize,
|
||||||
cur: usize,
|
cur: usize,
|
||||||
inner: &'a [&'b [u8]]
|
inner: &'a [&'b [u8]],
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, 'b> VecBuf<'a, 'b> {
|
impl<'a, 'b> VecBuf<'a, 'b> {
|
||||||
pub fn new(vbytes: &'a [&'b [u8]]) -> Self {
|
pub fn new(vbytes: &'a [&'b [u8]]) -> Self {
|
||||||
VecBuf { pos: 0, cur: 0, inner: vbytes }
|
VecBuf {
|
||||||
|
pos: 0,
|
||||||
|
cur: 0,
|
||||||
|
inner: vbytes,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, 'b> Buf for VecBuf<'a, 'b> {
|
impl<'a, 'b> Buf for VecBuf<'a, 'b> {
|
||||||
fn remaining(&self) -> usize {
|
fn remaining(&self) -> usize {
|
||||||
let sum = self.inner
|
let sum = self
|
||||||
|
.inner
|
||||||
.iter()
|
.iter()
|
||||||
.skip(self.pos)
|
.skip(self.pos)
|
||||||
.map(|bytes| bytes.len())
|
.map(|bytes| bytes.len())
|
||||||
@ -32,19 +36,21 @@ impl<'a, 'b> Buf for VecBuf<'a, 'b> {
|
|||||||
fn advance(&mut self, cnt: usize) {
|
fn advance(&mut self, cnt: usize) {
|
||||||
let current = self.inner[self.pos].len();
|
let current = self.inner[self.pos].len();
|
||||||
match (self.cur + cnt).cmp(¤t) {
|
match (self.cur + cnt).cmp(¤t) {
|
||||||
Ordering::Equal => if self.pos + 1 < self.inner.len() {
|
Ordering::Equal => {
|
||||||
self.pos += 1;
|
if self.pos + 1 < self.inner.len() {
|
||||||
self.cur = 0;
|
self.pos += 1;
|
||||||
} else {
|
self.cur = 0;
|
||||||
self.cur += cnt;
|
} else {
|
||||||
},
|
self.cur += cnt;
|
||||||
|
}
|
||||||
|
}
|
||||||
Ordering::Greater => {
|
Ordering::Greater => {
|
||||||
if self.pos + 1 < self.inner.len() {
|
if self.pos + 1 < self.inner.len() {
|
||||||
self.pos += 1;
|
self.pos += 1;
|
||||||
}
|
}
|
||||||
let remaining = self.cur + cnt - current;
|
let remaining = self.cur + cnt - current;
|
||||||
self.advance(remaining);
|
self.advance(remaining);
|
||||||
},
|
}
|
||||||
Ordering::Less => self.cur += cnt,
|
Ordering::Less => self.cur += cnt,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -120,8 +126,7 @@ mod test_vecbuf {
|
|||||||
let b1: &[u8] = &mut [0];
|
let b1: &[u8] = &mut [0];
|
||||||
let b2: &[u8] = &mut [0];
|
let b2: &[u8] = &mut [0];
|
||||||
|
|
||||||
let mut dst: [IoSlice; 2] =
|
let mut dst: [IoSlice; 2] = [IoSlice::new(b1), IoSlice::new(b2)];
|
||||||
[IoSlice::new(b1), IoSlice::new(b2)];
|
|
||||||
|
|
||||||
assert_eq!(2, buf.bytes_vectored(&mut dst[..]));
|
assert_eq!(2, buf.bytes_vectored(&mut dst[..]));
|
||||||
}
|
}
|
||||||
|
@ -1,19 +1,19 @@
|
|||||||
//! Asynchronous TLS/SSL streams for Tokio using [Rustls](https://github.com/ctz/rustls).
|
//! Asynchronous TLS/SSL streams for Tokio using [Rustls](https://github.com/ctz/rustls).
|
||||||
|
|
||||||
mod common;
|
|
||||||
pub mod client;
|
pub mod client;
|
||||||
|
mod common;
|
||||||
pub mod server;
|
pub mod server;
|
||||||
|
|
||||||
|
use common::{MidHandshake, Stream, TlsState};
|
||||||
|
use futures_core::future::FusedFuture;
|
||||||
|
use rustls::{ClientConfig, ClientSession, ServerConfig, ServerSession, Session};
|
||||||
|
use std::future::Future;
|
||||||
use std::io;
|
use std::io;
|
||||||
use std::pin::Pin;
|
use std::pin::Pin;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::future::Future;
|
use std::task::{Context, Poll};
|
||||||
use std::task::{ Context, Poll };
|
use tokio::io::{AsyncRead, AsyncWrite};
|
||||||
use futures_core::future::FusedFuture;
|
|
||||||
use tokio::io::{ AsyncRead, AsyncWrite };
|
|
||||||
use webpki::DNSNameRef;
|
use webpki::DNSNameRef;
|
||||||
use rustls::{ ClientConfig, ClientSession, ServerConfig, ServerSession, Session };
|
|
||||||
use common::{ Stream, TlsState, MidHandshake };
|
|
||||||
|
|
||||||
pub use rustls;
|
pub use rustls;
|
||||||
pub use webpki;
|
pub use webpki;
|
||||||
@ -88,7 +88,7 @@ impl TlsConnector {
|
|||||||
TlsState::Stream
|
TlsState::Stream
|
||||||
},
|
},
|
||||||
|
|
||||||
session
|
session,
|
||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -151,9 +151,7 @@ impl<IO: AsyncRead + AsyncWrite + Unpin> Future for Connect<IO> {
|
|||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||||
Pin::new(&mut self.0)
|
Pin::new(&mut self.0).poll(cx).map_err(|(err, _)| err)
|
||||||
.poll(cx)
|
|
||||||
.map_err(|(err, _)| err)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -169,9 +167,7 @@ impl<IO: AsyncRead + AsyncWrite + Unpin> Future for Accept<IO> {
|
|||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||||
Pin::new(&mut self.0)
|
Pin::new(&mut self.0).poll(cx).map_err(|(err, _)| err)
|
||||||
.poll(cx)
|
|
||||||
.map_err(|(err, _)| err)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
use super::*;
|
use super::*;
|
||||||
use rustls::Session;
|
|
||||||
use crate::common::IoSession;
|
use crate::common::IoSession;
|
||||||
|
use rustls::Session;
|
||||||
|
|
||||||
/// A wrapper around an underlying raw stream which implements the TLS or SSL
|
/// A wrapper around an underlying raw stream which implements the TLS or SSL
|
||||||
/// protocol.
|
/// protocol.
|
||||||
@ -60,29 +60,35 @@ where
|
|||||||
false
|
false
|
||||||
}
|
}
|
||||||
|
|
||||||
fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> {
|
fn poll_read(
|
||||||
|
self: Pin<&mut Self>,
|
||||||
|
cx: &mut Context<'_>,
|
||||||
|
buf: &mut [u8],
|
||||||
|
) -> Poll<io::Result<usize>> {
|
||||||
let this = self.get_mut();
|
let this = self.get_mut();
|
||||||
let mut stream = Stream::new(&mut this.io, &mut this.session)
|
let mut stream =
|
||||||
.set_eof(!this.state.readable());
|
Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
|
||||||
|
|
||||||
match &this.state {
|
match &this.state {
|
||||||
TlsState::Stream | TlsState::WriteShutdown => match stream.as_mut_pin().poll_read(cx, buf) {
|
TlsState::Stream | TlsState::WriteShutdown => {
|
||||||
Poll::Ready(Ok(0)) => {
|
match stream.as_mut_pin().poll_read(cx, buf) {
|
||||||
this.state.shutdown_read();
|
Poll::Ready(Ok(0)) => {
|
||||||
Poll::Ready(Ok(0))
|
this.state.shutdown_read();
|
||||||
}
|
Poll::Ready(Ok(0))
|
||||||
Poll::Ready(Ok(n)) => Poll::Ready(Ok(n)),
|
|
||||||
Poll::Ready(Err(ref err)) if err.kind() == io::ErrorKind::ConnectionAborted => {
|
|
||||||
this.state.shutdown_read();
|
|
||||||
if this.state.writeable() {
|
|
||||||
stream.session.send_close_notify();
|
|
||||||
this.state.shutdown_write();
|
|
||||||
}
|
}
|
||||||
Poll::Ready(Ok(0))
|
Poll::Ready(Ok(n)) => Poll::Ready(Ok(n)),
|
||||||
|
Poll::Ready(Err(ref err)) if err.kind() == io::ErrorKind::ConnectionAborted => {
|
||||||
|
this.state.shutdown_read();
|
||||||
|
if this.state.writeable() {
|
||||||
|
stream.session.send_close_notify();
|
||||||
|
this.state.shutdown_write();
|
||||||
|
}
|
||||||
|
Poll::Ready(Ok(0))
|
||||||
|
}
|
||||||
|
Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
|
||||||
|
Poll::Pending => Poll::Pending,
|
||||||
}
|
}
|
||||||
Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
|
}
|
||||||
Poll::Pending => Poll::Pending
|
|
||||||
},
|
|
||||||
TlsState::ReadShutdown | TlsState::FullyShutdown => Poll::Ready(Ok(0)),
|
TlsState::ReadShutdown | TlsState::FullyShutdown => Poll::Ready(Ok(0)),
|
||||||
#[cfg(feature = "early-data")]
|
#[cfg(feature = "early-data")]
|
||||||
s => unreachable!("server TLS can not hit this state: {:?}", s),
|
s => unreachable!("server TLS can not hit this state: {:?}", s),
|
||||||
@ -96,17 +102,21 @@ where
|
|||||||
{
|
{
|
||||||
/// Note: that it does not guarantee the final data to be sent.
|
/// Note: that it does not guarantee the final data to be sent.
|
||||||
/// To be cautious, you must manually call `flush`.
|
/// To be cautious, you must manually call `flush`.
|
||||||
fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
|
fn poll_write(
|
||||||
|
self: Pin<&mut Self>,
|
||||||
|
cx: &mut Context<'_>,
|
||||||
|
buf: &[u8],
|
||||||
|
) -> Poll<io::Result<usize>> {
|
||||||
let this = self.get_mut();
|
let this = self.get_mut();
|
||||||
let mut stream = Stream::new(&mut this.io, &mut this.session)
|
let mut stream =
|
||||||
.set_eof(!this.state.readable());
|
Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
|
||||||
stream.as_mut_pin().poll_write(cx, buf)
|
stream.as_mut_pin().poll_write(cx, buf)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||||
let this = self.get_mut();
|
let this = self.get_mut();
|
||||||
let mut stream = Stream::new(&mut this.io, &mut this.session)
|
let mut stream =
|
||||||
.set_eof(!this.state.readable());
|
Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
|
||||||
stream.as_mut_pin().poll_flush(cx)
|
stream.as_mut_pin().poll_flush(cx)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -117,8 +127,8 @@ where
|
|||||||
}
|
}
|
||||||
|
|
||||||
let this = self.get_mut();
|
let this = self.get_mut();
|
||||||
let mut stream = Stream::new(&mut this.io, &mut this.session)
|
let mut stream =
|
||||||
.set_eof(!this.state.readable());
|
Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
|
||||||
stream.as_mut_pin().poll_shutdown(cx)
|
stream.as_mut_pin().poll_shutdown(cx)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,21 +1,20 @@
|
|||||||
use std::io;
|
|
||||||
use std::sync::Arc;
|
|
||||||
use std::net::ToSocketAddrs;
|
|
||||||
use tokio::prelude::*;
|
|
||||||
use tokio::net::TcpStream;
|
|
||||||
use rustls::ClientConfig;
|
use rustls::ClientConfig;
|
||||||
use tokio_rustls::{ TlsConnector, client::TlsStream };
|
use std::io;
|
||||||
|
use std::net::ToSocketAddrs;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use tokio::net::TcpStream;
|
||||||
|
use tokio::prelude::*;
|
||||||
|
use tokio_rustls::{client::TlsStream, TlsConnector};
|
||||||
|
|
||||||
|
async fn get(
|
||||||
async fn get(config: Arc<ClientConfig>, domain: &str, port: u16)
|
config: Arc<ClientConfig>,
|
||||||
-> io::Result<(TlsStream<TcpStream>, String)>
|
domain: &str,
|
||||||
{
|
port: u16,
|
||||||
|
) -> io::Result<(TlsStream<TcpStream>, String)> {
|
||||||
let connector = TlsConnector::from(config);
|
let connector = TlsConnector::from(config);
|
||||||
let input = format!("GET / HTTP/1.0\r\nHost: {}\r\n\r\n", domain);
|
let input = format!("GET / HTTP/1.0\r\nHost: {}\r\n\r\n", domain);
|
||||||
|
|
||||||
let addr = (domain, port)
|
let addr = (domain, port).to_socket_addrs()?.next().unwrap();
|
||||||
.to_socket_addrs()?
|
|
||||||
.next().unwrap();
|
|
||||||
let domain = webpki::DNSNameRef::try_from_ascii_str(&domain).unwrap();
|
let domain = webpki::DNSNameRef::try_from_ascii_str(&domain).unwrap();
|
||||||
let mut buf = Vec::new();
|
let mut buf = Vec::new();
|
||||||
|
|
||||||
@ -31,7 +30,9 @@ async fn get(config: Arc<ClientConfig>, domain: &str, port: u16)
|
|||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_tls12() -> io::Result<()> {
|
async fn test_tls12() -> io::Result<()> {
|
||||||
let mut config = ClientConfig::new();
|
let mut config = ClientConfig::new();
|
||||||
config.root_store.add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS);
|
config
|
||||||
|
.root_store
|
||||||
|
.add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS);
|
||||||
config.versions = vec![rustls::ProtocolVersion::TLSv1_2];
|
config.versions = vec![rustls::ProtocolVersion::TLSv1_2];
|
||||||
let config = Arc::new(config);
|
let config = Arc::new(config);
|
||||||
let domain = "tls-v1-2.badssl.com";
|
let domain = "tls-v1-2.badssl.com";
|
||||||
@ -52,7 +53,9 @@ fn test_tls13() {
|
|||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_modern() -> io::Result<()> {
|
async fn test_modern() -> io::Result<()> {
|
||||||
let mut config = ClientConfig::new();
|
let mut config = ClientConfig::new();
|
||||||
config.root_store.add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS);
|
config
|
||||||
|
.root_store
|
||||||
|
.add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS);
|
||||||
let config = Arc::new(config);
|
let config = Arc::new(config);
|
||||||
let domain = "mozilla-modern.badssl.com";
|
let domain = "mozilla-modern.badssl.com";
|
||||||
|
|
||||||
|
@ -1,21 +1,17 @@
|
|||||||
#![cfg(feature = "early-data")]
|
#![cfg(feature = "early-data")]
|
||||||
|
|
||||||
use std::io::{ self, BufRead, BufReader, Cursor };
|
use std::io::{self, BufRead, BufReader, Cursor};
|
||||||
use std::process::{ Command, Child, Stdio };
|
|
||||||
use std::net::SocketAddr;
|
use std::net::SocketAddr;
|
||||||
|
use std::pin::Pin;
|
||||||
|
use std::process::{Child, Command, Stdio};
|
||||||
|
use std::process::{Child, Command, Stdio};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::marker::Unpin;
|
use std::task::{Context, Poll};
|
||||||
use std::pin::{ Pin };
|
|
||||||
use std::task::{ Context, Poll };
|
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
use tokio::prelude::*;
|
|
||||||
use tokio::net::TcpStream;
|
use tokio::net::TcpStream;
|
||||||
|
use tokio::prelude::*;
|
||||||
use tokio::time::delay_for;
|
use tokio::time::delay_for;
|
||||||
use futures_util::{ future, ready };
|
use tokio_rustls::{client::TlsStream, TlsConnector};
|
||||||
use rustls::ClientConfig;
|
|
||||||
use tokio_rustls::{ TlsConnector, client::TlsStream };
|
|
||||||
use std::future::Future;
|
|
||||||
|
|
||||||
|
|
||||||
struct Read1<T>(T);
|
struct Read1<T>(T);
|
||||||
|
|
||||||
@ -29,11 +25,12 @@ impl<T: AsyncRead + Unpin> Future for Read1<T> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn send(config: Arc<ClientConfig>, addr: SocketAddr, data: &[u8])
|
async fn send(
|
||||||
-> io::Result<TlsStream<TcpStream>>
|
config: Arc<ClientConfig>,
|
||||||
{
|
addr: SocketAddr,
|
||||||
let connector = TlsConnector::from(config)
|
data: &[u8],
|
||||||
.early_data(true);
|
) -> io::Result<TlsStream<TcpStream>> {
|
||||||
|
let connector = TlsConnector::from(config).early_data(true);
|
||||||
let stream = TcpStream::connect(&addr).await?;
|
let stream = TcpStream::connect(&addr).await?;
|
||||||
let domain = webpki::DNSNameRef::try_from_ascii_str("testserver.com").unwrap();
|
let domain = webpki::DNSNameRef::try_from_ascii_str("testserver.com").unwrap();
|
||||||
|
|
||||||
@ -98,10 +95,8 @@ async fn test_0rtt() -> io::Result<()> {
|
|||||||
let stdout = handle.0.stdout.as_mut().unwrap();
|
let stdout = handle.0.stdout.as_mut().unwrap();
|
||||||
let mut lines = BufReader::new(stdout).lines();
|
let mut lines = BufReader::new(stdout).lines();
|
||||||
|
|
||||||
let has_msg1 = lines.by_ref()
|
let has_msg1 = lines.by_ref().any(|line| line.unwrap().contains("hello"));
|
||||||
.any(|line| line.unwrap().contains("hello"));
|
let has_msg2 = lines.by_ref().any(|line| line.unwrap().contains("world!"));
|
||||||
let has_msg2 = lines.by_ref()
|
|
||||||
.any(|line| line.unwrap().contains("world!"));
|
|
||||||
|
|
||||||
assert!(has_msg1 && has_msg2);
|
assert!(has_msg1 && has_msg2);
|
||||||
|
|
||||||
|
@ -1,29 +1,30 @@
|
|||||||
use std::{ io, thread };
|
|
||||||
use std::io::{ BufReader, Cursor };
|
|
||||||
use std::sync::Arc;
|
|
||||||
use std::sync::mpsc::channel;
|
|
||||||
use std::net::SocketAddr;
|
|
||||||
use futures_util::future::TryFutureExt;
|
use futures_util::future::TryFutureExt;
|
||||||
use lazy_static::lazy_static;
|
use lazy_static::lazy_static;
|
||||||
|
use rustls::internal::pemfile::{certs, rsa_private_keys};
|
||||||
|
use rustls::{ClientConfig, ServerConfig};
|
||||||
|
use std::io::{BufReader, Cursor};
|
||||||
|
use std::net::SocketAddr;
|
||||||
|
use std::sync::mpsc::channel;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::{io, thread};
|
||||||
|
use tokio::io::{copy, split};
|
||||||
|
use tokio::net::{TcpListener, TcpStream};
|
||||||
use tokio::prelude::*;
|
use tokio::prelude::*;
|
||||||
use tokio::runtime;
|
use tokio::runtime;
|
||||||
use tokio::io::{ copy, split };
|
use tokio_rustls::{TlsAcceptor, TlsConnector};
|
||||||
use tokio::net::{ TcpListener, TcpStream };
|
|
||||||
use rustls::{ ServerConfig, ClientConfig };
|
|
||||||
use rustls::internal::pemfile::{ certs, rsa_private_keys };
|
|
||||||
use tokio_rustls::{ TlsConnector, TlsAcceptor };
|
|
||||||
|
|
||||||
const CERT: &str = include_str!("end.cert");
|
const CERT: &str = include_str!("end.cert");
|
||||||
const CHAIN: &str = include_str!("end.chain");
|
const CHAIN: &str = include_str!("end.chain");
|
||||||
const RSA: &str = include_str!("end.rsa");
|
const RSA: &str = include_str!("end.rsa");
|
||||||
|
|
||||||
lazy_static!{
|
lazy_static! {
|
||||||
static ref TEST_SERVER: (SocketAddr, &'static str, &'static str) = {
|
static ref TEST_SERVER: (SocketAddr, &'static str, &'static str) = {
|
||||||
let cert = certs(&mut BufReader::new(Cursor::new(CERT))).unwrap();
|
let cert = certs(&mut BufReader::new(Cursor::new(CERT))).unwrap();
|
||||||
let mut keys = rsa_private_keys(&mut BufReader::new(Cursor::new(RSA))).unwrap();
|
let mut keys = rsa_private_keys(&mut BufReader::new(Cursor::new(RSA))).unwrap();
|
||||||
|
|
||||||
let mut config = ServerConfig::new(rustls::NoClientAuth::new());
|
let mut config = ServerConfig::new(rustls::NoClientAuth::new());
|
||||||
config.set_single_cert(cert, keys.pop().unwrap())
|
config
|
||||||
|
.set_single_cert(cert, keys.pop().unwrap())
|
||||||
.expect("invalid key or certificate");
|
.expect("invalid key or certificate");
|
||||||
let acceptor = TlsAcceptor::from(Arc::new(config));
|
let acceptor = TlsAcceptor::from(Arc::new(config));
|
||||||
|
|
||||||
@ -55,11 +56,13 @@ lazy_static!{
|
|||||||
copy(&mut reader, &mut writer).await?;
|
copy(&mut reader, &mut writer).await?;
|
||||||
|
|
||||||
Ok(()) as io::Result<()>
|
Ok(()) as io::Result<()>
|
||||||
}.unwrap_or_else(|err| eprintln!("server: {:?}", err));
|
}
|
||||||
|
.unwrap_or_else(|err| eprintln!("server: {:?}", err));
|
||||||
|
|
||||||
handle.spawn(fut);
|
handle.spawn(fut);
|
||||||
}
|
}
|
||||||
}.unwrap_or_else(|err: io::Error| eprintln!("server: {:?}", err));
|
}
|
||||||
|
.unwrap_or_else(|err: io::Error| eprintln!("server: {:?}", err));
|
||||||
|
|
||||||
runtime.block_on(done);
|
runtime.block_on(done);
|
||||||
});
|
});
|
||||||
|
Loading…
Reference in New Issue
Block a user