Initial commit

This commit is contained in:
Lucio Franco 2020-01-09 18:36:35 -05:00
commit 43c85779ca
16 changed files with 1575 additions and 0 deletions

72
.github/workflows/CI.yml vendored Normal file
View File

@ -0,0 +1,72 @@
name: CI
on: [push, pull_request]
jobs:
check:
runs-on: ubuntu-latest
env:
RUSTFLAGS: "-D warnings"
steps:
- name: Checkout sources
uses: actions/checkout@v2
- name: Install stable toolchain
uses: actions-rs/toolchain@v1
with:
profile: minimal
toolchain: stable
override: true
- name: Run cargo check
uses: actions-rs/cargo@v1
with:
command: check --all --all-features --all-targets
test:
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-latest, macOS-latest, windows-latest]
rust: [stable]
env:
RUSTFLAGS: "-D warnings"
steps:
- uses: actions-rs/toolchain@v1
with:
toolchain: ${{ matrix.rust }}
profile: minimal
- uses: actions/checkout@master
- name: Test
run: cargo test --all --all-features
lints:
name: Lints
runs-on: ubuntu-latest
steps:
- name: Checkout sources
uses: actions/checkout@v2
- name: Install stable toolchain
uses: actions-rs/toolchain@v1
with:
profile: minimal
toolchain: stable
override: true
components: rustfmt, clippy
- name: Run cargo fmt
uses: actions-rs/cargo@v1
with:
command: fmt
args: --all -- --check
- name: Run cargo clippy
uses: actions-rs/cargo@v1
with:
command: clippy
args: -- -D warnings

3
.gitignore vendored Normal file
View File

@ -0,0 +1,3 @@
/target
**/*.rs.bk
Cargo.lock

4
Cargo.toml Normal file
View File

@ -0,0 +1,4 @@
[workspace]
members = [
"tokio-native-tls"
]

25
LICENSE Normal file
View File

@ -0,0 +1,25 @@
Copyright (c) 2019 Tokio Contributors
Permission is hereby granted, free of charge, to any
person obtaining a copy of this software and associated
documentation files (the "Software"), to deal in the
Software without restriction, including without
limitation the rights to use, copy, modify, merge,
publish, distribute, sublicense, and/or sell copies of
the Software, and to permit persons to whom the Software
is furnished to do so, subject to the following
conditions:
The above copyright notice and this permission notice
shall be included in all copies or substantial portions
of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF
ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED
TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT
SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR
IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.

63
README.md Normal file
View File

@ -0,0 +1,63 @@
# Tokio Tls
## Overview
This crate contains a collection of Tokio based TLS libraries.
- [`tokio-native-tls`](tokio-native-tls)
## Getting Help
First, see if the answer to your question can be found in the [Guides] or the
[API documentation]. If the answer is not there, there is an active community in
the [Tokio Discord server][chat]. We would be happy to try to answer your
question. Last, if that doesn't work, try opening an [issue] with the question.
[Guides]: https://tokio.rs/docs/
[API documentation]: https://docs.rs/tokio/latest/tokio
[chat]: https://discord.gg/tokio
[issue]: https://github.com/tokio-rs/tls/issues/new
## Contributing
:balloon: Thanks for your help improving the project! We are so happy to have
you! We have a [contributing guide][guide] to help you get involved in the Tokio
project.
[guide]: CONTRIBUTING.md
## Related Projects
In addition to the crates in this repository, the Tokio project also maintains
several other libraries, including:
* [`tracing`] (formerly `tokio-trace`): A framework for application-level
tracing and async-aware diagnostics.
* [`mio`]: A low-level, cross-platform abstraction over OS I/O APIs that powers
`tokio`.
* [`bytes`]: Utilities for working with bytes, including efficient byte buffers.
[`tokio`]: https://github.com/tokio-rs/tokio
[`tracing`]: https://github.com/tokio-rs/tracing
[`mio`]: https://github.com/tokio-rs/mio
[`bytes`]: https://github.com/tokio-rs/bytes
## Supported Rust Versions
Tokio is built against the latest stable, nightly, and beta Rust releases. The
minimum version supported is the stable release from three months before the
current stable release version. For example, if the latest stable Rust is 1.29,
the minimum version supported is 1.26. The current Tokio version is not
guaranteed to build on Rust versions earlier than the minimum supported version.
## License
This project is licensed under the [MIT license](LICENSE).
### Contribution
Unless you explicitly state otherwise, any contribution intentionally submitted
for inclusion in Tokio by you, shall be licensed as MIT, without any additional
terms or conditions.

View File

@ -0,0 +1,3 @@
# 0.1.0 (January 9th, 2019)
- Initial release from `tokio-tls 0.3`

View File

@ -0,0 +1,60 @@
[package]
name = "tokio-native-tls"
# When releasing to crates.io:
# - Remove path dependencies
# - Update html_root_url.
# - Update doc url
# - Cargo.toml
# - README.md
# - Update CHANGELOG.md.
# - Create "v0.1.x" git tag.
version = "0.1.0"
edition = "2018"
authors = ["Tokio Contributors <team@tokio.rs>"]
license = "MIT"
repository = "https://github.com/tokio-rs/tls"
homepage = "https://tokio.rs"
documentation = "https://docs.rs/tokio-native-tls/0.1.0/tokio_native_tls/"
description = """
An implementation of TLS/SSL streams for Tokio using native-tls giving an implementation of TLS
for nonblocking I/O streams.
"""
categories = ["asynchronous", "network-programming"]
[dependencies]
native-tls = "0.2"
tokio = { version = "0.2.0" }
[dev-dependencies]
tokio = { version = "0.2.0", features = ["macros", "stream", "rt-core", "io-util", "net"] }
tokio-util = { version = "0.2.0", features = ["full"] }
cfg-if = "0.1"
env_logger = { version = "0.6", default-features = false }
futures = { version = "0.3.0", features = ["async-await"] }
[target.'cfg(all(not(target_os = "macos"), not(windows), not(target_os = "ios")))'.dev-dependencies]
openssl = "0.10"
[target.'cfg(any(target_os = "macos", target_os = "ios"))'.dev-dependencies]
security-framework = "0.2"
[target.'cfg(windows)'.dev-dependencies]
schannel = "0.1"
[target.'cfg(windows)'.dev-dependencies.winapi]
version = "0.3"
features = [
"lmcons",
"basetsd",
"minwinbase",
"minwindef",
"ntdef",
"sysinfoapi",
"timezoneapi",
"wincrypt",
"winerror",
]
[package.metadata.docs.rs]
all-features = true

25
tokio-native-tls/LICENSE Normal file
View File

@ -0,0 +1,25 @@
Copyright (c) 2019 Tokio Contributors
Permission is hereby granted, free of charge, to any
person obtaining a copy of this software and associated
documentation files (the "Software"), to deal in the
Software without restriction, including without
limitation the rights to use, copy, modify, merge,
publish, distribute, sublicense, and/or sell copies of
the Software, and to permit persons to whom the Software
is furnished to do so, subject to the following
conditions:
The above copyright notice and this permission notice
shall be included in all copies or substantial portions
of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF
ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED
TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT
SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR
IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.

View File

@ -0,0 +1,14 @@
# tokio-tls
An implementation of TLS/SSL streams for Tokio built on top of the [`native-tls`
crate]
## License
This project is licensed under the [MIT license](./LICENSE).
### Contribution
Unless you explicitly state otherwise, any contribution intentionally submitted
for inclusion in Tokio by you, shall be licensed as MIT, without any additional
terms or conditions.

View File

@ -0,0 +1,39 @@
// #![warn(rust_2018_idioms)]
use native_tls::TlsConnector;
use std::error::Error;
use std::net::ToSocketAddrs;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
#[tokio::main]
async fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
let addr = "www.rust-lang.org:443"
.to_socket_addrs()?
.next()
.ok_or("failed to resolve www.rust-lang.org")?;
let socket = TcpStream::connect(&addr).await?;
let cx = TlsConnector::builder().build()?;
let cx = tokio_native_tls::TlsConnector::from(cx);
let mut socket = cx.connect("www.rust-lang.org", socket).await?;
socket
.write_all(
"\
GET / HTTP/1.0\r\n\
Host: www.rust-lang.org\r\n\
\r\n\
"
.as_bytes(),
)
.await?;
let mut data = Vec::new();
socket.read_to_end(&mut data).await?;
// println!("data: {:?}", &data);
println!("{}", String::from_utf8_lossy(&data[..]));
Ok(())
}

View File

@ -0,0 +1,54 @@
#![warn(rust_2018_idioms)]
// A tiny async TLS echo server with Tokio
use native_tls;
use native_tls::Identity;
use tokio;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpListener;
/**
an example to setup a tls server.
how to test:
wget https://127.0.0.1:12345 --no-check-certificate
*/
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
// Bind the server's socket
let addr = "127.0.0.1:12345".to_string();
let mut tcp: TcpListener = TcpListener::bind(&addr).await?;
// Create the TLS acceptor.
let der = include_bytes!("identity.p12");
let cert = Identity::from_pkcs12(der, "mypass")?;
let tls_acceptor =
tokio_native_tls::TlsAcceptor::from(native_tls::TlsAcceptor::builder(cert).build()?);
loop {
// Asynchronously wait for an inbound socket.
let (socket, remote_addr) = tcp.accept().await?;
let tls_acceptor = tls_acceptor.clone();
println!("accept connection from {}", remote_addr);
tokio::spawn(async move {
// Accept the TLS connection.
let mut tls_stream = tls_acceptor.accept(socket).await.expect("accept error");
// In a loop, read data from the socket and write the data back.
let mut buf = [0; 1024];
let n = tls_stream
.read(&mut buf)
.await
.expect("failed to read data from socket");
if n == 0 {
return;
}
println!("read={}", unsafe {
String::from_utf8_unchecked(buf[0..n].into())
});
tls_stream
.write_all(&buf[0..n])
.await
.expect("failed to write data to socket");
});
}
}

Binary file not shown.

361
tokio-native-tls/src/lib.rs Normal file
View File

@ -0,0 +1,361 @@
#![doc(html_root_url = "https://docs.rs/tokio-tls/0.3.0")]
#![warn(
missing_debug_implementations,
missing_docs,
rust_2018_idioms,
unreachable_pub
)]
#![deny(intra_doc_link_resolution_failure)]
#![doc(test(
no_crate_inject,
attr(deny(warnings, rust_2018_idioms), allow(dead_code, unused_variables))
))]
//! Async TLS streams
//!
//! This library is an implementation of TLS streams using the most appropriate
//! system library by default for negotiating the connection. That is, on
//! Windows this library uses SChannel, on OSX it uses SecureTransport, and on
//! other platforms it uses OpenSSL.
//!
//! Each TLS stream implements the `Read` and `Write` traits to interact and
//! interoperate with the rest of the futures I/O ecosystem. Client connections
//! initiated from this crate verify hostnames automatically and by default.
//!
//! This crate primarily exports this ability through two newtypes,
//! `TlsConnector` and `TlsAcceptor`. These newtypes augment the
//! functionality provided by the `native-tls` crate, on which this crate is
//! built. Configuration of TLS parameters is still primarily done through the
//! `native-tls` crate.
use tokio::io::{AsyncRead, AsyncWrite};
use native_tls::{Error, HandshakeError, MidHandshakeTlsStream};
use std::fmt;
use std::future::Future;
use std::io::{self, Read, Write};
use std::marker::Unpin;
use std::mem::MaybeUninit;
use std::pin::Pin;
use std::ptr::null_mut;
use std::task::{Context, Poll};
#[derive(Debug)]
struct AllowStd<S> {
inner: S,
context: *mut (),
}
/// A wrapper around an underlying raw stream which implements the TLS or SSL
/// protocol.
///
/// A `TlsStream<S>` represents a handshake that has been completed successfully
/// and both the server and the client are ready for receiving and sending
/// data. Bytes read from a `TlsStream` are decrypted from `S` and bytes written
/// to a `TlsStream` are encrypted when passing through to `S`.
#[derive(Debug)]
pub struct TlsStream<S>(native_tls::TlsStream<AllowStd<S>>);
/// A wrapper around a `native_tls::TlsConnector`, providing an async `connect`
/// method.
#[derive(Clone)]
pub struct TlsConnector(native_tls::TlsConnector);
/// A wrapper around a `native_tls::TlsAcceptor`, providing an async `accept`
/// method.
#[derive(Clone)]
pub struct TlsAcceptor(native_tls::TlsAcceptor);
struct MidHandshake<S>(Option<MidHandshakeTlsStream<AllowStd<S>>>);
enum StartedHandshake<S> {
Done(TlsStream<S>),
Mid(MidHandshakeTlsStream<AllowStd<S>>),
}
struct StartedHandshakeFuture<F, S>(Option<StartedHandshakeFutureInner<F, S>>);
struct StartedHandshakeFutureInner<F, S> {
f: F,
stream: S,
}
struct Guard<'a, S>(&'a mut TlsStream<S>)
where
AllowStd<S>: Read + Write;
impl<S> Drop for Guard<'_, S>
where
AllowStd<S>: Read + Write,
{
fn drop(&mut self) {
(self.0).0.get_mut().context = null_mut();
}
}
// *mut () context is neither Send nor Sync
unsafe impl<S: Send> Send for AllowStd<S> {}
unsafe impl<S: Sync> Sync for AllowStd<S> {}
impl<S> AllowStd<S>
where
S: Unpin,
{
fn with_context<F, R>(&mut self, f: F) -> R
where
F: FnOnce(&mut Context<'_>, Pin<&mut S>) -> R,
{
unsafe {
assert!(!self.context.is_null());
let waker = &mut *(self.context as *mut _);
f(waker, Pin::new(&mut self.inner))
}
}
}
impl<S> Read for AllowStd<S>
where
S: AsyncRead + Unpin,
{
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
match self.with_context(|ctx, stream| stream.poll_read(ctx, buf)) {
Poll::Ready(r) => r,
Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)),
}
}
}
impl<S> Write for AllowStd<S>
where
S: AsyncWrite + Unpin,
{
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
match self.with_context(|ctx, stream| stream.poll_write(ctx, buf)) {
Poll::Ready(r) => r,
Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)),
}
}
fn flush(&mut self) -> io::Result<()> {
match self.with_context(|ctx, stream| stream.poll_flush(ctx)) {
Poll::Ready(r) => r,
Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)),
}
}
}
fn cvt<T>(r: io::Result<T>) -> Poll<io::Result<T>> {
match r {
Ok(v) => Poll::Ready(Ok(v)),
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
Err(e) => Poll::Ready(Err(e)),
}
}
impl<S> TlsStream<S> {
fn with_context<F, R>(&mut self, ctx: &mut Context<'_>, f: F) -> R
where
F: FnOnce(&mut native_tls::TlsStream<AllowStd<S>>) -> R,
AllowStd<S>: Read + Write,
{
self.0.get_mut().context = ctx as *mut _ as *mut ();
let g = Guard(self);
f(&mut (g.0).0)
}
/// Returns a shared reference to the inner stream.
pub fn get_ref(&self) -> &S
where
S: AsyncRead + AsyncWrite + Unpin,
{
&self.0.get_ref().inner
}
/// Returns a mutable reference to the inner stream.
pub fn get_mut(&mut self) -> &mut S
where
S: AsyncRead + AsyncWrite + Unpin,
{
&mut self.0.get_mut().inner
}
}
impl<S> AsyncRead for TlsStream<S>
where
S: AsyncRead + AsyncWrite + Unpin,
{
unsafe fn prepare_uninitialized_buffer(&self, _: &mut [MaybeUninit<u8>]) -> bool {
// Note that this does not forward to `S` because the buffer is
// unconditionally filled in by OpenSSL, not the actual object `S`.
// We're decrypting bytes from `S` into the buffer above!
false
}
fn poll_read(
mut self: Pin<&mut Self>,
ctx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
self.with_context(ctx, |s| cvt(s.read(buf)))
}
}
impl<S> AsyncWrite for TlsStream<S>
where
S: AsyncRead + AsyncWrite + Unpin,
{
fn poll_write(
mut self: Pin<&mut Self>,
ctx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
self.with_context(ctx, |s| cvt(s.write(buf)))
}
fn poll_flush(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<io::Result<()>> {
self.with_context(ctx, |s| cvt(s.flush()))
}
fn poll_shutdown(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<io::Result<()>> {
match self.with_context(ctx, |s| s.shutdown()) {
Ok(()) => Poll::Ready(Ok(())),
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
Err(e) => Poll::Ready(Err(e)),
}
}
}
async fn handshake<F, S>(f: F, stream: S) -> Result<TlsStream<S>, Error>
where
F: FnOnce(
AllowStd<S>,
) -> Result<native_tls::TlsStream<AllowStd<S>>, HandshakeError<AllowStd<S>>>
+ Unpin,
S: AsyncRead + AsyncWrite + Unpin,
{
let start = StartedHandshakeFuture(Some(StartedHandshakeFutureInner { f, stream }));
match start.await {
Err(e) => Err(e),
Ok(StartedHandshake::Done(s)) => Ok(s),
Ok(StartedHandshake::Mid(s)) => MidHandshake(Some(s)).await,
}
}
impl<F, S> Future for StartedHandshakeFuture<F, S>
where
F: FnOnce(
AllowStd<S>,
) -> Result<native_tls::TlsStream<AllowStd<S>>, HandshakeError<AllowStd<S>>>
+ Unpin,
S: Unpin,
AllowStd<S>: Read + Write,
{
type Output = Result<StartedHandshake<S>, Error>;
fn poll(
mut self: Pin<&mut Self>,
ctx: &mut Context<'_>,
) -> Poll<Result<StartedHandshake<S>, Error>> {
let inner = self.0.take().expect("future polled after completion");
let stream = AllowStd {
inner: inner.stream,
context: ctx as *mut _ as *mut (),
};
match (inner.f)(stream) {
Ok(mut s) => {
s.get_mut().context = null_mut();
Poll::Ready(Ok(StartedHandshake::Done(TlsStream(s))))
}
Err(HandshakeError::WouldBlock(mut s)) => {
s.get_mut().context = null_mut();
Poll::Ready(Ok(StartedHandshake::Mid(s)))
}
Err(HandshakeError::Failure(e)) => Poll::Ready(Err(e)),
}
}
}
impl TlsConnector {
/// Connects the provided stream with this connector, assuming the provided
/// domain.
///
/// This function will internally call `TlsConnector::connect` to connect
/// the stream and returns a future representing the resolution of the
/// connection operation. The returned future will resolve to either
/// `TlsStream<S>` or `Error` depending if it's successful or not.
///
/// This is typically used for clients who have already established, for
/// example, a TCP connection to a remote server. That stream is then
/// provided here to perform the client half of a connection to a
/// TLS-powered server.
pub async fn connect<S>(&self, domain: &str, stream: S) -> Result<TlsStream<S>, Error>
where
S: AsyncRead + AsyncWrite + Unpin,
{
handshake(move |s| self.0.connect(domain, s), stream).await
}
}
impl fmt::Debug for TlsConnector {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("TlsConnector").finish()
}
}
impl From<native_tls::TlsConnector> for TlsConnector {
fn from(inner: native_tls::TlsConnector) -> TlsConnector {
TlsConnector(inner)
}
}
impl TlsAcceptor {
/// Accepts a new client connection with the provided stream.
///
/// This function will internally call `TlsAcceptor::accept` to connect
/// the stream and returns a future representing the resolution of the
/// connection operation. The returned future will resolve to either
/// `TlsStream<S>` or `Error` depending if it's successful or not.
///
/// This is typically used after a new socket has been accepted from a
/// `TcpListener`. That socket is then passed to this function to perform
/// the server half of accepting a client connection.
pub async fn accept<S>(&self, stream: S) -> Result<TlsStream<S>, Error>
where
S: AsyncRead + AsyncWrite + Unpin,
{
handshake(move |s| self.0.accept(s), stream).await
}
}
impl fmt::Debug for TlsAcceptor {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("TlsAcceptor").finish()
}
}
impl From<native_tls::TlsAcceptor> for TlsAcceptor {
fn from(inner: native_tls::TlsAcceptor) -> TlsAcceptor {
TlsAcceptor(inner)
}
}
impl<S: AsyncRead + AsyncWrite + Unpin> Future for MidHandshake<S> {
type Output = Result<TlsStream<S>, Error>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let mut_self = self.get_mut();
let mut s = mut_self.0.take().expect("future polled after completion");
s.get_mut().context = cx as *mut _ as *mut ();
match s.handshake() {
Ok(stream) => Poll::Ready(Ok(TlsStream(stream))),
Err(HandshakeError::Failure(e)) => Poll::Ready(Err(e)),
Err(HandshakeError::WouldBlock(mut s)) => {
s.get_mut().context = null_mut();
mut_self.0 = Some(s);
Poll::Pending
}
}
}
}

View File

@ -0,0 +1,123 @@
#![warn(rust_2018_idioms)]
use cfg_if::cfg_if;
use env_logger;
use native_tls::TlsConnector;
use std::io::{self, Error};
use std::net::ToSocketAddrs;
use tokio::net::TcpStream;
macro_rules! t {
($e:expr) => {
match $e {
Ok(e) => e,
Err(e) => panic!("{} failed with {:?}", stringify!($e), e),
}
};
}
cfg_if! {
if #[cfg(feature = "force-rustls")] {
fn verify_failed(err: &Error, s: &str) {
let err = err.to_string();
assert!(err.contains(s), "bad error: {}", err);
}
fn assert_expired_error(err: &Error) {
verify_failed(err, "CertExpired");
}
fn assert_wrong_host(err: &Error) {
verify_failed(err, "CertNotValidForName");
}
fn assert_self_signed(err: &Error) {
verify_failed(err, "UnknownIssuer");
}
fn assert_untrusted_root(err: &Error) {
verify_failed(err, "UnknownIssuer");
}
} else if #[cfg(any(feature = "force-openssl",
all(not(target_os = "macos"),
not(target_os = "windows"),
not(target_os = "ios"))))] {
fn verify_failed(err: &Error) {
assert!(format!("{}", err).contains("certificate verify failed"))
}
use verify_failed as assert_expired_error;
use verify_failed as assert_wrong_host;
use verify_failed as assert_self_signed;
use verify_failed as assert_untrusted_root;
} else if #[cfg(any(target_os = "macos", target_os = "ios"))] {
fn assert_invalid_cert_chain(err: &Error) {
assert!(format!("{}", err).contains("was not trusted."))
}
use crate::assert_invalid_cert_chain as assert_expired_error;
use crate::assert_invalid_cert_chain as assert_wrong_host;
use crate::assert_invalid_cert_chain as assert_self_signed;
use crate::assert_invalid_cert_chain as assert_untrusted_root;
} else {
fn assert_expired_error(err: &Error) {
let s = err.to_string();
assert!(s.contains("system clock"), "error = {:?}", s);
}
fn assert_wrong_host(err: &Error) {
let s = err.to_string();
assert!(s.contains("CN name"), "error = {:?}", s);
}
fn assert_self_signed(err: &Error) {
let s = err.to_string();
assert!(s.contains("root certificate which is not trusted"), "error = {:?}", s);
}
use assert_self_signed as assert_untrusted_root;
}
}
async fn get_host(host: &'static str) -> Error {
drop(env_logger::try_init());
let addr = format!("{}:443", host);
let addr = t!(addr.to_socket_addrs()).next().unwrap();
let socket = t!(TcpStream::connect(&addr).await);
let builder = TlsConnector::builder();
let cx = t!(builder.build());
let cx = tokio_native_tls::TlsConnector::from(cx);
let res = cx
.connect(host, socket)
.await
.map_err(|e| Error::new(io::ErrorKind::Other, e));
assert!(res.is_err());
res.err().unwrap()
}
#[tokio::test]
async fn expired() {
assert_expired_error(&get_host("expired.badssl.com").await)
}
// TODO: the OSX builders on Travis apparently fail this tests spuriously?
// passes locally though? Seems... bad!
#[tokio::test]
#[cfg_attr(all(target_os = "macos", feature = "force-openssl"), ignore)]
async fn wrong_host() {
assert_wrong_host(&get_host("wrong.host.badssl.com").await)
}
#[tokio::test]
async fn self_signed() {
assert_self_signed(&get_host("self-signed.badssl.com").await)
}
#[tokio::test]
async fn untrusted_root() {
assert_untrusted_root(&get_host("untrusted-root.badssl.com").await)
}

View File

@ -0,0 +1,101 @@
#![warn(rust_2018_idioms)]
use cfg_if::cfg_if;
use env_logger;
use native_tls;
use native_tls::TlsConnector;
use std::io;
use std::net::ToSocketAddrs;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
macro_rules! t {
($e:expr) => {
match $e {
Ok(e) => e,
Err(e) => panic!("{} failed with {:?}", stringify!($e), e),
}
};
}
cfg_if! {
if #[cfg(feature = "force-rustls")] {
fn assert_bad_hostname_error(err: &io::Error) {
let err = err.to_string();
assert!(err.contains("CertNotValidForName"), "bad error: {}", err);
}
} else if #[cfg(any(feature = "force-openssl",
all(not(target_os = "macos"),
not(target_os = "windows"),
not(target_os = "ios"))))] {
fn assert_bad_hostname_error(err: &io::Error) {
let err = err.get_ref().unwrap();
let err = err.downcast_ref::<native_tls::Error>().unwrap();
assert!(format!("{}", err).contains("certificate verify failed"));
}
} else if #[cfg(any(target_os = "macos", target_os = "ios"))] {
fn assert_bad_hostname_error(err: &io::Error) {
let err = err.get_ref().unwrap();
let err = err.downcast_ref::<native_tls::Error>().unwrap();
assert!(format!("{}", err).contains("was not trusted."));
}
} else {
fn assert_bad_hostname_error(err: &io::Error) {
let err = err.get_ref().unwrap();
let err = err.downcast_ref::<native_tls::Error>().unwrap();
assert!(format!("{}", err).contains("CN name"));
}
}
}
#[tokio::test]
async fn fetch_google() {
drop(env_logger::try_init());
// First up, resolve google.com
let addr = t!("google.com:443".to_socket_addrs()).next().unwrap();
let socket = TcpStream::connect(&addr).await.unwrap();
// Send off the request by first negotiating an SSL handshake, then writing
// of our request, then flushing, then finally read off the response.
let builder = TlsConnector::builder();
let connector = t!(builder.build());
let connector = tokio_native_tls::TlsConnector::from(connector);
let mut socket = t!(connector.connect("google.com", socket).await);
t!(socket.write_all(b"GET / HTTP/1.0\r\n\r\n").await);
let mut data = Vec::new();
t!(socket.read_to_end(&mut data).await);
// any response code is fine
assert!(data.starts_with(b"HTTP/1.0 "));
let data = String::from_utf8_lossy(&data);
let data = data.trim_end();
assert!(data.ends_with("</html>") || data.ends_with("</HTML>"));
}
fn native2io(e: native_tls::Error) -> io::Error {
io::Error::new(io::ErrorKind::Other, e)
}
// see comment in bad.rs for ignore reason
#[cfg_attr(all(target_os = "macos", feature = "force-openssl"), ignore)]
#[tokio::test]
async fn wrong_hostname_error() {
drop(env_logger::try_init());
let addr = t!("google.com:443".to_socket_addrs()).next().unwrap();
let socket = t!(TcpStream::connect(&addr).await);
let builder = TlsConnector::builder();
let connector = t!(builder.build());
let connector = tokio_native_tls::TlsConnector::from(connector);
let res = connector
.connect("rust-lang.org", socket)
.await
.map_err(native2io);
assert!(res.is_err());
assert_bad_hostname_error(&res.err().unwrap());
}

View File

@ -0,0 +1,628 @@
#![warn(rust_2018_idioms)]
use cfg_if::cfg_if;
use env_logger;
use futures::join;
use native_tls;
use native_tls::{Identity, TlsAcceptor, TlsConnector};
use std::io::Write;
use std::marker::Unpin;
use std::process::Command;
use std::ptr;
use tokio::io::{AsyncReadExt, AsyncWrite, AsyncWriteExt, Error, ErrorKind};
use tokio::net::{TcpListener, TcpStream};
use tokio::stream::StreamExt;
macro_rules! t {
($e:expr) => {
match $e {
Ok(e) => e,
Err(e) => panic!("{} failed with {:?}", stringify!($e), e),
}
};
}
#[allow(dead_code)]
struct Keys {
cert_der: Vec<u8>,
pkey_der: Vec<u8>,
pkcs12_der: Vec<u8>,
}
#[allow(dead_code)]
fn openssl_keys() -> &'static Keys {
static INIT: Once = Once::new();
static mut KEYS: *mut Keys = ptr::null_mut();
INIT.call_once(|| {
let path = t!(env::current_exe());
let path = path.parent().unwrap();
let keyfile = path.join("test.key");
let certfile = path.join("test.crt");
let config = path.join("openssl.config");
File::create(&config)
.unwrap()
.write_all(
b"\
[req]\n\
distinguished_name=dn\n\
[ dn ]\n\
CN=localhost\n\
[ ext ]\n\
basicConstraints=CA:FALSE,pathlen:0\n\
subjectAltName = @alt_names
extendedKeyUsage=serverAuth,clientAuth
[alt_names]
DNS.1 = localhost
",
)
.unwrap();
let subj = "/C=US/ST=Denial/L=Sprintfield/O=Dis/CN=localhost";
let output = t!(Command::new("openssl")
.arg("req")
.arg("-nodes")
.arg("-x509")
.arg("-newkey")
.arg("rsa:2048")
.arg("-config")
.arg(&config)
.arg("-extensions")
.arg("ext")
.arg("-subj")
.arg(subj)
.arg("-keyout")
.arg(&keyfile)
.arg("-out")
.arg(&certfile)
.arg("-days")
.arg("1")
.output());
assert!(output.status.success());
let crtout = t!(Command::new("openssl")
.arg("x509")
.arg("-outform")
.arg("der")
.arg("-in")
.arg(&certfile)
.output());
assert!(crtout.status.success());
let keyout = t!(Command::new("openssl")
.arg("rsa")
.arg("-outform")
.arg("der")
.arg("-in")
.arg(&keyfile)
.output());
assert!(keyout.status.success());
let pkcs12out = t!(Command::new("openssl")
.arg("pkcs12")
.arg("-export")
.arg("-nodes")
.arg("-inkey")
.arg(&keyfile)
.arg("-in")
.arg(&certfile)
.arg("-password")
.arg("pass:foobar")
.output());
assert!(pkcs12out.status.success());
let keys = Box::new(Keys {
cert_der: crtout.stdout,
pkey_der: keyout.stdout,
pkcs12_der: pkcs12out.stdout,
});
unsafe {
KEYS = Box::into_raw(keys);
}
});
unsafe { &*KEYS }
}
cfg_if! {
if #[cfg(feature = "rustls")] {
use webpki;
use untrusted;
use std::env;
use std::fs::File;
use std::process::Command;
use std::sync::Once;
use untrusted::Input;
use webpki::trust_anchor_util;
fn server_cx() -> io::Result<ServerContext> {
let mut cx = ServerContext::new();
let (cert, key) = keys();
cx.config_mut()
.set_single_cert(vec![cert.to_vec()], key.to_vec());
Ok(cx)
}
fn configure_client(cx: &mut ClientContext) {
let (cert, _key) = keys();
let cert = Input::from(cert);
let anchor = trust_anchor_util::cert_der_as_trust_anchor(cert).unwrap();
cx.config_mut().root_store.add_trust_anchors(&[anchor]);
}
// Like OpenSSL we generate certificates on the fly, but for OSX we
// also have to put them into a specific keychain. We put both the
// certificates and the keychain next to our binary.
//
// Right now I don't know of a way to programmatically create a
// self-signed certificate, so we just fork out to the `openssl` binary.
fn keys() -> (&'static [u8], &'static [u8]) {
static INIT: Once = Once::new();
static mut KEYS: *mut (Vec<u8>, Vec<u8>) = ptr::null_mut();
INIT.call_once(|| {
let (key, cert) = openssl_keys();
let path = t!(env::current_exe());
let path = path.parent().unwrap();
let keyfile = path.join("test.key");
let certfile = path.join("test.crt");
let config = path.join("openssl.config");
File::create(&config).unwrap().write_all(b"\
[req]\n\
distinguished_name=dn\n\
[ dn ]\n\
CN=localhost\n\
[ ext ]\n\
basicConstraints=CA:FALSE,pathlen:0\n\
subjectAltName = @alt_names
[alt_names]
DNS.1 = localhost
").unwrap();
let subj = "/C=US/ST=Denial/L=Sprintfield/O=Dis/CN=localhost";
let output = t!(Command::new("openssl")
.arg("req")
.arg("-nodes")
.arg("-x509")
.arg("-newkey").arg("rsa:2048")
.arg("-config").arg(&config)
.arg("-extensions").arg("ext")
.arg("-subj").arg(subj)
.arg("-keyout").arg(&keyfile)
.arg("-out").arg(&certfile)
.arg("-days").arg("1")
.output());
assert!(output.status.success());
let crtout = t!(Command::new("openssl")
.arg("x509")
.arg("-outform").arg("der")
.arg("-in").arg(&certfile)
.output());
assert!(crtout.status.success());
let keyout = t!(Command::new("openssl")
.arg("rsa")
.arg("-outform").arg("der")
.arg("-in").arg(&keyfile)
.output());
assert!(keyout.status.success());
let cert = crtout.stdout;
let key = keyout.stdout;
unsafe {
KEYS = Box::into_raw(Box::new((cert, key)));
}
});
unsafe {
(&(*KEYS).0, &(*KEYS).1)
}
}
} else if #[cfg(any(feature = "force-openssl",
all(not(target_os = "macos"),
not(target_os = "windows"),
not(target_os = "ios"))))] {
use std::fs::File;
use std::env;
use std::sync::Once;
fn contexts() -> (tokio_native_tls::TlsAcceptor, tokio_native_tls::TlsConnector) {
let keys = openssl_keys();
let pkcs12 = t!(Identity::from_pkcs12(&keys.pkcs12_der, "foobar"));
let srv = TlsAcceptor::builder(pkcs12);
let cert = t!(native_tls::Certificate::from_der(&keys.cert_der));
let mut client = TlsConnector::builder();
t!(client.add_root_certificate(cert).build());
(t!(srv.build()).into(), t!(client.build()).into())
}
} else if #[cfg(any(target_os = "macos", target_os = "ios"))] {
use std::env;
use std::fs::File;
use std::sync::Once;
fn contexts() -> (tokio_tls::TlsAcceptor, tokio_tls::TlsConnector) {
let keys = openssl_keys();
let pkcs12 = t!(Identity::from_pkcs12(&keys.pkcs12_der, "foobar"));
let srv = TlsAcceptor::builder(pkcs12);
let cert = native_tls::Certificate::from_der(&keys.cert_der).unwrap();
let mut client = TlsConnector::builder();
client.add_root_certificate(cert);
(t!(srv.build()).into(), t!(client.build()).into())
}
} else {
use schannel;
use winapi;
use std::env;
use std::fs::File;
use std::io;
use std::mem;
use std::sync::Once;
use schannel::cert_context::CertContext;
use schannel::cert_store::{CertStore, CertAdd, Memory};
use winapi::shared::basetsd::*;
use winapi::shared::lmcons::*;
use winapi::shared::minwindef::*;
use winapi::shared::ntdef::WCHAR;
use winapi::um::minwinbase::*;
use winapi::um::sysinfoapi::*;
use winapi::um::timezoneapi::*;
use winapi::um::wincrypt::*;
const FRIENDLY_NAME: &'static str = "tokio-tls localhost testing cert";
fn contexts() -> (tokio_tls::TlsAcceptor, tokio_tls::TlsConnector) {
let cert = localhost_cert();
let mut store = t!(Memory::new()).into_store();
t!(store.add_cert(&cert, CertAdd::Always));
let pkcs12_der = t!(store.export_pkcs12("foobar"));
let pkcs12 = t!(Identity::from_pkcs12(&pkcs12_der, "foobar"));
let srv = TlsAcceptor::builder(pkcs12);
let client = TlsConnector::builder();
(t!(srv.build()).into(), t!(client.build()).into())
}
// ====================================================================
// Magic!
//
// Lots of magic is happening here to wrangle certificates for running
// these tests on Windows. For more information see the test suite
// in the schannel-rs crate as this is just coyping that.
//
// The general gist of this though is that the only way to add custom
// trusted certificates is to add it to the system store of trust. To
// do that we go through the whole rigamarole here to generate a new
// self-signed certificate and then insert that into the system store.
//
// This generates some dialogs, so we print what we're doing sometimes,
// and otherwise we just manage the ephemeral certificates. Because
// they're in the system store we always ensure that they're only valid
// for a small period of time (e.g. 1 day).
fn localhost_cert() -> CertContext {
static INIT: Once = Once::new();
INIT.call_once(|| {
for cert in local_root_store().certs() {
let name = match cert.friendly_name() {
Ok(name) => name,
Err(_) => continue,
};
if name != FRIENDLY_NAME {
continue
}
if !cert.is_time_valid().unwrap() {
io::stdout().write_all(br#"
The tokio-tls test suite is about to delete an old copy of one of its
certificates from your root trust store. This certificate was only valid for one
day and it is no longer needed. The host should be "localhost" and the
description should mention "tokio-tls".
"#).unwrap();
cert.delete().unwrap();
} else {
return
}
}
install_certificate().unwrap();
});
for cert in local_root_store().certs() {
let name = match cert.friendly_name() {
Ok(name) => name,
Err(_) => continue,
};
if name == FRIENDLY_NAME {
return cert
}
}
panic!("couldn't find a cert");
}
fn local_root_store() -> CertStore {
if env::var("CI").is_ok() {
CertStore::open_local_machine("Root").unwrap()
} else {
CertStore::open_current_user("Root").unwrap()
}
}
fn install_certificate() -> io::Result<CertContext> {
unsafe {
let mut provider = 0;
let mut hkey = 0;
let mut buffer = "tokio-tls test suite".encode_utf16()
.chain(Some(0))
.collect::<Vec<_>>();
let res = CryptAcquireContextW(&mut provider,
buffer.as_ptr(),
ptr::null_mut(),
PROV_RSA_FULL,
CRYPT_MACHINE_KEYSET);
if res != TRUE {
// create a new key container (since it does not exist)
let res = CryptAcquireContextW(&mut provider,
buffer.as_ptr(),
ptr::null_mut(),
PROV_RSA_FULL,
CRYPT_NEWKEYSET | CRYPT_MACHINE_KEYSET);
if res != TRUE {
return Err(Error::last_os_error())
}
}
// create a new keypair (RSA-2048)
let res = CryptGenKey(provider,
AT_SIGNATURE,
0x0800<<16 | CRYPT_EXPORTABLE,
&mut hkey);
if res != TRUE {
return Err(Error::last_os_error());
}
// start creating the certificate
let name = "CN=localhost,O=tokio-tls,OU=tokio-tls,\
G=tokio_tls".encode_utf16()
.chain(Some(0))
.collect::<Vec<_>>();
let mut cname_buffer: [WCHAR; UNLEN as usize + 1] = mem::zeroed();
let mut cname_len = cname_buffer.len() as DWORD;
let res = CertStrToNameW(X509_ASN_ENCODING,
name.as_ptr(),
CERT_X500_NAME_STR,
ptr::null_mut(),
cname_buffer.as_mut_ptr() as *mut u8,
&mut cname_len,
ptr::null_mut());
if res != TRUE {
return Err(Error::last_os_error());
}
let mut subject_issuer = CERT_NAME_BLOB {
cbData: cname_len,
pbData: cname_buffer.as_ptr() as *mut u8,
};
let mut key_provider = CRYPT_KEY_PROV_INFO {
pwszContainerName: buffer.as_mut_ptr(),
pwszProvName: ptr::null_mut(),
dwProvType: PROV_RSA_FULL,
dwFlags: CRYPT_MACHINE_KEYSET,
cProvParam: 0,
rgProvParam: ptr::null_mut(),
dwKeySpec: AT_SIGNATURE,
};
let mut sig_algorithm = CRYPT_ALGORITHM_IDENTIFIER {
pszObjId: szOID_RSA_SHA256RSA.as_ptr() as *mut _,
Parameters: mem::zeroed(),
};
let mut expiration_date: SYSTEMTIME = mem::zeroed();
GetSystemTime(&mut expiration_date);
let mut file_time: FILETIME = mem::zeroed();
let res = SystemTimeToFileTime(&mut expiration_date,
&mut file_time);
if res != TRUE {
return Err(Error::last_os_error());
}
let mut timestamp: u64 = file_time.dwLowDateTime as u64 |
(file_time.dwHighDateTime as u64) << 32;
// one day, timestamp unit is in 100 nanosecond intervals
timestamp += (1E9 as u64) / 100 * (60 * 60 * 24);
file_time.dwLowDateTime = timestamp as u32;
file_time.dwHighDateTime = (timestamp >> 32) as u32;
let res = FileTimeToSystemTime(&file_time,
&mut expiration_date);
if res != TRUE {
return Err(Error::last_os_error());
}
// create a self signed certificate
let cert_context = CertCreateSelfSignCertificate(
0 as ULONG_PTR,
&mut subject_issuer,
0,
&mut key_provider,
&mut sig_algorithm,
ptr::null_mut(),
&mut expiration_date,
ptr::null_mut());
if cert_context.is_null() {
return Err(Error::last_os_error());
}
// TODO: this is.. a terrible hack. Right now `schannel`
// doesn't provide a public method to go from a raw
// cert context pointer to the `CertContext` structure it
// has, so we just fake it here with a transmute. This'll
// probably break at some point, but hopefully by then
// it'll have a method to do this!
struct MyCertContext<T>(T);
impl<T> Drop for MyCertContext<T> {
fn drop(&mut self) {}
}
let cert_context = MyCertContext(cert_context);
let cert_context: CertContext = mem::transmute(cert_context);
cert_context.set_friendly_name(FRIENDLY_NAME)?;
// install the certificate to the machine's local store
io::stdout().write_all(br#"
The tokio-tls test suite is about to add a certificate to your set of root
and trusted certificates. This certificate should be for the domain "localhost"
with the description related to "tokio-tls". This certificate is only valid
for one day and will be automatically deleted if you re-run the tokio-tls
test suite later.
"#).unwrap();
local_root_store().add_cert(&cert_context,
CertAdd::ReplaceExisting)?;
Ok(cert_context)
}
}
}
}
const AMT: usize = 128 * 1024;
async fn copy_data<W: AsyncWrite + Unpin>(mut w: W) -> Result<usize, Error> {
let mut data = vec![9; AMT as usize];
let mut amt = 0;
while !data.is_empty() {
let written = w.write(&data).await?;
if written <= data.len() {
amt += written;
data.resize(data.len() - written, 0);
} else {
w.write_all(&data).await?;
amt += data.len();
break;
}
println!("remaining: {}", data.len());
}
Ok(amt)
}
#[tokio::test]
async fn client_to_server() {
drop(env_logger::try_init());
// Create a server listening on a port, then figure out what that port is
let mut srv = t!(TcpListener::bind("127.0.0.1:0").await);
let addr = t!(srv.local_addr());
let (server_cx, client_cx) = contexts();
// Create a future to accept one socket, connect the ssl stream, and then
// read all the data from it.
let server = async move {
let mut incoming = srv.incoming();
let socket = t!(incoming.next().await.unwrap());
let mut socket = t!(server_cx.accept(socket).await);
let mut data = Vec::new();
t!(socket.read_to_end(&mut data).await);
data
};
// Create a future to connect to our server, connect the ssl stream, and
// then write a bunch of data to it.
let client = async move {
let socket = t!(TcpStream::connect(&addr).await);
let socket = t!(client_cx.connect("localhost", socket).await);
copy_data(socket).await
};
// Finally, run everything!
let (data, _) = join!(server, client);
// assert_eq!(amt, AMT);
assert!(data == vec![9; AMT]);
}
#[tokio::test]
async fn server_to_client() {
drop(env_logger::try_init());
// Create a server listening on a port, then figure out what that port is
let mut srv = t!(TcpListener::bind("127.0.0.1:0").await);
let addr = t!(srv.local_addr());
let (server_cx, client_cx) = contexts();
let server = async move {
let mut incoming = srv.incoming();
let socket = t!(incoming.next().await.unwrap());
let socket = t!(server_cx.accept(socket).await);
copy_data(socket).await
};
let client = async move {
let socket = t!(TcpStream::connect(&addr).await);
let mut socket = t!(client_cx.connect("localhost", socket).await);
let mut data = Vec::new();
t!(socket.read_to_end(&mut data).await);
data
};
// Finally, run everything!
let (_, data) = join!(server, client);
// assert_eq!(amt, AMT);
assert!(data == vec![9; AMT]);
}
#[tokio::test]
async fn one_byte_at_a_time() {
const AMT: usize = 1024;
drop(env_logger::try_init());
let mut srv = t!(TcpListener::bind("127.0.0.1:0").await);
let addr = t!(srv.local_addr());
let (server_cx, client_cx) = contexts();
let server = async move {
let mut incoming = srv.incoming();
let socket = t!(incoming.next().await.unwrap());
let mut socket = t!(server_cx.accept(socket).await);
let mut amt = 0;
for b in std::iter::repeat(9).take(AMT) {
let data = [b as u8];
t!(socket.write_all(&data).await);
amt += 1;
}
amt
};
let client = async move {
let socket = t!(TcpStream::connect(&addr).await);
let mut socket = t!(client_cx.connect("localhost", socket).await);
let mut data = Vec::new();
loop {
let mut buf = [0; 1];
match socket.read_exact(&mut buf).await {
Ok(_) => data.extend_from_slice(&buf),
Err(ref err) if err.kind() == ErrorKind::UnexpectedEof => break,
Err(err) => panic!(err),
}
}
data
};
let (amt, data) = join!(server, client);
assert_eq!(amt, AMT);
assert!(data == vec![9; AMT as usize]);
}