diff --git a/Cargo.toml b/Cargo.toml index 266de47..141e94b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,4 +1,5 @@ [workspace] members = [ - "tokio-native-tls" + "tokio-native-tls", + "tokio-rustls" ] diff --git a/tokio-rustls/Cargo.toml b/tokio-rustls/Cargo.toml new file mode 100644 index 0000000..0fcb8c5 --- /dev/null +++ b/tokio-rustls/Cargo.toml @@ -0,0 +1,31 @@ +[package] +name = "tokio-rustls" +version = "0.12.2" +authors = ["quininer kel "] +license = "MIT/Apache-2.0" +repository = "https://github.com/tokio-rs/tls" +homepage = "https://github.com/tokio-rs/tls" +documentation = "https://docs.rs/tokio-rustls" +readme = "README.md" +description = "Asynchronous TLS/SSL streams for Tokio using Rustls." +categories = ["asynchronous", "cryptography", "network-programming"] +edition = "2018" + +[dependencies] +tokio = "0.2.0" +futures-core = "0.3.1" +rustls = "0.16" +webpki = "0.21" + +bytes = { version = "0.5", optional = true } + +[features] +early-data = [] +dangerous_configuration = ["rustls/dangerous_configuration"] +unstable = ["bytes"] + +[dev-dependencies] +tokio = { version = "0.2.0", features = ["macros", "net", "io-util", "rt-core", "time"] } +futures-util = "0.3.1" +lazy_static = "1" +webpki-roots = "0.18" diff --git a/tokio-rustls/LICENSE-APACHE b/tokio-rustls/LICENSE-APACHE new file mode 100644 index 0000000..2154394 --- /dev/null +++ b/tokio-rustls/LICENSE-APACHE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + +TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + +1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + +2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + +3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + +4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + +5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + +6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + +7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + +8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + +9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + +END OF TERMS AND CONDITIONS + +APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + +Copyright 2017 quininer kel + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. diff --git a/tokio-rustls/LICENSE-MIT b/tokio-rustls/LICENSE-MIT new file mode 100644 index 0000000..4500636 --- /dev/null +++ b/tokio-rustls/LICENSE-MIT @@ -0,0 +1,25 @@ +Copyright (c) 2017 quininer kel + +Permission is hereby granted, free of charge, to any +person obtaining a copy of this software and associated +documentation files (the "Software"), to deal in the +Software without restriction, including without +limitation the rights to use, copy, modify, merge, +publish, distribute, sublicense, and/or sell copies of +the Software, and to permit persons to whom the Software +is furnished to do so, subject to the following +conditions: + +The above copyright notice and this permission notice +shall be included in all copies or substantial portions +of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF +ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED +TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A +PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT +SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR +IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. diff --git a/tokio-rustls/README.md b/tokio-rustls/README.md new file mode 100644 index 0000000..6faef8f --- /dev/null +++ b/tokio-rustls/README.md @@ -0,0 +1,65 @@ +# tokio-rustls +[![github actions](https://github.com/tokio-rs/tls/workflows/Rust/badge.svg)](https://github.com/tokio-rs/tls/actions) +[![crates](https://img.shields.io/crates/v/tokio-rustls.svg)](https://crates.io/crates/tokio-rustls) +[![license](https://img.shields.io/badge/License-MIT-blue.svg)](https://github.com/tokio-rs/tls/blob/master/tokio-rustls/LICENSE-MIT) +[![license](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://github.com/tokio-rs/tls/blob/master/tokio-rustls/LICENSE-APACHE) +[![docs.rs](https://docs.rs/tokio-rustls/badge.svg)](https://docs.rs/tokio-rustls/) + +Asynchronous TLS/SSL streams for [Tokio](https://tokio.rs/) using +[Rustls](https://github.com/ctz/rustls). + +### Basic Structure of a Client + +```rust +use webpki::DNSNameRef; +use tokio_rustls::{ TlsConnector, rustls::ClientConfig }; + +// ... + +let mut config = ClientConfig::new(); +config.root_store.add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS); +let config = TlsConnector::from(Arc::new(config)); +let dnsname = DNSNameRef::try_from_ascii_str("www.rust-lang.org").unwrap(); + +let stream = TcpStream::connect(&addr).await?; +let mut stream = config.connect(dnsname, stream).await?; + +// ... +``` + +### Client Example Program + +See [examples/client](examples/client/src/main.rs). You can run it with: + +```sh +cd examples/client +cargo run -- hsts.badssl.com +``` + +### Server Example Program + +See [examples/server](examples/server/src/main.rs). You can run it with: + +```sh +cd examples/server +cargo run -- 127.0.0.1 --cert mycert.der --key mykey.der +``` + +### License & Origin + +This project is licensed under either of + + * Apache License, Version 2.0, ([LICENSE-APACHE](LICENSE-APACHE) or + http://www.apache.org/licenses/LICENSE-2.0) + * MIT license ([LICENSE-MIT](LICENSE-MIT) or + http://opensource.org/licenses/MIT) + +at your option. + +This started as a fork of [tokio-tls](https://github.com/tokio-rs/tokio-tls). + +### Contribution + +Unless you explicitly state otherwise, any contribution intentionally submitted +for inclusion in tokio-rustls by you, as defined in the Apache-2.0 license, shall be +dual licensed as above, without any additional terms or conditions. diff --git a/tokio-rustls/examples/client/Cargo.toml b/tokio-rustls/examples/client/Cargo.toml new file mode 100644 index 0000000..40162f8 --- /dev/null +++ b/tokio-rustls/examples/client/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "client" +version = "0.1.0" +authors = ["quininer "] +edition = "2018" + +[dependencies] +futures-util = "0.3" +tokio = { version = "0.2", features = [ "net", "io-std", "io-util", "rt-threaded" ] } +structopt = "0.2" +tokio-rustls = { path = "../.." } +webpki-roots = "0.18" diff --git a/tokio-rustls/examples/client/src/main.rs b/tokio-rustls/examples/client/src/main.rs new file mode 100644 index 0000000..6012c7e --- /dev/null +++ b/tokio-rustls/examples/client/src/main.rs @@ -0,0 +1,88 @@ +use std::io; +use std::fs::File; +use std::path::PathBuf; +use std::sync::Arc; +use std::net::ToSocketAddrs; +use std::io::BufReader; +use futures_util::future; +use structopt::StructOpt; +use tokio::runtime; +use tokio::net::TcpStream; +use tokio::io::{ + AsyncWriteExt, + copy, split, + stdin as tokio_stdin, stdout as tokio_stdout +}; +use tokio_rustls::{ TlsConnector, rustls::ClientConfig, webpki::DNSNameRef }; + + +#[derive(StructOpt)] +struct Options { + host: String, + + /// port + #[structopt(short="p", long="port", default_value="443")] + port: u16, + + /// domain + #[structopt(short="d", long="domain")] + domain: Option, + + /// cafile + #[structopt(short="c", long="cafile", parse(from_os_str))] + cafile: Option +} + + +fn main() -> io::Result<()> { + let options = Options::from_args(); + + let addr = (options.host.as_str(), options.port) + .to_socket_addrs()? + .next() + .ok_or_else(|| io::Error::from(io::ErrorKind::NotFound))?; + let domain = options.domain.unwrap_or(options.host); + let content = format!( + "GET / HTTP/1.0\r\nHost: {}\r\n\r\n", + domain + ); + + let mut runtime = runtime::Builder::new() + .basic_scheduler() + .enable_io() + .build()?; + let mut config = ClientConfig::new(); + if let Some(cafile) = &options.cafile { + let mut pem = BufReader::new(File::open(cafile)?); + config.root_store.add_pem_file(&mut pem) + .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid cert"))?; + } else { + config.root_store.add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS); + } + let connector = TlsConnector::from(Arc::new(config)); + + let fut = async { + let stream = TcpStream::connect(&addr).await?; + + let (mut stdin, mut stdout) = (tokio_stdin(), tokio_stdout()); + + let domain = DNSNameRef::try_from_ascii_str(&domain) + .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid dnsname"))?; + + let mut stream = connector.connect(domain, stream).await?; + stream.write_all(content.as_bytes()).await?; + + let (mut reader, mut writer) = split(stream); + future::select( + copy(&mut reader, &mut stdout), + copy(&mut stdin, &mut writer) + ) + .await + .factor_first() + .0?; + + Ok(()) + }; + + runtime.block_on(fut) +} diff --git a/tokio-rustls/examples/server/Cargo.toml b/tokio-rustls/examples/server/Cargo.toml new file mode 100644 index 0000000..cd42662 --- /dev/null +++ b/tokio-rustls/examples/server/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "server" +version = "0.1.0" +authors = ["quininer "] +edition = "2018" + +[dependencies] +futures-util = "0.3" +tokio = { version = "0.2", features = [ "net", "io-util", "rt-threaded" ] } +structopt = "0.2" +tokio-rustls = { path = "../.." } diff --git a/tokio-rustls/examples/server/src/main.rs b/tokio-rustls/examples/server/src/main.rs new file mode 100644 index 0000000..1fcb0e4 --- /dev/null +++ b/tokio-rustls/examples/server/src/main.rs @@ -0,0 +1,99 @@ +use std::fs::File; +use std::sync::Arc; +use std::net::ToSocketAddrs; +use std::path::{ PathBuf, Path }; +use std::io::{ self, BufReader }; +use futures_util::future::TryFutureExt; +use structopt::StructOpt; +use tokio::runtime; +use tokio::net::TcpListener; +use tokio::io::{ AsyncWriteExt, copy, split }; +use tokio_rustls::rustls::{ Certificate, NoClientAuth, PrivateKey, ServerConfig }; +use tokio_rustls::rustls::internal::pemfile::{ certs, rsa_private_keys }; +use tokio_rustls::TlsAcceptor; + + +#[derive(StructOpt)] +struct Options { + addr: String, + + /// cert file + #[structopt(short="c", long="cert", parse(from_os_str))] + cert: PathBuf, + + /// key file + #[structopt(short="k", long="key", parse(from_os_str))] + key: PathBuf, + + /// echo mode + #[structopt(short="e", long="echo-mode")] + echo: bool +} + +fn load_certs(path: &Path) -> io::Result> { + certs(&mut BufReader::new(File::open(path)?)) + .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid cert")) +} + +fn load_keys(path: &Path) -> io::Result> { + rsa_private_keys(&mut BufReader::new(File::open(path)?)) + .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid key")) +} + + +fn main() -> io::Result<()> { + let options = Options::from_args(); + + let addr = options.addr.to_socket_addrs()? + .next() + .ok_or_else(|| io::Error::from(io::ErrorKind::AddrNotAvailable))?; + let certs = load_certs(&options.cert)?; + let mut keys = load_keys(&options.key)?; + let flag_echo = options.echo; + + let mut runtime = runtime::Builder::new() + .threaded_scheduler() + .enable_io() + .build()?; + let handle = runtime.handle().clone(); + let mut config = ServerConfig::new(NoClientAuth::new()); + config.set_single_cert(certs, keys.remove(0)) + .map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?; + let acceptor = TlsAcceptor::from(Arc::new(config)); + + let fut = async { + let mut listener = TcpListener::bind(&addr).await?; + + loop { + let (stream, peer_addr) = listener.accept().await?; + let acceptor = acceptor.clone(); + + let fut = async move { + let mut stream = acceptor.accept(stream).await?; + + if flag_echo { + let (mut reader, mut writer) = split(stream); + let n = copy(&mut reader, &mut writer).await?; + writer.flush().await?; + println!("Echo: {} - {}", peer_addr, n); + } else { + stream.write_all( + &b"HTTP/1.0 200 ok\r\n\ + Connection: close\r\n\ + Content-length: 12\r\n\ + \r\n\ + Hello world!"[..] + ).await?; + stream.flush().await?; + println!("Hello: {}", peer_addr); + } + + Ok(()) as io::Result<()> + }; + + handle.spawn(fut.unwrap_or_else(|err| eprintln!("{:?}", err))); + } + }; + + runtime.block_on(fut) +} diff --git a/tokio-rustls/src/client.rs b/tokio-rustls/src/client.rs new file mode 100644 index 0000000..5007aa8 --- /dev/null +++ b/tokio-rustls/src/client.rs @@ -0,0 +1,191 @@ +use super::*; +use rustls::Session; +use crate::common::IoSession; + + +/// A wrapper around an underlying raw stream which implements the TLS or SSL +/// protocol. +#[derive(Debug)] +pub struct TlsStream { + pub(crate) io: IO, + pub(crate) session: ClientSession, + pub(crate) state: TlsState, +} + +impl TlsStream { + #[inline] + pub fn get_ref(&self) -> (&IO, &ClientSession) { + (&self.io, &self.session) + } + + #[inline] + pub fn get_mut(&mut self) -> (&mut IO, &mut ClientSession) { + (&mut self.io, &mut self.session) + } + + #[inline] + pub fn into_inner(self) -> (IO, ClientSession) { + (self.io, self.session) + } +} + +impl IoSession for TlsStream { + type Io = IO; + type Session = ClientSession; + + #[inline] + fn skip_handshake(&self) -> bool { + self.state.is_early_data() + } + + #[inline] + fn get_mut(&mut self) -> (&mut TlsState, &mut Self::Io, &mut Self::Session) { + (&mut self.state, &mut self.io, &mut self.session) + } + + #[inline] + fn into_io(self) -> Self::Io { + self.io + } +} + +impl AsyncRead for TlsStream +where + IO: AsyncRead + AsyncWrite + Unpin, +{ + #[cfg(feature = "unstable")] + unsafe fn prepare_uninitialized_buffer(&self, _buf: &mut [std::mem::MaybeUninit]) -> bool { + false + } + + fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll> { + match self.state { + #[cfg(feature = "early-data")] + TlsState::EarlyData(..) => Poll::Pending, + TlsState::Stream | TlsState::WriteShutdown => { + let this = self.get_mut(); + let mut stream = Stream::new(&mut this.io, &mut this.session) + .set_eof(!this.state.readable()); + + match stream.as_mut_pin().poll_read(cx, buf) { + Poll::Ready(Ok(0)) => { + this.state.shutdown_read(); + Poll::Ready(Ok(0)) + }, + Poll::Ready(Ok(n)) => Poll::Ready(Ok(n)), + Poll::Ready(Err(ref e)) if e.kind() == io::ErrorKind::ConnectionAborted => { + this.state.shutdown_read(); + if this.state.writeable() { + stream.session.send_close_notify(); + this.state.shutdown_write(); + } + Poll::Ready(Ok(0)) + }, + output => output + } + } + TlsState::ReadShutdown | TlsState::FullyShutdown => Poll::Ready(Ok(0)), + } + } +} + +impl AsyncWrite for TlsStream +where + IO: AsyncRead + AsyncWrite + Unpin, +{ + /// Note: that it does not guarantee the final data to be sent. + /// To be cautious, you must manually call `flush`. + fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { + let this = self.get_mut(); + let mut stream = Stream::new(&mut this.io, &mut this.session) + .set_eof(!this.state.readable()); + + match this.state { + #[cfg(feature = "early-data")] + TlsState::EarlyData(ref mut pos, ref mut data) => { + use futures_core::ready; + use std::io::Write; + + // write early data + if let Some(mut early_data) = stream.session.early_data() { + let len = match early_data.write(buf) { + Ok(n) => n, + Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => + return Poll::Pending, + Err(err) => return Poll::Ready(Err(err)) + }; + if len != 0 { + data.extend_from_slice(&buf[..len]); + return Poll::Ready(Ok(len)); + } + } + + // complete handshake + while stream.session.is_handshaking() { + ready!(stream.handshake(cx))?; + } + + // write early data (fallback) + if !stream.session.is_early_data_accepted() { + while *pos < data.len() { + let len = ready!(stream.as_mut_pin().poll_write(cx, &data[*pos..]))?; + *pos += len; + } + } + + // end + this.state = TlsState::Stream; + stream.as_mut_pin().poll_write(cx, buf) + } + _ => stream.as_mut_pin().poll_write(cx, buf), + } + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.get_mut(); + let mut stream = Stream::new(&mut this.io, &mut this.session) + .set_eof(!this.state.readable()); + + #[cfg(feature = "early-data")] { + use futures_core::ready; + + if let TlsState::EarlyData(ref mut pos, ref mut data) = this.state { + // complete handshake + while stream.session.is_handshaking() { + ready!(stream.handshake(cx))?; + } + + // write early data (fallback) + if !stream.session.is_early_data_accepted() { + while *pos < data.len() { + let len = ready!(stream.as_mut_pin().poll_write(cx, &data[*pos..]))?; + *pos += len; + } + } + + this.state = TlsState::Stream; + } + } + + stream.as_mut_pin().poll_flush(cx) + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + if self.state.writeable() { + self.session.send_close_notify(); + self.state.shutdown_write(); + } + + #[cfg(feature = "early-data")] { + // we skip the handshake + if let TlsState::EarlyData(..) = self.state { + return Pin::new(&mut self.io).poll_shutdown(cx); + } + } + + let this = self.get_mut(); + let mut stream = Stream::new(&mut this.io, &mut this.session) + .set_eof(!this.state.readable()); + stream.as_mut_pin().poll_shutdown(cx) + } +} diff --git a/tokio-rustls/src/common/handshake.rs b/tokio-rustls/src/common/handshake.rs new file mode 100644 index 0000000..c59541e --- /dev/null +++ b/tokio-rustls/src/common/handshake.rs @@ -0,0 +1,84 @@ +use std::{ io, mem }; +use std::pin::Pin; +use std::future::Future; +use std::task::{ Context, Poll }; +use futures_core::future::FusedFuture; +use tokio::io::{ AsyncRead, AsyncWrite }; +use rustls::Session; +use crate::common::{ TlsState, Stream }; + + +pub(crate) trait IoSession { + type Io; + type Session; + + fn skip_handshake(&self) -> bool; + fn get_mut(&mut self) -> (&mut TlsState, &mut Self::Io, &mut Self::Session); + fn into_io(self) -> Self::Io; +} + +pub(crate) enum MidHandshake { + Handshaking(IS), + End, +} + +impl FusedFuture for MidHandshake +where + IS: IoSession + Unpin, + IS::Io: AsyncRead + AsyncWrite + Unpin, + IS::Session: Session + Unpin +{ + fn is_terminated(&self) -> bool { + if let MidHandshake::End = self { + true + } else { + false + } + } +} + +impl Future for MidHandshake +where + IS: IoSession + Unpin, + IS::Io: AsyncRead + AsyncWrite + Unpin, + IS::Session: Session + Unpin +{ + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.get_mut(); + + if let MidHandshake::Handshaking(mut stream) = mem::replace(this, MidHandshake::End) { + if !stream.skip_handshake() { + let (state, io, session) = stream.get_mut(); + let mut tls_stream = Stream::new(io, session) + .set_eof(!state.readable()); + + macro_rules! try_poll { + ( $e:expr ) => { + match $e { + Poll::Ready(Ok(_)) => (), + Poll::Ready(Err(err)) => return Poll::Ready(Err((err, stream.into_io()))), + Poll::Pending => { + *this = MidHandshake::Handshaking(stream); + return Poll::Pending; + } + } + } + } + + while tls_stream.session.is_handshaking() { + try_poll!(tls_stream.handshake(cx)); + } + + while tls_stream.session.wants_write() { + try_poll!(tls_stream.write_io(cx)); + } + } + + Poll::Ready(Ok(stream)) + } else { + panic!("unexpected polling after handshake") + } + } +} diff --git a/tokio-rustls/src/common/mod.rs b/tokio-rustls/src/common/mod.rs new file mode 100644 index 0000000..1d0dd07 --- /dev/null +++ b/tokio-rustls/src/common/mod.rs @@ -0,0 +1,347 @@ +mod handshake; + +#[cfg(feature = "unstable")] +mod vecbuf; + +use std::pin::Pin; +use std::task::{ Poll, Context }; +use std::io::{ self, Read }; +use rustls::Session; +use tokio::io::{ AsyncRead, AsyncWrite }; +use futures_core as futures; +pub(crate) use handshake::{ IoSession, MidHandshake }; + + +#[derive(Debug)] +pub enum TlsState { + #[cfg(feature = "early-data")] + EarlyData(usize, Vec), + Stream, + ReadShutdown, + WriteShutdown, + FullyShutdown, +} + +impl TlsState { + #[inline] + pub fn shutdown_read(&mut self) { + match *self { + TlsState::WriteShutdown | TlsState::FullyShutdown => + *self = TlsState::FullyShutdown, + _ => *self = TlsState::ReadShutdown, + } + } + + #[inline] + pub fn shutdown_write(&mut self) { + match *self { + TlsState::ReadShutdown | TlsState::FullyShutdown => + *self = TlsState::FullyShutdown, + _ => *self = TlsState::WriteShutdown, + } + } + + #[inline] + pub fn writeable(&self) -> bool { + match *self { + TlsState::WriteShutdown | TlsState::FullyShutdown => false, + _ => true, + } + } + + #[inline] + pub fn readable(&self) -> bool { + match self { + TlsState::ReadShutdown | TlsState::FullyShutdown => false, + _ => true, + } + } + + #[inline] + #[cfg(feature = "early-data")] + pub fn is_early_data(&self) -> bool { + match self { + TlsState::EarlyData(..) => true, + _ => false + } + } + + #[inline] + #[cfg(not(feature = "early-data"))] + pub const fn is_early_data(&self) -> bool { + false + } +} + +pub struct Stream<'a, IO, S> { + pub io: &'a mut IO, + pub session: &'a mut S, + pub eof: bool +} + +impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> { + pub fn new(io: &'a mut IO, session: &'a mut S) -> Self { + Stream { + io, + session, + // The state so far is only used to detect EOF, so either Stream + // or EarlyData state should both be all right. + eof: false, + } + } + + pub fn set_eof(mut self, eof: bool) -> Self { + self.eof = eof; + self + } + + pub fn as_mut_pin(&mut self) -> Pin<&mut Self> { + Pin::new(self) + } + + pub fn process_new_packets(&mut self, cx: &mut Context) -> io::Result<()> { + self.session.process_new_packets() + .map_err(|err| { + // In case we have an alert to send describing this error, + // try a last-gasp write -- but don't predate the primary + // error. + let _ = self.write_io(cx); + + io::Error::new(io::ErrorKind::InvalidData, err) + }) + } + + pub fn read_io(&mut self, cx: &mut Context) -> Poll> { + struct Reader<'a, 'b, T> { + io: &'a mut T, + cx: &'a mut Context<'b> + } + + impl<'a, 'b, T: AsyncRead + Unpin> Read for Reader<'a, 'b, T> { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + match Pin::new(&mut self.io).poll_read(self.cx, buf) { + Poll::Ready(result) => result, + Poll::Pending => Err(io::ErrorKind::WouldBlock.into()) + } + } + } + + let mut reader = Reader { io: self.io, cx }; + + let n = match self.session.read_tls(&mut reader) { + Ok(n) => n, + Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => return Poll::Pending, + Err(err) => return Poll::Ready(Err(err)) + }; + + Poll::Ready(Ok(n)) + } + + #[cfg(not(feature = "unstable"))] + pub fn write_io(&mut self, cx: &mut Context) -> Poll> { + use std::io::Write; + + struct Writer<'a, 'b, T> { + io: &'a mut T, + cx: &'a mut Context<'b> + } + + impl<'a, 'b, T: AsyncWrite + Unpin> Write for Writer<'a, 'b, T> { + fn write(&mut self, buf: &[u8]) -> io::Result { + match Pin::new(&mut self.io).poll_write(self.cx, buf) { + Poll::Ready(result) => result, + Poll::Pending => Err(io::ErrorKind::WouldBlock.into()) + } + } + + fn flush(&mut self) -> io::Result<()> { + match Pin::new(&mut self.io).poll_flush(self.cx) { + Poll::Ready(result) => result, + Poll::Pending => Err(io::ErrorKind::WouldBlock.into()) + } + } + } + + let mut writer = Writer { io: self.io, cx }; + + match self.session.write_tls(&mut writer) { + Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending, + result => Poll::Ready(result) + } + } + + #[cfg(feature = "unstable")] + pub fn write_io(&mut self, cx: &mut Context) -> Poll> { + use rustls::WriteV; + + struct Writer<'a, 'b, T> { + io: &'a mut T, + cx: &'a mut Context<'b> + } + + impl<'a, 'b, T: AsyncWrite + Unpin> WriteV for Writer<'a, 'b, T> { + fn writev(&mut self, vbuf: &[&[u8]]) -> io::Result { + use vecbuf::VecBuf; + + let mut vbuf = VecBuf::new(vbuf); + + match Pin::new(&mut self.io).poll_write_buf(self.cx, &mut vbuf) { + Poll::Ready(result) => result, + Poll::Pending => Err(io::ErrorKind::WouldBlock.into()) + } + } + } + + let mut writer = Writer { io: self.io, cx }; + + match self.session.writev_tls(&mut writer) { + Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending, + result => Poll::Ready(result) + } + } + + pub fn handshake(&mut self, cx: &mut Context) -> Poll> { + let mut wrlen = 0; + let mut rdlen = 0; + + loop { + let mut write_would_block = false; + let mut read_would_block = false; + + while self.session.wants_write() { + match self.write_io(cx) { + Poll::Ready(Ok(n)) => wrlen += n, + Poll::Pending => { + write_would_block = true; + break + }, + Poll::Ready(Err(err)) => return Poll::Ready(Err(err)) + } + } + + while !self.eof && self.session.wants_read() { + match self.read_io(cx) { + Poll::Ready(Ok(0)) => self.eof = true, + Poll::Ready(Ok(n)) => rdlen += n, + Poll::Pending => { + read_would_block = true; + break + }, + Poll::Ready(Err(err)) => return Poll::Ready(Err(err)) + } + } + + self.process_new_packets(cx)?; + + return match (self.eof, self.session.is_handshaking()) { + (true, true) => { + let err = io::Error::new(io::ErrorKind::UnexpectedEof, "tls handshake eof"); + Poll::Ready(Err(err)) + }, + (_, false) => Poll::Ready(Ok((rdlen, wrlen))), + (_, true) if write_would_block || read_would_block => if rdlen != 0 || wrlen != 0 { + Poll::Ready(Ok((rdlen, wrlen))) + } else { + Poll::Pending + }, + (..) => continue + } + } + } +} + +impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncRead for Stream<'a, IO, S> { + fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context, buf: &mut [u8]) -> Poll> { + let mut pos = 0; + + while pos != buf.len() { + let mut would_block = false; + + // read a packet + while self.session.wants_read() { + match self.read_io(cx) { + Poll::Ready(Ok(0)) => { + self.eof = true; + break + }, + Poll::Ready(Ok(_)) => (), + Poll::Pending => { + would_block = true; + break + }, + Poll::Ready(Err(err)) => return Poll::Ready(Err(err)) + } + } + + self.process_new_packets(cx)?; + + return match self.session.read(&mut buf[pos..]) { + Ok(0) if pos == 0 && would_block => Poll::Pending, + Ok(n) if self.eof || would_block => Poll::Ready(Ok(pos + n)), + Ok(n) => { + pos += n; + continue + }, + Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending, + Err(ref err) if err.kind() == io::ErrorKind::ConnectionAborted && pos != 0 => + Poll::Ready(Ok(pos)), + Err(err) => Poll::Ready(Err(err)) + } + } + + Poll::Ready(Ok(pos)) + } +} + +impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncWrite for Stream<'a, IO, S> { + fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll> { + let mut pos = 0; + + while pos != buf.len() { + let mut would_block = false; + + match self.session.write(&buf[pos..]) { + Ok(n) => pos += n, + Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => (), + Err(err) => return Poll::Ready(Err(err)) + }; + + while self.session.wants_write() { + match self.write_io(cx) { + Poll::Ready(Ok(0)) | Poll::Pending => { + would_block = true; + break + }, + Poll::Ready(Ok(_)) => (), + Poll::Ready(Err(err)) => return Poll::Ready(Err(err)) + } + } + + return match (pos, would_block) { + (0, true) => Poll::Pending, + (n, true) => Poll::Ready(Ok(n)), + (_, false) => continue + } + } + + Poll::Ready(Ok(pos)) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + self.session.flush()?; + while self.session.wants_write() { + futures::ready!(self.write_io(cx))?; + } + Pin::new(&mut self.io).poll_flush(cx) + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + while self.session.wants_write() { + futures::ready!(self.write_io(cx))?; + } + Pin::new(&mut self.io).poll_shutdown(cx) + } +} + +#[cfg(test)] +mod test_stream; diff --git a/tokio-rustls/src/common/test_stream.rs b/tokio-rustls/src/common/test_stream.rs new file mode 100644 index 0000000..0055014 --- /dev/null +++ b/tokio-rustls/src/common/test_stream.rs @@ -0,0 +1,220 @@ +use std::pin::Pin; +use std::sync::Arc; +use std::task::{ Poll, Context }; +use futures_core::ready; +use futures_util::future::poll_fn; +use futures_util::task::noop_waker_ref; +use tokio::io::{ AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt }; +use std::io::{ self, Read, Write, BufReader, Cursor }; +use webpki::DNSNameRef; +use rustls::internal::pemfile::{ certs, rsa_private_keys }; +use rustls::{ + ServerConfig, ClientConfig, + ServerSession, ClientSession, + Session, NoClientAuth +}; +use super::Stream; + + +struct Good<'a>(&'a mut dyn Session); + +impl<'a> AsyncRead for Good<'a> { + fn poll_read(mut self: Pin<&mut Self>, _cx: &mut Context<'_>, mut buf: &mut [u8]) -> Poll> { + Poll::Ready(self.0.write_tls(buf.by_ref())) + } +} + +impl<'a> AsyncWrite for Good<'a> { + fn poll_write(mut self: Pin<&mut Self>, _cx: &mut Context<'_>, mut buf: &[u8]) -> Poll> { + let len = self.0.read_tls(buf.by_ref())?; + self.0.process_new_packets() + .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?; + Poll::Ready(Ok(len)) + } + + fn poll_flush(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + self.0.process_new_packets() + .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?; + Poll::Ready(Ok(())) + } + + fn poll_shutdown(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + self.0.send_close_notify(); + Poll::Ready(Ok(())) + } +} + +struct Pending; + +impl AsyncRead for Pending { + fn poll_read(self: Pin<&mut Self>, _cx: &mut Context<'_>, _: &mut [u8]) -> Poll> { + Poll::Pending + } +} + +impl AsyncWrite for Pending { + fn poll_write(self: Pin<&mut Self>, _cx: &mut Context<'_>, _buf: &[u8]) -> Poll> { + Poll::Pending + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } +} + +struct Eof; + +impl AsyncRead for Eof { + fn poll_read(self: Pin<&mut Self>, _cx: &mut Context<'_>, _: &mut [u8]) -> Poll> { + Poll::Ready(Ok(0)) + } +} + +impl AsyncWrite for Eof { + fn poll_write(self: Pin<&mut Self>, _cx: &mut Context<'_>, buf: &[u8]) -> Poll> { + Poll::Ready(Ok(buf.len())) + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } +} + +#[tokio::test] +async fn stream_good() -> io::Result<()> { + const FILE: &'static [u8] = include_bytes!("../../README.md"); + + let (mut server, mut client) = make_pair(); + poll_fn(|cx| do_handshake(&mut client, &mut server, cx)).await?; + io::copy(&mut Cursor::new(FILE), &mut server)?; + + { + let mut good = Good(&mut server); + let mut stream = Stream::new(&mut good, &mut client); + + let mut buf = Vec::new(); + stream.read_to_end(&mut buf).await?; + assert_eq!(buf, FILE); + stream.write_all(b"Hello World!").await?; + stream.flush().await?; + } + + let mut buf = String::new(); + server.read_to_string(&mut buf)?; + assert_eq!(buf, "Hello World!"); + + Ok(()) as io::Result<()> +} + +#[tokio::test] +async fn stream_bad() -> io::Result<()> { + let (mut server, mut client) = make_pair(); + poll_fn(|cx| do_handshake(&mut client, &mut server, cx)).await?; + client.set_buffer_limit(1024); + + let mut bad = Pending; + let mut stream = Stream::new(&mut bad, &mut client); + assert_eq!(poll_fn(|cx| stream.as_mut_pin().poll_write(cx, &[0x42; 8])).await?, 8); + assert_eq!(poll_fn(|cx| stream.as_mut_pin().poll_write(cx, &[0x42; 8])).await?, 8); + let r = poll_fn(|cx| stream.as_mut_pin().poll_write(cx, &[0x00; 1024])).await?; // fill buffer + assert!(r < 1024); + + let mut cx = Context::from_waker(noop_waker_ref()); + let ret = stream.as_mut_pin().poll_write(&mut cx, &[0x01]); + assert!(ret.is_pending()); + + Ok(()) as io::Result<()> +} + +#[tokio::test] +async fn stream_handshake() -> io::Result<()> { + let (mut server, mut client) = make_pair(); + + { + let mut good = Good(&mut server); + let mut stream = Stream::new(&mut good, &mut client); + let (r, w) = poll_fn(|cx| stream.handshake(cx)).await?; + + assert!(r > 0); + assert!(w > 0); + + poll_fn(|cx| stream.handshake(cx)).await?; // finish server handshake + } + + assert!(!server.is_handshaking()); + assert!(!client.is_handshaking()); + + Ok(()) as io::Result<()> +} + +#[tokio::test] +async fn stream_handshake_eof() -> io::Result<()> { + let (_, mut client) = make_pair(); + + let mut bad = Eof; + let mut stream = Stream::new(&mut bad, &mut client); + + let mut cx = Context::from_waker(noop_waker_ref()); + let r = stream.handshake(&mut cx); + assert_eq!(r.map_err(|err| err.kind()), Poll::Ready(Err(io::ErrorKind::UnexpectedEof))); + + Ok(()) as io::Result<()> +} + +#[tokio::test] +async fn stream_eof() -> io::Result<()> { + let (mut server, mut client) = make_pair(); + poll_fn(|cx| do_handshake(&mut client, &mut server, cx)).await?; + + let mut good = Good(&mut server); + let mut stream = Stream::new(&mut good, &mut client).set_eof(true); + + let mut buf = Vec::new(); + stream.read_to_end(&mut buf).await?; + assert_eq!(buf.len(), 0); + + Ok(()) as io::Result<()> +} + +fn make_pair() -> (ServerSession, ClientSession) { + const CERT: &str = include_str!("../../tests/end.cert"); + const CHAIN: &str = include_str!("../../tests/end.chain"); + const RSA: &str = include_str!("../../tests/end.rsa"); + + let cert = certs(&mut BufReader::new(Cursor::new(CERT))).unwrap(); + let mut keys = rsa_private_keys(&mut BufReader::new(Cursor::new(RSA))).unwrap(); + let mut sconfig = ServerConfig::new(NoClientAuth::new()); + sconfig.set_single_cert(cert, keys.pop().unwrap()).unwrap(); + let server = ServerSession::new(&Arc::new(sconfig)); + + let domain = DNSNameRef::try_from_ascii_str("localhost").unwrap(); + let mut cconfig = ClientConfig::new(); + let mut chain = BufReader::new(Cursor::new(CHAIN)); + cconfig.root_store.add_pem_file(&mut chain).unwrap(); + let client = ClientSession::new(&Arc::new(cconfig), domain); + + (server, client) +} + +fn do_handshake(client: &mut ClientSession, server: &mut ServerSession, cx: &mut Context<'_>) -> Poll> { + let mut good = Good(server); + let mut stream = Stream::new(&mut good, client); + + while stream.session.is_handshaking() { + ready!(stream.handshake(cx))?; + } + + while stream.session.wants_write() { + ready!(stream.write_io(cx))?; + } + + Poll::Ready(Ok(())) +} diff --git a/tokio-rustls/src/common/vecbuf.rs b/tokio-rustls/src/common/vecbuf.rs new file mode 100644 index 0000000..6ea19e3 --- /dev/null +++ b/tokio-rustls/src/common/vecbuf.rs @@ -0,0 +1,128 @@ +use std::io::IoSlice; +use std::cmp::{ self, Ordering }; +use bytes::Buf; + + +pub struct VecBuf<'a, 'b: 'a> { + pos: usize, + cur: usize, + inner: &'a [&'b [u8]] +} + +impl<'a, 'b> VecBuf<'a, 'b> { + pub fn new(vbytes: &'a [&'b [u8]]) -> Self { + VecBuf { pos: 0, cur: 0, inner: vbytes } + } +} + +impl<'a, 'b> Buf for VecBuf<'a, 'b> { + fn remaining(&self) -> usize { + let sum = self.inner + .iter() + .skip(self.pos) + .map(|bytes| bytes.len()) + .sum::(); + sum - self.cur + } + + fn bytes(&self) -> &[u8] { + &self.inner[self.pos][self.cur..] + } + + fn advance(&mut self, cnt: usize) { + let current = self.inner[self.pos].len(); + match (self.cur + cnt).cmp(¤t) { + Ordering::Equal => if self.pos + 1 < self.inner.len() { + self.pos += 1; + self.cur = 0; + } else { + self.cur += cnt; + }, + Ordering::Greater => { + if self.pos + 1 < self.inner.len() { + self.pos += 1; + } + let remaining = self.cur + cnt - current; + self.advance(remaining); + }, + Ordering::Less => self.cur += cnt, + } + } + + #[allow(clippy::needless_range_loop)] + fn bytes_vectored<'c>(&'c self, dst: &mut [IoSlice<'c>]) -> usize { + let len = cmp::min(self.inner.len() - self.pos, dst.len()); + + if len > 0 { + dst[0] = IoSlice::new(self.bytes()); + } + + for i in 1..len { + dst[i] = IoSlice::new(&self.inner[self.pos + i]); + } + + len + } +} + +#[cfg(test)] +mod test_vecbuf { + use super::*; + + #[test] + fn test_fresh_cursor_vec() { + let mut buf = VecBuf::new(&[b"he", b"llo"]); + + assert_eq!(buf.remaining(), 5); + assert_eq!(buf.bytes(), b"he"); + + buf.advance(1); + + assert_eq!(buf.remaining(), 4); + assert_eq!(buf.bytes(), b"e"); + + buf.advance(1); + + assert_eq!(buf.remaining(), 3); + assert_eq!(buf.bytes(), b"llo"); + + buf.advance(3); + + assert_eq!(buf.remaining(), 0); + assert_eq!(buf.bytes(), b""); + } + + #[test] + fn test_get_u8() { + let mut buf = VecBuf::new(&[b"\x21z", b"omg"]); + assert_eq!(0x21, buf.get_u8()); + } + + #[test] + fn test_get_u16() { + let mut buf = VecBuf::new(&[b"\x21\x54z", b"omg"]); + assert_eq!(0x2154, buf.get_u16()); + let mut buf = VecBuf::new(&[b"\x21\x54z", b"omg"]); + assert_eq!(0x5421, buf.get_u16_le()); + } + + #[test] + #[should_panic] + fn test_get_u16_buffer_underflow() { + let mut buf = VecBuf::new(&[b"\x21"]); + buf.get_u16(); + } + + #[test] + fn test_bufs_vec() { + let buf = VecBuf::new(&[b"he", b"llo"]); + + let b1: &[u8] = &mut [0]; + let b2: &[u8] = &mut [0]; + + let mut dst: [IoSlice; 2] = + [IoSlice::new(b1), IoSlice::new(b2)]; + + assert_eq!(2, buf.bytes_vectored(&mut dst[..])); + } +} diff --git a/tokio-rustls/src/lib.rs b/tokio-rustls/src/lib.rs new file mode 100644 index 0000000..db34b07 --- /dev/null +++ b/tokio-rustls/src/lib.rs @@ -0,0 +1,316 @@ +//! Asynchronous TLS/SSL streams for Tokio using [Rustls](https://github.com/ctz/rustls). + +mod common; +pub mod client; +pub mod server; + +use std::io; +use std::pin::Pin; +use std::sync::Arc; +use std::future::Future; +use std::task::{ Context, Poll }; +use futures_core::future::FusedFuture; +use tokio::io::{ AsyncRead, AsyncWrite }; +use webpki::DNSNameRef; +use rustls::{ ClientConfig, ClientSession, ServerConfig, ServerSession, Session }; +use common::{ Stream, TlsState, MidHandshake }; + +pub use rustls; +pub use webpki; + +/// A wrapper around a `rustls::ClientConfig`, providing an async `connect` method. +#[derive(Clone)] +pub struct TlsConnector { + inner: Arc, + #[cfg(feature = "early-data")] + early_data: bool, +} + +/// A wrapper around a `rustls::ServerConfig`, providing an async `accept` method. +#[derive(Clone)] +pub struct TlsAcceptor { + inner: Arc, +} + +impl From> for TlsConnector { + fn from(inner: Arc) -> TlsConnector { + TlsConnector { + inner, + #[cfg(feature = "early-data")] + early_data: false, + } + } +} + +impl From> for TlsAcceptor { + fn from(inner: Arc) -> TlsAcceptor { + TlsAcceptor { inner } + } +} + +impl TlsConnector { + /// Enable 0-RTT. + /// + /// If you want to use 0-RTT, + /// You must also set `ClientConfig.enable_early_data` to `true`. + #[cfg(feature = "early-data")] + pub fn early_data(mut self, flag: bool) -> TlsConnector { + self.early_data = flag; + self + } + + #[inline] + pub fn connect(&self, domain: DNSNameRef, stream: IO) -> Connect + where + IO: AsyncRead + AsyncWrite + Unpin, + { + self.connect_with(domain, stream, |_| ()) + } + + pub fn connect_with(&self, domain: DNSNameRef, stream: IO, f: F) -> Connect + where + IO: AsyncRead + AsyncWrite + Unpin, + F: FnOnce(&mut ClientSession), + { + let mut session = ClientSession::new(&self.inner, domain); + f(&mut session); + + Connect(MidHandshake::Handshaking(client::TlsStream { + io: stream, + + #[cfg(not(feature = "early-data"))] + state: TlsState::Stream, + + #[cfg(feature = "early-data")] + state: if self.early_data && session.early_data().is_some() { + TlsState::EarlyData(0, Vec::new()) + } else { + TlsState::Stream + }, + + session + })) + } +} + +impl TlsAcceptor { + #[inline] + pub fn accept(&self, stream: IO) -> Accept + where + IO: AsyncRead + AsyncWrite + Unpin, + { + self.accept_with(stream, |_| ()) + } + + pub fn accept_with(&self, stream: IO, f: F) -> Accept + where + IO: AsyncRead + AsyncWrite + Unpin, + F: FnOnce(&mut ServerSession), + { + let mut session = ServerSession::new(&self.inner); + f(&mut session); + + Accept(MidHandshake::Handshaking(server::TlsStream { + session, + io: stream, + state: TlsState::Stream, + })) + } +} + +/// Future returned from `TlsConnector::connect` which will resolve +/// once the connection handshake has finished. +pub struct Connect(MidHandshake>); + +/// Future returned from `TlsAcceptor::accept` which will resolve +/// once the accept handshake has finished. +pub struct Accept(MidHandshake>); + +/// Like [Connect], but returns `IO` on failure. +pub struct FailableConnect(MidHandshake>); + +/// Like [Accept], but returns `IO` on failure. +pub struct FailableAccept(MidHandshake>); + +impl Connect { + #[inline] + pub fn into_failable(self) -> FailableConnect { + FailableConnect(self.0) + } +} + +impl Accept { + #[inline] + pub fn into_failable(self) -> FailableAccept { + FailableAccept(self.0) + } +} + +impl Future for Connect { + type Output = io::Result>; + + #[inline] + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + Pin::new(&mut self.0) + .poll(cx) + .map_err(|(err, _)| err) + } +} + +impl FusedFuture for Connect { + #[inline] + fn is_terminated(&self) -> bool { + self.0.is_terminated() + } +} + +impl Future for Accept { + type Output = io::Result>; + + #[inline] + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + Pin::new(&mut self.0) + .poll(cx) + .map_err(|(err, _)| err) + } +} + +impl FusedFuture for Accept { + #[inline] + fn is_terminated(&self) -> bool { + self.0.is_terminated() + } +} + +impl Future for FailableConnect { + type Output = Result, (io::Error, IO)>; + + #[inline] + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + Pin::new(&mut self.0).poll(cx) + } +} + +impl FusedFuture for FailableConnect { + #[inline] + fn is_terminated(&self) -> bool { + self.0.is_terminated() + } +} + +impl Future for FailableAccept { + type Output = Result, (io::Error, IO)>; + + #[inline] + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + Pin::new(&mut self.0).poll(cx) + } +} + +impl FusedFuture for FailableAccept { + #[inline] + fn is_terminated(&self) -> bool { + self.0.is_terminated() + } +} + +/// Unified TLS stream type +/// +/// This abstracts over the inner `client::TlsStream` and `server::TlsStream`, so you can use +/// a single type to keep both client- and server-initiated TLS-encrypted connections. +pub enum TlsStream { + Client(client::TlsStream), + Server(server::TlsStream), +} + +impl TlsStream { + pub fn get_ref(&self) -> (&T, &dyn Session) { + use TlsStream::*; + match self { + Client(io) => { + let (io, session) = io.get_ref(); + (io, &*session) + } + Server(io) => { + let (io, session) = io.get_ref(); + (io, &*session) + } + } + } + + pub fn get_mut(&mut self) -> (&mut T, &mut dyn Session) { + use TlsStream::*; + match self { + Client(io) => { + let (io, session) = io.get_mut(); + (io, &mut *session) + } + Server(io) => { + let (io, session) = io.get_mut(); + (io, &mut *session) + } + } + } +} + +impl From> for TlsStream { + fn from(s: client::TlsStream) -> Self { + Self::Client(s) + } +} + +impl From> for TlsStream { + fn from(s: server::TlsStream) -> Self { + Self::Server(s) + } +} + +impl AsyncRead for TlsStream +where + T: AsyncRead + AsyncWrite + Unpin, +{ + #[inline] + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + match self.get_mut() { + TlsStream::Client(x) => Pin::new(x).poll_read(cx, buf), + TlsStream::Server(x) => Pin::new(x).poll_read(cx, buf), + } + } +} + +impl AsyncWrite for TlsStream +where + T: AsyncRead + AsyncWrite + Unpin, +{ + #[inline] + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + match self.get_mut() { + TlsStream::Client(x) => Pin::new(x).poll_write(cx, buf), + TlsStream::Server(x) => Pin::new(x).poll_write(cx, buf), + } + } + + #[inline] + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.get_mut() { + TlsStream::Client(x) => Pin::new(x).poll_flush(cx), + TlsStream::Server(x) => Pin::new(x).poll_flush(cx), + } + } + + #[inline] + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.get_mut() { + TlsStream::Client(x) => Pin::new(x).poll_shutdown(cx), + TlsStream::Server(x) => Pin::new(x).poll_shutdown(cx), + } + } +} diff --git a/tokio-rustls/src/server.rs b/tokio-rustls/src/server.rs new file mode 100644 index 0000000..abf86d6 --- /dev/null +++ b/tokio-rustls/src/server.rs @@ -0,0 +1,124 @@ +use super::*; +use rustls::Session; +use crate::common::IoSession; + +/// A wrapper around an underlying raw stream which implements the TLS or SSL +/// protocol. +#[derive(Debug)] +pub struct TlsStream { + pub(crate) io: IO, + pub(crate) session: ServerSession, + pub(crate) state: TlsState, +} + +impl TlsStream { + #[inline] + pub fn get_ref(&self) -> (&IO, &ServerSession) { + (&self.io, &self.session) + } + + #[inline] + pub fn get_mut(&mut self) -> (&mut IO, &mut ServerSession) { + (&mut self.io, &mut self.session) + } + + #[inline] + pub fn into_inner(self) -> (IO, ServerSession) { + (self.io, self.session) + } +} + +impl IoSession for TlsStream { + type Io = IO; + type Session = ServerSession; + + #[inline] + fn skip_handshake(&self) -> bool { + false + } + + #[inline] + fn get_mut(&mut self) -> (&mut TlsState, &mut Self::Io, &mut Self::Session) { + (&mut self.state, &mut self.io, &mut self.session) + } + + #[inline] + fn into_io(self) -> Self::Io { + self.io + } +} + +impl AsyncRead for TlsStream +where + IO: AsyncRead + AsyncWrite + Unpin, +{ + #[cfg(feature = "unstable")] + unsafe fn prepare_uninitialized_buffer(&self, _buf: &mut [std::mem::MaybeUninit]) -> bool { + // TODO + // + // https://doc.rust-lang.org/nightly/std/io/trait.Read.html#method.initializer + false + } + + fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll> { + let this = self.get_mut(); + let mut stream = Stream::new(&mut this.io, &mut this.session) + .set_eof(!this.state.readable()); + + match &this.state { + TlsState::Stream | TlsState::WriteShutdown => match stream.as_mut_pin().poll_read(cx, buf) { + Poll::Ready(Ok(0)) => { + this.state.shutdown_read(); + Poll::Ready(Ok(0)) + } + Poll::Ready(Ok(n)) => Poll::Ready(Ok(n)), + Poll::Ready(Err(ref err)) if err.kind() == io::ErrorKind::ConnectionAborted => { + this.state.shutdown_read(); + if this.state.writeable() { + stream.session.send_close_notify(); + this.state.shutdown_write(); + } + Poll::Ready(Ok(0)) + } + Poll::Ready(Err(e)) => Poll::Ready(Err(e)), + Poll::Pending => Poll::Pending + }, + TlsState::ReadShutdown | TlsState::FullyShutdown => Poll::Ready(Ok(0)), + #[cfg(feature = "early-data")] + s => unreachable!("server TLS can not hit this state: {:?}", s), + } + } +} + +impl AsyncWrite for TlsStream +where + IO: AsyncRead + AsyncWrite + Unpin, +{ + /// Note: that it does not guarantee the final data to be sent. + /// To be cautious, you must manually call `flush`. + fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { + let this = self.get_mut(); + let mut stream = Stream::new(&mut this.io, &mut this.session) + .set_eof(!this.state.readable()); + stream.as_mut_pin().poll_write(cx, buf) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.get_mut(); + let mut stream = Stream::new(&mut this.io, &mut this.session) + .set_eof(!this.state.readable()); + stream.as_mut_pin().poll_flush(cx) + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + if self.state.writeable() { + self.session.send_close_notify(); + self.state.shutdown_write(); + } + + let this = self.get_mut(); + let mut stream = Stream::new(&mut this.io, &mut this.session) + .set_eof(!this.state.readable()); + stream.as_mut_pin().poll_shutdown(cx) + } +} diff --git a/tokio-rustls/tests/badssl.rs b/tokio-rustls/tests/badssl.rs new file mode 100644 index 0000000..3a02e86 --- /dev/null +++ b/tokio-rustls/tests/badssl.rs @@ -0,0 +1,63 @@ +use std::io; +use std::sync::Arc; +use std::net::ToSocketAddrs; +use tokio::prelude::*; +use tokio::net::TcpStream; +use rustls::ClientConfig; +use tokio_rustls::{ TlsConnector, client::TlsStream }; + + +async fn get(config: Arc, domain: &str, port: u16) + -> io::Result<(TlsStream, String)> +{ + let connector = TlsConnector::from(config); + let input = format!("GET / HTTP/1.0\r\nHost: {}\r\n\r\n", domain); + + let addr = (domain, port) + .to_socket_addrs()? + .next().unwrap(); + let domain = webpki::DNSNameRef::try_from_ascii_str(&domain).unwrap(); + let mut buf = Vec::new(); + + let stream = TcpStream::connect(&addr).await?; + let mut stream = connector.connect(domain, stream).await?; + stream.write_all(input.as_bytes()).await?; + stream.flush().await?; + stream.read_to_end(&mut buf).await?; + + Ok((stream, String::from_utf8(buf).unwrap())) +} + +#[tokio::test] +async fn test_tls12() -> io::Result<()> { + let mut config = ClientConfig::new(); + config.root_store.add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS); + config.versions = vec![rustls::ProtocolVersion::TLSv1_2]; + let config = Arc::new(config); + let domain = "tls-v1-2.badssl.com"; + + let (_, output) = get(config.clone(), domain, 1012).await?; + assert!(output.contains("tls-v1-2.badssl.com")); + + Ok(()) +} + +#[ignore] +#[should_panic] +#[test] +fn test_tls13() { + unimplemented!("todo https://github.com/chromium/badssl.com/pull/373"); +} + +#[tokio::test] +async fn test_modern() -> io::Result<()> { + let mut config = ClientConfig::new(); + config.root_store.add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS); + let config = Arc::new(config); + let domain = "mozilla-modern.badssl.com"; + + let (_, output) = get(config.clone(), domain, 443).await?; + assert!(output.contains("mozilla-modern.badssl.com")); + + Ok(()) +} diff --git a/tokio-rustls/tests/early-data.rs b/tokio-rustls/tests/early-data.rs new file mode 100644 index 0000000..35523e0 --- /dev/null +++ b/tokio-rustls/tests/early-data.rs @@ -0,0 +1,108 @@ +#![cfg(feature = "early-data")] + +use std::io::{ self, BufReader, BufRead, Cursor }; +use std::process::{ Command, Child, Stdio }; +use std::net::SocketAddr; +use std::sync::Arc; +use std::marker::Unpin; +use std::pin::{ Pin }; +use std::task::{ Context, Poll }; +use std::time::Duration; +use tokio::prelude::*; +use tokio::net::TcpStream; +use tokio::time::delay_for; +use futures_util::{ future, ready }; +use rustls::ClientConfig; +use tokio_rustls::{ TlsConnector, client::TlsStream }; +use std::future::Future; + + +struct Read1(T); + +impl Future for Read1 { + type Output = io::Result<()>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let mut buf = [0]; + ready!(Pin::new(&mut self.0).poll_read(cx, &mut buf))?; + Poll::Pending + } +} + +async fn send(config: Arc, addr: SocketAddr, data: &[u8]) + -> io::Result> +{ + let connector = TlsConnector::from(config) + .early_data(true); + let stream = TcpStream::connect(&addr).await?; + let domain = webpki::DNSNameRef::try_from_ascii_str("testserver.com").unwrap(); + + let mut stream = connector.connect(domain, stream).await?; + stream.write_all(data).await?; + stream.flush().await?; + + // sleep 1s + // + // see https://www.mail-archive.com/openssl-users@openssl.org/msg84451.html + let sleep1 = delay_for(Duration::from_secs(1)); + let mut stream = match future::select(Read1(stream), sleep1).await { + future::Either::Right((_, Read1(stream))) => stream, + future::Either::Left((Err(err), _)) => return Err(err), + future::Either::Left((Ok(_), _)) => unreachable!(), + }; + + stream.shutdown().await?; + + Ok(stream) +} + +struct DropKill(Child); + +impl Drop for DropKill { + fn drop(&mut self) { + self.0.kill().unwrap(); + } +} + +#[tokio::test] +async fn test_0rtt() -> io::Result<()> { + let mut handle = Command::new("openssl") + .arg("s_server") + .arg("-early_data") + .arg("-tls1_3") + .args(&["-cert", "./tests/end.cert"]) + .args(&["-key", "./tests/end.rsa"]) + .args(&["-port", "12354"]) + .stdout(Stdio::piped()) + .spawn() + .map(DropKill)?; + + // wait openssl server + delay_for(Duration::from_secs(1)).await; + + let mut config = ClientConfig::new(); + let mut chain = BufReader::new(Cursor::new(include_str!("end.chain"))); + config.root_store.add_pem_file(&mut chain).unwrap(); + config.versions = vec![rustls::ProtocolVersion::TLSv1_3]; + config.enable_early_data = true; + let config = Arc::new(config); + let addr = SocketAddr::from(([127, 0, 0, 1], 12354)); + + let io = send(config.clone(), addr, b"hello").await?; + assert!(!io.get_ref().1.is_early_data_accepted()); + + let io = send(config, addr, b"world!").await?; + assert!(io.get_ref().1.is_early_data_accepted()); + + let stdout = handle.0.stdout.as_mut().unwrap(); + let mut lines = BufReader::new(stdout).lines(); + + let has_msg1 = lines.by_ref() + .any(|line| line.unwrap().contains("hello")); + let has_msg2 = lines.by_ref() + .any(|line| line.unwrap().contains("world!")); + + assert!(has_msg1 && has_msg2); + + Ok(()) +} diff --git a/tokio-rustls/tests/end.cert b/tokio-rustls/tests/end.cert new file mode 100644 index 0000000..66f087e --- /dev/null +++ b/tokio-rustls/tests/end.cert @@ -0,0 +1,24 @@ +-----BEGIN CERTIFICATE----- +MIIEADCCAmigAwIBAgICAcgwDQYJKoZIhvcNAQELBQAwLDEqMCgGA1UEAwwhcG9u +eXRvd24gUlNBIGxldmVsIDIgaW50ZXJtZWRpYXRlMB4XDTE2MTIxMDE3NDIzM1oX +DTIyMDYwMjE3NDIzM1owGTEXMBUGA1UEAwwOdGVzdHNlcnZlci5jb20wggEiMA0G +CSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQC1YDz66+7VD4DL1+/sVHMQ+BbDRgmD +OQlX++mfW8D3QNQm/qDBEbu7T7qqdc9GKDar4WIzBN8SBkzM1EjMGwNnZPV/Tfz0 +qUAR1L/7Zzf1GaFZvWXgksyUpfwvmprH3Iy/dpkETwtPthpTPNlui3hZnm/5kkjR +RWg9HmID4O04Ld6SK313v2ZgrPZbkKvbqlqhUnYWjL3blKVGbpXIsuZzEU9Ph+gH +tPcEhZpFsM6eLe+2TVscIrycMEOTXqAAmO6zZ9sQWtfllu3CElm904H6+jA/9Leg +al72pMmkYr8wWniqDDuijXuCPlVx5EDFFyxBmW18UeDEQaKV3kNfelaTAgMBAAGj +gb4wgbswDAYDVR0TAQH/BAIwADALBgNVHQ8EBAMCBsAwHQYDVR0OBBYEFIYhJkVy +AAKT6cY/ruH1Eu+NNxteMEIGA1UdIwQ7MDmAFNwuPy4Do//Sm5CZDrocHWTrNr96 +oR6kHDAaMRgwFgYDVQQDDA9wb255dG93biBSU0EgQ0GCAXswOwYDVR0RBDQwMoIO +dGVzdHNlcnZlci5jb22CFXNlY29uZC50ZXN0c2VydmVyLmNvbYIJbG9jYWxob3N0 +MA0GCSqGSIb3DQEBCwUAA4IBgQCWV76jfQDZKtfmj45fTwZzoe/PxjWPRbAvSEnt +LRHrPhqQfpMLqpun8uu/w86mHiR/AmiAySMu3zivW6wfGzlRWLi/zCyO6r9LGsgH +bNk5CF642cdZFvn1SiSm1oGXQrolIpcyXu88nUpt74RnY4ETCC1dRQKqxsYufe5T +DOmTm3ChinNW4QRG3yvW6DVuyxVAgZvofyKJOsM3GO6oogIM41aBqZ3UTwmIwp6D +oISdiATslFOzYzjnyXNR8DG8OOkv1ehWuyb8x+hQCZAuogQOWYtCSd6k3kKgd0EM +4CWbt1XDV9ZJwBf2uxZeKuCu/KIy9auNtijAwPsUv9qxuzko018zhl3lWm5p2Sqw +O7fFshU3A6df8hMw7ST6/tgFY7geT88U4iJhfWMwr/CZSRSVMXhTyJgbLIXxKYZj +Ym5v4NAIQP6hI4HixzQaYgrhW6YX6myk+emMjQLRJHT8uHvmT7fuxMJVWWgsCkr1 +C75pRQEagykN/Uzr5e6Tm8sVu88= +-----END CERTIFICATE----- diff --git a/tokio-rustls/tests/end.chain b/tokio-rustls/tests/end.chain new file mode 100644 index 0000000..7c39013 --- /dev/null +++ b/tokio-rustls/tests/end.chain @@ -0,0 +1,89 @@ +-----BEGIN CERTIFICATE----- +MIIGnzCCAoegAwIBAgIBezANBgkqhkiG9w0BAQsFADAaMRgwFgYDVQQDDA9wb255 +dG93biBSU0EgQ0EwHhcNMTYxMjEwMTc0MjMzWhcNMjYxMjA4MTc0MjMzWjAsMSow +KAYDVQQDDCFwb255dG93biBSU0EgbGV2ZWwgMiBpbnRlcm1lZGlhdGUwggGiMA0G +CSqGSIb3DQEBAQUAA4IBjwAwggGKAoIBgQDnfb7vaJbaHEyVTflswWhmHqx5W0NO +KyKbDp2zXEJwDO+NDJq6i1HGnFd/vO4LyjJBU1wUsKtE+m55cfRmUHVuZ2w4n/VF +p7Z7n+SNuvJNcrzDxyKVy4GIZ39zQePnniqtLqXh6eI8Ow6jiMgVxC/wbWcVLKv6 +4RM+2fLjJAC9b27QfjhOlMKVeMOEvPrrpjLSauaHAktQPhuzIAwzxM0+KnvDkWWy +NVqAV/lq6fSO/9vJRhM4E2nxo6yqi7qTdxVxMmKsNn7L6HvjQgx+FXziAUs55Qd9 +cP7etCmPmoefkcgdbxDOIKH8D+DvfacZwngqcnr/q96Ff4uJ13d2OzR1mWVSZ2hE +JQt/BbZBANciqu9OZf3dj6uOOXgFF705ak0GfLtpZpc29M+fVnknXPDSiKFqjzOO +KL+SRGyuNc9ZYjBKkXPJ1OToAs6JSvgDxfOfX0thuo2rslqfpj2qCFugsRIRAqvb +eyFwg+BPM/P/EfauXlAcQtBF04fOi7xN2okCAwEAAaNeMFwwHQYDVR0OBBYEFNwu +Py4Do//Sm5CZDrocHWTrNr96MCAGA1UdJQEB/wQWMBQGCCsGAQUFBwMBBggrBgEF +BQcDAjAMBgNVHRMEBTADAQH/MAsGA1UdDwQEAwIB/jANBgkqhkiG9w0BAQsFAAOC +BAEAMHZpBqDIUAVFZNw4XbuimXQ4K8q4uePrLGHLb4F/gHbr8kYrU4H+cy4l+xXf +2dlEBdZoqjSF7uXzQg5Fd8Ff3ZgutXd1xeUJnxo0VdpKIhqeaTPqhffC2X6FQQH5 +KrN7NVWQSnUhPNpBFELpmdpY1lHigFW7nytYj0C6VJ4QsbqhfW+n/t+Zgqtfh/Od +ZbclzxFwMM55zRA2HP6IwXS2+d61Jk/RpDHTzhWdjGH4906zGNNMa7slHpCTA9Ju +TrtjEAGt2PBSievBJOHZW80KVAoEX2n9B3ZABaz+uX0VVZG0D2FwhPpUeA57YiXu +qiktZR4Ankph3LabXp4IlAX16qpYsEW8TWE/HLreeqoM0WDoI6rF9qnTpV2KWqBf +ziMYkfSkT7hQ2bWc493lW+QwSxCsuBsDwlrCwAl6jFSf1+jEQx98/8n9rDNyD9dL +PvECmtF30WY98nwZ9/kO2DufQrd0mwSHcIT0pAwl5fimpkwTjj+TTbytO3M4jK5L +tuIzsViQ95BmJQ3XuLdkQ/Ug8rpECYRX5fQX1qXkkvl920ohpKqKyEji1OmfmJ0Z +tZChaEcu3Mp3U+gD4az2ogmle3i/Phz8ZEPFo4/21G5Qd72z0lBgaQIeyyCk5MHt +Yg0vA7X0/w4bz+OJv5tf7zJsPCYSprr+c/7YUJk9Fqu6+g9ZAavI99xFKdGhz4Og +w0trnKNCxYc6+NPopTDbXuY+fo4DK7C0CSae5sKs7013Ne6w4KvgfLKpvlemkGfg +ZA3+1FMXVfFIEH7Cw9cx6F02Sr3k1VrU68oM3wH5nvTUkELOf8nRMlzliQjVCpKB +yFSe9dzRVSFEbMDxChiEulGgNUHj/6wwpg0ZmCwPRHutppT3jkfEqizN5iHb69GH +k6kol6knJofkaL656Q3Oc9o0ZrMlFh1RwmOvAk5fVK0/CV88/phROz2Wdmy5Bz4a +t0vzqFWA54y6+9EEVoOk9SU0CYfpGtpX4URjLK1EUG/l+RR3366Uee6TPrtEZ9cg +56VQMxhSaRNAvJ6DfiSuscSCNJzwuXaMXSZydGYnnP9Tb9p6c1uy1sXdluZkBIcK +CgC+gdDMSNlDn9ghc4xZGkuA8bjzfAYuRuGKmfTt8uuklkjw2b9w3SHjC4/Cmd2W +cFRnzfg2oL6e78hNg2ZGgsLzvb6Lu6/5IhXCO7RitzYf2+HLBbc+YLFsnG3qeGe1 +28yGnXOQd97Cr4+IzFucVy/33gMQkesNUSDFJSq1gE/hGrMgTTMQJ7yC3PRqg0kG +tpqTyKNdM0g1adxlR1qfDPvpUBApkgBbySnMyWEr5+tBuoHUtH2m49oV9YD4odMJ +yJjlGxituO/YNN6O8oANlraG1Q== +-----END CERTIFICATE----- +-----BEGIN CERTIFICATE----- +MIIJBzCCBO+gAwIBAgIJAN7WS1mRS9A+MA0GCSqGSIb3DQEBCwUAMBoxGDAWBgNV +BAMMD3Bvbnl0b3duIFJTQSBDQTAeFw0xNjEyMTAxNzQyMzNaFw0yNjEyMDgxNzQy +MzNaMBoxGDAWBgNVBAMMD3Bvbnl0b3duIFJTQSBDQTCCBCIwDQYJKoZIhvcNAQEB +BQADggQPADCCBAoCggQBAMNEzJ7aNdD2JSk9+NF9Hh2za9OQnt1d/7j6DtE3ieoT +ms8mMSXzoImXZayZ9Glx3yx/RhEb2vmINyb0vRUM4I/GH+XHdOBcs9kaJNv/Mpw4 +Ggd4e1LUqV1pzNrhYwRrTQTKyaDiDX2WEBNfQaaYnHltmSmsfyt3Klj+IMc6CyqV +q8SOQ6Go414Vn++Jj7p3E6owdwuvSvO8ERLobiA6vYB+qrS7E48c4zRIAFIO4uwt +g4TiCJLLWc1fRSoqGGX7KS+LzQF8Pq67IOHVna4e9peSe6nQnm0LQZAmaosYHvF4 +AX0Bj6TLv9PXCAGtB7Pciev5Br0tRZEdVyYfmwiVKUWcp77TghV3W+VaJVhPh5LN +X91ktvpeYek3uglqv2ZHtSG2S1KkBtTkbMOD+a2BEUfq0c0+BIsj6jdvt4cvIfet +4gUOxCvYMBs4/dmNT1zoe/kJ0lf8YXYLsXwVWdIW3jEE8QdkLtLI9XfyU9OKLZuD +mmoAf7ezvv/T3nKLFqhcwUFGgGtCIX+oWC16XSbDPBcKDBwNZn8C49b7BLdxqAg3 +msfxwhYzSs9F1MXt/h2dh7FVmkCSxtgNDX3NJn5/yT6USws2y0AS5vXVP9hRf0NV +KfKn9XlmHCxnZExwm68uZkUUYHB05jSWFojbfWE+Mf9djUeQ4FuwusztZdbyQ4yS +mMtBXO0I6SQBmjCoOa1ySW3DTuw/eKCfq+PoxqWD434bYA9nUa+pE27MP7GLyjCS +6+ED3MACizSF0YxkcC9pWUo4L5FKp+DxnNbtzMIILnsDZTVHOvKUy/gjTyTWm/+7 +2t98l7vBE8gn3Aux0V5WFe2uZIZ07wIi/OThoBO8mpt9Bm5cJTG07JStKEXX/UH1 +nL7cDZ2V5qbf4hJdDy4qixxxIZtmf//1BRlVQ9iYTOsMoy+36DXWbc3vSmjRefW1 +YENt4zxOPe4LUq2Z+LXq1OgVQrHrVevux0vieys7Rr2gA1sH8FaaNwTr7Q8dq+Av +Evk+iOUH4FuYorU1HuGHPkAkvLWosVwlB+VhfEai0V6+PmttmaOnCJNHfFTu5wCu +B9CFJ1tdzTzAbrLwgtWmO70KV7CfZPHO7lMWhSvplU0i5T9WytxP91IoFtXwRSO8 ++Ghyu0ynB3HywCH2dez89Vy903P6PEU0qTnYWRz6D/wi5+yHHNrm9CilWurs/Qex +kyB7lLD7Cb1JJc8QIFTqT6vj+cids3xd245hUdpFyZTX99YbF6IkiB2zGi5wvUmP +f1GPvkTLb7eF7bne9OClEjEqvc0hVJ2abO2WXkqxlQFEYZHNofm+y6bnby/BZZJo +beaSFcLOCe2Z8iZvVnzfHBCeLyWE89gc94z784S3LEsCAwEAAaNQME4wHQYDVR0O +BBYEFNz2wEPCQbx9OdRCNE4eALwHJfIgMB8GA1UdIwQYMBaAFNz2wEPCQbx9OdRC +NE4eALwHJfIgMAwGA1UdEwQFMAMBAf8wDQYJKoZIhvcNAQELBQADggQBACbm2YX7 +sBG0Aslj36gmVlCTTluNg2tuK2isHbK3YhNwujrH/o/o2OV7UeUkZkPwE4g4/SjC +OwDWYniRNyDKBOeD9Q0XxR5z5IZQO+pRVvXF8DXO6kygWCOJM9XheKxp9Uke0aDg +m8F02NslKLUdy7piGlLSz1sgdjiE3izIwFZRpZY7sMozNWWvSAmzprbkE78LghIm +VEydQzIQlr5soWqc65uFLNbEA6QBPoFc6dDW+mnzXf8nrZUM03CACxAsuq/YkjRp +OHgwgfdNRdlu4YhZtuQNak4BUvDmigTGxDC+aMJw0ldL1bLtqLG6BvQbyLNPOOfo +5S8lGh4y06gb//052xHaqtCh5Ax5sHUE5By6wKHAKbuJy26qyKfaRoc3Jigs4Fd5 +3CuoDWHbyXfkgKiU+sc+1mvCxQKFRJ2fpGEFP8iEcLvdUae7ZkRM4Kb0vST+QhQV +fDaFkM3Bwqtui5YaZ6cHHQVyXQdujCmfesoZXKil2yduQ3KWgePjewzRV+aDWMzk +qKaF+TRANSqWbBU6JTwwQ4veKQThU3ir7nS2ovdPbhNS/FnWoKodj6eaqXfdYuBh +XOXLewIF568MJsLOuBubeAO2a9LOlhnv6eLGp2P4M7vwEdN/LRRQtwBBmqq8C3h+ +ewrJP12B/ag0bJDi9vCgPhYtDEpjpfsnxZEIqVZwshJ/MqXykFp2kYk62ylyfDWq +veI/aHwpzT2k+4CI/XmPWXl9NlI50HPdpcwCBDy8xVHwb/x7stNgQdIhaj9tzmKa +S+eqitclc8Iqrbd523H//QDzm8yiqRZUdveNa9gioTMErR0ujCpK8tO8mVZcVfNX +i1/Vsar5++nXcPhxKsd1t8XV2dk3gUZIfMgzLLzs+KSiFg+bT3c7LkCd+I3w30Iv +fh9cxFBAyYO9giwxaCfJgoz7OYqaHOOtASF85UV7gK9ELT7/z+RAcS/UfY1xbd54 +hIi1vRZj8lfkAYNtnYlud44joi1BvW/GZGFCiJ13SSvfHNs9v/5xguyCSgyCc0qx +ZkN/fzj/5wFQbxSl3MPn/JrsvlH6wvJht1SA50uVdUvJ5e5V8EgLYfMqlJNNpTHP +wZcHF+Dw126oyu2KhUxD126Gusxp+tV6I0EEZnVwwduFQWq9xm/gT+qohpveeylf +Q2XGz56DF2udJJnSFGSqzQOl9XopNC/4ecBMwIzqdFSpaWgK3VNAcigyDajgoE4v +ZuiVDEiLhLowZvi1V8GOWzcka7R2BQBjhOLWByQGDcm8cOMS7w8oCSQCaYmJyHvE +tTHq7fX6/sXv0AJqM3ysSdU01IVBNahnr5WEkmQMaFF0DGvRfqkVdKcChwrKv7r2 +DLxargy39i2aQGg= +-----END CERTIFICATE----- diff --git a/tokio-rustls/tests/end.rsa b/tokio-rustls/tests/end.rsa new file mode 100644 index 0000000..744bba5 --- /dev/null +++ b/tokio-rustls/tests/end.rsa @@ -0,0 +1,27 @@ +-----BEGIN RSA PRIVATE KEY----- +MIIEogIBAAKCAQEAtWA8+uvu1Q+Ay9fv7FRzEPgWw0YJgzkJV/vpn1vA90DUJv6g +wRG7u0+6qnXPRig2q+FiMwTfEgZMzNRIzBsDZ2T1f0389KlAEdS/+2c39RmhWb1l +4JLMlKX8L5qax9yMv3aZBE8LT7YaUzzZbot4WZ5v+ZJI0UVoPR5iA+DtOC3ekit9 +d79mYKz2W5Cr26paoVJ2Foy925SlRm6VyLLmcxFPT4foB7T3BIWaRbDOni3vtk1b +HCK8nDBDk16gAJjus2fbEFrX5ZbtwhJZvdOB+vowP/S3oGpe9qTJpGK/MFp4qgw7 +oo17gj5VceRAxRcsQZltfFHgxEGild5DX3pWkwIDAQABAoIBAFDTazlSbGML/pRY +TTWeyIw2UkaA7npIr45C13BJfitw+1nJPK/tDCDDveZ6i3yzLPHZhV5A/HtWzWC1 +9R7nptOrnO83PNN2nPOVQFxzOe+ClXGdQkoagQp5EXHRTspj0WD9I+FUrDDAcOjJ +BAgMJPyi6zlnZAXGDVa3NGyQDoZqwU2k36L4rEsJIkG0NVurZhpiCexNkkf32495 +TOINQ0iKdfJ4iZoEYQ9G+x4NiuAJRCHuIcH76SNfT+Uv3wX0ut5EFPtflnvtdgcp +QVcoKwYdO0+mgO5xqWlBcsujSvgBdiNAGnAxKHWiEaacuIJi4+yYovyEebP6QI2X +Zg/U2wkCgYEA794dE5CPXLOmv6nioVC/ubOESk7vjSlEka/XFbKr4EY794YEqrB1 +8TUqg09Bn3396AS1e6P2shr3bxos5ybhOxDGSLnJ+aC0tRFjd1BPKnA80vZM7ggt +5cjmdD5Zp0tIQTIAAYU5bONQOwj0ej4PE7lny26eLa5vfvCwlrD+rM0CgYEAwZMN +W/5PA2A+EM08IaHic8my0dCunrNLF890ouZnDG99SbgMGvvEsGIcCP1sai702hNh +VgGDxCz6/HUy+4O4YNFVtjY7uGEpfIEcEI7CsLQRP2ggWEFxThZtnEtO8PbM3J/i +qcS6njHdE+0XuCjgZwGgva5xH2pkWFzw/AIpEN8CgYB2HOo2axWc8T2n3TCifI+c +EqCOsqXU3cBM+MgxgASQcCUxMkX0AuZguuxPMmS+85xmdoMi+c8NTqgOhlYcEJIR +sqXgw9OH3zF8g6513w7Md+4Ld4rUHyTypGWOUfF1pmVS7RsBpKdtTdWA7FzuIMbt +0HsiujqbheyTFlPuMAOH9QKBgBWS1gJSrWuq5j/pH7J/4EUXTZ6kq1F0mgHlVRJy +qzlvk38LzA2V0a32wTkfRV3wLcnALzDuqkjK2o4YYb42R+5CZlMQaEd8TKtbmE0g +HAKljuaKLFCpun8BcOXiXsHsP5i3GQPisQnAdOsrmWEk7R2NyORa9LCToutWMGVl +uD3xAoGAA183Vldm+m4KPsKS17t8MbwBryDXvowGzruh/Z+PGA0spr+ke4XxwT1y +kMMP1+5flzmjlAf4+W8LehKuVqvQoMlPn5UVHmSxQ7cGx/O/o6Gbn8Q25/6UT+sM +B1Y0rlLoKG62pnkeXp1O4I57gnClatWRg5qw11a8V8e3jvDKIYM= +-----END RSA PRIVATE KEY----- diff --git a/tokio-rustls/tests/test.rs b/tokio-rustls/tests/test.rs new file mode 100644 index 0000000..9b98688 --- /dev/null +++ b/tokio-rustls/tests/test.rs @@ -0,0 +1,128 @@ +use std::{ io, thread }; +use std::io::{ BufReader, Cursor }; +use std::sync::Arc; +use std::sync::mpsc::channel; +use std::net::SocketAddr; +use futures_util::future::TryFutureExt; +use lazy_static::lazy_static; +use tokio::prelude::*; +use tokio::runtime; +use tokio::io::{ copy, split }; +use tokio::net::{ TcpListener, TcpStream }; +use rustls::{ ServerConfig, ClientConfig }; +use rustls::internal::pemfile::{ certs, rsa_private_keys }; +use tokio_rustls::{ TlsConnector, TlsAcceptor }; + +const CERT: &str = include_str!("end.cert"); +const CHAIN: &str = include_str!("end.chain"); +const RSA: &str = include_str!("end.rsa"); + +lazy_static!{ + static ref TEST_SERVER: (SocketAddr, &'static str, &'static str) = { + let cert = certs(&mut BufReader::new(Cursor::new(CERT))).unwrap(); + let mut keys = rsa_private_keys(&mut BufReader::new(Cursor::new(RSA))).unwrap(); + + let mut config = ServerConfig::new(rustls::NoClientAuth::new()); + config.set_single_cert(cert, keys.pop().unwrap()) + .expect("invalid key or certificate"); + let acceptor = TlsAcceptor::from(Arc::new(config)); + + let (send, recv) = channel(); + + thread::spawn(move || { + let mut runtime = runtime::Builder::new() + .basic_scheduler() + .enable_io() + .build() + .unwrap(); + + let handle = runtime.handle().clone(); + + let done = async move { + let addr = SocketAddr::from(([127, 0, 0, 1], 0)); + let mut listener = TcpListener::bind(&addr).await?; + + send.send(listener.local_addr()?).unwrap(); + + loop { + let (stream, _) = listener.accept().await?; + + let acceptor = acceptor.clone(); + let fut = async move { + let stream = acceptor.accept(stream).await?; + + let (mut reader, mut writer) = split(stream); + copy(&mut reader, &mut writer).await?; + + Ok(()) as io::Result<()> + }.unwrap_or_else(|err| eprintln!("server: {:?}", err)); + + handle.spawn(fut); + } + }.unwrap_or_else(|err: io::Error| eprintln!("server: {:?}", err)); + + runtime.block_on(done); + }); + + let addr = recv.recv().unwrap(); + (addr, "testserver.com", CHAIN) + }; +} + +fn start_server() -> &'static (SocketAddr, &'static str, &'static str) { + &*TEST_SERVER +} + +async fn start_client(addr: SocketAddr, domain: &str, config: Arc) -> io::Result<()> { + const FILE: &'static [u8] = include_bytes!("../README.md"); + + let domain = webpki::DNSNameRef::try_from_ascii_str(domain).unwrap(); + let config = TlsConnector::from(config); + let mut buf = vec![0; FILE.len()]; + + let stream = TcpStream::connect(&addr).await?; + let mut stream = config.connect(domain, stream).await?; + stream.write_all(FILE).await?; + stream.flush().await?; + stream.read_exact(&mut buf).await?; + + assert_eq!(buf, FILE); + + Ok(()) +} + +#[tokio::test] +async fn pass() -> io::Result<()> { + let (addr, domain, chain) = start_server(); + + // TODO: not sure how to resolve this right now but since + // TcpStream::bind now returns a future it creates a race + // condition until its ready sometimes. + use std::time::*; + tokio::time::delay_for(Duration::from_secs(1)).await; + + let mut config = ClientConfig::new(); + let mut chain = BufReader::new(Cursor::new(chain)); + config.root_store.add_pem_file(&mut chain).unwrap(); + let config = Arc::new(config); + + start_client(addr.clone(), domain, config.clone()).await?; + + Ok(()) +} + +#[tokio::test] +async fn fail() -> io::Result<()> { + let (addr, domain, chain) = start_server(); + + let mut config = ClientConfig::new(); + let mut chain = BufReader::new(Cursor::new(chain)); + config.root_store.add_pem_file(&mut chain).unwrap(); + let config = Arc::new(config); + + assert_ne!(domain, &"google.com"); + let ret = start_client(addr.clone(), "google.com", config).await; + assert!(ret.is_err()); + + Ok(()) +}