Merge branch 'master' into new-api
This commit is contained in:
commit
f633a72f02
20
.travis.yml
20
.travis.yml
@ -1,13 +1,21 @@
|
||||
language: rust
|
||||
rust:
|
||||
- stable
|
||||
cache: cargo
|
||||
os:
|
||||
- linux
|
||||
- osx
|
||||
|
||||
matrix:
|
||||
include:
|
||||
- rust: stable
|
||||
os: linux
|
||||
- rust: nightly
|
||||
env: FEATURE=nightly
|
||||
os: linux
|
||||
- rust: stable
|
||||
os: osx
|
||||
- rust: nightly
|
||||
env: FEATURE=nightly
|
||||
os: osx
|
||||
|
||||
script:
|
||||
- cargo test --all-features
|
||||
- cargo test --features "$FEATURE"
|
||||
- cd examples/server
|
||||
- cargo check
|
||||
- cd ../../examples/client
|
||||
|
17
Cargo.toml
17
Cargo.toml
@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "tokio-rustls"
|
||||
version = "0.7.1"
|
||||
version = "0.7.2"
|
||||
authors = ["quininer kel <quininer@live.com>"]
|
||||
license = "MIT/Apache-2.0"
|
||||
repository = "https://github.com/quininer/tokio-rustls"
|
||||
@ -15,20 +15,17 @@ travis-ci = { repository = "quininer/tokio-rustls" }
|
||||
appveyor = { repository = "quininer/tokio-rustls" }
|
||||
|
||||
[dependencies]
|
||||
futures-core = { version = "0.2.0", optional = true }
|
||||
futures-io = { version = "0.2.0", optional = true }
|
||||
tokio = { version = "0.1.6", optional = true }
|
||||
bytes = { version = "0.4", optional = true }
|
||||
iovec = { version = "0.1", optional = true }
|
||||
rustls = "0.13"
|
||||
webpki = "0.18.1"
|
||||
|
||||
[dev-dependencies]
|
||||
# futures = "0.2.0"
|
||||
tokio = "0.1.6"
|
||||
lazy_static = "1"
|
||||
|
||||
[features]
|
||||
default = [ "tokio" ]
|
||||
# unstable-futures = [
|
||||
# "futures-core",
|
||||
# "futures-io",
|
||||
# "tokio/unstable-futures"
|
||||
# ]
|
||||
default = ["tokio-support"]
|
||||
nightly = ["bytes", "iovec"]
|
||||
tokio-support = ["tokio"]
|
||||
|
@ -13,7 +13,7 @@ install:
|
||||
build: false
|
||||
|
||||
test_script:
|
||||
- 'cargo test --all-features'
|
||||
- 'cargo test'
|
||||
- 'cd examples/server'
|
||||
- 'cargo check'
|
||||
- 'cd ../../examples/client'
|
||||
|
@ -5,15 +5,8 @@ authors = ["quininer <quininer@live.com>"]
|
||||
|
||||
[dependencies]
|
||||
webpki = "0.18.1"
|
||||
tokio-rustls = { path = "../..", default-features = false, features = [ "tokio" ] }
|
||||
|
||||
tokio-rustls = { path = "../.." }
|
||||
tokio = "0.1"
|
||||
|
||||
clap = "2.26"
|
||||
clap = "2"
|
||||
webpki-roots = "0.15"
|
||||
|
||||
[target.'cfg(unix)'.dependencies]
|
||||
tokio-file-unix = "0.5"
|
||||
|
||||
[target.'cfg(not(unix))'.dependencies]
|
||||
tokio-fs = "0.1"
|
||||
tokio-stdin-stdout = "0.1"
|
||||
|
@ -4,8 +4,7 @@ extern crate webpki;
|
||||
extern crate webpki_roots;
|
||||
extern crate tokio_rustls;
|
||||
|
||||
#[cfg(unix)] extern crate tokio_file_unix;
|
||||
#[cfg(not(unix))] extern crate tokio_fs;
|
||||
extern crate tokio_stdin_stdout;
|
||||
|
||||
use std::sync::Arc;
|
||||
use std::net::ToSocketAddrs;
|
||||
@ -16,6 +15,7 @@ use tokio::net::TcpStream;
|
||||
use tokio::prelude::*;
|
||||
use clap::{ App, Arg };
|
||||
use tokio_rustls::{ TlsConnector, rustls::ClientConfig };
|
||||
use tokio_stdin_stdout::{ stdin as tokio_stdin, stdout as tokio_stdout };
|
||||
|
||||
fn app() -> App<'static, 'static> {
|
||||
App::new("client")
|
||||
@ -52,59 +52,23 @@ fn main() {
|
||||
let config = TlsConnector::from(Arc::new(config));
|
||||
|
||||
let socket = TcpStream::connect(&addr);
|
||||
let (stdin, stdout) = (tokio_stdin(0), tokio_stdout(0));
|
||||
|
||||
#[cfg(unix)]
|
||||
let resp = {
|
||||
use tokio::reactor::Handle;
|
||||
use tokio_file_unix::{ raw_stdin, raw_stdout, File };
|
||||
let done = socket
|
||||
.and_then(move |stream| {
|
||||
let domain = webpki::DNSNameRef::try_from_ascii_str(&domain).unwrap();
|
||||
config.connect(domain, stream)
|
||||
})
|
||||
.and_then(move |stream| io::write_all(stream, text))
|
||||
.and_then(move |(stream, _)| {
|
||||
let (r, w) = stream.split();
|
||||
io::copy(r, stdout)
|
||||
.map(drop)
|
||||
.select2(io::copy(stdin, w).map(drop))
|
||||
.map_err(|res| res.split().0)
|
||||
})
|
||||
.map(drop)
|
||||
.map_err(|err| eprintln!("{:?}", err));
|
||||
|
||||
let stdin = raw_stdin()
|
||||
.and_then(File::new_nb)
|
||||
.and_then(|fd| fd.into_reader(&Handle::current()))
|
||||
.unwrap();
|
||||
let stdout = raw_stdout()
|
||||
.and_then(File::new_nb)
|
||||
.and_then(|fd| fd.into_io(&Handle::current()))
|
||||
.unwrap();
|
||||
|
||||
socket
|
||||
.and_then(move |stream| {
|
||||
let domain = webpki::DNSNameRef::try_from_ascii_str(&domain).unwrap();
|
||||
config.connect(domain, stream)
|
||||
})
|
||||
.and_then(move |stream| io::write_all(stream, text))
|
||||
.and_then(move |(stream, _)| {
|
||||
let (r, w) = stream.split();
|
||||
io::copy(r, stdout)
|
||||
.map(drop)
|
||||
.select2(io::copy(stdin, w).map(drop))
|
||||
.map_err(|res| res.split().0)
|
||||
})
|
||||
.map(drop)
|
||||
.map_err(|err| eprintln!("{:?}", err))
|
||||
};
|
||||
|
||||
#[cfg(not(unix))]
|
||||
let resp = {
|
||||
use tokio_fs::{ stdin as tokio_stdin, stdout as tokio_stdout };
|
||||
|
||||
let (stdin, stdout) = (tokio_stdin(), tokio_stdout());
|
||||
|
||||
socket
|
||||
.and_then(move |stream| {
|
||||
let domain = webpki::DNSNameRef::try_from_ascii_str(&domain).unwrap();
|
||||
config.connect(domain, stream)
|
||||
})
|
||||
.and_then(move |stream| io::write_all(stream, text))
|
||||
.and_then(move |(stream, _)| {
|
||||
let (r, w) = stream.split();
|
||||
io::copy(r, stdout)
|
||||
.map(drop)
|
||||
.join(io::copy(stdin, w).map(drop))
|
||||
})
|
||||
.map(drop)
|
||||
.map_err(|err| eprintln!("{:?}", err))
|
||||
};
|
||||
|
||||
tokio::run(resp);
|
||||
tokio::run(done);
|
||||
}
|
||||
|
@ -4,9 +4,6 @@ version = "0.1.0"
|
||||
authors = ["quininer <quininer@live.com>"]
|
||||
|
||||
[dependencies]
|
||||
tokio-rustls = { path = "../..", default-features = false, features = [ "tokio" ] }
|
||||
|
||||
tokio-rustls = { path = "../.." }
|
||||
tokio = { version = "0.1.6" }
|
||||
# futures = "0.2.0-beta"
|
||||
|
||||
clap = "2.26"
|
||||
clap = "2"
|
||||
|
153
src/common/mod.rs
Normal file
153
src/common/mod.rs
Normal file
@ -0,0 +1,153 @@
|
||||
#[cfg(feature = "nightly")]
|
||||
#[cfg(feature = "tokio-support")]
|
||||
mod vecbuf;
|
||||
|
||||
use std::io::{ self, Read, Write };
|
||||
#[cfg(feature = "nightly")]
|
||||
use std::io::Initializer;
|
||||
use rustls::Session;
|
||||
#[cfg(feature = "nightly")]
|
||||
use rustls::WriteV;
|
||||
#[cfg(feature = "nightly")]
|
||||
#[cfg(feature = "tokio-support")]
|
||||
use tokio::io::AsyncWrite;
|
||||
|
||||
pub struct Stream<'a, S: 'a, IO: 'a> {
|
||||
session: &'a mut S,
|
||||
io: &'a mut IO
|
||||
}
|
||||
|
||||
pub trait WriteTls<'a, S: Session, IO: Read + Write>: Read + Write {
|
||||
fn write_tls(&mut self) -> io::Result<usize>;
|
||||
}
|
||||
|
||||
impl<'a, S: Session, IO: Read + Write> Stream<'a, S, IO> {
|
||||
pub fn new(session: &'a mut S, io: &'a mut IO) -> Self {
|
||||
Stream { session, io }
|
||||
}
|
||||
|
||||
pub fn complete_io(&mut self) -> io::Result<(usize, usize)> {
|
||||
// fork from https://github.com/ctz/rustls/blob/master/src/session.rs#L161
|
||||
|
||||
let until_handshaked = self.session.is_handshaking();
|
||||
let mut eof = false;
|
||||
let mut wrlen = 0;
|
||||
let mut rdlen = 0;
|
||||
|
||||
loop {
|
||||
while self.session.wants_write() {
|
||||
wrlen += self.write_tls()?;
|
||||
}
|
||||
|
||||
if !until_handshaked && wrlen > 0 {
|
||||
return Ok((rdlen, wrlen));
|
||||
}
|
||||
|
||||
if !eof && self.session.wants_read() {
|
||||
match self.session.read_tls(self.io)? {
|
||||
0 => eof = true,
|
||||
n => rdlen += n
|
||||
}
|
||||
}
|
||||
|
||||
match self.session.process_new_packets() {
|
||||
Ok(_) => {},
|
||||
Err(e) => {
|
||||
// In case we have an alert to send describing this error,
|
||||
// try a last-gasp write -- but don't predate the primary
|
||||
// error.
|
||||
let _ignored = self.write_tls();
|
||||
|
||||
return Err(io::Error::new(io::ErrorKind::InvalidData, e));
|
||||
},
|
||||
};
|
||||
|
||||
match (eof, until_handshaked, self.session.is_handshaking()) {
|
||||
(_, true, false) => return Ok((rdlen, wrlen)),
|
||||
(_, false, _) => return Ok((rdlen, wrlen)),
|
||||
(true, true, true) => return Err(io::Error::from(io::ErrorKind::UnexpectedEof)),
|
||||
(..) => ()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "nightly"))]
|
||||
impl<'a, S: Session, IO: Read + Write> WriteTls<'a, S, IO> for Stream<'a, S, IO> {
|
||||
fn write_tls(&mut self) -> io::Result<usize> {
|
||||
self.session.write_tls(self.io)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "nightly")]
|
||||
impl<'a, S: Session, IO: Read + Write> WriteTls<'a, S, IO> for Stream<'a, S, IO> {
|
||||
default fn write_tls(&mut self) -> io::Result<usize> {
|
||||
self.session.write_tls(self.io)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "nightly")]
|
||||
#[cfg(feature = "tokio-support")]
|
||||
impl<'a, S: Session, IO: Read + AsyncWrite> WriteTls<'a, S, IO> for Stream<'a, S, IO> {
|
||||
fn write_tls(&mut self) -> io::Result<usize> {
|
||||
use tokio::prelude::Async;
|
||||
use self::vecbuf::VecBuf;
|
||||
|
||||
struct V<'a, IO: 'a>(&'a mut IO);
|
||||
|
||||
impl<'a, IO: AsyncWrite> WriteV for V<'a, IO> {
|
||||
fn writev(&mut self, vbytes: &[&[u8]]) -> io::Result<usize> {
|
||||
let mut vbytes = VecBuf::new(vbytes);
|
||||
match self.0.write_buf(&mut vbytes) {
|
||||
Ok(Async::Ready(n)) => Ok(n),
|
||||
Ok(Async::NotReady) => Err(io::ErrorKind::WouldBlock.into()),
|
||||
Err(err) => Err(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut vecbuf = V(self.io);
|
||||
self.session.writev_tls(&mut vecbuf)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, S: Session, IO: Read + Write> Read for Stream<'a, S, IO> {
|
||||
#[cfg(feature = "nightly")]
|
||||
unsafe fn initializer(&self) -> Initializer {
|
||||
Initializer::nop()
|
||||
}
|
||||
|
||||
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
|
||||
while self.session.wants_read() {
|
||||
if let (0, 0) = self.complete_io()? {
|
||||
break
|
||||
}
|
||||
}
|
||||
self.session.read(buf)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, S: Session, IO: Read + Write> io::Write for Stream<'a, S, IO> {
|
||||
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
|
||||
let len = self.session.write(buf)?;
|
||||
while self.session.wants_write() {
|
||||
match self.complete_io() {
|
||||
Ok(_) => (),
|
||||
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock && len != 0 => break,
|
||||
Err(err) => return Err(err)
|
||||
}
|
||||
}
|
||||
Ok(len)
|
||||
}
|
||||
|
||||
fn flush(&mut self) -> io::Result<()> {
|
||||
self.session.flush()?;
|
||||
if self.session.wants_write() {
|
||||
self.complete_io()?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test_stream;
|
161
src/common/test_stream.rs
Normal file
161
src/common/test_stream.rs
Normal file
@ -0,0 +1,161 @@
|
||||
use std::sync::Arc;
|
||||
use std::io::{ self, Read, Write, BufReader, Cursor };
|
||||
use webpki::DNSNameRef;
|
||||
use rustls::internal::pemfile::{ certs, rsa_private_keys };
|
||||
use rustls::{
|
||||
ServerConfig, ClientConfig,
|
||||
ServerSession, ClientSession,
|
||||
Session, NoClientAuth
|
||||
};
|
||||
use super::Stream;
|
||||
|
||||
|
||||
struct Good<'a>(&'a mut Session);
|
||||
|
||||
impl<'a> Read for Good<'a> {
|
||||
fn read(&mut self, mut buf: &mut [u8]) -> io::Result<usize> {
|
||||
self.0.write_tls(buf.by_ref())
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> Write for Good<'a> {
|
||||
fn write(&mut self, mut buf: &[u8]) -> io::Result<usize> {
|
||||
let len = self.0.read_tls(buf.by_ref())?;
|
||||
self.0.process_new_packets()
|
||||
.map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?;
|
||||
Ok(len)
|
||||
}
|
||||
|
||||
fn flush(&mut self) -> io::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
struct Bad(bool);
|
||||
|
||||
impl Read for Bad {
|
||||
fn read(&mut self, _: &mut [u8]) -> io::Result<usize> {
|
||||
Ok(0)
|
||||
}
|
||||
}
|
||||
|
||||
impl Write for Bad {
|
||||
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
|
||||
if self.0 {
|
||||
Err(io::ErrorKind::WouldBlock.into())
|
||||
} else {
|
||||
Ok(buf.len())
|
||||
}
|
||||
}
|
||||
|
||||
fn flush(&mut self) -> io::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#[test]
|
||||
fn stream_good() -> io::Result<()> {
|
||||
const FILE: &'static [u8] = include_bytes!("../../README.md");
|
||||
|
||||
let (mut server, mut client) = make_pair();
|
||||
do_handshake(&mut client, &mut server);
|
||||
io::copy(&mut Cursor::new(FILE), &mut server)?;
|
||||
|
||||
{
|
||||
let mut good = Good(&mut server);
|
||||
let mut stream = Stream::new(&mut client, &mut good);
|
||||
|
||||
let mut buf = Vec::new();
|
||||
stream.read_to_end(&mut buf)?;
|
||||
assert_eq!(buf, FILE);
|
||||
stream.write_all(b"Hello World!")?
|
||||
}
|
||||
|
||||
let mut buf = String::new();
|
||||
server.read_to_string(&mut buf)?;
|
||||
assert_eq!(buf, "Hello World!");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn stream_bad() -> io::Result<()> {
|
||||
let (mut server, mut client) = make_pair();
|
||||
do_handshake(&mut client, &mut server);
|
||||
client.set_buffer_limit(1024);
|
||||
|
||||
let mut bad = Bad(true);
|
||||
let mut stream = Stream::new(&mut client, &mut bad);
|
||||
assert_eq!(stream.write(&[0x42; 8])?, 8);
|
||||
assert_eq!(stream.write(&[0x42; 8])?, 8);
|
||||
let r = stream.write(&[0x00; 1024])?; // fill buffer
|
||||
assert!(r < 1024);
|
||||
assert_eq!(
|
||||
stream.write(&[0x01]).unwrap_err().kind(),
|
||||
io::ErrorKind::WouldBlock
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn stream_handshake() -> io::Result<()> {
|
||||
let (mut server, mut client) = make_pair();
|
||||
|
||||
{
|
||||
let mut good = Good(&mut server);
|
||||
let mut stream = Stream::new(&mut client, &mut good);
|
||||
let (r, w) = stream.complete_io()?;
|
||||
|
||||
assert!(r > 0);
|
||||
assert!(w > 0);
|
||||
|
||||
stream.complete_io()?; // finish server handshake
|
||||
}
|
||||
|
||||
assert!(!server.is_handshaking());
|
||||
assert!(!client.is_handshaking());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn stream_handshake_eof() -> io::Result<()> {
|
||||
let (_, mut client) = make_pair();
|
||||
|
||||
let mut bad = Bad(false);
|
||||
let mut stream = Stream::new(&mut client, &mut bad);
|
||||
let r = stream.complete_io();
|
||||
|
||||
assert_eq!(r.unwrap_err().kind(), io::ErrorKind::UnexpectedEof);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn make_pair() -> (ServerSession, ClientSession) {
|
||||
const CERT: &str = include_str!("../../tests/end.cert");
|
||||
const CHAIN: &str = include_str!("../../tests/end.chain");
|
||||
const RSA: &str = include_str!("../../tests/end.rsa");
|
||||
|
||||
let cert = certs(&mut BufReader::new(Cursor::new(CERT))).unwrap();
|
||||
let mut keys = rsa_private_keys(&mut BufReader::new(Cursor::new(RSA))).unwrap();
|
||||
let mut sconfig = ServerConfig::new(NoClientAuth::new());
|
||||
sconfig.set_single_cert(cert, keys.pop().unwrap()).unwrap();
|
||||
let server = ServerSession::new(&Arc::new(sconfig));
|
||||
|
||||
let domain = DNSNameRef::try_from_ascii_str("localhost").unwrap();
|
||||
let mut cconfig = ClientConfig::new();
|
||||
let mut chain = BufReader::new(Cursor::new(CHAIN));
|
||||
cconfig.root_store.add_pem_file(&mut chain).unwrap();
|
||||
let client = ClientSession::new(&Arc::new(cconfig), domain);
|
||||
|
||||
(server, client)
|
||||
}
|
||||
|
||||
fn do_handshake(client: &mut ClientSession, server: &mut ServerSession) {
|
||||
let mut good = Good(server);
|
||||
let mut stream = Stream::new(client, &mut good);
|
||||
stream.complete_io().unwrap();
|
||||
stream.complete_io().unwrap();
|
||||
}
|
122
src/common/vecbuf.rs
Normal file
122
src/common/vecbuf.rs
Normal file
@ -0,0 +1,122 @@
|
||||
use std::cmp::{ self, Ordering };
|
||||
use bytes::Buf;
|
||||
use iovec::IoVec;
|
||||
|
||||
pub struct VecBuf<'a, 'b: 'a> {
|
||||
pos: usize,
|
||||
cur: usize,
|
||||
inner: &'a [&'b [u8]]
|
||||
}
|
||||
|
||||
impl<'a, 'b> VecBuf<'a, 'b> {
|
||||
pub fn new(vbytes: &'a [&'b [u8]]) -> Self {
|
||||
VecBuf { pos: 0, cur: 0, inner: vbytes }
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, 'b> Buf for VecBuf<'a, 'b> {
|
||||
fn remaining(&self) -> usize {
|
||||
let sum = self.inner
|
||||
.iter()
|
||||
.skip(self.pos)
|
||||
.map(|bytes| bytes.len())
|
||||
.sum::<usize>();
|
||||
sum - self.cur
|
||||
}
|
||||
|
||||
fn bytes(&self) -> &[u8] {
|
||||
&self.inner[self.pos][self.cur..]
|
||||
}
|
||||
|
||||
fn advance(&mut self, cnt: usize) {
|
||||
let current = self.inner[self.pos].len();
|
||||
match (self.cur + cnt).cmp(¤t) {
|
||||
Ordering::Equal => if self.pos + 1 < self.inner.len() {
|
||||
self.pos += 1;
|
||||
self.cur = 0;
|
||||
} else {
|
||||
self.cur += cnt;
|
||||
},
|
||||
Ordering::Greater => {
|
||||
if self.pos + 1 < self.inner.len() {
|
||||
self.pos += 1;
|
||||
}
|
||||
let remaining = self.cur + cnt - current;
|
||||
self.advance(remaining);
|
||||
},
|
||||
Ordering::Less => self.cur += cnt,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg_attr(feature = "cargo-clippy", allow(needless_range_loop))]
|
||||
fn bytes_vec<'c>(&'c self, dst: &mut [&'c IoVec]) -> usize {
|
||||
let len = cmp::min(self.inner.len() - self.pos, dst.len());
|
||||
|
||||
if len > 0 {
|
||||
dst[0] = self.bytes().into();
|
||||
}
|
||||
|
||||
for i in 1..len {
|
||||
dst[i] = self.inner[self.pos + i].into();
|
||||
}
|
||||
|
||||
len
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test_vecbuf {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_fresh_cursor_vec() {
|
||||
let mut buf = VecBuf::new(&[b"he", b"llo"]);
|
||||
|
||||
assert_eq!(buf.remaining(), 5);
|
||||
assert_eq!(buf.bytes(), b"he");
|
||||
|
||||
buf.advance(2);
|
||||
|
||||
assert_eq!(buf.remaining(), 3);
|
||||
assert_eq!(buf.bytes(), b"llo");
|
||||
|
||||
buf.advance(3);
|
||||
|
||||
assert_eq!(buf.remaining(), 0);
|
||||
assert_eq!(buf.bytes(), b"");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_get_u8() {
|
||||
let mut buf = VecBuf::new(&[b"\x21z", b"omg"]);
|
||||
assert_eq!(0x21, buf.get_u8());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_get_u16() {
|
||||
let mut buf = VecBuf::new(&[b"\x21\x54z", b"omg"]);
|
||||
assert_eq!(0x2154, buf.get_u16_be());
|
||||
let mut buf = VecBuf::new(&[b"\x21\x54z", b"omg"]);
|
||||
assert_eq!(0x5421, buf.get_u16_le());
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn test_get_u16_buffer_underflow() {
|
||||
let mut buf = VecBuf::new(&[b"\x21"]);
|
||||
buf.get_u16_be();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bufs_vec() {
|
||||
let buf = VecBuf::new(&[b"he", b"llo"]);
|
||||
|
||||
let b1: &[u8] = &mut [0];
|
||||
let b2: &[u8] = &mut [0];
|
||||
|
||||
let mut dst: [&IoVec; 2] =
|
||||
[b1.into(), b2.into()];
|
||||
|
||||
assert_eq!(2, buf.bytes_vec(&mut dst[..]));
|
||||
}
|
||||
}
|
@ -1,170 +0,0 @@
|
||||
extern crate futures_core;
|
||||
extern crate futures_io;
|
||||
|
||||
use super::*;
|
||||
use self::futures_core::{ Future, Poll, Async };
|
||||
use self::futures_core::task::Context;
|
||||
use self::futures_io::{ Error, AsyncRead, AsyncWrite };
|
||||
|
||||
|
||||
impl<S: AsyncRead + AsyncWrite> Future for ConnectAsync<S> {
|
||||
type Item = TlsStream<S, ClientSession>;
|
||||
type Error = io::Error;
|
||||
|
||||
fn poll(&mut self, ctx: &mut Context) -> Poll<Self::Item, Self::Error> {
|
||||
self.0.poll(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + AsyncWrite> Future for AcceptAsync<S> {
|
||||
type Item = TlsStream<S, ServerSession>;
|
||||
type Error = io::Error;
|
||||
|
||||
fn poll(&mut self, ctx: &mut Context) -> Poll<Self::Item, Self::Error> {
|
||||
self.0.poll(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
macro_rules! async {
|
||||
( to $r:expr ) => {
|
||||
match $r {
|
||||
Ok(Async::Ready(n)) => Ok(n),
|
||||
Ok(Async::Pending) => Err(io::ErrorKind::WouldBlock.into()),
|
||||
Err(e) => Err(e)
|
||||
}
|
||||
};
|
||||
( from $r:expr ) => {
|
||||
match $r {
|
||||
Ok(n) => Ok(Async::Ready(n)),
|
||||
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => Ok(Async::Pending),
|
||||
Err(e) => Err(e)
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
struct TaskStream<'a, 'b: 'a, S: 'a> {
|
||||
io: &'a mut S,
|
||||
task: &'a mut Context<'b>
|
||||
}
|
||||
|
||||
impl<'a, 'b, S> io::Read for TaskStream<'a, 'b, S>
|
||||
where S: AsyncRead + AsyncWrite
|
||||
{
|
||||
#[inline]
|
||||
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
|
||||
async!(to self.io.poll_read(self.task, buf))
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, 'b, S> io::Write for TaskStream<'a, 'b, S>
|
||||
where S: AsyncRead + AsyncWrite
|
||||
{
|
||||
#[inline]
|
||||
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
|
||||
async!(to self.io.poll_write(self.task, buf))
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn flush(&mut self) -> io::Result<()> {
|
||||
async!(to self.io.poll_flush(self.task))
|
||||
}
|
||||
}
|
||||
|
||||
impl<S, C> Future for MidHandshake<S, C>
|
||||
where S: AsyncRead + AsyncWrite, C: Session
|
||||
{
|
||||
type Item = TlsStream<S, C>;
|
||||
type Error = io::Error;
|
||||
|
||||
fn poll(&mut self, ctx: &mut Context) -> Poll<Self::Item, Self::Error> {
|
||||
loop {
|
||||
let stream = self.inner.as_mut().unwrap();
|
||||
if !stream.session.is_handshaking() { break };
|
||||
|
||||
let (io, session) = stream.get_mut();
|
||||
let mut taskio = TaskStream { io, task: ctx };
|
||||
|
||||
match session.complete_io(&mut taskio) {
|
||||
Ok(_) => (),
|
||||
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => return Ok(Async::Pending),
|
||||
Err(e) => return Err(e)
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Async::Ready(self.inner.take().unwrap()))
|
||||
}
|
||||
}
|
||||
|
||||
impl<S, C> AsyncRead for TlsStream<S, C>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite,
|
||||
C: Session
|
||||
{
|
||||
fn poll_read(&mut self, ctx: &mut Context, buf: &mut [u8]) -> Poll<usize, Error> {
|
||||
if self.eof {
|
||||
return Ok(Async::Ready(0));
|
||||
}
|
||||
|
||||
// TODO nll
|
||||
let result = {
|
||||
let (io, session) = self.get_mut();
|
||||
let mut taskio = TaskStream { io, task: ctx };
|
||||
let mut stream = Stream::new(session, &mut taskio);
|
||||
io::Read::read(&mut stream, buf)
|
||||
};
|
||||
|
||||
match result {
|
||||
Ok(0) => { self.eof = true; Ok(Async::Ready(0)) },
|
||||
Ok(n) => Ok(Async::Ready(n)),
|
||||
Err(ref e) if e.kind() == io::ErrorKind::ConnectionAborted => {
|
||||
self.eof = true;
|
||||
self.is_shutdown = true;
|
||||
self.session.send_close_notify();
|
||||
Ok(Async::Ready(0))
|
||||
},
|
||||
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => Ok(Async::Pending),
|
||||
Err(e) => Err(e)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S, C> AsyncWrite for TlsStream<S, C>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite,
|
||||
C: Session
|
||||
{
|
||||
fn poll_write(&mut self, ctx: &mut Context, buf: &[u8]) -> Poll<usize, Error> {
|
||||
let (io, session) = self.get_mut();
|
||||
let mut taskio = TaskStream { io, task: ctx };
|
||||
let mut stream = Stream::new(session, &mut taskio);
|
||||
|
||||
async!(from io::Write::write(&mut stream, buf))
|
||||
}
|
||||
|
||||
fn poll_flush(&mut self, ctx: &mut Context) -> Poll<(), Error> {
|
||||
let (io, session) = self.get_mut();
|
||||
let mut taskio = TaskStream { io, task: ctx };
|
||||
|
||||
{
|
||||
let mut stream = Stream::new(session, &mut taskio);
|
||||
async!(from io::Write::flush(&mut stream))?;
|
||||
}
|
||||
|
||||
async!(from io::Write::flush(&mut taskio))
|
||||
}
|
||||
|
||||
fn poll_close(&mut self, ctx: &mut Context) -> Poll<(), Error> {
|
||||
if !self.is_shutdown {
|
||||
self.session.send_close_notify();
|
||||
self.is_shutdown = true;
|
||||
}
|
||||
|
||||
{
|
||||
let (io, session) = self.get_mut();
|
||||
let mut taskio = TaskStream { io, task: ctx };
|
||||
async!(from session.complete_io(&mut taskio))?;
|
||||
}
|
||||
|
||||
self.io.poll_close(ctx)
|
||||
}
|
||||
}
|
18
src/lib.rs
18
src/lib.rs
@ -1,10 +1,22 @@
|
||||
//! Asynchronous TLS/SSL streams for Tokio using [Rustls](https://github.com/ctz/rustls).
|
||||
|
||||
#![cfg_attr(feature = "nightly", feature(specialization, read_initializer))]
|
||||
|
||||
pub extern crate rustls;
|
||||
pub extern crate webpki;
|
||||
|
||||
#[cfg(feature = "tokio")] mod tokio_impl;
|
||||
#[cfg(feature = "unstable-futures")] mod futures_impl;
|
||||
#[cfg(feature = "tokio-support")]
|
||||
extern crate tokio;
|
||||
#[cfg(feature = "nightly")]
|
||||
#[cfg(feature = "tokio-support")]
|
||||
extern crate bytes;
|
||||
#[cfg(feature = "nightly")]
|
||||
#[cfg(feature = "tokio-support")]
|
||||
extern crate iovec;
|
||||
|
||||
|
||||
mod common;
|
||||
#[cfg(feature = "tokio-support")] mod tokio_impl;
|
||||
|
||||
use std::io;
|
||||
use std::sync::Arc;
|
||||
@ -12,8 +24,8 @@ use webpki::DNSNameRef;
|
||||
use rustls::{
|
||||
Session, ClientSession, ServerSession,
|
||||
ClientConfig, ServerConfig,
|
||||
Stream
|
||||
};
|
||||
use common::Stream;
|
||||
|
||||
|
||||
pub struct TlsConnector {
|
||||
|
@ -1,9 +1,8 @@
|
||||
extern crate tokio;
|
||||
|
||||
use super::*;
|
||||
use self::tokio::prelude::*;
|
||||
use self::tokio::io::{ AsyncRead, AsyncWrite };
|
||||
use self::tokio::prelude::Poll;
|
||||
use tokio::prelude::*;
|
||||
use tokio::io::{ AsyncRead, AsyncWrite };
|
||||
use tokio::prelude::Poll;
|
||||
use common::Stream;
|
||||
|
||||
|
||||
impl<S: AsyncRead + AsyncWrite> Future for Connect<S> {
|
||||
@ -31,16 +30,17 @@ impl<S, C> Future for MidHandshake<S, C>
|
||||
type Error = io::Error;
|
||||
|
||||
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
|
||||
loop {
|
||||
{
|
||||
let stream = self.inner.as_mut().unwrap();
|
||||
if !stream.session.is_handshaking() { break };
|
||||
if stream.session.is_handshaking() {
|
||||
let (io, session) = stream.get_mut();
|
||||
let mut stream = Stream::new(session, io);
|
||||
|
||||
let (io, session) = stream.get_mut();
|
||||
|
||||
match session.complete_io(io) {
|
||||
Ok(_) => (),
|
||||
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => return Ok(Async::NotReady),
|
||||
Err(e) => return Err(e)
|
||||
match stream.complete_io() {
|
||||
Ok(_) => (),
|
||||
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => return Ok(Async::NotReady),
|
||||
Err(e) => return Err(e)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -52,7 +52,11 @@ impl<S, C> AsyncRead for TlsStream<S, C>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite,
|
||||
C: Session
|
||||
{}
|
||||
{
|
||||
unsafe fn prepare_uninitialized_buffer(&self, _: &mut [u8]) -> bool {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
impl<S, C> AsyncWrite for TlsStream<S, C>
|
||||
where
|
||||
|
143
tests/test.rs
143
tests/test.rs
@ -1,81 +1,89 @@
|
||||
#[macro_use] extern crate lazy_static;
|
||||
extern crate rustls;
|
||||
extern crate tokio;
|
||||
extern crate tokio_rustls;
|
||||
extern crate webpki;
|
||||
|
||||
#[cfg(feature = "unstable-futures")] extern crate futures;
|
||||
|
||||
use std::{ io, thread };
|
||||
use std::io::{ BufReader, Cursor };
|
||||
use std::sync::Arc;
|
||||
use std::sync::mpsc::channel;
|
||||
use std::net::{ SocketAddr, IpAddr, Ipv4Addr };
|
||||
use std::net::SocketAddr;
|
||||
use tokio::net::{ TcpListener, TcpStream };
|
||||
use rustls::{ Certificate, PrivateKey, ServerConfig, ClientConfig };
|
||||
use rustls::{ ServerConfig, ClientConfig };
|
||||
use rustls::internal::pemfile::{ certs, rsa_private_keys };
|
||||
use tokio_rustls::{ TlsConnector, TlsAcceptor };
|
||||
|
||||
const CERT: &str = include_str!("end.cert");
|
||||
const CHAIN: &str = include_str!("end.chain");
|
||||
const RSA: &str = include_str!("end.rsa");
|
||||
const HELLO_WORLD: &[u8] = b"Hello world!";
|
||||
|
||||
lazy_static!{
|
||||
static ref TEST_SERVER: (SocketAddr, &'static str, &'static str) = {
|
||||
use tokio::prelude::*;
|
||||
use tokio::io as aio;
|
||||
|
||||
fn start_server(cert: Vec<Certificate>, rsa: PrivateKey) -> SocketAddr {
|
||||
use tokio::prelude::*;
|
||||
use tokio::io as aio;
|
||||
let cert = certs(&mut BufReader::new(Cursor::new(CERT))).unwrap();
|
||||
let mut keys = rsa_private_keys(&mut BufReader::new(Cursor::new(RSA))).unwrap();
|
||||
|
||||
let mut config = ServerConfig::new(rustls::NoClientAuth::new());
|
||||
config.set_single_cert(cert, rsa)
|
||||
.expect("invalid key or certificate");
|
||||
let config = TlsAcceptor::from(Arc::new(config));
|
||||
let mut config = ServerConfig::new(rustls::NoClientAuth::new());
|
||||
config.set_single_cert(cert, keys.pop().unwrap())
|
||||
.expect("invalid key or certificate");
|
||||
let config = TlsAcceptor::from(Arc::new(config));
|
||||
|
||||
let (send, recv) = channel();
|
||||
let (send, recv) = channel();
|
||||
|
||||
thread::spawn(move || {
|
||||
let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 0);
|
||||
let listener = TcpListener::bind(&addr).unwrap();
|
||||
thread::spawn(move || {
|
||||
let addr = SocketAddr::from(([127, 0, 0, 1], 0));
|
||||
let listener = TcpListener::bind(&addr).unwrap();
|
||||
|
||||
send.send(listener.local_addr().unwrap()).unwrap();
|
||||
send.send(listener.local_addr().unwrap()).unwrap();
|
||||
|
||||
let done = listener.incoming()
|
||||
.for_each(move |stream| {
|
||||
let done = config.accept(stream)
|
||||
.and_then(|stream| aio::read_exact(stream, vec![0; HELLO_WORLD.len()]))
|
||||
.and_then(|(stream, buf)| {
|
||||
assert_eq!(buf, HELLO_WORLD);
|
||||
aio::write_all(stream, HELLO_WORLD)
|
||||
})
|
||||
.then(|_| Ok(()));
|
||||
let done = listener.incoming()
|
||||
.for_each(move |stream| {
|
||||
let done = config.accept(stream)
|
||||
.and_then(|stream| {
|
||||
let (reader, writer) = stream.split();
|
||||
aio::copy(reader, writer)
|
||||
})
|
||||
.then(|_| Ok(()));
|
||||
|
||||
tokio::spawn(done);
|
||||
Ok(())
|
||||
})
|
||||
.map_err(|err| panic!("{:?}", err));
|
||||
tokio::spawn(done);
|
||||
Ok(())
|
||||
})
|
||||
.map_err(|err| panic!("{:?}", err));
|
||||
|
||||
tokio::run(done);
|
||||
});
|
||||
tokio::run(done);
|
||||
});
|
||||
|
||||
recv.recv().unwrap()
|
||||
let addr = recv.recv().unwrap();
|
||||
(addr, "localhost", CHAIN)
|
||||
};
|
||||
}
|
||||
|
||||
fn start_client(addr: &SocketAddr, domain: &str, chain: Option<BufReader<Cursor<&str>>>) -> io::Result<()> {
|
||||
|
||||
fn start_server() -> &'static (SocketAddr, &'static str, &'static str) {
|
||||
&*TEST_SERVER
|
||||
}
|
||||
|
||||
fn start_client(addr: &SocketAddr, domain: &str, chain: &str) -> io::Result<()> {
|
||||
use tokio::prelude::*;
|
||||
use tokio::io as aio;
|
||||
|
||||
const FILE: &'static [u8] = include_bytes!("../README.md");
|
||||
|
||||
let domain = webpki::DNSNameRef::try_from_ascii_str(domain).unwrap();
|
||||
let mut config = ClientConfig::new();
|
||||
if let Some(mut chain) = chain {
|
||||
config.root_store.add_pem_file(&mut chain).unwrap();
|
||||
}
|
||||
let mut chain = BufReader::new(Cursor::new(chain));
|
||||
config.root_store.add_pem_file(&mut chain).unwrap();
|
||||
let config = TlsConnector::from(Arc::new(config));
|
||||
|
||||
let done = TcpStream::connect(addr)
|
||||
.and_then(|stream| config.connect(domain, stream))
|
||||
.and_then(|stream| aio::write_all(stream, HELLO_WORLD))
|
||||
.and_then(|(stream, _)| aio::read_exact(stream, vec![0; HELLO_WORLD.len()]))
|
||||
.and_then(|stream| aio::write_all(stream, FILE))
|
||||
.and_then(|(stream, _)| aio::read_exact(stream, vec![0; FILE.len()]))
|
||||
.and_then(|(stream, buf)| {
|
||||
assert_eq!(buf, HELLO_WORLD);
|
||||
assert_eq!(buf, FILE);
|
||||
aio::shutdown(stream)
|
||||
})
|
||||
.map(drop);
|
||||
@ -83,62 +91,17 @@ fn start_client(addr: &SocketAddr, domain: &str, chain: Option<BufReader<Cursor<
|
||||
done.wait()
|
||||
}
|
||||
|
||||
#[cfg(feature = "unstable-futures")]
|
||||
fn start_client2(addr: &SocketAddr, domain: &str, chain: Option<BufReader<Cursor<&str>>>) -> io::Result<()> {
|
||||
use futures::FutureExt;
|
||||
use futures::io::{ AsyncReadExt, AsyncWriteExt };
|
||||
use futures::executor::block_on;
|
||||
|
||||
let domain = webpki::DNSNameRef::try_from_ascii_str(domain).unwrap();
|
||||
let mut config = ClientConfig::new();
|
||||
if let Some(mut chain) = chain {
|
||||
config.root_store.add_pem_file(&mut chain).unwrap();
|
||||
}
|
||||
let config = TlsConnector::from(Arc::new(config));
|
||||
|
||||
let done = TcpStream::connect(addr)
|
||||
.and_then(|stream| config.connect(domain, stream))
|
||||
.and_then(|stream| stream.write_all(HELLO_WORLD))
|
||||
.and_then(|(stream, _)| stream.read_exact(vec![0; HELLO_WORLD.len()]))
|
||||
.and_then(|(stream, buf)| {
|
||||
assert_eq!(buf, HELLO_WORLD);
|
||||
stream.close()
|
||||
})
|
||||
.map(drop);
|
||||
|
||||
block_on(done)
|
||||
}
|
||||
|
||||
|
||||
#[test]
|
||||
fn pass() {
|
||||
let cert = certs(&mut BufReader::new(Cursor::new(CERT))).unwrap();
|
||||
let mut keys = rsa_private_keys(&mut BufReader::new(Cursor::new(RSA))).unwrap();
|
||||
let chain = BufReader::new(Cursor::new(CHAIN));
|
||||
let (addr, domain, chain) = start_server();
|
||||
|
||||
let addr = start_server(cert, keys.pop().unwrap());
|
||||
start_client(&addr, "localhost", Some(chain)).unwrap();
|
||||
start_client(addr, domain, chain).unwrap();
|
||||
}
|
||||
|
||||
#[cfg(feature = "unstable-futures")]
|
||||
#[test]
|
||||
fn pass2() {
|
||||
let cert = certs(&mut BufReader::new(Cursor::new(CERT))).unwrap();
|
||||
let mut keys = rsa_private_keys(&mut BufReader::new(Cursor::new(RSA))).unwrap();
|
||||
let chain = BufReader::new(Cursor::new(CHAIN));
|
||||
|
||||
let addr = start_server(cert, keys.pop().unwrap());
|
||||
start_client2(&addr, "localhost", Some(chain)).unwrap();
|
||||
}
|
||||
|
||||
#[should_panic]
|
||||
#[test]
|
||||
fn fail() {
|
||||
let cert = certs(&mut BufReader::new(Cursor::new(CERT))).unwrap();
|
||||
let mut keys = rsa_private_keys(&mut BufReader::new(Cursor::new(RSA))).unwrap();
|
||||
let chain = BufReader::new(Cursor::new(CHAIN));
|
||||
let (addr, domain, chain) = start_server();
|
||||
|
||||
let addr = start_server(cert, keys.pop().unwrap());
|
||||
|
||||
start_client(&addr, "google.com", Some(chain)).unwrap();
|
||||
assert_ne!(domain, &"google.com");
|
||||
assert!(start_client(addr, "google.com", chain).is_err());
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user