commit 43c85779ca8bc1e30d50dc4875f70e5beb7e8b6f Author: Lucio Franco Date: Thu Jan 9 18:36:35 2020 -0500 Initial commit diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml new file mode 100644 index 0000000..836a506 --- /dev/null +++ b/.github/workflows/CI.yml @@ -0,0 +1,72 @@ +name: CI + +on: [push, pull_request] + +jobs: + check: + runs-on: ubuntu-latest + + env: + RUSTFLAGS: "-D warnings" + + steps: + - name: Checkout sources + uses: actions/checkout@v2 + + - name: Install stable toolchain + uses: actions-rs/toolchain@v1 + with: + profile: minimal + toolchain: stable + override: true + + - name: Run cargo check + uses: actions-rs/cargo@v1 + with: + command: check --all --all-features --all-targets + + test: + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest, macOS-latest, windows-latest] + rust: [stable] + + env: + RUSTFLAGS: "-D warnings" + + steps: + - uses: actions-rs/toolchain@v1 + with: + toolchain: ${{ matrix.rust }} + profile: minimal + - uses: actions/checkout@master + - name: Test + run: cargo test --all --all-features + + lints: + name: Lints + runs-on: ubuntu-latest + steps: + - name: Checkout sources + uses: actions/checkout@v2 + + - name: Install stable toolchain + uses: actions-rs/toolchain@v1 + with: + profile: minimal + toolchain: stable + override: true + components: rustfmt, clippy + + - name: Run cargo fmt + uses: actions-rs/cargo@v1 + with: + command: fmt + args: --all -- --check + + - name: Run cargo clippy + uses: actions-rs/cargo@v1 + with: + command: clippy + args: -- -D warnings diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..6936990 --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +/target +**/*.rs.bk +Cargo.lock diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..266de47 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,4 @@ +[workspace] +members = [ + "tokio-native-tls" +] diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..cdb28b4 --- /dev/null +++ b/LICENSE @@ -0,0 +1,25 @@ +Copyright (c) 2019 Tokio Contributors + +Permission is hereby granted, free of charge, to any +person obtaining a copy of this software and associated +documentation files (the "Software"), to deal in the +Software without restriction, including without +limitation the rights to use, copy, modify, merge, +publish, distribute, sublicense, and/or sell copies of +the Software, and to permit persons to whom the Software +is furnished to do so, subject to the following +conditions: + +The above copyright notice and this permission notice +shall be included in all copies or substantial portions +of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF +ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED +TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A +PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT +SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR +IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. diff --git a/README.md b/README.md new file mode 100644 index 0000000..58165ca --- /dev/null +++ b/README.md @@ -0,0 +1,63 @@ +# Tokio Tls + +## Overview + +This crate contains a collection of Tokio based TLS libraries. + +- [`tokio-native-tls`](tokio-native-tls) + +## Getting Help + +First, see if the answer to your question can be found in the [Guides] or the +[API documentation]. If the answer is not there, there is an active community in +the [Tokio Discord server][chat]. We would be happy to try to answer your +question. Last, if that doesn't work, try opening an [issue] with the question. + +[Guides]: https://tokio.rs/docs/ +[API documentation]: https://docs.rs/tokio/latest/tokio +[chat]: https://discord.gg/tokio +[issue]: https://github.com/tokio-rs/tls/issues/new + +## Contributing + +:balloon: Thanks for your help improving the project! We are so happy to have +you! We have a [contributing guide][guide] to help you get involved in the Tokio +project. + +[guide]: CONTRIBUTING.md + +## Related Projects + +In addition to the crates in this repository, the Tokio project also maintains +several other libraries, including: + +* [`tracing`] (formerly `tokio-trace`): A framework for application-level + tracing and async-aware diagnostics. + +* [`mio`]: A low-level, cross-platform abstraction over OS I/O APIs that powers + `tokio`. + +* [`bytes`]: Utilities for working with bytes, including efficient byte buffers. + +[`tokio`]: https://github.com/tokio-rs/tokio +[`tracing`]: https://github.com/tokio-rs/tracing +[`mio`]: https://github.com/tokio-rs/mio +[`bytes`]: https://github.com/tokio-rs/bytes + +## Supported Rust Versions + +Tokio is built against the latest stable, nightly, and beta Rust releases. The +minimum version supported is the stable release from three months before the +current stable release version. For example, if the latest stable Rust is 1.29, +the minimum version supported is 1.26. The current Tokio version is not +guaranteed to build on Rust versions earlier than the minimum supported version. + +## License + +This project is licensed under the [MIT license](LICENSE). + +### Contribution + +Unless you explicitly state otherwise, any contribution intentionally submitted +for inclusion in Tokio by you, shall be licensed as MIT, without any additional +terms or conditions. diff --git a/tokio-native-tls/CHANGELOG.md b/tokio-native-tls/CHANGELOG.md new file mode 100644 index 0000000..bd1aa9d --- /dev/null +++ b/tokio-native-tls/CHANGELOG.md @@ -0,0 +1,3 @@ +# 0.1.0 (January 9th, 2019) + +- Initial release from `tokio-tls 0.3` diff --git a/tokio-native-tls/Cargo.toml b/tokio-native-tls/Cargo.toml new file mode 100644 index 0000000..87e3d3e --- /dev/null +++ b/tokio-native-tls/Cargo.toml @@ -0,0 +1,60 @@ +[package] +name = "tokio-native-tls" +# When releasing to crates.io: +# - Remove path dependencies +# - Update html_root_url. +# - Update doc url +# - Cargo.toml +# - README.md +# - Update CHANGELOG.md. +# - Create "v0.1.x" git tag. +version = "0.1.0" +edition = "2018" +authors = ["Tokio Contributors "] +license = "MIT" +repository = "https://github.com/tokio-rs/tls" +homepage = "https://tokio.rs" +documentation = "https://docs.rs/tokio-native-tls/0.1.0/tokio_native_tls/" +description = """ +An implementation of TLS/SSL streams for Tokio using native-tls giving an implementation of TLS +for nonblocking I/O streams. +""" +categories = ["asynchronous", "network-programming"] + +[dependencies] +native-tls = "0.2" +tokio = { version = "0.2.0" } + +[dev-dependencies] +tokio = { version = "0.2.0", features = ["macros", "stream", "rt-core", "io-util", "net"] } +tokio-util = { version = "0.2.0", features = ["full"] } + +cfg-if = "0.1" +env_logger = { version = "0.6", default-features = false } +futures = { version = "0.3.0", features = ["async-await"] } + +[target.'cfg(all(not(target_os = "macos"), not(windows), not(target_os = "ios")))'.dev-dependencies] +openssl = "0.10" + +[target.'cfg(any(target_os = "macos", target_os = "ios"))'.dev-dependencies] +security-framework = "0.2" + +[target.'cfg(windows)'.dev-dependencies] +schannel = "0.1" + +[target.'cfg(windows)'.dev-dependencies.winapi] +version = "0.3" +features = [ + "lmcons", + "basetsd", + "minwinbase", + "minwindef", + "ntdef", + "sysinfoapi", + "timezoneapi", + "wincrypt", + "winerror", +] + +[package.metadata.docs.rs] +all-features = true diff --git a/tokio-native-tls/LICENSE b/tokio-native-tls/LICENSE new file mode 100644 index 0000000..cdb28b4 --- /dev/null +++ b/tokio-native-tls/LICENSE @@ -0,0 +1,25 @@ +Copyright (c) 2019 Tokio Contributors + +Permission is hereby granted, free of charge, to any +person obtaining a copy of this software and associated +documentation files (the "Software"), to deal in the +Software without restriction, including without +limitation the rights to use, copy, modify, merge, +publish, distribute, sublicense, and/or sell copies of +the Software, and to permit persons to whom the Software +is furnished to do so, subject to the following +conditions: + +The above copyright notice and this permission notice +shall be included in all copies or substantial portions +of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF +ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED +TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A +PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT +SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR +IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. diff --git a/tokio-native-tls/README.md b/tokio-native-tls/README.md new file mode 100644 index 0000000..455612b --- /dev/null +++ b/tokio-native-tls/README.md @@ -0,0 +1,14 @@ +# tokio-tls + +An implementation of TLS/SSL streams for Tokio built on top of the [`native-tls` +crate] + +## License + +This project is licensed under the [MIT license](./LICENSE). + +### Contribution + +Unless you explicitly state otherwise, any contribution intentionally submitted +for inclusion in Tokio by you, shall be licensed as MIT, without any additional +terms or conditions. diff --git a/tokio-native-tls/examples/download-rust-lang.rs b/tokio-native-tls/examples/download-rust-lang.rs new file mode 100644 index 0000000..6f864c3 --- /dev/null +++ b/tokio-native-tls/examples/download-rust-lang.rs @@ -0,0 +1,39 @@ +// #![warn(rust_2018_idioms)] + +use native_tls::TlsConnector; +use std::error::Error; +use std::net::ToSocketAddrs; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::TcpStream; + +#[tokio::main] +async fn main() -> Result<(), Box> { + let addr = "www.rust-lang.org:443" + .to_socket_addrs()? + .next() + .ok_or("failed to resolve www.rust-lang.org")?; + + let socket = TcpStream::connect(&addr).await?; + let cx = TlsConnector::builder().build()?; + let cx = tokio_native_tls::TlsConnector::from(cx); + + let mut socket = cx.connect("www.rust-lang.org", socket).await?; + + socket + .write_all( + "\ + GET / HTTP/1.0\r\n\ + Host: www.rust-lang.org\r\n\ + \r\n\ + " + .as_bytes(), + ) + .await?; + + let mut data = Vec::new(); + socket.read_to_end(&mut data).await?; + + // println!("data: {:?}", &data); + println!("{}", String::from_utf8_lossy(&data[..])); + Ok(()) +} diff --git a/tokio-native-tls/examples/echo.rs b/tokio-native-tls/examples/echo.rs new file mode 100644 index 0000000..2887c6f --- /dev/null +++ b/tokio-native-tls/examples/echo.rs @@ -0,0 +1,54 @@ +#![warn(rust_2018_idioms)] + +// A tiny async TLS echo server with Tokio +use native_tls; +use native_tls::Identity; +use tokio; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::TcpListener; + +/** +an example to setup a tls server. +how to test: +wget https://127.0.0.1:12345 --no-check-certificate +*/ +#[tokio::main] +async fn main() -> Result<(), Box> { + // Bind the server's socket + let addr = "127.0.0.1:12345".to_string(); + let mut tcp: TcpListener = TcpListener::bind(&addr).await?; + + // Create the TLS acceptor. + let der = include_bytes!("identity.p12"); + let cert = Identity::from_pkcs12(der, "mypass")?; + let tls_acceptor = + tokio_native_tls::TlsAcceptor::from(native_tls::TlsAcceptor::builder(cert).build()?); + loop { + // Asynchronously wait for an inbound socket. + let (socket, remote_addr) = tcp.accept().await?; + let tls_acceptor = tls_acceptor.clone(); + println!("accept connection from {}", remote_addr); + tokio::spawn(async move { + // Accept the TLS connection. + let mut tls_stream = tls_acceptor.accept(socket).await.expect("accept error"); + // In a loop, read data from the socket and write the data back. + + let mut buf = [0; 1024]; + let n = tls_stream + .read(&mut buf) + .await + .expect("failed to read data from socket"); + + if n == 0 { + return; + } + println!("read={}", unsafe { + String::from_utf8_unchecked(buf[0..n].into()) + }); + tls_stream + .write_all(&buf[0..n]) + .await + .expect("failed to write data to socket"); + }); + } +} diff --git a/tokio-native-tls/examples/identity.p12 b/tokio-native-tls/examples/identity.p12 new file mode 100644 index 0000000..d16abb8 Binary files /dev/null and b/tokio-native-tls/examples/identity.p12 differ diff --git a/tokio-native-tls/src/lib.rs b/tokio-native-tls/src/lib.rs new file mode 100644 index 0000000..2770650 --- /dev/null +++ b/tokio-native-tls/src/lib.rs @@ -0,0 +1,361 @@ +#![doc(html_root_url = "https://docs.rs/tokio-tls/0.3.0")] +#![warn( + missing_debug_implementations, + missing_docs, + rust_2018_idioms, + unreachable_pub +)] +#![deny(intra_doc_link_resolution_failure)] +#![doc(test( + no_crate_inject, + attr(deny(warnings, rust_2018_idioms), allow(dead_code, unused_variables)) +))] + +//! Async TLS streams +//! +//! This library is an implementation of TLS streams using the most appropriate +//! system library by default for negotiating the connection. That is, on +//! Windows this library uses SChannel, on OSX it uses SecureTransport, and on +//! other platforms it uses OpenSSL. +//! +//! Each TLS stream implements the `Read` and `Write` traits to interact and +//! interoperate with the rest of the futures I/O ecosystem. Client connections +//! initiated from this crate verify hostnames automatically and by default. +//! +//! This crate primarily exports this ability through two newtypes, +//! `TlsConnector` and `TlsAcceptor`. These newtypes augment the +//! functionality provided by the `native-tls` crate, on which this crate is +//! built. Configuration of TLS parameters is still primarily done through the +//! `native-tls` crate. + +use tokio::io::{AsyncRead, AsyncWrite}; + +use native_tls::{Error, HandshakeError, MidHandshakeTlsStream}; +use std::fmt; +use std::future::Future; +use std::io::{self, Read, Write}; +use std::marker::Unpin; +use std::mem::MaybeUninit; +use std::pin::Pin; +use std::ptr::null_mut; +use std::task::{Context, Poll}; + +#[derive(Debug)] +struct AllowStd { + inner: S, + context: *mut (), +} + +/// A wrapper around an underlying raw stream which implements the TLS or SSL +/// protocol. +/// +/// A `TlsStream` represents a handshake that has been completed successfully +/// and both the server and the client are ready for receiving and sending +/// data. Bytes read from a `TlsStream` are decrypted from `S` and bytes written +/// to a `TlsStream` are encrypted when passing through to `S`. +#[derive(Debug)] +pub struct TlsStream(native_tls::TlsStream>); + +/// A wrapper around a `native_tls::TlsConnector`, providing an async `connect` +/// method. +#[derive(Clone)] +pub struct TlsConnector(native_tls::TlsConnector); + +/// A wrapper around a `native_tls::TlsAcceptor`, providing an async `accept` +/// method. +#[derive(Clone)] +pub struct TlsAcceptor(native_tls::TlsAcceptor); + +struct MidHandshake(Option>>); + +enum StartedHandshake { + Done(TlsStream), + Mid(MidHandshakeTlsStream>), +} + +struct StartedHandshakeFuture(Option>); +struct StartedHandshakeFutureInner { + f: F, + stream: S, +} + +struct Guard<'a, S>(&'a mut TlsStream) +where + AllowStd: Read + Write; + +impl Drop for Guard<'_, S> +where + AllowStd: Read + Write, +{ + fn drop(&mut self) { + (self.0).0.get_mut().context = null_mut(); + } +} + +// *mut () context is neither Send nor Sync +unsafe impl Send for AllowStd {} +unsafe impl Sync for AllowStd {} + +impl AllowStd +where + S: Unpin, +{ + fn with_context(&mut self, f: F) -> R + where + F: FnOnce(&mut Context<'_>, Pin<&mut S>) -> R, + { + unsafe { + assert!(!self.context.is_null()); + let waker = &mut *(self.context as *mut _); + f(waker, Pin::new(&mut self.inner)) + } + } +} + +impl Read for AllowStd +where + S: AsyncRead + Unpin, +{ + fn read(&mut self, buf: &mut [u8]) -> io::Result { + match self.with_context(|ctx, stream| stream.poll_read(ctx, buf)) { + Poll::Ready(r) => r, + Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)), + } + } +} + +impl Write for AllowStd +where + S: AsyncWrite + Unpin, +{ + fn write(&mut self, buf: &[u8]) -> io::Result { + match self.with_context(|ctx, stream| stream.poll_write(ctx, buf)) { + Poll::Ready(r) => r, + Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)), + } + } + + fn flush(&mut self) -> io::Result<()> { + match self.with_context(|ctx, stream| stream.poll_flush(ctx)) { + Poll::Ready(r) => r, + Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)), + } + } +} + +fn cvt(r: io::Result) -> Poll> { + match r { + Ok(v) => Poll::Ready(Ok(v)), + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => Poll::Pending, + Err(e) => Poll::Ready(Err(e)), + } +} + +impl TlsStream { + fn with_context(&mut self, ctx: &mut Context<'_>, f: F) -> R + where + F: FnOnce(&mut native_tls::TlsStream>) -> R, + AllowStd: Read + Write, + { + self.0.get_mut().context = ctx as *mut _ as *mut (); + let g = Guard(self); + f(&mut (g.0).0) + } + + /// Returns a shared reference to the inner stream. + pub fn get_ref(&self) -> &S + where + S: AsyncRead + AsyncWrite + Unpin, + { + &self.0.get_ref().inner + } + + /// Returns a mutable reference to the inner stream. + pub fn get_mut(&mut self) -> &mut S + where + S: AsyncRead + AsyncWrite + Unpin, + { + &mut self.0.get_mut().inner + } +} + +impl AsyncRead for TlsStream +where + S: AsyncRead + AsyncWrite + Unpin, +{ + unsafe fn prepare_uninitialized_buffer(&self, _: &mut [MaybeUninit]) -> bool { + // Note that this does not forward to `S` because the buffer is + // unconditionally filled in by OpenSSL, not the actual object `S`. + // We're decrypting bytes from `S` into the buffer above! + false + } + + fn poll_read( + mut self: Pin<&mut Self>, + ctx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + self.with_context(ctx, |s| cvt(s.read(buf))) + } +} + +impl AsyncWrite for TlsStream +where + S: AsyncRead + AsyncWrite + Unpin, +{ + fn poll_write( + mut self: Pin<&mut Self>, + ctx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + self.with_context(ctx, |s| cvt(s.write(buf))) + } + + fn poll_flush(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll> { + self.with_context(ctx, |s| cvt(s.flush())) + } + + fn poll_shutdown(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll> { + match self.with_context(ctx, |s| s.shutdown()) { + Ok(()) => Poll::Ready(Ok(())), + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => Poll::Pending, + Err(e) => Poll::Ready(Err(e)), + } + } +} + +async fn handshake(f: F, stream: S) -> Result, Error> +where + F: FnOnce( + AllowStd, + ) -> Result>, HandshakeError>> + + Unpin, + S: AsyncRead + AsyncWrite + Unpin, +{ + let start = StartedHandshakeFuture(Some(StartedHandshakeFutureInner { f, stream })); + + match start.await { + Err(e) => Err(e), + Ok(StartedHandshake::Done(s)) => Ok(s), + Ok(StartedHandshake::Mid(s)) => MidHandshake(Some(s)).await, + } +} + +impl Future for StartedHandshakeFuture +where + F: FnOnce( + AllowStd, + ) -> Result>, HandshakeError>> + + Unpin, + S: Unpin, + AllowStd: Read + Write, +{ + type Output = Result, Error>; + + fn poll( + mut self: Pin<&mut Self>, + ctx: &mut Context<'_>, + ) -> Poll, Error>> { + let inner = self.0.take().expect("future polled after completion"); + let stream = AllowStd { + inner: inner.stream, + context: ctx as *mut _ as *mut (), + }; + + match (inner.f)(stream) { + Ok(mut s) => { + s.get_mut().context = null_mut(); + Poll::Ready(Ok(StartedHandshake::Done(TlsStream(s)))) + } + Err(HandshakeError::WouldBlock(mut s)) => { + s.get_mut().context = null_mut(); + Poll::Ready(Ok(StartedHandshake::Mid(s))) + } + Err(HandshakeError::Failure(e)) => Poll::Ready(Err(e)), + } + } +} + +impl TlsConnector { + /// Connects the provided stream with this connector, assuming the provided + /// domain. + /// + /// This function will internally call `TlsConnector::connect` to connect + /// the stream and returns a future representing the resolution of the + /// connection operation. The returned future will resolve to either + /// `TlsStream` or `Error` depending if it's successful or not. + /// + /// This is typically used for clients who have already established, for + /// example, a TCP connection to a remote server. That stream is then + /// provided here to perform the client half of a connection to a + /// TLS-powered server. + pub async fn connect(&self, domain: &str, stream: S) -> Result, Error> + where + S: AsyncRead + AsyncWrite + Unpin, + { + handshake(move |s| self.0.connect(domain, s), stream).await + } +} + +impl fmt::Debug for TlsConnector { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("TlsConnector").finish() + } +} + +impl From for TlsConnector { + fn from(inner: native_tls::TlsConnector) -> TlsConnector { + TlsConnector(inner) + } +} + +impl TlsAcceptor { + /// Accepts a new client connection with the provided stream. + /// + /// This function will internally call `TlsAcceptor::accept` to connect + /// the stream and returns a future representing the resolution of the + /// connection operation. The returned future will resolve to either + /// `TlsStream` or `Error` depending if it's successful or not. + /// + /// This is typically used after a new socket has been accepted from a + /// `TcpListener`. That socket is then passed to this function to perform + /// the server half of accepting a client connection. + pub async fn accept(&self, stream: S) -> Result, Error> + where + S: AsyncRead + AsyncWrite + Unpin, + { + handshake(move |s| self.0.accept(s), stream).await + } +} + +impl fmt::Debug for TlsAcceptor { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("TlsAcceptor").finish() + } +} + +impl From for TlsAcceptor { + fn from(inner: native_tls::TlsAcceptor) -> TlsAcceptor { + TlsAcceptor(inner) + } +} + +impl Future for MidHandshake { + type Output = Result, Error>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let mut_self = self.get_mut(); + let mut s = mut_self.0.take().expect("future polled after completion"); + + s.get_mut().context = cx as *mut _ as *mut (); + match s.handshake() { + Ok(stream) => Poll::Ready(Ok(TlsStream(stream))), + Err(HandshakeError::Failure(e)) => Poll::Ready(Err(e)), + Err(HandshakeError::WouldBlock(mut s)) => { + s.get_mut().context = null_mut(); + mut_self.0 = Some(s); + Poll::Pending + } + } + } +} diff --git a/tokio-native-tls/tests/bad.rs b/tokio-native-tls/tests/bad.rs new file mode 100644 index 0000000..55e39d6 --- /dev/null +++ b/tokio-native-tls/tests/bad.rs @@ -0,0 +1,123 @@ +#![warn(rust_2018_idioms)] + +use cfg_if::cfg_if; +use env_logger; +use native_tls::TlsConnector; +use std::io::{self, Error}; +use std::net::ToSocketAddrs; +use tokio::net::TcpStream; + +macro_rules! t { + ($e:expr) => { + match $e { + Ok(e) => e, + Err(e) => panic!("{} failed with {:?}", stringify!($e), e), + } + }; +} + +cfg_if! { + if #[cfg(feature = "force-rustls")] { + fn verify_failed(err: &Error, s: &str) { + let err = err.to_string(); + assert!(err.contains(s), "bad error: {}", err); + } + + fn assert_expired_error(err: &Error) { + verify_failed(err, "CertExpired"); + } + + fn assert_wrong_host(err: &Error) { + verify_failed(err, "CertNotValidForName"); + } + + fn assert_self_signed(err: &Error) { + verify_failed(err, "UnknownIssuer"); + } + + fn assert_untrusted_root(err: &Error) { + verify_failed(err, "UnknownIssuer"); + } + } else if #[cfg(any(feature = "force-openssl", + all(not(target_os = "macos"), + not(target_os = "windows"), + not(target_os = "ios"))))] { + fn verify_failed(err: &Error) { + assert!(format!("{}", err).contains("certificate verify failed")) + } + + use verify_failed as assert_expired_error; + use verify_failed as assert_wrong_host; + use verify_failed as assert_self_signed; + use verify_failed as assert_untrusted_root; + } else if #[cfg(any(target_os = "macos", target_os = "ios"))] { + + fn assert_invalid_cert_chain(err: &Error) { + assert!(format!("{}", err).contains("was not trusted.")) + } + + use crate::assert_invalid_cert_chain as assert_expired_error; + use crate::assert_invalid_cert_chain as assert_wrong_host; + use crate::assert_invalid_cert_chain as assert_self_signed; + use crate::assert_invalid_cert_chain as assert_untrusted_root; + } else { + fn assert_expired_error(err: &Error) { + let s = err.to_string(); + assert!(s.contains("system clock"), "error = {:?}", s); + } + + fn assert_wrong_host(err: &Error) { + let s = err.to_string(); + assert!(s.contains("CN name"), "error = {:?}", s); + } + + fn assert_self_signed(err: &Error) { + let s = err.to_string(); + assert!(s.contains("root certificate which is not trusted"), "error = {:?}", s); + } + + use assert_self_signed as assert_untrusted_root; + } +} + +async fn get_host(host: &'static str) -> Error { + drop(env_logger::try_init()); + + let addr = format!("{}:443", host); + let addr = t!(addr.to_socket_addrs()).next().unwrap(); + + let socket = t!(TcpStream::connect(&addr).await); + let builder = TlsConnector::builder(); + let cx = t!(builder.build()); + let cx = tokio_native_tls::TlsConnector::from(cx); + let res = cx + .connect(host, socket) + .await + .map_err(|e| Error::new(io::ErrorKind::Other, e)); + + assert!(res.is_err()); + res.err().unwrap() +} + +#[tokio::test] +async fn expired() { + assert_expired_error(&get_host("expired.badssl.com").await) +} + +// TODO: the OSX builders on Travis apparently fail this tests spuriously? +// passes locally though? Seems... bad! +#[tokio::test] +#[cfg_attr(all(target_os = "macos", feature = "force-openssl"), ignore)] +async fn wrong_host() { + assert_wrong_host(&get_host("wrong.host.badssl.com").await) +} + +#[tokio::test] +async fn self_signed() { + assert_self_signed(&get_host("self-signed.badssl.com").await) +} + +#[tokio::test] +async fn untrusted_root() { + assert_untrusted_root(&get_host("untrusted-root.badssl.com").await) +} diff --git a/tokio-native-tls/tests/google.rs b/tokio-native-tls/tests/google.rs new file mode 100644 index 0000000..57f1e93 --- /dev/null +++ b/tokio-native-tls/tests/google.rs @@ -0,0 +1,101 @@ +#![warn(rust_2018_idioms)] + +use cfg_if::cfg_if; +use env_logger; +use native_tls; +use native_tls::TlsConnector; +use std::io; +use std::net::ToSocketAddrs; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::TcpStream; + +macro_rules! t { + ($e:expr) => { + match $e { + Ok(e) => e, + Err(e) => panic!("{} failed with {:?}", stringify!($e), e), + } + }; +} + +cfg_if! { + if #[cfg(feature = "force-rustls")] { + fn assert_bad_hostname_error(err: &io::Error) { + let err = err.to_string(); + assert!(err.contains("CertNotValidForName"), "bad error: {}", err); + } + } else if #[cfg(any(feature = "force-openssl", + all(not(target_os = "macos"), + not(target_os = "windows"), + not(target_os = "ios"))))] { + fn assert_bad_hostname_error(err: &io::Error) { + let err = err.get_ref().unwrap(); + let err = err.downcast_ref::().unwrap(); + assert!(format!("{}", err).contains("certificate verify failed")); + } + } else if #[cfg(any(target_os = "macos", target_os = "ios"))] { + fn assert_bad_hostname_error(err: &io::Error) { + let err = err.get_ref().unwrap(); + let err = err.downcast_ref::().unwrap(); + assert!(format!("{}", err).contains("was not trusted.")); + } + } else { + fn assert_bad_hostname_error(err: &io::Error) { + let err = err.get_ref().unwrap(); + let err = err.downcast_ref::().unwrap(); + assert!(format!("{}", err).contains("CN name")); + } + } +} + +#[tokio::test] +async fn fetch_google() { + drop(env_logger::try_init()); + + // First up, resolve google.com + let addr = t!("google.com:443".to_socket_addrs()).next().unwrap(); + + let socket = TcpStream::connect(&addr).await.unwrap(); + + // Send off the request by first negotiating an SSL handshake, then writing + // of our request, then flushing, then finally read off the response. + let builder = TlsConnector::builder(); + let connector = t!(builder.build()); + let connector = tokio_native_tls::TlsConnector::from(connector); + let mut socket = t!(connector.connect("google.com", socket).await); + t!(socket.write_all(b"GET / HTTP/1.0\r\n\r\n").await); + let mut data = Vec::new(); + t!(socket.read_to_end(&mut data).await); + + // any response code is fine + assert!(data.starts_with(b"HTTP/1.0 ")); + + let data = String::from_utf8_lossy(&data); + let data = data.trim_end(); + assert!(data.ends_with("") || data.ends_with("")); +} + +fn native2io(e: native_tls::Error) -> io::Error { + io::Error::new(io::ErrorKind::Other, e) +} + +// see comment in bad.rs for ignore reason +#[cfg_attr(all(target_os = "macos", feature = "force-openssl"), ignore)] +#[tokio::test] +async fn wrong_hostname_error() { + drop(env_logger::try_init()); + + let addr = t!("google.com:443".to_socket_addrs()).next().unwrap(); + + let socket = t!(TcpStream::connect(&addr).await); + let builder = TlsConnector::builder(); + let connector = t!(builder.build()); + let connector = tokio_native_tls::TlsConnector::from(connector); + let res = connector + .connect("rust-lang.org", socket) + .await + .map_err(native2io); + + assert!(res.is_err()); + assert_bad_hostname_error(&res.err().unwrap()); +} diff --git a/tokio-native-tls/tests/smoke.rs b/tokio-native-tls/tests/smoke.rs new file mode 100644 index 0000000..baa7fb7 --- /dev/null +++ b/tokio-native-tls/tests/smoke.rs @@ -0,0 +1,628 @@ +#![warn(rust_2018_idioms)] + +use cfg_if::cfg_if; +use env_logger; +use futures::join; +use native_tls; +use native_tls::{Identity, TlsAcceptor, TlsConnector}; +use std::io::Write; +use std::marker::Unpin; +use std::process::Command; +use std::ptr; +use tokio::io::{AsyncReadExt, AsyncWrite, AsyncWriteExt, Error, ErrorKind}; +use tokio::net::{TcpListener, TcpStream}; +use tokio::stream::StreamExt; + +macro_rules! t { + ($e:expr) => { + match $e { + Ok(e) => e, + Err(e) => panic!("{} failed with {:?}", stringify!($e), e), + } + }; +} + +#[allow(dead_code)] +struct Keys { + cert_der: Vec, + pkey_der: Vec, + pkcs12_der: Vec, +} + +#[allow(dead_code)] +fn openssl_keys() -> &'static Keys { + static INIT: Once = Once::new(); + static mut KEYS: *mut Keys = ptr::null_mut(); + + INIT.call_once(|| { + let path = t!(env::current_exe()); + let path = path.parent().unwrap(); + let keyfile = path.join("test.key"); + let certfile = path.join("test.crt"); + let config = path.join("openssl.config"); + + File::create(&config) + .unwrap() + .write_all( + b"\ + [req]\n\ + distinguished_name=dn\n\ + [ dn ]\n\ + CN=localhost\n\ + [ ext ]\n\ + basicConstraints=CA:FALSE,pathlen:0\n\ + subjectAltName = @alt_names + extendedKeyUsage=serverAuth,clientAuth + [alt_names] + DNS.1 = localhost + ", + ) + .unwrap(); + + let subj = "/C=US/ST=Denial/L=Sprintfield/O=Dis/CN=localhost"; + let output = t!(Command::new("openssl") + .arg("req") + .arg("-nodes") + .arg("-x509") + .arg("-newkey") + .arg("rsa:2048") + .arg("-config") + .arg(&config) + .arg("-extensions") + .arg("ext") + .arg("-subj") + .arg(subj) + .arg("-keyout") + .arg(&keyfile) + .arg("-out") + .arg(&certfile) + .arg("-days") + .arg("1") + .output()); + assert!(output.status.success()); + + let crtout = t!(Command::new("openssl") + .arg("x509") + .arg("-outform") + .arg("der") + .arg("-in") + .arg(&certfile) + .output()); + assert!(crtout.status.success()); + let keyout = t!(Command::new("openssl") + .arg("rsa") + .arg("-outform") + .arg("der") + .arg("-in") + .arg(&keyfile) + .output()); + assert!(keyout.status.success()); + + let pkcs12out = t!(Command::new("openssl") + .arg("pkcs12") + .arg("-export") + .arg("-nodes") + .arg("-inkey") + .arg(&keyfile) + .arg("-in") + .arg(&certfile) + .arg("-password") + .arg("pass:foobar") + .output()); + assert!(pkcs12out.status.success()); + + let keys = Box::new(Keys { + cert_der: crtout.stdout, + pkey_der: keyout.stdout, + pkcs12_der: pkcs12out.stdout, + }); + unsafe { + KEYS = Box::into_raw(keys); + } + }); + unsafe { &*KEYS } +} + +cfg_if! { + if #[cfg(feature = "rustls")] { + use webpki; + use untrusted; + use std::env; + use std::fs::File; + use std::process::Command; + use std::sync::Once; + + use untrusted::Input; + use webpki::trust_anchor_util; + + fn server_cx() -> io::Result { + let mut cx = ServerContext::new(); + + let (cert, key) = keys(); + cx.config_mut() + .set_single_cert(vec![cert.to_vec()], key.to_vec()); + + Ok(cx) + } + + fn configure_client(cx: &mut ClientContext) { + let (cert, _key) = keys(); + let cert = Input::from(cert); + let anchor = trust_anchor_util::cert_der_as_trust_anchor(cert).unwrap(); + cx.config_mut().root_store.add_trust_anchors(&[anchor]); + } + + // Like OpenSSL we generate certificates on the fly, but for OSX we + // also have to put them into a specific keychain. We put both the + // certificates and the keychain next to our binary. + // + // Right now I don't know of a way to programmatically create a + // self-signed certificate, so we just fork out to the `openssl` binary. + fn keys() -> (&'static [u8], &'static [u8]) { + static INIT: Once = Once::new(); + static mut KEYS: *mut (Vec, Vec) = ptr::null_mut(); + + INIT.call_once(|| { + let (key, cert) = openssl_keys(); + let path = t!(env::current_exe()); + let path = path.parent().unwrap(); + let keyfile = path.join("test.key"); + let certfile = path.join("test.crt"); + let config = path.join("openssl.config"); + + File::create(&config).unwrap().write_all(b"\ + [req]\n\ + distinguished_name=dn\n\ + [ dn ]\n\ + CN=localhost\n\ + [ ext ]\n\ + basicConstraints=CA:FALSE,pathlen:0\n\ + subjectAltName = @alt_names + [alt_names] + DNS.1 = localhost + ").unwrap(); + + let subj = "/C=US/ST=Denial/L=Sprintfield/O=Dis/CN=localhost"; + let output = t!(Command::new("openssl") + .arg("req") + .arg("-nodes") + .arg("-x509") + .arg("-newkey").arg("rsa:2048") + .arg("-config").arg(&config) + .arg("-extensions").arg("ext") + .arg("-subj").arg(subj) + .arg("-keyout").arg(&keyfile) + .arg("-out").arg(&certfile) + .arg("-days").arg("1") + .output()); + assert!(output.status.success()); + + let crtout = t!(Command::new("openssl") + .arg("x509") + .arg("-outform").arg("der") + .arg("-in").arg(&certfile) + .output()); + assert!(crtout.status.success()); + let keyout = t!(Command::new("openssl") + .arg("rsa") + .arg("-outform").arg("der") + .arg("-in").arg(&keyfile) + .output()); + assert!(keyout.status.success()); + + let cert = crtout.stdout; + let key = keyout.stdout; + unsafe { + KEYS = Box::into_raw(Box::new((cert, key))); + } + }); + unsafe { + (&(*KEYS).0, &(*KEYS).1) + } + } + } else if #[cfg(any(feature = "force-openssl", + all(not(target_os = "macos"), + not(target_os = "windows"), + not(target_os = "ios"))))] { + use std::fs::File; + use std::env; + use std::sync::Once; + + fn contexts() -> (tokio_native_tls::TlsAcceptor, tokio_native_tls::TlsConnector) { + let keys = openssl_keys(); + + let pkcs12 = t!(Identity::from_pkcs12(&keys.pkcs12_der, "foobar")); + let srv = TlsAcceptor::builder(pkcs12); + + let cert = t!(native_tls::Certificate::from_der(&keys.cert_der)); + + let mut client = TlsConnector::builder(); + t!(client.add_root_certificate(cert).build()); + + (t!(srv.build()).into(), t!(client.build()).into()) + } + } else if #[cfg(any(target_os = "macos", target_os = "ios"))] { + use std::env; + use std::fs::File; + use std::sync::Once; + + fn contexts() -> (tokio_tls::TlsAcceptor, tokio_tls::TlsConnector) { + let keys = openssl_keys(); + + let pkcs12 = t!(Identity::from_pkcs12(&keys.pkcs12_der, "foobar")); + let srv = TlsAcceptor::builder(pkcs12); + + let cert = native_tls::Certificate::from_der(&keys.cert_der).unwrap(); + let mut client = TlsConnector::builder(); + client.add_root_certificate(cert); + + (t!(srv.build()).into(), t!(client.build()).into()) + } + } else { + use schannel; + use winapi; + + use std::env; + use std::fs::File; + use std::io; + use std::mem; + use std::sync::Once; + + use schannel::cert_context::CertContext; + use schannel::cert_store::{CertStore, CertAdd, Memory}; + use winapi::shared::basetsd::*; + use winapi::shared::lmcons::*; + use winapi::shared::minwindef::*; + use winapi::shared::ntdef::WCHAR; + use winapi::um::minwinbase::*; + use winapi::um::sysinfoapi::*; + use winapi::um::timezoneapi::*; + use winapi::um::wincrypt::*; + + const FRIENDLY_NAME: &'static str = "tokio-tls localhost testing cert"; + + fn contexts() -> (tokio_tls::TlsAcceptor, tokio_tls::TlsConnector) { + let cert = localhost_cert(); + let mut store = t!(Memory::new()).into_store(); + t!(store.add_cert(&cert, CertAdd::Always)); + let pkcs12_der = t!(store.export_pkcs12("foobar")); + let pkcs12 = t!(Identity::from_pkcs12(&pkcs12_der, "foobar")); + + let srv = TlsAcceptor::builder(pkcs12); + let client = TlsConnector::builder(); + (t!(srv.build()).into(), t!(client.build()).into()) + } + + // ==================================================================== + // Magic! + // + // Lots of magic is happening here to wrangle certificates for running + // these tests on Windows. For more information see the test suite + // in the schannel-rs crate as this is just coyping that. + // + // The general gist of this though is that the only way to add custom + // trusted certificates is to add it to the system store of trust. To + // do that we go through the whole rigamarole here to generate a new + // self-signed certificate and then insert that into the system store. + // + // This generates some dialogs, so we print what we're doing sometimes, + // and otherwise we just manage the ephemeral certificates. Because + // they're in the system store we always ensure that they're only valid + // for a small period of time (e.g. 1 day). + + fn localhost_cert() -> CertContext { + static INIT: Once = Once::new(); + INIT.call_once(|| { + for cert in local_root_store().certs() { + let name = match cert.friendly_name() { + Ok(name) => name, + Err(_) => continue, + }; + if name != FRIENDLY_NAME { + continue + } + if !cert.is_time_valid().unwrap() { + io::stdout().write_all(br#" + +The tokio-tls test suite is about to delete an old copy of one of its +certificates from your root trust store. This certificate was only valid for one +day and it is no longer needed. The host should be "localhost" and the +description should mention "tokio-tls". + + "#).unwrap(); + cert.delete().unwrap(); + } else { + return + } + } + + install_certificate().unwrap(); + }); + + for cert in local_root_store().certs() { + let name = match cert.friendly_name() { + Ok(name) => name, + Err(_) => continue, + }; + if name == FRIENDLY_NAME { + return cert + } + } + + panic!("couldn't find a cert"); + } + + fn local_root_store() -> CertStore { + if env::var("CI").is_ok() { + CertStore::open_local_machine("Root").unwrap() + } else { + CertStore::open_current_user("Root").unwrap() + } + } + + fn install_certificate() -> io::Result { + unsafe { + let mut provider = 0; + let mut hkey = 0; + + let mut buffer = "tokio-tls test suite".encode_utf16() + .chain(Some(0)) + .collect::>(); + let res = CryptAcquireContextW(&mut provider, + buffer.as_ptr(), + ptr::null_mut(), + PROV_RSA_FULL, + CRYPT_MACHINE_KEYSET); + if res != TRUE { + // create a new key container (since it does not exist) + let res = CryptAcquireContextW(&mut provider, + buffer.as_ptr(), + ptr::null_mut(), + PROV_RSA_FULL, + CRYPT_NEWKEYSET | CRYPT_MACHINE_KEYSET); + if res != TRUE { + return Err(Error::last_os_error()) + } + } + + // create a new keypair (RSA-2048) + let res = CryptGenKey(provider, + AT_SIGNATURE, + 0x0800<<16 | CRYPT_EXPORTABLE, + &mut hkey); + if res != TRUE { + return Err(Error::last_os_error()); + } + + // start creating the certificate + let name = "CN=localhost,O=tokio-tls,OU=tokio-tls,\ + G=tokio_tls".encode_utf16() + .chain(Some(0)) + .collect::>(); + let mut cname_buffer: [WCHAR; UNLEN as usize + 1] = mem::zeroed(); + let mut cname_len = cname_buffer.len() as DWORD; + let res = CertStrToNameW(X509_ASN_ENCODING, + name.as_ptr(), + CERT_X500_NAME_STR, + ptr::null_mut(), + cname_buffer.as_mut_ptr() as *mut u8, + &mut cname_len, + ptr::null_mut()); + if res != TRUE { + return Err(Error::last_os_error()); + } + + let mut subject_issuer = CERT_NAME_BLOB { + cbData: cname_len, + pbData: cname_buffer.as_ptr() as *mut u8, + }; + let mut key_provider = CRYPT_KEY_PROV_INFO { + pwszContainerName: buffer.as_mut_ptr(), + pwszProvName: ptr::null_mut(), + dwProvType: PROV_RSA_FULL, + dwFlags: CRYPT_MACHINE_KEYSET, + cProvParam: 0, + rgProvParam: ptr::null_mut(), + dwKeySpec: AT_SIGNATURE, + }; + let mut sig_algorithm = CRYPT_ALGORITHM_IDENTIFIER { + pszObjId: szOID_RSA_SHA256RSA.as_ptr() as *mut _, + Parameters: mem::zeroed(), + }; + let mut expiration_date: SYSTEMTIME = mem::zeroed(); + GetSystemTime(&mut expiration_date); + let mut file_time: FILETIME = mem::zeroed(); + let res = SystemTimeToFileTime(&mut expiration_date, + &mut file_time); + if res != TRUE { + return Err(Error::last_os_error()); + } + let mut timestamp: u64 = file_time.dwLowDateTime as u64 | + (file_time.dwHighDateTime as u64) << 32; + // one day, timestamp unit is in 100 nanosecond intervals + timestamp += (1E9 as u64) / 100 * (60 * 60 * 24); + file_time.dwLowDateTime = timestamp as u32; + file_time.dwHighDateTime = (timestamp >> 32) as u32; + let res = FileTimeToSystemTime(&file_time, + &mut expiration_date); + if res != TRUE { + return Err(Error::last_os_error()); + } + + // create a self signed certificate + let cert_context = CertCreateSelfSignCertificate( + 0 as ULONG_PTR, + &mut subject_issuer, + 0, + &mut key_provider, + &mut sig_algorithm, + ptr::null_mut(), + &mut expiration_date, + ptr::null_mut()); + if cert_context.is_null() { + return Err(Error::last_os_error()); + } + + // TODO: this is.. a terrible hack. Right now `schannel` + // doesn't provide a public method to go from a raw + // cert context pointer to the `CertContext` structure it + // has, so we just fake it here with a transmute. This'll + // probably break at some point, but hopefully by then + // it'll have a method to do this! + struct MyCertContext(T); + impl Drop for MyCertContext { + fn drop(&mut self) {} + } + + let cert_context = MyCertContext(cert_context); + let cert_context: CertContext = mem::transmute(cert_context); + + cert_context.set_friendly_name(FRIENDLY_NAME)?; + + // install the certificate to the machine's local store + io::stdout().write_all(br#" + +The tokio-tls test suite is about to add a certificate to your set of root +and trusted certificates. This certificate should be for the domain "localhost" +with the description related to "tokio-tls". This certificate is only valid +for one day and will be automatically deleted if you re-run the tokio-tls +test suite later. + + "#).unwrap(); + local_root_store().add_cert(&cert_context, + CertAdd::ReplaceExisting)?; + Ok(cert_context) + } + } + } +} + +const AMT: usize = 128 * 1024; + +async fn copy_data(mut w: W) -> Result { + let mut data = vec![9; AMT as usize]; + let mut amt = 0; + while !data.is_empty() { + let written = w.write(&data).await?; + if written <= data.len() { + amt += written; + data.resize(data.len() - written, 0); + } else { + w.write_all(&data).await?; + amt += data.len(); + break; + } + + println!("remaining: {}", data.len()); + } + Ok(amt) +} + +#[tokio::test] +async fn client_to_server() { + drop(env_logger::try_init()); + + // Create a server listening on a port, then figure out what that port is + let mut srv = t!(TcpListener::bind("127.0.0.1:0").await); + let addr = t!(srv.local_addr()); + + let (server_cx, client_cx) = contexts(); + + // Create a future to accept one socket, connect the ssl stream, and then + // read all the data from it. + let server = async move { + let mut incoming = srv.incoming(); + let socket = t!(incoming.next().await.unwrap()); + let mut socket = t!(server_cx.accept(socket).await); + let mut data = Vec::new(); + t!(socket.read_to_end(&mut data).await); + data + }; + + // Create a future to connect to our server, connect the ssl stream, and + // then write a bunch of data to it. + let client = async move { + let socket = t!(TcpStream::connect(&addr).await); + let socket = t!(client_cx.connect("localhost", socket).await); + copy_data(socket).await + }; + + // Finally, run everything! + let (data, _) = join!(server, client); + // assert_eq!(amt, AMT); + assert!(data == vec![9; AMT]); +} + +#[tokio::test] +async fn server_to_client() { + drop(env_logger::try_init()); + + // Create a server listening on a port, then figure out what that port is + let mut srv = t!(TcpListener::bind("127.0.0.1:0").await); + let addr = t!(srv.local_addr()); + + let (server_cx, client_cx) = contexts(); + + let server = async move { + let mut incoming = srv.incoming(); + let socket = t!(incoming.next().await.unwrap()); + let socket = t!(server_cx.accept(socket).await); + copy_data(socket).await + }; + + let client = async move { + let socket = t!(TcpStream::connect(&addr).await); + let mut socket = t!(client_cx.connect("localhost", socket).await); + let mut data = Vec::new(); + t!(socket.read_to_end(&mut data).await); + data + }; + + // Finally, run everything! + let (_, data) = join!(server, client); + // assert_eq!(amt, AMT); + assert!(data == vec![9; AMT]); +} + +#[tokio::test] +async fn one_byte_at_a_time() { + const AMT: usize = 1024; + drop(env_logger::try_init()); + + let mut srv = t!(TcpListener::bind("127.0.0.1:0").await); + let addr = t!(srv.local_addr()); + + let (server_cx, client_cx) = contexts(); + + let server = async move { + let mut incoming = srv.incoming(); + let socket = t!(incoming.next().await.unwrap()); + let mut socket = t!(server_cx.accept(socket).await); + let mut amt = 0; + for b in std::iter::repeat(9).take(AMT) { + let data = [b as u8]; + t!(socket.write_all(&data).await); + amt += 1; + } + amt + }; + + let client = async move { + let socket = t!(TcpStream::connect(&addr).await); + let mut socket = t!(client_cx.connect("localhost", socket).await); + let mut data = Vec::new(); + loop { + let mut buf = [0; 1]; + match socket.read_exact(&mut buf).await { + Ok(_) => data.extend_from_slice(&buf), + Err(ref err) if err.kind() == ErrorKind::UnexpectedEof => break, + Err(err) => panic!(err), + } + } + data + }; + + let (amt, data) = join!(server, client); + assert_eq!(amt, AMT); + assert!(data == vec![9; AMT as usize]); +}