diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 8fa61c3..7456886 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -1,6 +1,10 @@ name: CI -on: [push, pull_request] +on: + push: + branches: + - master + pull_request: {} jobs: check: @@ -40,7 +44,7 @@ jobs: profile: minimal - uses: actions/checkout@master - name: Test - run: cargo test --all --all-features + run: cargo test --all lints: name: Lints diff --git a/tokio-native-tls/tests/cert.der b/tokio-native-tls/tests/cert.der new file mode 100644 index 0000000..e1f964d Binary files /dev/null and b/tokio-native-tls/tests/cert.der differ diff --git a/tokio-native-tls/tests/identity.p12 b/tokio-native-tls/tests/identity.p12 new file mode 100644 index 0000000..d16abb8 Binary files /dev/null and b/tokio-native-tls/tests/identity.p12 differ diff --git a/tokio-native-tls/tests/root-ca.der b/tokio-native-tls/tests/root-ca.der new file mode 100644 index 0000000..a9335c6 Binary files /dev/null and b/tokio-native-tls/tests/root-ca.der differ diff --git a/tokio-native-tls/tests/smoke.rs b/tokio-native-tls/tests/smoke.rs index baa7fb7..48c29a0 100644 --- a/tokio-native-tls/tests/smoke.rs +++ b/tokio-native-tls/tests/smoke.rs @@ -1,500 +1,126 @@ -#![warn(rust_2018_idioms)] - -use cfg_if::cfg_if; -use env_logger; use futures::join; -use native_tls; -use native_tls::{Identity, TlsAcceptor, TlsConnector}; -use std::io::Write; -use std::marker::Unpin; -use std::process::Command; -use std::ptr; -use tokio::io::{AsyncReadExt, AsyncWrite, AsyncWriteExt, Error, ErrorKind}; -use tokio::net::{TcpListener, TcpStream}; -use tokio::stream::StreamExt; +use native_tls::{Certificate, Identity}; +use std::io::Error; +use tokio::{ + io::{AsyncReadExt, AsyncWrite, AsyncWriteExt}, + net::{TcpListener, TcpStream}, +}; +use tokio_native_tls::{TlsAcceptor, TlsConnector}; -macro_rules! t { - ($e:expr) => { - match $e { - Ok(e) => e, - Err(e) => panic!("{} failed with {:?}", stringify!($e), e), - } +#[tokio::test] +async fn client_to_server() { + let mut srv = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = srv.local_addr().unwrap(); + + let (server_tls, client_tls) = context(); + + // Create a future to accept one socket, connect the ssl stream, and then + // read all the data from it. + let server = async move { + let (socket, _) = srv.accept().await.unwrap(); + let mut socket = server_tls.accept(socket).await.unwrap(); + let mut data = Vec::new(); + socket.read_to_end(&mut data).await.unwrap(); + data }; + + // Create a future to connect to our server, connect the ssl stream, and + // then write a bunch of data to it. + let client = async move { + let socket = TcpStream::connect(&addr).await.unwrap(); + let socket = client_tls.connect("foobar.com", socket).await.unwrap(); + copy_data(socket).await + }; + + // Finally, run everything! + let (data, _) = join!(server, client); + // assert_eq!(amt, AMT); + assert!(data == vec![9; AMT]); } -#[allow(dead_code)] -struct Keys { - cert_der: Vec, - pkey_der: Vec, - pkcs12_der: Vec, +#[tokio::test] +async fn server_to_client() { + // Create a server listening on a port, then figure out what that port is + let mut srv = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = srv.local_addr().unwrap(); + + let (server_tls, client_tls) = context(); + + let server = async move { + let (socket, _) = srv.accept().await.unwrap(); + let socket = server_tls.accept(socket).await.unwrap(); + copy_data(socket).await + }; + + let client = async move { + let socket = TcpStream::connect(&addr).await.unwrap(); + let mut socket = client_tls.connect("foobar.com", socket).await.unwrap(); + let mut data = Vec::new(); + socket.read_to_end(&mut data).await.unwrap(); + data + }; + + // Finally, run everything! + let (_, data) = join!(server, client); + assert!(data == vec![9; AMT]); } -#[allow(dead_code)] -fn openssl_keys() -> &'static Keys { - static INIT: Once = Once::new(); - static mut KEYS: *mut Keys = ptr::null_mut(); +#[tokio::test] +async fn one_byte_at_a_time() { + const AMT: usize = 1024; - INIT.call_once(|| { - let path = t!(env::current_exe()); - let path = path.parent().unwrap(); - let keyfile = path.join("test.key"); - let certfile = path.join("test.crt"); - let config = path.join("openssl.config"); + let mut srv = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = srv.local_addr().unwrap(); - File::create(&config) - .unwrap() - .write_all( - b"\ - [req]\n\ - distinguished_name=dn\n\ - [ dn ]\n\ - CN=localhost\n\ - [ ext ]\n\ - basicConstraints=CA:FALSE,pathlen:0\n\ - subjectAltName = @alt_names - extendedKeyUsage=serverAuth,clientAuth - [alt_names] - DNS.1 = localhost - ", - ) - .unwrap(); + let (server_tls, client_tls) = context(); - let subj = "/C=US/ST=Denial/L=Sprintfield/O=Dis/CN=localhost"; - let output = t!(Command::new("openssl") - .arg("req") - .arg("-nodes") - .arg("-x509") - .arg("-newkey") - .arg("rsa:2048") - .arg("-config") - .arg(&config) - .arg("-extensions") - .arg("ext") - .arg("-subj") - .arg(subj) - .arg("-keyout") - .arg(&keyfile) - .arg("-out") - .arg(&certfile) - .arg("-days") - .arg("1") - .output()); - assert!(output.status.success()); - - let crtout = t!(Command::new("openssl") - .arg("x509") - .arg("-outform") - .arg("der") - .arg("-in") - .arg(&certfile) - .output()); - assert!(crtout.status.success()); - let keyout = t!(Command::new("openssl") - .arg("rsa") - .arg("-outform") - .arg("der") - .arg("-in") - .arg(&keyfile) - .output()); - assert!(keyout.status.success()); - - let pkcs12out = t!(Command::new("openssl") - .arg("pkcs12") - .arg("-export") - .arg("-nodes") - .arg("-inkey") - .arg(&keyfile) - .arg("-in") - .arg(&certfile) - .arg("-password") - .arg("pass:foobar") - .output()); - assert!(pkcs12out.status.success()); - - let keys = Box::new(Keys { - cert_der: crtout.stdout, - pkey_der: keyout.stdout, - pkcs12_der: pkcs12out.stdout, - }); - unsafe { - KEYS = Box::into_raw(keys); + let server = async move { + let (socket, _) = srv.accept().await.unwrap(); + let mut socket = server_tls.accept(socket).await.unwrap(); + let mut amt = 0; + for b in std::iter::repeat(9).take(AMT) { + let data = [b as u8]; + socket.write_all(&data).await.unwrap(); + amt += 1; } - }); - unsafe { &*KEYS } + amt + }; + + let client = async move { + let socket = TcpStream::connect(&addr).await.unwrap(); + let mut socket = client_tls.connect("foobar.com", socket).await.unwrap(); + let mut data = Vec::new(); + loop { + let mut buf = [0; 1]; + match socket.read_exact(&mut buf).await { + Ok(_) => data.extend_from_slice(&buf), + Err(ref err) if err.kind() == std::io::ErrorKind::UnexpectedEof => break, + Err(err) => panic!(err), + } + } + data + }; + + let (amt, data) = join!(server, client); + assert_eq!(amt, AMT); + assert!(data == vec![9; AMT as usize]); } -cfg_if! { - if #[cfg(feature = "rustls")] { - use webpki; - use untrusted; - use std::env; - use std::fs::File; - use std::process::Command; - use std::sync::Once; +fn context() -> (TlsAcceptor, TlsConnector) { + // Certs borrowed from `rust-native-tls/tests` + let pkcs12 = include_bytes!("identity.p12"); + let der = include_bytes!("root-ca.der"); - use untrusted::Input; - use webpki::trust_anchor_util; + let identity = Identity::from_pkcs12(pkcs12, "mypass").unwrap(); + let acceptor = native_tls::TlsAcceptor::builder(identity).build().unwrap(); - fn server_cx() -> io::Result { - let mut cx = ServerContext::new(); + let cert = Certificate::from_der(der).unwrap(); + let connector = native_tls::TlsConnector::builder() + .add_root_certificate(cert) + .build() + .unwrap(); - let (cert, key) = keys(); - cx.config_mut() - .set_single_cert(vec![cert.to_vec()], key.to_vec()); - - Ok(cx) - } - - fn configure_client(cx: &mut ClientContext) { - let (cert, _key) = keys(); - let cert = Input::from(cert); - let anchor = trust_anchor_util::cert_der_as_trust_anchor(cert).unwrap(); - cx.config_mut().root_store.add_trust_anchors(&[anchor]); - } - - // Like OpenSSL we generate certificates on the fly, but for OSX we - // also have to put them into a specific keychain. We put both the - // certificates and the keychain next to our binary. - // - // Right now I don't know of a way to programmatically create a - // self-signed certificate, so we just fork out to the `openssl` binary. - fn keys() -> (&'static [u8], &'static [u8]) { - static INIT: Once = Once::new(); - static mut KEYS: *mut (Vec, Vec) = ptr::null_mut(); - - INIT.call_once(|| { - let (key, cert) = openssl_keys(); - let path = t!(env::current_exe()); - let path = path.parent().unwrap(); - let keyfile = path.join("test.key"); - let certfile = path.join("test.crt"); - let config = path.join("openssl.config"); - - File::create(&config).unwrap().write_all(b"\ - [req]\n\ - distinguished_name=dn\n\ - [ dn ]\n\ - CN=localhost\n\ - [ ext ]\n\ - basicConstraints=CA:FALSE,pathlen:0\n\ - subjectAltName = @alt_names - [alt_names] - DNS.1 = localhost - ").unwrap(); - - let subj = "/C=US/ST=Denial/L=Sprintfield/O=Dis/CN=localhost"; - let output = t!(Command::new("openssl") - .arg("req") - .arg("-nodes") - .arg("-x509") - .arg("-newkey").arg("rsa:2048") - .arg("-config").arg(&config) - .arg("-extensions").arg("ext") - .arg("-subj").arg(subj) - .arg("-keyout").arg(&keyfile) - .arg("-out").arg(&certfile) - .arg("-days").arg("1") - .output()); - assert!(output.status.success()); - - let crtout = t!(Command::new("openssl") - .arg("x509") - .arg("-outform").arg("der") - .arg("-in").arg(&certfile) - .output()); - assert!(crtout.status.success()); - let keyout = t!(Command::new("openssl") - .arg("rsa") - .arg("-outform").arg("der") - .arg("-in").arg(&keyfile) - .output()); - assert!(keyout.status.success()); - - let cert = crtout.stdout; - let key = keyout.stdout; - unsafe { - KEYS = Box::into_raw(Box::new((cert, key))); - } - }); - unsafe { - (&(*KEYS).0, &(*KEYS).1) - } - } - } else if #[cfg(any(feature = "force-openssl", - all(not(target_os = "macos"), - not(target_os = "windows"), - not(target_os = "ios"))))] { - use std::fs::File; - use std::env; - use std::sync::Once; - - fn contexts() -> (tokio_native_tls::TlsAcceptor, tokio_native_tls::TlsConnector) { - let keys = openssl_keys(); - - let pkcs12 = t!(Identity::from_pkcs12(&keys.pkcs12_der, "foobar")); - let srv = TlsAcceptor::builder(pkcs12); - - let cert = t!(native_tls::Certificate::from_der(&keys.cert_der)); - - let mut client = TlsConnector::builder(); - t!(client.add_root_certificate(cert).build()); - - (t!(srv.build()).into(), t!(client.build()).into()) - } - } else if #[cfg(any(target_os = "macos", target_os = "ios"))] { - use std::env; - use std::fs::File; - use std::sync::Once; - - fn contexts() -> (tokio_tls::TlsAcceptor, tokio_tls::TlsConnector) { - let keys = openssl_keys(); - - let pkcs12 = t!(Identity::from_pkcs12(&keys.pkcs12_der, "foobar")); - let srv = TlsAcceptor::builder(pkcs12); - - let cert = native_tls::Certificate::from_der(&keys.cert_der).unwrap(); - let mut client = TlsConnector::builder(); - client.add_root_certificate(cert); - - (t!(srv.build()).into(), t!(client.build()).into()) - } - } else { - use schannel; - use winapi; - - use std::env; - use std::fs::File; - use std::io; - use std::mem; - use std::sync::Once; - - use schannel::cert_context::CertContext; - use schannel::cert_store::{CertStore, CertAdd, Memory}; - use winapi::shared::basetsd::*; - use winapi::shared::lmcons::*; - use winapi::shared::minwindef::*; - use winapi::shared::ntdef::WCHAR; - use winapi::um::minwinbase::*; - use winapi::um::sysinfoapi::*; - use winapi::um::timezoneapi::*; - use winapi::um::wincrypt::*; - - const FRIENDLY_NAME: &'static str = "tokio-tls localhost testing cert"; - - fn contexts() -> (tokio_tls::TlsAcceptor, tokio_tls::TlsConnector) { - let cert = localhost_cert(); - let mut store = t!(Memory::new()).into_store(); - t!(store.add_cert(&cert, CertAdd::Always)); - let pkcs12_der = t!(store.export_pkcs12("foobar")); - let pkcs12 = t!(Identity::from_pkcs12(&pkcs12_der, "foobar")); - - let srv = TlsAcceptor::builder(pkcs12); - let client = TlsConnector::builder(); - (t!(srv.build()).into(), t!(client.build()).into()) - } - - // ==================================================================== - // Magic! - // - // Lots of magic is happening here to wrangle certificates for running - // these tests on Windows. For more information see the test suite - // in the schannel-rs crate as this is just coyping that. - // - // The general gist of this though is that the only way to add custom - // trusted certificates is to add it to the system store of trust. To - // do that we go through the whole rigamarole here to generate a new - // self-signed certificate and then insert that into the system store. - // - // This generates some dialogs, so we print what we're doing sometimes, - // and otherwise we just manage the ephemeral certificates. Because - // they're in the system store we always ensure that they're only valid - // for a small period of time (e.g. 1 day). - - fn localhost_cert() -> CertContext { - static INIT: Once = Once::new(); - INIT.call_once(|| { - for cert in local_root_store().certs() { - let name = match cert.friendly_name() { - Ok(name) => name, - Err(_) => continue, - }; - if name != FRIENDLY_NAME { - continue - } - if !cert.is_time_valid().unwrap() { - io::stdout().write_all(br#" - -The tokio-tls test suite is about to delete an old copy of one of its -certificates from your root trust store. This certificate was only valid for one -day and it is no longer needed. The host should be "localhost" and the -description should mention "tokio-tls". - - "#).unwrap(); - cert.delete().unwrap(); - } else { - return - } - } - - install_certificate().unwrap(); - }); - - for cert in local_root_store().certs() { - let name = match cert.friendly_name() { - Ok(name) => name, - Err(_) => continue, - }; - if name == FRIENDLY_NAME { - return cert - } - } - - panic!("couldn't find a cert"); - } - - fn local_root_store() -> CertStore { - if env::var("CI").is_ok() { - CertStore::open_local_machine("Root").unwrap() - } else { - CertStore::open_current_user("Root").unwrap() - } - } - - fn install_certificate() -> io::Result { - unsafe { - let mut provider = 0; - let mut hkey = 0; - - let mut buffer = "tokio-tls test suite".encode_utf16() - .chain(Some(0)) - .collect::>(); - let res = CryptAcquireContextW(&mut provider, - buffer.as_ptr(), - ptr::null_mut(), - PROV_RSA_FULL, - CRYPT_MACHINE_KEYSET); - if res != TRUE { - // create a new key container (since it does not exist) - let res = CryptAcquireContextW(&mut provider, - buffer.as_ptr(), - ptr::null_mut(), - PROV_RSA_FULL, - CRYPT_NEWKEYSET | CRYPT_MACHINE_KEYSET); - if res != TRUE { - return Err(Error::last_os_error()) - } - } - - // create a new keypair (RSA-2048) - let res = CryptGenKey(provider, - AT_SIGNATURE, - 0x0800<<16 | CRYPT_EXPORTABLE, - &mut hkey); - if res != TRUE { - return Err(Error::last_os_error()); - } - - // start creating the certificate - let name = "CN=localhost,O=tokio-tls,OU=tokio-tls,\ - G=tokio_tls".encode_utf16() - .chain(Some(0)) - .collect::>(); - let mut cname_buffer: [WCHAR; UNLEN as usize + 1] = mem::zeroed(); - let mut cname_len = cname_buffer.len() as DWORD; - let res = CertStrToNameW(X509_ASN_ENCODING, - name.as_ptr(), - CERT_X500_NAME_STR, - ptr::null_mut(), - cname_buffer.as_mut_ptr() as *mut u8, - &mut cname_len, - ptr::null_mut()); - if res != TRUE { - return Err(Error::last_os_error()); - } - - let mut subject_issuer = CERT_NAME_BLOB { - cbData: cname_len, - pbData: cname_buffer.as_ptr() as *mut u8, - }; - let mut key_provider = CRYPT_KEY_PROV_INFO { - pwszContainerName: buffer.as_mut_ptr(), - pwszProvName: ptr::null_mut(), - dwProvType: PROV_RSA_FULL, - dwFlags: CRYPT_MACHINE_KEYSET, - cProvParam: 0, - rgProvParam: ptr::null_mut(), - dwKeySpec: AT_SIGNATURE, - }; - let mut sig_algorithm = CRYPT_ALGORITHM_IDENTIFIER { - pszObjId: szOID_RSA_SHA256RSA.as_ptr() as *mut _, - Parameters: mem::zeroed(), - }; - let mut expiration_date: SYSTEMTIME = mem::zeroed(); - GetSystemTime(&mut expiration_date); - let mut file_time: FILETIME = mem::zeroed(); - let res = SystemTimeToFileTime(&mut expiration_date, - &mut file_time); - if res != TRUE { - return Err(Error::last_os_error()); - } - let mut timestamp: u64 = file_time.dwLowDateTime as u64 | - (file_time.dwHighDateTime as u64) << 32; - // one day, timestamp unit is in 100 nanosecond intervals - timestamp += (1E9 as u64) / 100 * (60 * 60 * 24); - file_time.dwLowDateTime = timestamp as u32; - file_time.dwHighDateTime = (timestamp >> 32) as u32; - let res = FileTimeToSystemTime(&file_time, - &mut expiration_date); - if res != TRUE { - return Err(Error::last_os_error()); - } - - // create a self signed certificate - let cert_context = CertCreateSelfSignCertificate( - 0 as ULONG_PTR, - &mut subject_issuer, - 0, - &mut key_provider, - &mut sig_algorithm, - ptr::null_mut(), - &mut expiration_date, - ptr::null_mut()); - if cert_context.is_null() { - return Err(Error::last_os_error()); - } - - // TODO: this is.. a terrible hack. Right now `schannel` - // doesn't provide a public method to go from a raw - // cert context pointer to the `CertContext` structure it - // has, so we just fake it here with a transmute. This'll - // probably break at some point, but hopefully by then - // it'll have a method to do this! - struct MyCertContext(T); - impl Drop for MyCertContext { - fn drop(&mut self) {} - } - - let cert_context = MyCertContext(cert_context); - let cert_context: CertContext = mem::transmute(cert_context); - - cert_context.set_friendly_name(FRIENDLY_NAME)?; - - // install the certificate to the machine's local store - io::stdout().write_all(br#" - -The tokio-tls test suite is about to add a certificate to your set of root -and trusted certificates. This certificate should be for the domain "localhost" -with the description related to "tokio-tls". This certificate is only valid -for one day and will be automatically deleted if you re-run the tokio-tls -test suite later. - - "#).unwrap(); - local_root_store().add_cert(&cert_context, - CertAdd::ReplaceExisting)?; - Ok(cert_context) - } - } - } + (acceptor.into(), connector.into()) } const AMT: usize = 128 * 1024; @@ -517,112 +143,3 @@ async fn copy_data(mut w: W) -> Result { } Ok(amt) } - -#[tokio::test] -async fn client_to_server() { - drop(env_logger::try_init()); - - // Create a server listening on a port, then figure out what that port is - let mut srv = t!(TcpListener::bind("127.0.0.1:0").await); - let addr = t!(srv.local_addr()); - - let (server_cx, client_cx) = contexts(); - - // Create a future to accept one socket, connect the ssl stream, and then - // read all the data from it. - let server = async move { - let mut incoming = srv.incoming(); - let socket = t!(incoming.next().await.unwrap()); - let mut socket = t!(server_cx.accept(socket).await); - let mut data = Vec::new(); - t!(socket.read_to_end(&mut data).await); - data - }; - - // Create a future to connect to our server, connect the ssl stream, and - // then write a bunch of data to it. - let client = async move { - let socket = t!(TcpStream::connect(&addr).await); - let socket = t!(client_cx.connect("localhost", socket).await); - copy_data(socket).await - }; - - // Finally, run everything! - let (data, _) = join!(server, client); - // assert_eq!(amt, AMT); - assert!(data == vec![9; AMT]); -} - -#[tokio::test] -async fn server_to_client() { - drop(env_logger::try_init()); - - // Create a server listening on a port, then figure out what that port is - let mut srv = t!(TcpListener::bind("127.0.0.1:0").await); - let addr = t!(srv.local_addr()); - - let (server_cx, client_cx) = contexts(); - - let server = async move { - let mut incoming = srv.incoming(); - let socket = t!(incoming.next().await.unwrap()); - let socket = t!(server_cx.accept(socket).await); - copy_data(socket).await - }; - - let client = async move { - let socket = t!(TcpStream::connect(&addr).await); - let mut socket = t!(client_cx.connect("localhost", socket).await); - let mut data = Vec::new(); - t!(socket.read_to_end(&mut data).await); - data - }; - - // Finally, run everything! - let (_, data) = join!(server, client); - // assert_eq!(amt, AMT); - assert!(data == vec![9; AMT]); -} - -#[tokio::test] -async fn one_byte_at_a_time() { - const AMT: usize = 1024; - drop(env_logger::try_init()); - - let mut srv = t!(TcpListener::bind("127.0.0.1:0").await); - let addr = t!(srv.local_addr()); - - let (server_cx, client_cx) = contexts(); - - let server = async move { - let mut incoming = srv.incoming(); - let socket = t!(incoming.next().await.unwrap()); - let mut socket = t!(server_cx.accept(socket).await); - let mut amt = 0; - for b in std::iter::repeat(9).take(AMT) { - let data = [b as u8]; - t!(socket.write_all(&data).await); - amt += 1; - } - amt - }; - - let client = async move { - let socket = t!(TcpStream::connect(&addr).await); - let mut socket = t!(client_cx.connect("localhost", socket).await); - let mut data = Vec::new(); - loop { - let mut buf = [0; 1]; - match socket.read_exact(&mut buf).await { - Ok(_) => data.extend_from_slice(&buf), - Err(ref err) if err.kind() == ErrorKind::UnexpectedEof => break, - Err(err) => panic!(err), - } - } - data - }; - - let (amt, data) = join!(server, client); - assert_eq!(amt, AMT); - assert!(data == vec![9; AMT as usize]); -} diff --git a/tokio-rustls/src/client.rs b/tokio-rustls/src/client.rs index 5007aa8..ff7f857 100644 --- a/tokio-rustls/src/client.rs +++ b/tokio-rustls/src/client.rs @@ -1,7 +1,6 @@ use super::*; -use rustls::Session; use crate::common::IoSession; - +use rustls::Session; /// A wrapper around an underlying raw stream which implements the TLS or SSL /// protocol. @@ -58,20 +57,24 @@ where false } - fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll> { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { match self.state { #[cfg(feature = "early-data")] TlsState::EarlyData(..) => Poll::Pending, TlsState::Stream | TlsState::WriteShutdown => { let this = self.get_mut(); - let mut stream = Stream::new(&mut this.io, &mut this.session) - .set_eof(!this.state.readable()); + let mut stream = + Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable()); match stream.as_mut_pin().poll_read(cx, buf) { Poll::Ready(Ok(0)) => { this.state.shutdown_read(); Poll::Ready(Ok(0)) - }, + } Poll::Ready(Ok(n)) => Poll::Ready(Ok(n)), Poll::Ready(Err(ref e)) if e.kind() == io::ErrorKind::ConnectionAborted => { this.state.shutdown_read(); @@ -80,8 +83,8 @@ where this.state.shutdown_write(); } Poll::Ready(Ok(0)) - }, - output => output + } + output => output, } } TlsState::ReadShutdown | TlsState::FullyShutdown => Poll::Ready(Ok(0)), @@ -95,10 +98,14 @@ where { /// Note: that it does not guarantee the final data to be sent. /// To be cautious, you must manually call `flush`. - fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { let this = self.get_mut(); - let mut stream = Stream::new(&mut this.io, &mut this.session) - .set_eof(!this.state.readable()); + let mut stream = + Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable()); match this.state { #[cfg(feature = "early-data")] @@ -110,9 +117,10 @@ where if let Some(mut early_data) = stream.session.early_data() { let len = match early_data.write(buf) { Ok(n) => n, - Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => - return Poll::Pending, - Err(err) => return Poll::Ready(Err(err)) + Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => { + return Poll::Pending + } + Err(err) => return Poll::Ready(Err(err)), }; if len != 0 { data.extend_from_slice(&buf[..len]); @@ -143,10 +151,11 @@ where fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let this = self.get_mut(); - let mut stream = Stream::new(&mut this.io, &mut this.session) - .set_eof(!this.state.readable()); + let mut stream = + Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable()); - #[cfg(feature = "early-data")] { + #[cfg(feature = "early-data")] + { use futures_core::ready; if let TlsState::EarlyData(ref mut pos, ref mut data) = this.state { @@ -176,7 +185,8 @@ where self.state.shutdown_write(); } - #[cfg(feature = "early-data")] { + #[cfg(feature = "early-data")] + { // we skip the handshake if let TlsState::EarlyData(..) = self.state { return Pin::new(&mut self.io).poll_shutdown(cx); @@ -184,8 +194,8 @@ where } let this = self.get_mut(); - let mut stream = Stream::new(&mut this.io, &mut this.session) - .set_eof(!this.state.readable()); + let mut stream = + Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable()); stream.as_mut_pin().poll_shutdown(cx) } } diff --git a/tokio-rustls/src/common/handshake.rs b/tokio-rustls/src/common/handshake.rs index c59541e..b9b7894 100644 --- a/tokio-rustls/src/common/handshake.rs +++ b/tokio-rustls/src/common/handshake.rs @@ -1,12 +1,11 @@ -use std::{ io, mem }; -use std::pin::Pin; -use std::future::Future; -use std::task::{ Context, Poll }; +use crate::common::{Stream, TlsState}; use futures_core::future::FusedFuture; -use tokio::io::{ AsyncRead, AsyncWrite }; use rustls::Session; -use crate::common::{ TlsState, Stream }; - +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; +use std::{io, mem}; +use tokio::io::{AsyncRead, AsyncWrite}; pub(crate) trait IoSession { type Io; @@ -26,7 +25,7 @@ impl FusedFuture for MidHandshake where IS: IoSession + Unpin, IS::Io: AsyncRead + AsyncWrite + Unpin, - IS::Session: Session + Unpin + IS::Session: Session + Unpin, { fn is_terminated(&self) -> bool { if let MidHandshake::End = self { @@ -41,7 +40,7 @@ impl Future for MidHandshake where IS: IoSession + Unpin, IS::Io: AsyncRead + AsyncWrite + Unpin, - IS::Session: Session + Unpin + IS::Session: Session + Unpin, { type Output = Result; @@ -51,20 +50,21 @@ where if let MidHandshake::Handshaking(mut stream) = mem::replace(this, MidHandshake::End) { if !stream.skip_handshake() { let (state, io, session) = stream.get_mut(); - let mut tls_stream = Stream::new(io, session) - .set_eof(!state.readable()); + let mut tls_stream = Stream::new(io, session).set_eof(!state.readable()); macro_rules! try_poll { ( $e:expr ) => { match $e { Poll::Ready(Ok(_)) => (), - Poll::Ready(Err(err)) => return Poll::Ready(Err((err, stream.into_io()))), + Poll::Ready(Err(err)) => { + return Poll::Ready(Err((err, stream.into_io()))) + } Poll::Pending => { *this = MidHandshake::Handshaking(stream); return Poll::Pending; } } - } + }; } while tls_stream.session.is_handshaking() { diff --git a/tokio-rustls/src/common/mod.rs b/tokio-rustls/src/common/mod.rs index 1d0dd07..778fa92 100644 --- a/tokio-rustls/src/common/mod.rs +++ b/tokio-rustls/src/common/mod.rs @@ -3,14 +3,13 @@ mod handshake; #[cfg(feature = "unstable")] mod vecbuf; -use std::pin::Pin; -use std::task::{ Poll, Context }; -use std::io::{ self, Read }; -use rustls::Session; -use tokio::io::{ AsyncRead, AsyncWrite }; use futures_core as futures; -pub(crate) use handshake::{ IoSession, MidHandshake }; - +pub(crate) use handshake::{IoSession, MidHandshake}; +use rustls::Session; +use std::io::{self, Read}; +use std::pin::Pin; +use std::task::{Context, Poll}; +use tokio::io::{AsyncRead, AsyncWrite}; #[derive(Debug)] pub enum TlsState { @@ -26,8 +25,7 @@ impl TlsState { #[inline] pub fn shutdown_read(&mut self) { match *self { - TlsState::WriteShutdown | TlsState::FullyShutdown => - *self = TlsState::FullyShutdown, + TlsState::WriteShutdown | TlsState::FullyShutdown => *self = TlsState::FullyShutdown, _ => *self = TlsState::ReadShutdown, } } @@ -35,8 +33,7 @@ impl TlsState { #[inline] pub fn shutdown_write(&mut self) { match *self { - TlsState::ReadShutdown | TlsState::FullyShutdown => - *self = TlsState::FullyShutdown, + TlsState::ReadShutdown | TlsState::FullyShutdown => *self = TlsState::FullyShutdown, _ => *self = TlsState::WriteShutdown, } } @@ -62,7 +59,7 @@ impl TlsState { pub fn is_early_data(&self) -> bool { match self { TlsState::EarlyData(..) => true, - _ => false + _ => false, } } @@ -76,7 +73,7 @@ impl TlsState { pub struct Stream<'a, IO, S> { pub io: &'a mut IO, pub session: &'a mut S, - pub eof: bool + pub eof: bool, } impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> { @@ -100,28 +97,27 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> { } pub fn process_new_packets(&mut self, cx: &mut Context) -> io::Result<()> { - self.session.process_new_packets() - .map_err(|err| { - // In case we have an alert to send describing this error, - // try a last-gasp write -- but don't predate the primary - // error. - let _ = self.write_io(cx); + self.session.process_new_packets().map_err(|err| { + // In case we have an alert to send describing this error, + // try a last-gasp write -- but don't predate the primary + // error. + let _ = self.write_io(cx); - io::Error::new(io::ErrorKind::InvalidData, err) - }) + io::Error::new(io::ErrorKind::InvalidData, err) + }) } pub fn read_io(&mut self, cx: &mut Context) -> Poll> { struct Reader<'a, 'b, T> { io: &'a mut T, - cx: &'a mut Context<'b> + cx: &'a mut Context<'b>, } impl<'a, 'b, T: AsyncRead + Unpin> Read for Reader<'a, 'b, T> { fn read(&mut self, buf: &mut [u8]) -> io::Result { match Pin::new(&mut self.io).poll_read(self.cx, buf) { Poll::Ready(result) => result, - Poll::Pending => Err(io::ErrorKind::WouldBlock.into()) + Poll::Pending => Err(io::ErrorKind::WouldBlock.into()), } } } @@ -131,7 +127,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> { let n = match self.session.read_tls(&mut reader) { Ok(n) => n, Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => return Poll::Pending, - Err(err) => return Poll::Ready(Err(err)) + Err(err) => return Poll::Ready(Err(err)), }; Poll::Ready(Ok(n)) @@ -143,21 +139,21 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> { struct Writer<'a, 'b, T> { io: &'a mut T, - cx: &'a mut Context<'b> + cx: &'a mut Context<'b>, } impl<'a, 'b, T: AsyncWrite + Unpin> Write for Writer<'a, 'b, T> { fn write(&mut self, buf: &[u8]) -> io::Result { match Pin::new(&mut self.io).poll_write(self.cx, buf) { Poll::Ready(result) => result, - Poll::Pending => Err(io::ErrorKind::WouldBlock.into()) + Poll::Pending => Err(io::ErrorKind::WouldBlock.into()), } } fn flush(&mut self) -> io::Result<()> { match Pin::new(&mut self.io).poll_flush(self.cx) { Poll::Ready(result) => result, - Poll::Pending => Err(io::ErrorKind::WouldBlock.into()) + Poll::Pending => Err(io::ErrorKind::WouldBlock.into()), } } } @@ -166,7 +162,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> { match self.session.write_tls(&mut writer) { Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending, - result => Poll::Ready(result) + result => Poll::Ready(result), } } @@ -176,7 +172,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> { struct Writer<'a, 'b, T> { io: &'a mut T, - cx: &'a mut Context<'b> + cx: &'a mut Context<'b>, } impl<'a, 'b, T: AsyncWrite + Unpin> WriteV for Writer<'a, 'b, T> { @@ -187,7 +183,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> { match Pin::new(&mut self.io).poll_write_buf(self.cx, &mut vbuf) { Poll::Ready(result) => result, - Poll::Pending => Err(io::ErrorKind::WouldBlock.into()) + Poll::Pending => Err(io::ErrorKind::WouldBlock.into()), } } } @@ -196,7 +192,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> { match self.session.writev_tls(&mut writer) { Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending, - result => Poll::Ready(result) + result => Poll::Ready(result), } } @@ -213,9 +209,9 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> { Poll::Ready(Ok(n)) => wrlen += n, Poll::Pending => { write_would_block = true; - break - }, - Poll::Ready(Err(err)) => return Poll::Ready(Err(err)) + break; + } + Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), } } @@ -225,9 +221,9 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> { Poll::Ready(Ok(n)) => rdlen += n, Poll::Pending => { read_would_block = true; - break - }, - Poll::Ready(Err(err)) => return Poll::Ready(Err(err)) + break; + } + Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), } } @@ -237,21 +233,27 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> { (true, true) => { let err = io::Error::new(io::ErrorKind::UnexpectedEof, "tls handshake eof"); Poll::Ready(Err(err)) - }, + } (_, false) => Poll::Ready(Ok((rdlen, wrlen))), - (_, true) if write_would_block || read_would_block => if rdlen != 0 || wrlen != 0 { - Poll::Ready(Ok((rdlen, wrlen))) - } else { - Poll::Pending - }, - (..) => continue - } + (_, true) if write_would_block || read_would_block => { + if rdlen != 0 || wrlen != 0 { + Poll::Ready(Ok((rdlen, wrlen))) + } else { + Poll::Pending + } + } + (..) => continue, + }; } } } impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncRead for Stream<'a, IO, S> { - fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context, buf: &mut [u8]) -> Poll> { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context, + buf: &mut [u8], + ) -> Poll> { let mut pos = 0; while pos != buf.len() { @@ -262,14 +264,14 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncRead for Stream<'a match self.read_io(cx) { Poll::Ready(Ok(0)) => { self.eof = true; - break - }, + break; + } Poll::Ready(Ok(_)) => (), Poll::Pending => { would_block = true; - break - }, - Poll::Ready(Err(err)) => return Poll::Ready(Err(err)) + break; + } + Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), } } @@ -280,13 +282,14 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncRead for Stream<'a Ok(n) if self.eof || would_block => Poll::Ready(Ok(pos + n)), Ok(n) => { pos += n; - continue - }, + continue; + } Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending, - Err(ref err) if err.kind() == io::ErrorKind::ConnectionAborted && pos != 0 => - Poll::Ready(Ok(pos)), - Err(err) => Poll::Ready(Err(err)) - } + Err(ref err) if err.kind() == io::ErrorKind::ConnectionAborted && pos != 0 => { + Poll::Ready(Ok(pos)) + } + Err(err) => Poll::Ready(Err(err)), + }; } Poll::Ready(Ok(pos)) @@ -294,7 +297,11 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncRead for Stream<'a } impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncWrite for Stream<'a, IO, S> { - fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll> { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context, + buf: &[u8], + ) -> Poll> { let mut pos = 0; while pos != buf.len() { @@ -303,25 +310,25 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncWrite for Stream<' match self.session.write(&buf[pos..]) { Ok(n) => pos += n, Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => (), - Err(err) => return Poll::Ready(Err(err)) + Err(err) => return Poll::Ready(Err(err)), }; while self.session.wants_write() { match self.write_io(cx) { Poll::Ready(Ok(0)) | Poll::Pending => { would_block = true; - break - }, + break; + } Poll::Ready(Ok(_)) => (), - Poll::Ready(Err(err)) => return Poll::Ready(Err(err)) + Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), } } return match (pos, would_block) { (0, true) => Poll::Pending, (n, true) => Poll::Ready(Ok(n)), - (_, false) => continue - } + (_, false) => continue, + }; } Poll::Ready(Ok(pos)) diff --git a/tokio-rustls/src/common/test_stream.rs b/tokio-rustls/src/common/test_stream.rs index 0055014..b333239 100644 --- a/tokio-rustls/src/common/test_stream.rs +++ b/tokio-rustls/src/common/test_stream.rs @@ -1,39 +1,44 @@ -use std::pin::Pin; -use std::sync::Arc; -use std::task::{ Poll, Context }; +use super::Stream; use futures_core::ready; use futures_util::future::poll_fn; use futures_util::task::noop_waker_ref; -use tokio::io::{ AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt }; -use std::io::{ self, Read, Write, BufReader, Cursor }; +use rustls::internal::pemfile::{certs, rsa_private_keys}; +use rustls::{ClientConfig, ClientSession, NoClientAuth, ServerConfig, ServerSession, Session}; +use std::io::{self, BufReader, Cursor, Read, Write}; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use webpki::DNSNameRef; -use rustls::internal::pemfile::{ certs, rsa_private_keys }; -use rustls::{ - ServerConfig, ClientConfig, - ServerSession, ClientSession, - Session, NoClientAuth -}; -use super::Stream; - struct Good<'a>(&'a mut dyn Session); impl<'a> AsyncRead for Good<'a> { - fn poll_read(mut self: Pin<&mut Self>, _cx: &mut Context<'_>, mut buf: &mut [u8]) -> Poll> { + fn poll_read( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + mut buf: &mut [u8], + ) -> Poll> { Poll::Ready(self.0.write_tls(buf.by_ref())) } } impl<'a> AsyncWrite for Good<'a> { - fn poll_write(mut self: Pin<&mut Self>, _cx: &mut Context<'_>, mut buf: &[u8]) -> Poll> { + fn poll_write( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + mut buf: &[u8], + ) -> Poll> { let len = self.0.read_tls(buf.by_ref())?; - self.0.process_new_packets() + self.0 + .process_new_packets() .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?; Poll::Ready(Ok(len)) } fn poll_flush(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { - self.0.process_new_packets() + self.0 + .process_new_packets() .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?; Poll::Ready(Ok(())) } @@ -47,13 +52,21 @@ impl<'a> AsyncWrite for Good<'a> { struct Pending; impl AsyncRead for Pending { - fn poll_read(self: Pin<&mut Self>, _cx: &mut Context<'_>, _: &mut [u8]) -> Poll> { + fn poll_read( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + _: &mut [u8], + ) -> Poll> { Poll::Pending } } impl AsyncWrite for Pending { - fn poll_write(self: Pin<&mut Self>, _cx: &mut Context<'_>, _buf: &[u8]) -> Poll> { + fn poll_write( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + _buf: &[u8], + ) -> Poll> { Poll::Pending } @@ -69,13 +82,21 @@ impl AsyncWrite for Pending { struct Eof; impl AsyncRead for Eof { - fn poll_read(self: Pin<&mut Self>, _cx: &mut Context<'_>, _: &mut [u8]) -> Poll> { + fn poll_read( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + _: &mut [u8], + ) -> Poll> { Poll::Ready(Ok(0)) } } impl AsyncWrite for Eof { - fn poll_write(self: Pin<&mut Self>, _cx: &mut Context<'_>, buf: &[u8]) -> Poll> { + fn poll_write( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { Poll::Ready(Ok(buf.len())) } @@ -122,8 +143,14 @@ async fn stream_bad() -> io::Result<()> { let mut bad = Pending; let mut stream = Stream::new(&mut bad, &mut client); - assert_eq!(poll_fn(|cx| stream.as_mut_pin().poll_write(cx, &[0x42; 8])).await?, 8); - assert_eq!(poll_fn(|cx| stream.as_mut_pin().poll_write(cx, &[0x42; 8])).await?, 8); + assert_eq!( + poll_fn(|cx| stream.as_mut_pin().poll_write(cx, &[0x42; 8])).await?, + 8 + ); + assert_eq!( + poll_fn(|cx| stream.as_mut_pin().poll_write(cx, &[0x42; 8])).await?, + 8 + ); let r = poll_fn(|cx| stream.as_mut_pin().poll_write(cx, &[0x00; 1024])).await?; // fill buffer assert!(r < 1024); @@ -164,7 +191,10 @@ async fn stream_handshake_eof() -> io::Result<()> { let mut cx = Context::from_waker(noop_waker_ref()); let r = stream.handshake(&mut cx); - assert_eq!(r.map_err(|err| err.kind()), Poll::Ready(Err(io::ErrorKind::UnexpectedEof))); + assert_eq!( + r.map_err(|err| err.kind()), + Poll::Ready(Err(io::ErrorKind::UnexpectedEof)) + ); Ok(()) as io::Result<()> } @@ -204,7 +234,11 @@ fn make_pair() -> (ServerSession, ClientSession) { (server, client) } -fn do_handshake(client: &mut ClientSession, server: &mut ServerSession, cx: &mut Context<'_>) -> Poll> { +fn do_handshake( + client: &mut ClientSession, + server: &mut ServerSession, + cx: &mut Context<'_>, +) -> Poll> { let mut good = Good(server); let mut stream = Stream::new(&mut good, client); diff --git a/tokio-rustls/src/common/vecbuf.rs b/tokio-rustls/src/common/vecbuf.rs index 6ea19e3..35fd573 100644 --- a/tokio-rustls/src/common/vecbuf.rs +++ b/tokio-rustls/src/common/vecbuf.rs @@ -1,23 +1,27 @@ -use std::io::IoSlice; -use std::cmp::{ self, Ordering }; use bytes::Buf; - +use std::cmp::{self, Ordering}; +use std::io::IoSlice; pub struct VecBuf<'a, 'b: 'a> { pos: usize, cur: usize, - inner: &'a [&'b [u8]] + inner: &'a [&'b [u8]], } impl<'a, 'b> VecBuf<'a, 'b> { pub fn new(vbytes: &'a [&'b [u8]]) -> Self { - VecBuf { pos: 0, cur: 0, inner: vbytes } + VecBuf { + pos: 0, + cur: 0, + inner: vbytes, + } } } impl<'a, 'b> Buf for VecBuf<'a, 'b> { fn remaining(&self) -> usize { - let sum = self.inner + let sum = self + .inner .iter() .skip(self.pos) .map(|bytes| bytes.len()) @@ -32,19 +36,21 @@ impl<'a, 'b> Buf for VecBuf<'a, 'b> { fn advance(&mut self, cnt: usize) { let current = self.inner[self.pos].len(); match (self.cur + cnt).cmp(¤t) { - Ordering::Equal => if self.pos + 1 < self.inner.len() { - self.pos += 1; - self.cur = 0; - } else { - self.cur += cnt; - }, + Ordering::Equal => { + if self.pos + 1 < self.inner.len() { + self.pos += 1; + self.cur = 0; + } else { + self.cur += cnt; + } + } Ordering::Greater => { if self.pos + 1 < self.inner.len() { self.pos += 1; } let remaining = self.cur + cnt - current; self.advance(remaining); - }, + } Ordering::Less => self.cur += cnt, } } @@ -120,8 +126,7 @@ mod test_vecbuf { let b1: &[u8] = &mut [0]; let b2: &[u8] = &mut [0]; - let mut dst: [IoSlice; 2] = - [IoSlice::new(b1), IoSlice::new(b2)]; + let mut dst: [IoSlice; 2] = [IoSlice::new(b1), IoSlice::new(b2)]; assert_eq!(2, buf.bytes_vectored(&mut dst[..])); } diff --git a/tokio-rustls/src/lib.rs b/tokio-rustls/src/lib.rs index db34b07..161a31a 100644 --- a/tokio-rustls/src/lib.rs +++ b/tokio-rustls/src/lib.rs @@ -1,19 +1,19 @@ //! Asynchronous TLS/SSL streams for Tokio using [Rustls](https://github.com/ctz/rustls). -mod common; pub mod client; +mod common; pub mod server; +use common::{MidHandshake, Stream, TlsState}; +use futures_core::future::FusedFuture; +use rustls::{ClientConfig, ClientSession, ServerConfig, ServerSession, Session}; +use std::future::Future; use std::io; use std::pin::Pin; use std::sync::Arc; -use std::future::Future; -use std::task::{ Context, Poll }; -use futures_core::future::FusedFuture; -use tokio::io::{ AsyncRead, AsyncWrite }; +use std::task::{Context, Poll}; +use tokio::io::{AsyncRead, AsyncWrite}; use webpki::DNSNameRef; -use rustls::{ ClientConfig, ClientSession, ServerConfig, ServerSession, Session }; -use common::{ Stream, TlsState, MidHandshake }; pub use rustls; pub use webpki; @@ -88,7 +88,7 @@ impl TlsConnector { TlsState::Stream }, - session + session, })) } } @@ -151,9 +151,7 @@ impl Future for Connect { #[inline] fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - Pin::new(&mut self.0) - .poll(cx) - .map_err(|(err, _)| err) + Pin::new(&mut self.0).poll(cx).map_err(|(err, _)| err) } } @@ -169,9 +167,7 @@ impl Future for Accept { #[inline] fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - Pin::new(&mut self.0) - .poll(cx) - .map_err(|(err, _)| err) + Pin::new(&mut self.0).poll(cx).map_err(|(err, _)| err) } } diff --git a/tokio-rustls/src/server.rs b/tokio-rustls/src/server.rs index abf86d6..b5f8375 100644 --- a/tokio-rustls/src/server.rs +++ b/tokio-rustls/src/server.rs @@ -1,6 +1,6 @@ use super::*; -use rustls::Session; use crate::common::IoSession; +use rustls::Session; /// A wrapper around an underlying raw stream which implements the TLS or SSL /// protocol. @@ -60,29 +60,35 @@ where false } - fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll> { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { let this = self.get_mut(); - let mut stream = Stream::new(&mut this.io, &mut this.session) - .set_eof(!this.state.readable()); + let mut stream = + Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable()); match &this.state { - TlsState::Stream | TlsState::WriteShutdown => match stream.as_mut_pin().poll_read(cx, buf) { - Poll::Ready(Ok(0)) => { - this.state.shutdown_read(); - Poll::Ready(Ok(0)) - } - Poll::Ready(Ok(n)) => Poll::Ready(Ok(n)), - Poll::Ready(Err(ref err)) if err.kind() == io::ErrorKind::ConnectionAborted => { - this.state.shutdown_read(); - if this.state.writeable() { - stream.session.send_close_notify(); - this.state.shutdown_write(); + TlsState::Stream | TlsState::WriteShutdown => { + match stream.as_mut_pin().poll_read(cx, buf) { + Poll::Ready(Ok(0)) => { + this.state.shutdown_read(); + Poll::Ready(Ok(0)) } - Poll::Ready(Ok(0)) + Poll::Ready(Ok(n)) => Poll::Ready(Ok(n)), + Poll::Ready(Err(ref err)) if err.kind() == io::ErrorKind::ConnectionAborted => { + this.state.shutdown_read(); + if this.state.writeable() { + stream.session.send_close_notify(); + this.state.shutdown_write(); + } + Poll::Ready(Ok(0)) + } + Poll::Ready(Err(e)) => Poll::Ready(Err(e)), + Poll::Pending => Poll::Pending, } - Poll::Ready(Err(e)) => Poll::Ready(Err(e)), - Poll::Pending => Poll::Pending - }, + } TlsState::ReadShutdown | TlsState::FullyShutdown => Poll::Ready(Ok(0)), #[cfg(feature = "early-data")] s => unreachable!("server TLS can not hit this state: {:?}", s), @@ -96,17 +102,21 @@ where { /// Note: that it does not guarantee the final data to be sent. /// To be cautious, you must manually call `flush`. - fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { let this = self.get_mut(); - let mut stream = Stream::new(&mut this.io, &mut this.session) - .set_eof(!this.state.readable()); + let mut stream = + Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable()); stream.as_mut_pin().poll_write(cx, buf) } fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let this = self.get_mut(); - let mut stream = Stream::new(&mut this.io, &mut this.session) - .set_eof(!this.state.readable()); + let mut stream = + Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable()); stream.as_mut_pin().poll_flush(cx) } @@ -117,8 +127,8 @@ where } let this = self.get_mut(); - let mut stream = Stream::new(&mut this.io, &mut this.session) - .set_eof(!this.state.readable()); + let mut stream = + Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable()); stream.as_mut_pin().poll_shutdown(cx) } } diff --git a/tokio-rustls/tests/badssl.rs b/tokio-rustls/tests/badssl.rs index 3a02e86..4c564c6 100644 --- a/tokio-rustls/tests/badssl.rs +++ b/tokio-rustls/tests/badssl.rs @@ -1,21 +1,20 @@ -use std::io; -use std::sync::Arc; -use std::net::ToSocketAddrs; -use tokio::prelude::*; -use tokio::net::TcpStream; use rustls::ClientConfig; -use tokio_rustls::{ TlsConnector, client::TlsStream }; +use std::io; +use std::net::ToSocketAddrs; +use std::sync::Arc; +use tokio::net::TcpStream; +use tokio::prelude::*; +use tokio_rustls::{client::TlsStream, TlsConnector}; - -async fn get(config: Arc, domain: &str, port: u16) - -> io::Result<(TlsStream, String)> -{ +async fn get( + config: Arc, + domain: &str, + port: u16, +) -> io::Result<(TlsStream, String)> { let connector = TlsConnector::from(config); let input = format!("GET / HTTP/1.0\r\nHost: {}\r\n\r\n", domain); - let addr = (domain, port) - .to_socket_addrs()? - .next().unwrap(); + let addr = (domain, port).to_socket_addrs()?.next().unwrap(); let domain = webpki::DNSNameRef::try_from_ascii_str(&domain).unwrap(); let mut buf = Vec::new(); @@ -31,7 +30,9 @@ async fn get(config: Arc, domain: &str, port: u16) #[tokio::test] async fn test_tls12() -> io::Result<()> { let mut config = ClientConfig::new(); - config.root_store.add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS); + config + .root_store + .add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS); config.versions = vec![rustls::ProtocolVersion::TLSv1_2]; let config = Arc::new(config); let domain = "tls-v1-2.badssl.com"; @@ -52,7 +53,9 @@ fn test_tls13() { #[tokio::test] async fn test_modern() -> io::Result<()> { let mut config = ClientConfig::new(); - config.root_store.add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS); + config + .root_store + .add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS); let config = Arc::new(config); let domain = "mozilla-modern.badssl.com"; diff --git a/tokio-rustls/tests/early-data.rs b/tokio-rustls/tests/early-data.rs index d7f2df9..325fcad 100644 --- a/tokio-rustls/tests/early-data.rs +++ b/tokio-rustls/tests/early-data.rs @@ -1,21 +1,17 @@ #![cfg(feature = "early-data")] -use std::io::{ self, BufRead, BufReader, Cursor }; -use std::process::{ Command, Child, Stdio }; +use std::io::{self, BufRead, BufReader, Cursor}; use std::net::SocketAddr; +use std::pin::Pin; +use std::process::{Child, Command, Stdio}; +use std::process::{Child, Command, Stdio}; use std::sync::Arc; -use std::marker::Unpin; -use std::pin::{ Pin }; -use std::task::{ Context, Poll }; +use std::task::{Context, Poll}; use std::time::Duration; -use tokio::prelude::*; use tokio::net::TcpStream; +use tokio::prelude::*; use tokio::time::delay_for; -use futures_util::{ future, ready }; -use rustls::ClientConfig; -use tokio_rustls::{ TlsConnector, client::TlsStream }; -use std::future::Future; - +use tokio_rustls::{client::TlsStream, TlsConnector}; struct Read1(T); @@ -29,11 +25,12 @@ impl Future for Read1 { } } -async fn send(config: Arc, addr: SocketAddr, data: &[u8]) - -> io::Result> -{ - let connector = TlsConnector::from(config) - .early_data(true); +async fn send( + config: Arc, + addr: SocketAddr, + data: &[u8], +) -> io::Result> { + let connector = TlsConnector::from(config).early_data(true); let stream = TcpStream::connect(&addr).await?; let domain = webpki::DNSNameRef::try_from_ascii_str("testserver.com").unwrap(); @@ -98,10 +95,8 @@ async fn test_0rtt() -> io::Result<()> { let stdout = handle.0.stdout.as_mut().unwrap(); let mut lines = BufReader::new(stdout).lines(); - let has_msg1 = lines.by_ref() - .any(|line| line.unwrap().contains("hello")); - let has_msg2 = lines.by_ref() - .any(|line| line.unwrap().contains("world!")); + let has_msg1 = lines.by_ref().any(|line| line.unwrap().contains("hello")); + let has_msg2 = lines.by_ref().any(|line| line.unwrap().contains("world!")); assert!(has_msg1 && has_msg2); diff --git a/tokio-rustls/tests/test.rs b/tokio-rustls/tests/test.rs index 9b98688..d0b449d 100644 --- a/tokio-rustls/tests/test.rs +++ b/tokio-rustls/tests/test.rs @@ -1,29 +1,30 @@ -use std::{ io, thread }; -use std::io::{ BufReader, Cursor }; -use std::sync::Arc; -use std::sync::mpsc::channel; -use std::net::SocketAddr; use futures_util::future::TryFutureExt; use lazy_static::lazy_static; +use rustls::internal::pemfile::{certs, rsa_private_keys}; +use rustls::{ClientConfig, ServerConfig}; +use std::io::{BufReader, Cursor}; +use std::net::SocketAddr; +use std::sync::mpsc::channel; +use std::sync::Arc; +use std::{io, thread}; +use tokio::io::{copy, split}; +use tokio::net::{TcpListener, TcpStream}; use tokio::prelude::*; use tokio::runtime; -use tokio::io::{ copy, split }; -use tokio::net::{ TcpListener, TcpStream }; -use rustls::{ ServerConfig, ClientConfig }; -use rustls::internal::pemfile::{ certs, rsa_private_keys }; -use tokio_rustls::{ TlsConnector, TlsAcceptor }; +use tokio_rustls::{TlsAcceptor, TlsConnector}; const CERT: &str = include_str!("end.cert"); const CHAIN: &str = include_str!("end.chain"); const RSA: &str = include_str!("end.rsa"); -lazy_static!{ +lazy_static! { static ref TEST_SERVER: (SocketAddr, &'static str, &'static str) = { let cert = certs(&mut BufReader::new(Cursor::new(CERT))).unwrap(); let mut keys = rsa_private_keys(&mut BufReader::new(Cursor::new(RSA))).unwrap(); let mut config = ServerConfig::new(rustls::NoClientAuth::new()); - config.set_single_cert(cert, keys.pop().unwrap()) + config + .set_single_cert(cert, keys.pop().unwrap()) .expect("invalid key or certificate"); let acceptor = TlsAcceptor::from(Arc::new(config)); @@ -55,11 +56,13 @@ lazy_static!{ copy(&mut reader, &mut writer).await?; Ok(()) as io::Result<()> - }.unwrap_or_else(|err| eprintln!("server: {:?}", err)); + } + .unwrap_or_else(|err| eprintln!("server: {:?}", err)); handle.spawn(fut); } - }.unwrap_or_else(|err: io::Error| eprintln!("server: {:?}", err)); + } + .unwrap_or_else(|err: io::Error| eprintln!("server: {:?}", err)); runtime.block_on(done); });