diff --git a/.travis.yml b/.travis.yml index 3d5e1db..043e804 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,13 +1,21 @@ language: rust -rust: - - stable cache: cargo -os: - - linux - - osx + +matrix: + include: + - rust: stable + os: linux + - rust: nightly + env: FEATURE=nightly + os: linux + - rust: stable + os: osx + - rust: nightly + env: FEATURE=nightly + os: osx script: - - cargo test --all-features + - cargo test --features "$FEATURE" - cd examples/server - cargo check - cd ../../examples/client diff --git a/Cargo.toml b/Cargo.toml index e289e0f..6106795 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tokio-rustls" -version = "0.7.1" +version = "0.7.2" authors = ["quininer kel "] license = "MIT/Apache-2.0" repository = "https://github.com/quininer/tokio-rustls" @@ -15,20 +15,17 @@ travis-ci = { repository = "quininer/tokio-rustls" } appveyor = { repository = "quininer/tokio-rustls" } [dependencies] -futures-core = { version = "0.2.0", optional = true } -futures-io = { version = "0.2.0", optional = true } tokio = { version = "0.1.6", optional = true } +bytes = { version = "0.4", optional = true } +iovec = { version = "0.1", optional = true } rustls = "0.13" webpki = "0.18.1" [dev-dependencies] -# futures = "0.2.0" tokio = "0.1.6" +lazy_static = "1" [features] -default = [ "tokio" ] -# unstable-futures = [ -# "futures-core", -# "futures-io", -# "tokio/unstable-futures" -# ] +default = ["tokio-support"] +nightly = ["bytes", "iovec"] +tokio-support = ["tokio"] diff --git a/appveyor.yml b/appveyor.yml index 7ede91c..038274b 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -13,7 +13,7 @@ install: build: false test_script: - - 'cargo test --all-features' + - 'cargo test' - 'cd examples/server' - 'cargo check' - 'cd ../../examples/client' diff --git a/examples/client/Cargo.toml b/examples/client/Cargo.toml index acf13bd..780ea88 100644 --- a/examples/client/Cargo.toml +++ b/examples/client/Cargo.toml @@ -5,15 +5,8 @@ authors = ["quininer "] [dependencies] webpki = "0.18.1" -tokio-rustls = { path = "../..", default-features = false, features = [ "tokio" ] } - +tokio-rustls = { path = "../.." } tokio = "0.1" - -clap = "2.26" +clap = "2" webpki-roots = "0.15" - -[target.'cfg(unix)'.dependencies] -tokio-file-unix = "0.5" - -[target.'cfg(not(unix))'.dependencies] -tokio-fs = "0.1" +tokio-stdin-stdout = "0.1" diff --git a/examples/client/src/main.rs b/examples/client/src/main.rs index 0d34f64..ff6b315 100644 --- a/examples/client/src/main.rs +++ b/examples/client/src/main.rs @@ -4,8 +4,7 @@ extern crate webpki; extern crate webpki_roots; extern crate tokio_rustls; -#[cfg(unix)] extern crate tokio_file_unix; -#[cfg(not(unix))] extern crate tokio_fs; +extern crate tokio_stdin_stdout; use std::sync::Arc; use std::net::ToSocketAddrs; @@ -16,6 +15,7 @@ use tokio::net::TcpStream; use tokio::prelude::*; use clap::{ App, Arg }; use tokio_rustls::{ TlsConnector, rustls::ClientConfig }; +use tokio_stdin_stdout::{ stdin as tokio_stdin, stdout as tokio_stdout }; fn app() -> App<'static, 'static> { App::new("client") @@ -52,59 +52,23 @@ fn main() { let config = TlsConnector::from(Arc::new(config)); let socket = TcpStream::connect(&addr); + let (stdin, stdout) = (tokio_stdin(0), tokio_stdout(0)); - #[cfg(unix)] - let resp = { - use tokio::reactor::Handle; - use tokio_file_unix::{ raw_stdin, raw_stdout, File }; + let done = socket + .and_then(move |stream| { + let domain = webpki::DNSNameRef::try_from_ascii_str(&domain).unwrap(); + config.connect(domain, stream) + }) + .and_then(move |stream| io::write_all(stream, text)) + .and_then(move |(stream, _)| { + let (r, w) = stream.split(); + io::copy(r, stdout) + .map(drop) + .select2(io::copy(stdin, w).map(drop)) + .map_err(|res| res.split().0) + }) + .map(drop) + .map_err(|err| eprintln!("{:?}", err)); - let stdin = raw_stdin() - .and_then(File::new_nb) - .and_then(|fd| fd.into_reader(&Handle::current())) - .unwrap(); - let stdout = raw_stdout() - .and_then(File::new_nb) - .and_then(|fd| fd.into_io(&Handle::current())) - .unwrap(); - - socket - .and_then(move |stream| { - let domain = webpki::DNSNameRef::try_from_ascii_str(&domain).unwrap(); - config.connect(domain, stream) - }) - .and_then(move |stream| io::write_all(stream, text)) - .and_then(move |(stream, _)| { - let (r, w) = stream.split(); - io::copy(r, stdout) - .map(drop) - .select2(io::copy(stdin, w).map(drop)) - .map_err(|res| res.split().0) - }) - .map(drop) - .map_err(|err| eprintln!("{:?}", err)) - }; - - #[cfg(not(unix))] - let resp = { - use tokio_fs::{ stdin as tokio_stdin, stdout as tokio_stdout }; - - let (stdin, stdout) = (tokio_stdin(), tokio_stdout()); - - socket - .and_then(move |stream| { - let domain = webpki::DNSNameRef::try_from_ascii_str(&domain).unwrap(); - config.connect(domain, stream) - }) - .and_then(move |stream| io::write_all(stream, text)) - .and_then(move |(stream, _)| { - let (r, w) = stream.split(); - io::copy(r, stdout) - .map(drop) - .join(io::copy(stdin, w).map(drop)) - }) - .map(drop) - .map_err(|err| eprintln!("{:?}", err)) - }; - - tokio::run(resp); + tokio::run(done); } diff --git a/examples/server/Cargo.toml b/examples/server/Cargo.toml index 98329ce..170693f 100644 --- a/examples/server/Cargo.toml +++ b/examples/server/Cargo.toml @@ -4,9 +4,6 @@ version = "0.1.0" authors = ["quininer "] [dependencies] -tokio-rustls = { path = "../..", default-features = false, features = [ "tokio" ] } - +tokio-rustls = { path = "../.." } tokio = { version = "0.1.6" } -# futures = "0.2.0-beta" - -clap = "2.26" +clap = "2" diff --git a/src/common/mod.rs b/src/common/mod.rs new file mode 100644 index 0000000..7db198e --- /dev/null +++ b/src/common/mod.rs @@ -0,0 +1,153 @@ +#[cfg(feature = "nightly")] +#[cfg(feature = "tokio-support")] +mod vecbuf; + +use std::io::{ self, Read, Write }; +#[cfg(feature = "nightly")] +use std::io::Initializer; +use rustls::Session; +#[cfg(feature = "nightly")] +use rustls::WriteV; +#[cfg(feature = "nightly")] +#[cfg(feature = "tokio-support")] +use tokio::io::AsyncWrite; + +pub struct Stream<'a, S: 'a, IO: 'a> { + session: &'a mut S, + io: &'a mut IO +} + +pub trait WriteTls<'a, S: Session, IO: Read + Write>: Read + Write { + fn write_tls(&mut self) -> io::Result; +} + +impl<'a, S: Session, IO: Read + Write> Stream<'a, S, IO> { + pub fn new(session: &'a mut S, io: &'a mut IO) -> Self { + Stream { session, io } + } + + pub fn complete_io(&mut self) -> io::Result<(usize, usize)> { + // fork from https://github.com/ctz/rustls/blob/master/src/session.rs#L161 + + let until_handshaked = self.session.is_handshaking(); + let mut eof = false; + let mut wrlen = 0; + let mut rdlen = 0; + + loop { + while self.session.wants_write() { + wrlen += self.write_tls()?; + } + + if !until_handshaked && wrlen > 0 { + return Ok((rdlen, wrlen)); + } + + if !eof && self.session.wants_read() { + match self.session.read_tls(self.io)? { + 0 => eof = true, + n => rdlen += n + } + } + + match self.session.process_new_packets() { + Ok(_) => {}, + Err(e) => { + // In case we have an alert to send describing this error, + // try a last-gasp write -- but don't predate the primary + // error. + let _ignored = self.write_tls(); + + return Err(io::Error::new(io::ErrorKind::InvalidData, e)); + }, + }; + + match (eof, until_handshaked, self.session.is_handshaking()) { + (_, true, false) => return Ok((rdlen, wrlen)), + (_, false, _) => return Ok((rdlen, wrlen)), + (true, true, true) => return Err(io::Error::from(io::ErrorKind::UnexpectedEof)), + (..) => () + } + } + } +} + +#[cfg(not(feature = "nightly"))] +impl<'a, S: Session, IO: Read + Write> WriteTls<'a, S, IO> for Stream<'a, S, IO> { + fn write_tls(&mut self) -> io::Result { + self.session.write_tls(self.io) + } +} + +#[cfg(feature = "nightly")] +impl<'a, S: Session, IO: Read + Write> WriteTls<'a, S, IO> for Stream<'a, S, IO> { + default fn write_tls(&mut self) -> io::Result { + self.session.write_tls(self.io) + } +} + +#[cfg(feature = "nightly")] +#[cfg(feature = "tokio-support")] +impl<'a, S: Session, IO: Read + AsyncWrite> WriteTls<'a, S, IO> for Stream<'a, S, IO> { + fn write_tls(&mut self) -> io::Result { + use tokio::prelude::Async; + use self::vecbuf::VecBuf; + + struct V<'a, IO: 'a>(&'a mut IO); + + impl<'a, IO: AsyncWrite> WriteV for V<'a, IO> { + fn writev(&mut self, vbytes: &[&[u8]]) -> io::Result { + let mut vbytes = VecBuf::new(vbytes); + match self.0.write_buf(&mut vbytes) { + Ok(Async::Ready(n)) => Ok(n), + Ok(Async::NotReady) => Err(io::ErrorKind::WouldBlock.into()), + Err(err) => Err(err) + } + } + } + + let mut vecbuf = V(self.io); + self.session.writev_tls(&mut vecbuf) + } +} + +impl<'a, S: Session, IO: Read + Write> Read for Stream<'a, S, IO> { + #[cfg(feature = "nightly")] + unsafe fn initializer(&self) -> Initializer { + Initializer::nop() + } + + fn read(&mut self, buf: &mut [u8]) -> io::Result { + while self.session.wants_read() { + if let (0, 0) = self.complete_io()? { + break + } + } + self.session.read(buf) + } +} + +impl<'a, S: Session, IO: Read + Write> io::Write for Stream<'a, S, IO> { + fn write(&mut self, buf: &[u8]) -> io::Result { + let len = self.session.write(buf)?; + while self.session.wants_write() { + match self.complete_io() { + Ok(_) => (), + Err(ref err) if err.kind() == io::ErrorKind::WouldBlock && len != 0 => break, + Err(err) => return Err(err) + } + } + Ok(len) + } + + fn flush(&mut self) -> io::Result<()> { + self.session.flush()?; + if self.session.wants_write() { + self.complete_io()?; + } + Ok(()) + } +} + +#[cfg(test)] +mod test_stream; diff --git a/src/common/test_stream.rs b/src/common/test_stream.rs new file mode 100644 index 0000000..2cd9e87 --- /dev/null +++ b/src/common/test_stream.rs @@ -0,0 +1,161 @@ +use std::sync::Arc; +use std::io::{ self, Read, Write, BufReader, Cursor }; +use webpki::DNSNameRef; +use rustls::internal::pemfile::{ certs, rsa_private_keys }; +use rustls::{ + ServerConfig, ClientConfig, + ServerSession, ClientSession, + Session, NoClientAuth +}; +use super::Stream; + + +struct Good<'a>(&'a mut Session); + +impl<'a> Read for Good<'a> { + fn read(&mut self, mut buf: &mut [u8]) -> io::Result { + self.0.write_tls(buf.by_ref()) + } +} + +impl<'a> Write for Good<'a> { + fn write(&mut self, mut buf: &[u8]) -> io::Result { + let len = self.0.read_tls(buf.by_ref())?; + self.0.process_new_packets() + .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?; + Ok(len) + } + + fn flush(&mut self) -> io::Result<()> { + Ok(()) + } +} + +struct Bad(bool); + +impl Read for Bad { + fn read(&mut self, _: &mut [u8]) -> io::Result { + Ok(0) + } +} + +impl Write for Bad { + fn write(&mut self, buf: &[u8]) -> io::Result { + if self.0 { + Err(io::ErrorKind::WouldBlock.into()) + } else { + Ok(buf.len()) + } + } + + fn flush(&mut self) -> io::Result<()> { + Ok(()) + } +} + + +#[test] +fn stream_good() -> io::Result<()> { + const FILE: &'static [u8] = include_bytes!("../../README.md"); + + let (mut server, mut client) = make_pair(); + do_handshake(&mut client, &mut server); + io::copy(&mut Cursor::new(FILE), &mut server)?; + + { + let mut good = Good(&mut server); + let mut stream = Stream::new(&mut client, &mut good); + + let mut buf = Vec::new(); + stream.read_to_end(&mut buf)?; + assert_eq!(buf, FILE); + stream.write_all(b"Hello World!")? + } + + let mut buf = String::new(); + server.read_to_string(&mut buf)?; + assert_eq!(buf, "Hello World!"); + + Ok(()) +} + +#[test] +fn stream_bad() -> io::Result<()> { + let (mut server, mut client) = make_pair(); + do_handshake(&mut client, &mut server); + client.set_buffer_limit(1024); + + let mut bad = Bad(true); + let mut stream = Stream::new(&mut client, &mut bad); + assert_eq!(stream.write(&[0x42; 8])?, 8); + assert_eq!(stream.write(&[0x42; 8])?, 8); + let r = stream.write(&[0x00; 1024])?; // fill buffer + assert!(r < 1024); + assert_eq!( + stream.write(&[0x01]).unwrap_err().kind(), + io::ErrorKind::WouldBlock + ); + + Ok(()) +} + +#[test] +fn stream_handshake() -> io::Result<()> { + let (mut server, mut client) = make_pair(); + + { + let mut good = Good(&mut server); + let mut stream = Stream::new(&mut client, &mut good); + let (r, w) = stream.complete_io()?; + + assert!(r > 0); + assert!(w > 0); + + stream.complete_io()?; // finish server handshake + } + + assert!(!server.is_handshaking()); + assert!(!client.is_handshaking()); + + Ok(()) +} + +#[test] +fn stream_handshake_eof() -> io::Result<()> { + let (_, mut client) = make_pair(); + + let mut bad = Bad(false); + let mut stream = Stream::new(&mut client, &mut bad); + let r = stream.complete_io(); + + assert_eq!(r.unwrap_err().kind(), io::ErrorKind::UnexpectedEof); + + Ok(()) +} + +fn make_pair() -> (ServerSession, ClientSession) { + const CERT: &str = include_str!("../../tests/end.cert"); + const CHAIN: &str = include_str!("../../tests/end.chain"); + const RSA: &str = include_str!("../../tests/end.rsa"); + + let cert = certs(&mut BufReader::new(Cursor::new(CERT))).unwrap(); + let mut keys = rsa_private_keys(&mut BufReader::new(Cursor::new(RSA))).unwrap(); + let mut sconfig = ServerConfig::new(NoClientAuth::new()); + sconfig.set_single_cert(cert, keys.pop().unwrap()).unwrap(); + let server = ServerSession::new(&Arc::new(sconfig)); + + let domain = DNSNameRef::try_from_ascii_str("localhost").unwrap(); + let mut cconfig = ClientConfig::new(); + let mut chain = BufReader::new(Cursor::new(CHAIN)); + cconfig.root_store.add_pem_file(&mut chain).unwrap(); + let client = ClientSession::new(&Arc::new(cconfig), domain); + + (server, client) +} + +fn do_handshake(client: &mut ClientSession, server: &mut ServerSession) { + let mut good = Good(server); + let mut stream = Stream::new(client, &mut good); + stream.complete_io().unwrap(); + stream.complete_io().unwrap(); +} diff --git a/src/common/vecbuf.rs b/src/common/vecbuf.rs new file mode 100644 index 0000000..81bec86 --- /dev/null +++ b/src/common/vecbuf.rs @@ -0,0 +1,122 @@ +use std::cmp::{ self, Ordering }; +use bytes::Buf; +use iovec::IoVec; + +pub struct VecBuf<'a, 'b: 'a> { + pos: usize, + cur: usize, + inner: &'a [&'b [u8]] +} + +impl<'a, 'b> VecBuf<'a, 'b> { + pub fn new(vbytes: &'a [&'b [u8]]) -> Self { + VecBuf { pos: 0, cur: 0, inner: vbytes } + } +} + +impl<'a, 'b> Buf for VecBuf<'a, 'b> { + fn remaining(&self) -> usize { + let sum = self.inner + .iter() + .skip(self.pos) + .map(|bytes| bytes.len()) + .sum::(); + sum - self.cur + } + + fn bytes(&self) -> &[u8] { + &self.inner[self.pos][self.cur..] + } + + fn advance(&mut self, cnt: usize) { + let current = self.inner[self.pos].len(); + match (self.cur + cnt).cmp(¤t) { + Ordering::Equal => if self.pos + 1 < self.inner.len() { + self.pos += 1; + self.cur = 0; + } else { + self.cur += cnt; + }, + Ordering::Greater => { + if self.pos + 1 < self.inner.len() { + self.pos += 1; + } + let remaining = self.cur + cnt - current; + self.advance(remaining); + }, + Ordering::Less => self.cur += cnt, + } + } + + #[cfg_attr(feature = "cargo-clippy", allow(needless_range_loop))] + fn bytes_vec<'c>(&'c self, dst: &mut [&'c IoVec]) -> usize { + let len = cmp::min(self.inner.len() - self.pos, dst.len()); + + if len > 0 { + dst[0] = self.bytes().into(); + } + + for i in 1..len { + dst[i] = self.inner[self.pos + i].into(); + } + + len + } +} + +#[cfg(test)] +mod test_vecbuf { + use super::*; + + #[test] + fn test_fresh_cursor_vec() { + let mut buf = VecBuf::new(&[b"he", b"llo"]); + + assert_eq!(buf.remaining(), 5); + assert_eq!(buf.bytes(), b"he"); + + buf.advance(2); + + assert_eq!(buf.remaining(), 3); + assert_eq!(buf.bytes(), b"llo"); + + buf.advance(3); + + assert_eq!(buf.remaining(), 0); + assert_eq!(buf.bytes(), b""); + } + + #[test] + fn test_get_u8() { + let mut buf = VecBuf::new(&[b"\x21z", b"omg"]); + assert_eq!(0x21, buf.get_u8()); + } + + #[test] + fn test_get_u16() { + let mut buf = VecBuf::new(&[b"\x21\x54z", b"omg"]); + assert_eq!(0x2154, buf.get_u16_be()); + let mut buf = VecBuf::new(&[b"\x21\x54z", b"omg"]); + assert_eq!(0x5421, buf.get_u16_le()); + } + + #[test] + #[should_panic] + fn test_get_u16_buffer_underflow() { + let mut buf = VecBuf::new(&[b"\x21"]); + buf.get_u16_be(); + } + + #[test] + fn test_bufs_vec() { + let buf = VecBuf::new(&[b"he", b"llo"]); + + let b1: &[u8] = &mut [0]; + let b2: &[u8] = &mut [0]; + + let mut dst: [&IoVec; 2] = + [b1.into(), b2.into()]; + + assert_eq!(2, buf.bytes_vec(&mut dst[..])); + } +} diff --git a/src/futures_impl.rs b/src/futures_impl.rs deleted file mode 100644 index 6771316..0000000 --- a/src/futures_impl.rs +++ /dev/null @@ -1,170 +0,0 @@ -extern crate futures_core; -extern crate futures_io; - -use super::*; -use self::futures_core::{ Future, Poll, Async }; -use self::futures_core::task::Context; -use self::futures_io::{ Error, AsyncRead, AsyncWrite }; - - -impl Future for ConnectAsync { - type Item = TlsStream; - type Error = io::Error; - - fn poll(&mut self, ctx: &mut Context) -> Poll { - self.0.poll(ctx) - } -} - -impl Future for AcceptAsync { - type Item = TlsStream; - type Error = io::Error; - - fn poll(&mut self, ctx: &mut Context) -> Poll { - self.0.poll(ctx) - } -} - -macro_rules! async { - ( to $r:expr ) => { - match $r { - Ok(Async::Ready(n)) => Ok(n), - Ok(Async::Pending) => Err(io::ErrorKind::WouldBlock.into()), - Err(e) => Err(e) - } - }; - ( from $r:expr ) => { - match $r { - Ok(n) => Ok(Async::Ready(n)), - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => Ok(Async::Pending), - Err(e) => Err(e) - } - }; -} - -struct TaskStream<'a, 'b: 'a, S: 'a> { - io: &'a mut S, - task: &'a mut Context<'b> -} - -impl<'a, 'b, S> io::Read for TaskStream<'a, 'b, S> - where S: AsyncRead + AsyncWrite -{ - #[inline] - fn read(&mut self, buf: &mut [u8]) -> io::Result { - async!(to self.io.poll_read(self.task, buf)) - } -} - -impl<'a, 'b, S> io::Write for TaskStream<'a, 'b, S> - where S: AsyncRead + AsyncWrite -{ - #[inline] - fn write(&mut self, buf: &[u8]) -> io::Result { - async!(to self.io.poll_write(self.task, buf)) - } - - #[inline] - fn flush(&mut self) -> io::Result<()> { - async!(to self.io.poll_flush(self.task)) - } -} - -impl Future for MidHandshake - where S: AsyncRead + AsyncWrite, C: Session -{ - type Item = TlsStream; - type Error = io::Error; - - fn poll(&mut self, ctx: &mut Context) -> Poll { - loop { - let stream = self.inner.as_mut().unwrap(); - if !stream.session.is_handshaking() { break }; - - let (io, session) = stream.get_mut(); - let mut taskio = TaskStream { io, task: ctx }; - - match session.complete_io(&mut taskio) { - Ok(_) => (), - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => return Ok(Async::Pending), - Err(e) => return Err(e) - } - } - - Ok(Async::Ready(self.inner.take().unwrap())) - } -} - -impl AsyncRead for TlsStream - where - S: AsyncRead + AsyncWrite, - C: Session -{ - fn poll_read(&mut self, ctx: &mut Context, buf: &mut [u8]) -> Poll { - if self.eof { - return Ok(Async::Ready(0)); - } - - // TODO nll - let result = { - let (io, session) = self.get_mut(); - let mut taskio = TaskStream { io, task: ctx }; - let mut stream = Stream::new(session, &mut taskio); - io::Read::read(&mut stream, buf) - }; - - match result { - Ok(0) => { self.eof = true; Ok(Async::Ready(0)) }, - Ok(n) => Ok(Async::Ready(n)), - Err(ref e) if e.kind() == io::ErrorKind::ConnectionAborted => { - self.eof = true; - self.is_shutdown = true; - self.session.send_close_notify(); - Ok(Async::Ready(0)) - }, - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => Ok(Async::Pending), - Err(e) => Err(e) - } - } -} - -impl AsyncWrite for TlsStream - where - S: AsyncRead + AsyncWrite, - C: Session -{ - fn poll_write(&mut self, ctx: &mut Context, buf: &[u8]) -> Poll { - let (io, session) = self.get_mut(); - let mut taskio = TaskStream { io, task: ctx }; - let mut stream = Stream::new(session, &mut taskio); - - async!(from io::Write::write(&mut stream, buf)) - } - - fn poll_flush(&mut self, ctx: &mut Context) -> Poll<(), Error> { - let (io, session) = self.get_mut(); - let mut taskio = TaskStream { io, task: ctx }; - - { - let mut stream = Stream::new(session, &mut taskio); - async!(from io::Write::flush(&mut stream))?; - } - - async!(from io::Write::flush(&mut taskio)) - } - - fn poll_close(&mut self, ctx: &mut Context) -> Poll<(), Error> { - if !self.is_shutdown { - self.session.send_close_notify(); - self.is_shutdown = true; - } - - { - let (io, session) = self.get_mut(); - let mut taskio = TaskStream { io, task: ctx }; - async!(from session.complete_io(&mut taskio))?; - } - - self.io.poll_close(ctx) - } -} diff --git a/src/lib.rs b/src/lib.rs index 8d43c22..d61caf0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,10 +1,22 @@ //! Asynchronous TLS/SSL streams for Tokio using [Rustls](https://github.com/ctz/rustls). +#![cfg_attr(feature = "nightly", feature(specialization, read_initializer))] + pub extern crate rustls; pub extern crate webpki; -#[cfg(feature = "tokio")] mod tokio_impl; -#[cfg(feature = "unstable-futures")] mod futures_impl; +#[cfg(feature = "tokio-support")] +extern crate tokio; +#[cfg(feature = "nightly")] +#[cfg(feature = "tokio-support")] +extern crate bytes; +#[cfg(feature = "nightly")] +#[cfg(feature = "tokio-support")] +extern crate iovec; + + +mod common; +#[cfg(feature = "tokio-support")] mod tokio_impl; use std::io; use std::sync::Arc; @@ -12,8 +24,8 @@ use webpki::DNSNameRef; use rustls::{ Session, ClientSession, ServerSession, ClientConfig, ServerConfig, - Stream }; +use common::Stream; pub struct TlsConnector { diff --git a/src/tokio_impl.rs b/src/tokio_impl.rs index d9598bf..11179dc 100644 --- a/src/tokio_impl.rs +++ b/src/tokio_impl.rs @@ -1,9 +1,8 @@ -extern crate tokio; - use super::*; -use self::tokio::prelude::*; -use self::tokio::io::{ AsyncRead, AsyncWrite }; -use self::tokio::prelude::Poll; +use tokio::prelude::*; +use tokio::io::{ AsyncRead, AsyncWrite }; +use tokio::prelude::Poll; +use common::Stream; impl Future for Connect { @@ -31,16 +30,17 @@ impl Future for MidHandshake type Error = io::Error; fn poll(&mut self) -> Poll { - loop { + { let stream = self.inner.as_mut().unwrap(); - if !stream.session.is_handshaking() { break }; + if stream.session.is_handshaking() { + let (io, session) = stream.get_mut(); + let mut stream = Stream::new(session, io); - let (io, session) = stream.get_mut(); - - match session.complete_io(io) { - Ok(_) => (), - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => return Ok(Async::NotReady), - Err(e) => return Err(e) + match stream.complete_io() { + Ok(_) => (), + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => return Ok(Async::NotReady), + Err(e) => return Err(e) + } } } @@ -52,7 +52,11 @@ impl AsyncRead for TlsStream where S: AsyncRead + AsyncWrite, C: Session -{} +{ + unsafe fn prepare_uninitialized_buffer(&self, _: &mut [u8]) -> bool { + false + } +} impl AsyncWrite for TlsStream where diff --git a/tests/test.rs b/tests/test.rs index 7eae2af..8833253 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -1,81 +1,89 @@ +#[macro_use] extern crate lazy_static; extern crate rustls; extern crate tokio; extern crate tokio_rustls; extern crate webpki; -#[cfg(feature = "unstable-futures")] extern crate futures; - use std::{ io, thread }; use std::io::{ BufReader, Cursor }; use std::sync::Arc; use std::sync::mpsc::channel; -use std::net::{ SocketAddr, IpAddr, Ipv4Addr }; +use std::net::SocketAddr; use tokio::net::{ TcpListener, TcpStream }; -use rustls::{ Certificate, PrivateKey, ServerConfig, ClientConfig }; +use rustls::{ ServerConfig, ClientConfig }; use rustls::internal::pemfile::{ certs, rsa_private_keys }; use tokio_rustls::{ TlsConnector, TlsAcceptor }; const CERT: &str = include_str!("end.cert"); const CHAIN: &str = include_str!("end.chain"); const RSA: &str = include_str!("end.rsa"); -const HELLO_WORLD: &[u8] = b"Hello world!"; +lazy_static!{ + static ref TEST_SERVER: (SocketAddr, &'static str, &'static str) = { + use tokio::prelude::*; + use tokio::io as aio; -fn start_server(cert: Vec, rsa: PrivateKey) -> SocketAddr { - use tokio::prelude::*; - use tokio::io as aio; + let cert = certs(&mut BufReader::new(Cursor::new(CERT))).unwrap(); + let mut keys = rsa_private_keys(&mut BufReader::new(Cursor::new(RSA))).unwrap(); - let mut config = ServerConfig::new(rustls::NoClientAuth::new()); - config.set_single_cert(cert, rsa) - .expect("invalid key or certificate"); - let config = TlsAcceptor::from(Arc::new(config)); + let mut config = ServerConfig::new(rustls::NoClientAuth::new()); + config.set_single_cert(cert, keys.pop().unwrap()) + .expect("invalid key or certificate"); + let config = TlsAcceptor::from(Arc::new(config)); - let (send, recv) = channel(); + let (send, recv) = channel(); - thread::spawn(move || { - let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 0); - let listener = TcpListener::bind(&addr).unwrap(); + thread::spawn(move || { + let addr = SocketAddr::from(([127, 0, 0, 1], 0)); + let listener = TcpListener::bind(&addr).unwrap(); - send.send(listener.local_addr().unwrap()).unwrap(); + send.send(listener.local_addr().unwrap()).unwrap(); - let done = listener.incoming() - .for_each(move |stream| { - let done = config.accept(stream) - .and_then(|stream| aio::read_exact(stream, vec![0; HELLO_WORLD.len()])) - .and_then(|(stream, buf)| { - assert_eq!(buf, HELLO_WORLD); - aio::write_all(stream, HELLO_WORLD) - }) - .then(|_| Ok(())); + let done = listener.incoming() + .for_each(move |stream| { + let done = config.accept(stream) + .and_then(|stream| { + let (reader, writer) = stream.split(); + aio::copy(reader, writer) + }) + .then(|_| Ok(())); - tokio::spawn(done); - Ok(()) - }) - .map_err(|err| panic!("{:?}", err)); + tokio::spawn(done); + Ok(()) + }) + .map_err(|err| panic!("{:?}", err)); - tokio::run(done); - }); + tokio::run(done); + }); - recv.recv().unwrap() + let addr = recv.recv().unwrap(); + (addr, "localhost", CHAIN) + }; } -fn start_client(addr: &SocketAddr, domain: &str, chain: Option>>) -> io::Result<()> { + +fn start_server() -> &'static (SocketAddr, &'static str, &'static str) { + &*TEST_SERVER +} + +fn start_client(addr: &SocketAddr, domain: &str, chain: &str) -> io::Result<()> { use tokio::prelude::*; use tokio::io as aio; + const FILE: &'static [u8] = include_bytes!("../README.md"); + let domain = webpki::DNSNameRef::try_from_ascii_str(domain).unwrap(); let mut config = ClientConfig::new(); - if let Some(mut chain) = chain { - config.root_store.add_pem_file(&mut chain).unwrap(); - } + let mut chain = BufReader::new(Cursor::new(chain)); + config.root_store.add_pem_file(&mut chain).unwrap(); let config = TlsConnector::from(Arc::new(config)); let done = TcpStream::connect(addr) .and_then(|stream| config.connect(domain, stream)) - .and_then(|stream| aio::write_all(stream, HELLO_WORLD)) - .and_then(|(stream, _)| aio::read_exact(stream, vec![0; HELLO_WORLD.len()])) + .and_then(|stream| aio::write_all(stream, FILE)) + .and_then(|(stream, _)| aio::read_exact(stream, vec![0; FILE.len()])) .and_then(|(stream, buf)| { - assert_eq!(buf, HELLO_WORLD); + assert_eq!(buf, FILE); aio::shutdown(stream) }) .map(drop); @@ -83,62 +91,17 @@ fn start_client(addr: &SocketAddr, domain: &str, chain: Option>>) -> io::Result<()> { - use futures::FutureExt; - use futures::io::{ AsyncReadExt, AsyncWriteExt }; - use futures::executor::block_on; - - let domain = webpki::DNSNameRef::try_from_ascii_str(domain).unwrap(); - let mut config = ClientConfig::new(); - if let Some(mut chain) = chain { - config.root_store.add_pem_file(&mut chain).unwrap(); - } - let config = TlsConnector::from(Arc::new(config)); - - let done = TcpStream::connect(addr) - .and_then(|stream| config.connect(domain, stream)) - .and_then(|stream| stream.write_all(HELLO_WORLD)) - .and_then(|(stream, _)| stream.read_exact(vec![0; HELLO_WORLD.len()])) - .and_then(|(stream, buf)| { - assert_eq!(buf, HELLO_WORLD); - stream.close() - }) - .map(drop); - - block_on(done) -} - - #[test] fn pass() { - let cert = certs(&mut BufReader::new(Cursor::new(CERT))).unwrap(); - let mut keys = rsa_private_keys(&mut BufReader::new(Cursor::new(RSA))).unwrap(); - let chain = BufReader::new(Cursor::new(CHAIN)); + let (addr, domain, chain) = start_server(); - let addr = start_server(cert, keys.pop().unwrap()); - start_client(&addr, "localhost", Some(chain)).unwrap(); + start_client(addr, domain, chain).unwrap(); } -#[cfg(feature = "unstable-futures")] -#[test] -fn pass2() { - let cert = certs(&mut BufReader::new(Cursor::new(CERT))).unwrap(); - let mut keys = rsa_private_keys(&mut BufReader::new(Cursor::new(RSA))).unwrap(); - let chain = BufReader::new(Cursor::new(CHAIN)); - - let addr = start_server(cert, keys.pop().unwrap()); - start_client2(&addr, "localhost", Some(chain)).unwrap(); -} - -#[should_panic] #[test] fn fail() { - let cert = certs(&mut BufReader::new(Cursor::new(CERT))).unwrap(); - let mut keys = rsa_private_keys(&mut BufReader::new(Cursor::new(RSA))).unwrap(); - let chain = BufReader::new(Cursor::new(CHAIN)); + let (addr, domain, chain) = start_server(); - let addr = start_server(cert, keys.pop().unwrap()); - - start_client(&addr, "google.com", Some(chain)).unwrap(); + assert_ne!(domain, &"google.com"); + assert!(start_client(addr, "google.com", chain).is_err()); }