Merge branch 'master' into new-api

This commit is contained in:
quininer 2018-09-17 20:09:50 +08:00
commit f633a72f02
13 changed files with 568 additions and 364 deletions

View File

@ -1,13 +1,21 @@
language: rust
rust:
- stable
cache: cargo
os:
- linux
- osx
matrix:
include:
- rust: stable
os: linux
- rust: nightly
env: FEATURE=nightly
os: linux
- rust: stable
os: osx
- rust: nightly
env: FEATURE=nightly
os: osx
script:
- cargo test --all-features
- cargo test --features "$FEATURE"
- cd examples/server
- cargo check
- cd ../../examples/client

View File

@ -1,6 +1,6 @@
[package]
name = "tokio-rustls"
version = "0.7.1"
version = "0.7.2"
authors = ["quininer kel <quininer@live.com>"]
license = "MIT/Apache-2.0"
repository = "https://github.com/quininer/tokio-rustls"
@ -15,20 +15,17 @@ travis-ci = { repository = "quininer/tokio-rustls" }
appveyor = { repository = "quininer/tokio-rustls" }
[dependencies]
futures-core = { version = "0.2.0", optional = true }
futures-io = { version = "0.2.0", optional = true }
tokio = { version = "0.1.6", optional = true }
bytes = { version = "0.4", optional = true }
iovec = { version = "0.1", optional = true }
rustls = "0.13"
webpki = "0.18.1"
[dev-dependencies]
# futures = "0.2.0"
tokio = "0.1.6"
lazy_static = "1"
[features]
default = [ "tokio" ]
# unstable-futures = [
# "futures-core",
# "futures-io",
# "tokio/unstable-futures"
# ]
default = ["tokio-support"]
nightly = ["bytes", "iovec"]
tokio-support = ["tokio"]

View File

@ -13,7 +13,7 @@ install:
build: false
test_script:
- 'cargo test --all-features'
- 'cargo test'
- 'cd examples/server'
- 'cargo check'
- 'cd ../../examples/client'

View File

@ -5,15 +5,8 @@ authors = ["quininer <quininer@live.com>"]
[dependencies]
webpki = "0.18.1"
tokio-rustls = { path = "../..", default-features = false, features = [ "tokio" ] }
tokio-rustls = { path = "../.." }
tokio = "0.1"
clap = "2.26"
clap = "2"
webpki-roots = "0.15"
[target.'cfg(unix)'.dependencies]
tokio-file-unix = "0.5"
[target.'cfg(not(unix))'.dependencies]
tokio-fs = "0.1"
tokio-stdin-stdout = "0.1"

View File

@ -4,8 +4,7 @@ extern crate webpki;
extern crate webpki_roots;
extern crate tokio_rustls;
#[cfg(unix)] extern crate tokio_file_unix;
#[cfg(not(unix))] extern crate tokio_fs;
extern crate tokio_stdin_stdout;
use std::sync::Arc;
use std::net::ToSocketAddrs;
@ -16,6 +15,7 @@ use tokio::net::TcpStream;
use tokio::prelude::*;
use clap::{ App, Arg };
use tokio_rustls::{ TlsConnector, rustls::ClientConfig };
use tokio_stdin_stdout::{ stdin as tokio_stdin, stdout as tokio_stdout };
fn app() -> App<'static, 'static> {
App::new("client")
@ -52,59 +52,23 @@ fn main() {
let config = TlsConnector::from(Arc::new(config));
let socket = TcpStream::connect(&addr);
let (stdin, stdout) = (tokio_stdin(0), tokio_stdout(0));
#[cfg(unix)]
let resp = {
use tokio::reactor::Handle;
use tokio_file_unix::{ raw_stdin, raw_stdout, File };
let done = socket
.and_then(move |stream| {
let domain = webpki::DNSNameRef::try_from_ascii_str(&domain).unwrap();
config.connect(domain, stream)
})
.and_then(move |stream| io::write_all(stream, text))
.and_then(move |(stream, _)| {
let (r, w) = stream.split();
io::copy(r, stdout)
.map(drop)
.select2(io::copy(stdin, w).map(drop))
.map_err(|res| res.split().0)
})
.map(drop)
.map_err(|err| eprintln!("{:?}", err));
let stdin = raw_stdin()
.and_then(File::new_nb)
.and_then(|fd| fd.into_reader(&Handle::current()))
.unwrap();
let stdout = raw_stdout()
.and_then(File::new_nb)
.and_then(|fd| fd.into_io(&Handle::current()))
.unwrap();
socket
.and_then(move |stream| {
let domain = webpki::DNSNameRef::try_from_ascii_str(&domain).unwrap();
config.connect(domain, stream)
})
.and_then(move |stream| io::write_all(stream, text))
.and_then(move |(stream, _)| {
let (r, w) = stream.split();
io::copy(r, stdout)
.map(drop)
.select2(io::copy(stdin, w).map(drop))
.map_err(|res| res.split().0)
})
.map(drop)
.map_err(|err| eprintln!("{:?}", err))
};
#[cfg(not(unix))]
let resp = {
use tokio_fs::{ stdin as tokio_stdin, stdout as tokio_stdout };
let (stdin, stdout) = (tokio_stdin(), tokio_stdout());
socket
.and_then(move |stream| {
let domain = webpki::DNSNameRef::try_from_ascii_str(&domain).unwrap();
config.connect(domain, stream)
})
.and_then(move |stream| io::write_all(stream, text))
.and_then(move |(stream, _)| {
let (r, w) = stream.split();
io::copy(r, stdout)
.map(drop)
.join(io::copy(stdin, w).map(drop))
})
.map(drop)
.map_err(|err| eprintln!("{:?}", err))
};
tokio::run(resp);
tokio::run(done);
}

View File

@ -4,9 +4,6 @@ version = "0.1.0"
authors = ["quininer <quininer@live.com>"]
[dependencies]
tokio-rustls = { path = "../..", default-features = false, features = [ "tokio" ] }
tokio-rustls = { path = "../.." }
tokio = { version = "0.1.6" }
# futures = "0.2.0-beta"
clap = "2.26"
clap = "2"

153
src/common/mod.rs Normal file
View File

@ -0,0 +1,153 @@
#[cfg(feature = "nightly")]
#[cfg(feature = "tokio-support")]
mod vecbuf;
use std::io::{ self, Read, Write };
#[cfg(feature = "nightly")]
use std::io::Initializer;
use rustls::Session;
#[cfg(feature = "nightly")]
use rustls::WriteV;
#[cfg(feature = "nightly")]
#[cfg(feature = "tokio-support")]
use tokio::io::AsyncWrite;
pub struct Stream<'a, S: 'a, IO: 'a> {
session: &'a mut S,
io: &'a mut IO
}
pub trait WriteTls<'a, S: Session, IO: Read + Write>: Read + Write {
fn write_tls(&mut self) -> io::Result<usize>;
}
impl<'a, S: Session, IO: Read + Write> Stream<'a, S, IO> {
pub fn new(session: &'a mut S, io: &'a mut IO) -> Self {
Stream { session, io }
}
pub fn complete_io(&mut self) -> io::Result<(usize, usize)> {
// fork from https://github.com/ctz/rustls/blob/master/src/session.rs#L161
let until_handshaked = self.session.is_handshaking();
let mut eof = false;
let mut wrlen = 0;
let mut rdlen = 0;
loop {
while self.session.wants_write() {
wrlen += self.write_tls()?;
}
if !until_handshaked && wrlen > 0 {
return Ok((rdlen, wrlen));
}
if !eof && self.session.wants_read() {
match self.session.read_tls(self.io)? {
0 => eof = true,
n => rdlen += n
}
}
match self.session.process_new_packets() {
Ok(_) => {},
Err(e) => {
// In case we have an alert to send describing this error,
// try a last-gasp write -- but don't predate the primary
// error.
let _ignored = self.write_tls();
return Err(io::Error::new(io::ErrorKind::InvalidData, e));
},
};
match (eof, until_handshaked, self.session.is_handshaking()) {
(_, true, false) => return Ok((rdlen, wrlen)),
(_, false, _) => return Ok((rdlen, wrlen)),
(true, true, true) => return Err(io::Error::from(io::ErrorKind::UnexpectedEof)),
(..) => ()
}
}
}
}
#[cfg(not(feature = "nightly"))]
impl<'a, S: Session, IO: Read + Write> WriteTls<'a, S, IO> for Stream<'a, S, IO> {
fn write_tls(&mut self) -> io::Result<usize> {
self.session.write_tls(self.io)
}
}
#[cfg(feature = "nightly")]
impl<'a, S: Session, IO: Read + Write> WriteTls<'a, S, IO> for Stream<'a, S, IO> {
default fn write_tls(&mut self) -> io::Result<usize> {
self.session.write_tls(self.io)
}
}
#[cfg(feature = "nightly")]
#[cfg(feature = "tokio-support")]
impl<'a, S: Session, IO: Read + AsyncWrite> WriteTls<'a, S, IO> for Stream<'a, S, IO> {
fn write_tls(&mut self) -> io::Result<usize> {
use tokio::prelude::Async;
use self::vecbuf::VecBuf;
struct V<'a, IO: 'a>(&'a mut IO);
impl<'a, IO: AsyncWrite> WriteV for V<'a, IO> {
fn writev(&mut self, vbytes: &[&[u8]]) -> io::Result<usize> {
let mut vbytes = VecBuf::new(vbytes);
match self.0.write_buf(&mut vbytes) {
Ok(Async::Ready(n)) => Ok(n),
Ok(Async::NotReady) => Err(io::ErrorKind::WouldBlock.into()),
Err(err) => Err(err)
}
}
}
let mut vecbuf = V(self.io);
self.session.writev_tls(&mut vecbuf)
}
}
impl<'a, S: Session, IO: Read + Write> Read for Stream<'a, S, IO> {
#[cfg(feature = "nightly")]
unsafe fn initializer(&self) -> Initializer {
Initializer::nop()
}
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
while self.session.wants_read() {
if let (0, 0) = self.complete_io()? {
break
}
}
self.session.read(buf)
}
}
impl<'a, S: Session, IO: Read + Write> io::Write for Stream<'a, S, IO> {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
let len = self.session.write(buf)?;
while self.session.wants_write() {
match self.complete_io() {
Ok(_) => (),
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock && len != 0 => break,
Err(err) => return Err(err)
}
}
Ok(len)
}
fn flush(&mut self) -> io::Result<()> {
self.session.flush()?;
if self.session.wants_write() {
self.complete_io()?;
}
Ok(())
}
}
#[cfg(test)]
mod test_stream;

161
src/common/test_stream.rs Normal file
View File

@ -0,0 +1,161 @@
use std::sync::Arc;
use std::io::{ self, Read, Write, BufReader, Cursor };
use webpki::DNSNameRef;
use rustls::internal::pemfile::{ certs, rsa_private_keys };
use rustls::{
ServerConfig, ClientConfig,
ServerSession, ClientSession,
Session, NoClientAuth
};
use super::Stream;
struct Good<'a>(&'a mut Session);
impl<'a> Read for Good<'a> {
fn read(&mut self, mut buf: &mut [u8]) -> io::Result<usize> {
self.0.write_tls(buf.by_ref())
}
}
impl<'a> Write for Good<'a> {
fn write(&mut self, mut buf: &[u8]) -> io::Result<usize> {
let len = self.0.read_tls(buf.by_ref())?;
self.0.process_new_packets()
.map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?;
Ok(len)
}
fn flush(&mut self) -> io::Result<()> {
Ok(())
}
}
struct Bad(bool);
impl Read for Bad {
fn read(&mut self, _: &mut [u8]) -> io::Result<usize> {
Ok(0)
}
}
impl Write for Bad {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
if self.0 {
Err(io::ErrorKind::WouldBlock.into())
} else {
Ok(buf.len())
}
}
fn flush(&mut self) -> io::Result<()> {
Ok(())
}
}
#[test]
fn stream_good() -> io::Result<()> {
const FILE: &'static [u8] = include_bytes!("../../README.md");
let (mut server, mut client) = make_pair();
do_handshake(&mut client, &mut server);
io::copy(&mut Cursor::new(FILE), &mut server)?;
{
let mut good = Good(&mut server);
let mut stream = Stream::new(&mut client, &mut good);
let mut buf = Vec::new();
stream.read_to_end(&mut buf)?;
assert_eq!(buf, FILE);
stream.write_all(b"Hello World!")?
}
let mut buf = String::new();
server.read_to_string(&mut buf)?;
assert_eq!(buf, "Hello World!");
Ok(())
}
#[test]
fn stream_bad() -> io::Result<()> {
let (mut server, mut client) = make_pair();
do_handshake(&mut client, &mut server);
client.set_buffer_limit(1024);
let mut bad = Bad(true);
let mut stream = Stream::new(&mut client, &mut bad);
assert_eq!(stream.write(&[0x42; 8])?, 8);
assert_eq!(stream.write(&[0x42; 8])?, 8);
let r = stream.write(&[0x00; 1024])?; // fill buffer
assert!(r < 1024);
assert_eq!(
stream.write(&[0x01]).unwrap_err().kind(),
io::ErrorKind::WouldBlock
);
Ok(())
}
#[test]
fn stream_handshake() -> io::Result<()> {
let (mut server, mut client) = make_pair();
{
let mut good = Good(&mut server);
let mut stream = Stream::new(&mut client, &mut good);
let (r, w) = stream.complete_io()?;
assert!(r > 0);
assert!(w > 0);
stream.complete_io()?; // finish server handshake
}
assert!(!server.is_handshaking());
assert!(!client.is_handshaking());
Ok(())
}
#[test]
fn stream_handshake_eof() -> io::Result<()> {
let (_, mut client) = make_pair();
let mut bad = Bad(false);
let mut stream = Stream::new(&mut client, &mut bad);
let r = stream.complete_io();
assert_eq!(r.unwrap_err().kind(), io::ErrorKind::UnexpectedEof);
Ok(())
}
fn make_pair() -> (ServerSession, ClientSession) {
const CERT: &str = include_str!("../../tests/end.cert");
const CHAIN: &str = include_str!("../../tests/end.chain");
const RSA: &str = include_str!("../../tests/end.rsa");
let cert = certs(&mut BufReader::new(Cursor::new(CERT))).unwrap();
let mut keys = rsa_private_keys(&mut BufReader::new(Cursor::new(RSA))).unwrap();
let mut sconfig = ServerConfig::new(NoClientAuth::new());
sconfig.set_single_cert(cert, keys.pop().unwrap()).unwrap();
let server = ServerSession::new(&Arc::new(sconfig));
let domain = DNSNameRef::try_from_ascii_str("localhost").unwrap();
let mut cconfig = ClientConfig::new();
let mut chain = BufReader::new(Cursor::new(CHAIN));
cconfig.root_store.add_pem_file(&mut chain).unwrap();
let client = ClientSession::new(&Arc::new(cconfig), domain);
(server, client)
}
fn do_handshake(client: &mut ClientSession, server: &mut ServerSession) {
let mut good = Good(server);
let mut stream = Stream::new(client, &mut good);
stream.complete_io().unwrap();
stream.complete_io().unwrap();
}

122
src/common/vecbuf.rs Normal file
View File

@ -0,0 +1,122 @@
use std::cmp::{ self, Ordering };
use bytes::Buf;
use iovec::IoVec;
pub struct VecBuf<'a, 'b: 'a> {
pos: usize,
cur: usize,
inner: &'a [&'b [u8]]
}
impl<'a, 'b> VecBuf<'a, 'b> {
pub fn new(vbytes: &'a [&'b [u8]]) -> Self {
VecBuf { pos: 0, cur: 0, inner: vbytes }
}
}
impl<'a, 'b> Buf for VecBuf<'a, 'b> {
fn remaining(&self) -> usize {
let sum = self.inner
.iter()
.skip(self.pos)
.map(|bytes| bytes.len())
.sum::<usize>();
sum - self.cur
}
fn bytes(&self) -> &[u8] {
&self.inner[self.pos][self.cur..]
}
fn advance(&mut self, cnt: usize) {
let current = self.inner[self.pos].len();
match (self.cur + cnt).cmp(&current) {
Ordering::Equal => if self.pos + 1 < self.inner.len() {
self.pos += 1;
self.cur = 0;
} else {
self.cur += cnt;
},
Ordering::Greater => {
if self.pos + 1 < self.inner.len() {
self.pos += 1;
}
let remaining = self.cur + cnt - current;
self.advance(remaining);
},
Ordering::Less => self.cur += cnt,
}
}
#[cfg_attr(feature = "cargo-clippy", allow(needless_range_loop))]
fn bytes_vec<'c>(&'c self, dst: &mut [&'c IoVec]) -> usize {
let len = cmp::min(self.inner.len() - self.pos, dst.len());
if len > 0 {
dst[0] = self.bytes().into();
}
for i in 1..len {
dst[i] = self.inner[self.pos + i].into();
}
len
}
}
#[cfg(test)]
mod test_vecbuf {
use super::*;
#[test]
fn test_fresh_cursor_vec() {
let mut buf = VecBuf::new(&[b"he", b"llo"]);
assert_eq!(buf.remaining(), 5);
assert_eq!(buf.bytes(), b"he");
buf.advance(2);
assert_eq!(buf.remaining(), 3);
assert_eq!(buf.bytes(), b"llo");
buf.advance(3);
assert_eq!(buf.remaining(), 0);
assert_eq!(buf.bytes(), b"");
}
#[test]
fn test_get_u8() {
let mut buf = VecBuf::new(&[b"\x21z", b"omg"]);
assert_eq!(0x21, buf.get_u8());
}
#[test]
fn test_get_u16() {
let mut buf = VecBuf::new(&[b"\x21\x54z", b"omg"]);
assert_eq!(0x2154, buf.get_u16_be());
let mut buf = VecBuf::new(&[b"\x21\x54z", b"omg"]);
assert_eq!(0x5421, buf.get_u16_le());
}
#[test]
#[should_panic]
fn test_get_u16_buffer_underflow() {
let mut buf = VecBuf::new(&[b"\x21"]);
buf.get_u16_be();
}
#[test]
fn test_bufs_vec() {
let buf = VecBuf::new(&[b"he", b"llo"]);
let b1: &[u8] = &mut [0];
let b2: &[u8] = &mut [0];
let mut dst: [&IoVec; 2] =
[b1.into(), b2.into()];
assert_eq!(2, buf.bytes_vec(&mut dst[..]));
}
}

View File

@ -1,170 +0,0 @@
extern crate futures_core;
extern crate futures_io;
use super::*;
use self::futures_core::{ Future, Poll, Async };
use self::futures_core::task::Context;
use self::futures_io::{ Error, AsyncRead, AsyncWrite };
impl<S: AsyncRead + AsyncWrite> Future for ConnectAsync<S> {
type Item = TlsStream<S, ClientSession>;
type Error = io::Error;
fn poll(&mut self, ctx: &mut Context) -> Poll<Self::Item, Self::Error> {
self.0.poll(ctx)
}
}
impl<S: AsyncRead + AsyncWrite> Future for AcceptAsync<S> {
type Item = TlsStream<S, ServerSession>;
type Error = io::Error;
fn poll(&mut self, ctx: &mut Context) -> Poll<Self::Item, Self::Error> {
self.0.poll(ctx)
}
}
macro_rules! async {
( to $r:expr ) => {
match $r {
Ok(Async::Ready(n)) => Ok(n),
Ok(Async::Pending) => Err(io::ErrorKind::WouldBlock.into()),
Err(e) => Err(e)
}
};
( from $r:expr ) => {
match $r {
Ok(n) => Ok(Async::Ready(n)),
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => Ok(Async::Pending),
Err(e) => Err(e)
}
};
}
struct TaskStream<'a, 'b: 'a, S: 'a> {
io: &'a mut S,
task: &'a mut Context<'b>
}
impl<'a, 'b, S> io::Read for TaskStream<'a, 'b, S>
where S: AsyncRead + AsyncWrite
{
#[inline]
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
async!(to self.io.poll_read(self.task, buf))
}
}
impl<'a, 'b, S> io::Write for TaskStream<'a, 'b, S>
where S: AsyncRead + AsyncWrite
{
#[inline]
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
async!(to self.io.poll_write(self.task, buf))
}
#[inline]
fn flush(&mut self) -> io::Result<()> {
async!(to self.io.poll_flush(self.task))
}
}
impl<S, C> Future for MidHandshake<S, C>
where S: AsyncRead + AsyncWrite, C: Session
{
type Item = TlsStream<S, C>;
type Error = io::Error;
fn poll(&mut self, ctx: &mut Context) -> Poll<Self::Item, Self::Error> {
loop {
let stream = self.inner.as_mut().unwrap();
if !stream.session.is_handshaking() { break };
let (io, session) = stream.get_mut();
let mut taskio = TaskStream { io, task: ctx };
match session.complete_io(&mut taskio) {
Ok(_) => (),
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => return Ok(Async::Pending),
Err(e) => return Err(e)
}
}
Ok(Async::Ready(self.inner.take().unwrap()))
}
}
impl<S, C> AsyncRead for TlsStream<S, C>
where
S: AsyncRead + AsyncWrite,
C: Session
{
fn poll_read(&mut self, ctx: &mut Context, buf: &mut [u8]) -> Poll<usize, Error> {
if self.eof {
return Ok(Async::Ready(0));
}
// TODO nll
let result = {
let (io, session) = self.get_mut();
let mut taskio = TaskStream { io, task: ctx };
let mut stream = Stream::new(session, &mut taskio);
io::Read::read(&mut stream, buf)
};
match result {
Ok(0) => { self.eof = true; Ok(Async::Ready(0)) },
Ok(n) => Ok(Async::Ready(n)),
Err(ref e) if e.kind() == io::ErrorKind::ConnectionAborted => {
self.eof = true;
self.is_shutdown = true;
self.session.send_close_notify();
Ok(Async::Ready(0))
},
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => Ok(Async::Pending),
Err(e) => Err(e)
}
}
}
impl<S, C> AsyncWrite for TlsStream<S, C>
where
S: AsyncRead + AsyncWrite,
C: Session
{
fn poll_write(&mut self, ctx: &mut Context, buf: &[u8]) -> Poll<usize, Error> {
let (io, session) = self.get_mut();
let mut taskio = TaskStream { io, task: ctx };
let mut stream = Stream::new(session, &mut taskio);
async!(from io::Write::write(&mut stream, buf))
}
fn poll_flush(&mut self, ctx: &mut Context) -> Poll<(), Error> {
let (io, session) = self.get_mut();
let mut taskio = TaskStream { io, task: ctx };
{
let mut stream = Stream::new(session, &mut taskio);
async!(from io::Write::flush(&mut stream))?;
}
async!(from io::Write::flush(&mut taskio))
}
fn poll_close(&mut self, ctx: &mut Context) -> Poll<(), Error> {
if !self.is_shutdown {
self.session.send_close_notify();
self.is_shutdown = true;
}
{
let (io, session) = self.get_mut();
let mut taskio = TaskStream { io, task: ctx };
async!(from session.complete_io(&mut taskio))?;
}
self.io.poll_close(ctx)
}
}

View File

@ -1,10 +1,22 @@
//! Asynchronous TLS/SSL streams for Tokio using [Rustls](https://github.com/ctz/rustls).
#![cfg_attr(feature = "nightly", feature(specialization, read_initializer))]
pub extern crate rustls;
pub extern crate webpki;
#[cfg(feature = "tokio")] mod tokio_impl;
#[cfg(feature = "unstable-futures")] mod futures_impl;
#[cfg(feature = "tokio-support")]
extern crate tokio;
#[cfg(feature = "nightly")]
#[cfg(feature = "tokio-support")]
extern crate bytes;
#[cfg(feature = "nightly")]
#[cfg(feature = "tokio-support")]
extern crate iovec;
mod common;
#[cfg(feature = "tokio-support")] mod tokio_impl;
use std::io;
use std::sync::Arc;
@ -12,8 +24,8 @@ use webpki::DNSNameRef;
use rustls::{
Session, ClientSession, ServerSession,
ClientConfig, ServerConfig,
Stream
};
use common::Stream;
pub struct TlsConnector {

View File

@ -1,9 +1,8 @@
extern crate tokio;
use super::*;
use self::tokio::prelude::*;
use self::tokio::io::{ AsyncRead, AsyncWrite };
use self::tokio::prelude::Poll;
use tokio::prelude::*;
use tokio::io::{ AsyncRead, AsyncWrite };
use tokio::prelude::Poll;
use common::Stream;
impl<S: AsyncRead + AsyncWrite> Future for Connect<S> {
@ -31,16 +30,17 @@ impl<S, C> Future for MidHandshake<S, C>
type Error = io::Error;
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
loop {
{
let stream = self.inner.as_mut().unwrap();
if !stream.session.is_handshaking() { break };
if stream.session.is_handshaking() {
let (io, session) = stream.get_mut();
let mut stream = Stream::new(session, io);
let (io, session) = stream.get_mut();
match session.complete_io(io) {
Ok(_) => (),
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => return Ok(Async::NotReady),
Err(e) => return Err(e)
match stream.complete_io() {
Ok(_) => (),
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => return Ok(Async::NotReady),
Err(e) => return Err(e)
}
}
}
@ -52,7 +52,11 @@ impl<S, C> AsyncRead for TlsStream<S, C>
where
S: AsyncRead + AsyncWrite,
C: Session
{}
{
unsafe fn prepare_uninitialized_buffer(&self, _: &mut [u8]) -> bool {
false
}
}
impl<S, C> AsyncWrite for TlsStream<S, C>
where

View File

@ -1,81 +1,89 @@
#[macro_use] extern crate lazy_static;
extern crate rustls;
extern crate tokio;
extern crate tokio_rustls;
extern crate webpki;
#[cfg(feature = "unstable-futures")] extern crate futures;
use std::{ io, thread };
use std::io::{ BufReader, Cursor };
use std::sync::Arc;
use std::sync::mpsc::channel;
use std::net::{ SocketAddr, IpAddr, Ipv4Addr };
use std::net::SocketAddr;
use tokio::net::{ TcpListener, TcpStream };
use rustls::{ Certificate, PrivateKey, ServerConfig, ClientConfig };
use rustls::{ ServerConfig, ClientConfig };
use rustls::internal::pemfile::{ certs, rsa_private_keys };
use tokio_rustls::{ TlsConnector, TlsAcceptor };
const CERT: &str = include_str!("end.cert");
const CHAIN: &str = include_str!("end.chain");
const RSA: &str = include_str!("end.rsa");
const HELLO_WORLD: &[u8] = b"Hello world!";
lazy_static!{
static ref TEST_SERVER: (SocketAddr, &'static str, &'static str) = {
use tokio::prelude::*;
use tokio::io as aio;
fn start_server(cert: Vec<Certificate>, rsa: PrivateKey) -> SocketAddr {
use tokio::prelude::*;
use tokio::io as aio;
let cert = certs(&mut BufReader::new(Cursor::new(CERT))).unwrap();
let mut keys = rsa_private_keys(&mut BufReader::new(Cursor::new(RSA))).unwrap();
let mut config = ServerConfig::new(rustls::NoClientAuth::new());
config.set_single_cert(cert, rsa)
.expect("invalid key or certificate");
let config = TlsAcceptor::from(Arc::new(config));
let mut config = ServerConfig::new(rustls::NoClientAuth::new());
config.set_single_cert(cert, keys.pop().unwrap())
.expect("invalid key or certificate");
let config = TlsAcceptor::from(Arc::new(config));
let (send, recv) = channel();
let (send, recv) = channel();
thread::spawn(move || {
let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 0);
let listener = TcpListener::bind(&addr).unwrap();
thread::spawn(move || {
let addr = SocketAddr::from(([127, 0, 0, 1], 0));
let listener = TcpListener::bind(&addr).unwrap();
send.send(listener.local_addr().unwrap()).unwrap();
send.send(listener.local_addr().unwrap()).unwrap();
let done = listener.incoming()
.for_each(move |stream| {
let done = config.accept(stream)
.and_then(|stream| aio::read_exact(stream, vec![0; HELLO_WORLD.len()]))
.and_then(|(stream, buf)| {
assert_eq!(buf, HELLO_WORLD);
aio::write_all(stream, HELLO_WORLD)
})
.then(|_| Ok(()));
let done = listener.incoming()
.for_each(move |stream| {
let done = config.accept(stream)
.and_then(|stream| {
let (reader, writer) = stream.split();
aio::copy(reader, writer)
})
.then(|_| Ok(()));
tokio::spawn(done);
Ok(())
})
.map_err(|err| panic!("{:?}", err));
tokio::spawn(done);
Ok(())
})
.map_err(|err| panic!("{:?}", err));
tokio::run(done);
});
tokio::run(done);
});
recv.recv().unwrap()
let addr = recv.recv().unwrap();
(addr, "localhost", CHAIN)
};
}
fn start_client(addr: &SocketAddr, domain: &str, chain: Option<BufReader<Cursor<&str>>>) -> io::Result<()> {
fn start_server() -> &'static (SocketAddr, &'static str, &'static str) {
&*TEST_SERVER
}
fn start_client(addr: &SocketAddr, domain: &str, chain: &str) -> io::Result<()> {
use tokio::prelude::*;
use tokio::io as aio;
const FILE: &'static [u8] = include_bytes!("../README.md");
let domain = webpki::DNSNameRef::try_from_ascii_str(domain).unwrap();
let mut config = ClientConfig::new();
if let Some(mut chain) = chain {
config.root_store.add_pem_file(&mut chain).unwrap();
}
let mut chain = BufReader::new(Cursor::new(chain));
config.root_store.add_pem_file(&mut chain).unwrap();
let config = TlsConnector::from(Arc::new(config));
let done = TcpStream::connect(addr)
.and_then(|stream| config.connect(domain, stream))
.and_then(|stream| aio::write_all(stream, HELLO_WORLD))
.and_then(|(stream, _)| aio::read_exact(stream, vec![0; HELLO_WORLD.len()]))
.and_then(|stream| aio::write_all(stream, FILE))
.and_then(|(stream, _)| aio::read_exact(stream, vec![0; FILE.len()]))
.and_then(|(stream, buf)| {
assert_eq!(buf, HELLO_WORLD);
assert_eq!(buf, FILE);
aio::shutdown(stream)
})
.map(drop);
@ -83,62 +91,17 @@ fn start_client(addr: &SocketAddr, domain: &str, chain: Option<BufReader<Cursor<
done.wait()
}
#[cfg(feature = "unstable-futures")]
fn start_client2(addr: &SocketAddr, domain: &str, chain: Option<BufReader<Cursor<&str>>>) -> io::Result<()> {
use futures::FutureExt;
use futures::io::{ AsyncReadExt, AsyncWriteExt };
use futures::executor::block_on;
let domain = webpki::DNSNameRef::try_from_ascii_str(domain).unwrap();
let mut config = ClientConfig::new();
if let Some(mut chain) = chain {
config.root_store.add_pem_file(&mut chain).unwrap();
}
let config = TlsConnector::from(Arc::new(config));
let done = TcpStream::connect(addr)
.and_then(|stream| config.connect(domain, stream))
.and_then(|stream| stream.write_all(HELLO_WORLD))
.and_then(|(stream, _)| stream.read_exact(vec![0; HELLO_WORLD.len()]))
.and_then(|(stream, buf)| {
assert_eq!(buf, HELLO_WORLD);
stream.close()
})
.map(drop);
block_on(done)
}
#[test]
fn pass() {
let cert = certs(&mut BufReader::new(Cursor::new(CERT))).unwrap();
let mut keys = rsa_private_keys(&mut BufReader::new(Cursor::new(RSA))).unwrap();
let chain = BufReader::new(Cursor::new(CHAIN));
let (addr, domain, chain) = start_server();
let addr = start_server(cert, keys.pop().unwrap());
start_client(&addr, "localhost", Some(chain)).unwrap();
start_client(addr, domain, chain).unwrap();
}
#[cfg(feature = "unstable-futures")]
#[test]
fn pass2() {
let cert = certs(&mut BufReader::new(Cursor::new(CERT))).unwrap();
let mut keys = rsa_private_keys(&mut BufReader::new(Cursor::new(RSA))).unwrap();
let chain = BufReader::new(Cursor::new(CHAIN));
let addr = start_server(cert, keys.pop().unwrap());
start_client2(&addr, "localhost", Some(chain)).unwrap();
}
#[should_panic]
#[test]
fn fail() {
let cert = certs(&mut BufReader::new(Cursor::new(CERT))).unwrap();
let mut keys = rsa_private_keys(&mut BufReader::new(Cursor::new(RSA))).unwrap();
let chain = BufReader::new(Cursor::new(CHAIN));
let (addr, domain, chain) = start_server();
let addr = start_server(cert, keys.pop().unwrap());
start_client(&addr, "google.com", Some(chain)).unwrap();
assert_ne!(domain, &"google.com");
assert!(start_client(addr, "google.com", chain).is_err());
}