From 41a6a3b501299809a88fe22ca6dae7179c7fa1eb Mon Sep 17 00:00:00 2001 From: quininer Date: Thu, 16 Aug 2018 12:20:27 +0800 Subject: [PATCH 01/15] impl vecbuf for tokio --- Cargo.toml | 2 + src/common.rs | 99 +++++++++++++++++++++++++++++++++++++++++++++++ src/lib.rs | 6 +++ src/tokio_impl.rs | 8 ++-- 4 files changed, 110 insertions(+), 5 deletions(-) create mode 100644 src/common.rs diff --git a/Cargo.toml b/Cargo.toml index e289e0f..94edce4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,6 +18,8 @@ appveyor = { repository = "quininer/tokio-rustls" } futures-core = { version = "0.2.0", optional = true } futures-io = { version = "0.2.0", optional = true } tokio = { version = "0.1.6", optional = true } +bytes = { version = "*" } +iovec = { version = "*" } rustls = "0.13" webpki = "0.18.1" diff --git a/src/common.rs b/src/common.rs new file mode 100644 index 0000000..799ee7e --- /dev/null +++ b/src/common.rs @@ -0,0 +1,99 @@ +use std::cmp::{ self, Ordering }; +use std::io::{ self, Read, Write }; +use rustls::{ Session, WriteV }; +use tokio::prelude::Async; +use tokio::io::AsyncWrite; +use bytes::Buf; +use iovec::IoVec; + + +pub struct Stream<'a, S: 'a, IO: 'a> { + session: &'a mut S, + io: &'a mut IO +} + +/* +impl<'a, S: Session, IO: Write> Stream<'a, S, IO> { + pub default fn write_tls(&mut self) -> io::Result { + self.session.write_tls(self.io) + } +} +*/ + +impl<'a, S: Session, IO: AsyncWrite> Stream<'a, S, IO> { + pub fn write_tls(&mut self) -> io::Result { + struct V<'a, IO: 'a>(&'a mut IO); + + impl<'a, IO: AsyncWrite> WriteV for V<'a, IO> { + fn writev(&mut self, vbytes: &[&[u8]]) -> io::Result { + let mut vbytes = VecBuf::new(vbytes); + match self.0.write_buf(&mut vbytes) { + Ok(Async::Ready(n)) => Ok(n), + Ok(Async::NotReady) => Err(io::ErrorKind::WouldBlock.into()), + Err(err) => Err(err) + } + } + } + + let mut vecbuf = V(self.io); + self.session.writev_tls(&mut vecbuf) + } +} + + +struct VecBuf<'a, 'b: 'a> { + pos: usize, + cur: usize, + inner: &'a [&'b [u8]] +} + +impl<'a, 'b> VecBuf<'a, 'b> { + fn new(vbytes: &'a [&'b [u8]]) -> Self { + VecBuf { pos: 0, cur: 0, inner: vbytes } + } +} + +impl<'a, 'b> Buf for VecBuf<'a, 'b> { + fn remaining(&self) -> usize { + let sum = self.inner + .iter() + .skip(self.pos) + .map(|bytes| bytes.len()) + .sum::(); + sum - self.cur + } + + fn bytes(&self) -> &[u8] { + &self.inner[self.pos][self.cur..] + } + + fn advance(&mut self, cnt: usize) { + let current = self.inner[self.pos].len(); + match (self.cur + cnt).cmp(¤t) { + Ordering::Equal => { + if self.pos < self.inner.len() { + self.pos += 1; + } + self.cur = 0; + }, + Ordering::Greater => { + if self.pos < self.inner.len() { + self.pos += 1; + } + let remaining = self.cur + cnt - current; + self.advance(remaining); + }, + Ordering::Less => self.cur += cnt, + } + } + + fn bytes_vec<'c>(&'c self, dst: &mut [&'c IoVec]) -> usize { + let len = cmp::min(self.inner.len() - self.pos, dst.len()); + + for i in 0..len { + dst[i] = self.inner[self.pos + i].into(); + } + + len + } +} diff --git a/src/lib.rs b/src/lib.rs index d1c6c7d..81da5fe 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,6 +3,12 @@ pub extern crate rustls; pub extern crate webpki; +extern crate tokio; +extern crate bytes; +extern crate iovec; + + +mod common; #[cfg(feature = "tokio")] mod tokio_impl; #[cfg(feature = "unstable-futures")] mod futures_impl; diff --git a/src/tokio_impl.rs b/src/tokio_impl.rs index 936c14b..e9a00a9 100644 --- a/src/tokio_impl.rs +++ b/src/tokio_impl.rs @@ -1,9 +1,7 @@ -extern crate tokio; - use super::*; -use self::tokio::prelude::*; -use self::tokio::io::{ AsyncRead, AsyncWrite }; -use self::tokio::prelude::Poll; +use tokio::prelude::*; +use tokio::io::{ AsyncRead, AsyncWrite }; +use tokio::prelude::Poll; impl Future for ConnectAsync { From 518ad51376ace135487d29dd1e41e0f1392b9c40 Mon Sep 17 00:00:00 2001 From: quininer Date: Thu, 16 Aug 2018 15:29:16 +0800 Subject: [PATCH 02/15] impl complete_io --- src/common.rs | 97 +++++++++++++++++++++++++++++++++++++++++++---- src/lib.rs | 4 +- src/tokio_impl.rs | 18 +++++---- 3 files changed, 103 insertions(+), 16 deletions(-) diff --git a/src/common.rs b/src/common.rs index 799ee7e..df83537 100644 --- a/src/common.rs +++ b/src/common.rs @@ -12,16 +12,98 @@ pub struct Stream<'a, S: 'a, IO: 'a> { io: &'a mut IO } -/* -impl<'a, S: Session, IO: Write> Stream<'a, S, IO> { - pub default fn write_tls(&mut self) -> io::Result { - self.session.write_tls(self.io) +pub trait CompleteIo<'a, S: Session, IO: Read + Write>: Read + Write { + fn write_tls(&mut self) -> io::Result; + fn complete_io(&mut self) -> io::Result<(usize, usize)>; +} + +impl<'a, S: Session, IO: Read + Write> Stream<'a, S, IO> { + pub fn new(session: &'a mut S, io: &'a mut IO) -> Self { + Stream { session, io } } } -*/ -impl<'a, S: Session, IO: AsyncWrite> Stream<'a, S, IO> { - pub fn write_tls(&mut self) -> io::Result { +impl<'a, S: Session, IO: Read + Write> CompleteIo<'a, S, IO> for Stream<'a, S, IO> { + default fn write_tls(&mut self) -> io::Result { + self.session.write_tls(self.io) + } + + fn complete_io(&mut self) -> io::Result<(usize, usize)> { + // fork from https://github.com/ctz/rustls/blob/master/src/session.rs#L161 + + let until_handshaked = self.session.is_handshaking(); + let mut eof = false; + let mut wrlen = 0; + let mut rdlen = 0; + + loop { + while self.session.wants_write() { + wrlen += self.write_tls()?; + } + + if !until_handshaked && wrlen > 0 { + return Ok((rdlen, wrlen)); + } + + if !eof && self.session.wants_read() { + match self.session.read_tls(self.io)? { + 0 => eof = true, + n => rdlen += n + } + } + + match self.session.process_new_packets() { + Ok(_) => {}, + Err(e) => { + // In case we have an alert to send describing this error, + // try a last-gasp write -- but don't predate the primary + // error. + let _ignored = self.write_tls(); + + return Err(io::Error::new(io::ErrorKind::InvalidData, e)); + }, + }; + + match (eof, until_handshaked, self.session.is_handshaking()) { + (_, true, false) => return Ok((rdlen, wrlen)), + (_, false, _) => return Ok((rdlen, wrlen)), + (true, true, true) => return Err(io::Error::from(io::ErrorKind::UnexpectedEof)), + (..) => () + } + } + } +} + +impl<'a, S: Session, IO: Read + Write> Read for Stream<'a, S, IO> { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + while self.session.wants_read() { + if let (0, 0) = self.complete_io()? { + break + } + } + + self.session.read(buf) + } +} + +impl<'a, S: Session, IO: Read + Write> io::Write for Stream<'a, S, IO> { + fn write(&mut self, buf: &[u8]) -> io::Result { + let len = self.session.write(buf)?; + self.complete_io()?; + Ok(len) + } + + fn flush(&mut self) -> io::Result<()> { + self.session.flush()?; + if self.session.wants_write() { + self.complete_io()?; + } + Ok(()) + } +} + +impl<'a, S: Session, IO: Read + AsyncWrite> CompleteIo<'a, S, IO> for Stream<'a, S, IO> { + fn write_tls(&mut self) -> io::Result { struct V<'a, IO: 'a>(&'a mut IO); impl<'a, IO: AsyncWrite> WriteV for V<'a, IO> { @@ -41,6 +123,7 @@ impl<'a, S: Session, IO: AsyncWrite> Stream<'a, S, IO> { } +// TODO test struct VecBuf<'a, 'b: 'a> { pos: usize, cur: usize, diff --git a/src/lib.rs b/src/lib.rs index 81da5fe..f8432d9 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,7 @@ //! Asynchronous TLS/SSL streams for Tokio using [Rustls](https://github.com/ctz/rustls). +#![feature(specialization)] + pub extern crate rustls; pub extern crate webpki; @@ -18,8 +20,8 @@ use webpki::DNSNameRef; use rustls::{ Session, ClientSession, ServerSession, ClientConfig, ServerConfig, - Stream }; +use common::Stream; /// Extension trait for the `Arc` type in the `rustls` crate. diff --git a/src/tokio_impl.rs b/src/tokio_impl.rs index e9a00a9..663d6ca 100644 --- a/src/tokio_impl.rs +++ b/src/tokio_impl.rs @@ -2,6 +2,7 @@ use super::*; use tokio::prelude::*; use tokio::io::{ AsyncRead, AsyncWrite }; use tokio::prelude::Poll; +use common::{ Stream, CompleteIo }; impl Future for ConnectAsync { @@ -29,16 +30,17 @@ impl Future for MidHandshake type Error = io::Error; fn poll(&mut self) -> Poll { - loop { + { let stream = self.inner.as_mut().unwrap(); - if !stream.session.is_handshaking() { break }; + if stream.session.is_handshaking() { + let (io, session) = stream.get_mut(); + let mut stream = Stream::new(session, io); - let (io, session) = stream.get_mut(); - - match session.complete_io(io) { - Ok(_) => (), - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => return Ok(Async::NotReady), - Err(e) => return Err(e) + match stream.complete_io() { + Ok(_) => (), + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => return Ok(Async::NotReady), + Err(e) => return Err(e) + } } } From 32f328fc142e6e6379f1ae6445fc58e1cc883ec7 Mon Sep 17 00:00:00 2001 From: quininer Date: Thu, 16 Aug 2018 17:59:59 +0800 Subject: [PATCH 03/15] remove futures 0.2 code --- Cargo.toml | 8 -------- src/lib.rs | 3 ++- src/tokio_impl.rs | 8 +++----- tests/test.rs | 37 ------------------------------------- 4 files changed, 5 insertions(+), 51 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index e289e0f..5d06420 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,20 +15,12 @@ travis-ci = { repository = "quininer/tokio-rustls" } appveyor = { repository = "quininer/tokio-rustls" } [dependencies] -futures-core = { version = "0.2.0", optional = true } -futures-io = { version = "0.2.0", optional = true } tokio = { version = "0.1.6", optional = true } rustls = "0.13" webpki = "0.18.1" [dev-dependencies] -# futures = "0.2.0" tokio = "0.1.6" [features] default = [ "tokio" ] -# unstable-futures = [ -# "futures-core", -# "futures-io", -# "tokio/unstable-futures" -# ] diff --git a/src/lib.rs b/src/lib.rs index d1c6c7d..69db77d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,8 +3,9 @@ pub extern crate rustls; pub extern crate webpki; +extern crate tokio; + #[cfg(feature = "tokio")] mod tokio_impl; -#[cfg(feature = "unstable-futures")] mod futures_impl; use std::io; use std::sync::Arc; diff --git a/src/tokio_impl.rs b/src/tokio_impl.rs index 936c14b..e9a00a9 100644 --- a/src/tokio_impl.rs +++ b/src/tokio_impl.rs @@ -1,9 +1,7 @@ -extern crate tokio; - use super::*; -use self::tokio::prelude::*; -use self::tokio::io::{ AsyncRead, AsyncWrite }; -use self::tokio::prelude::Poll; +use tokio::prelude::*; +use tokio::io::{ AsyncRead, AsyncWrite }; +use tokio::prelude::Poll; impl Future for ConnectAsync { diff --git a/tests/test.rs b/tests/test.rs index e64dd82..c6262e9 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -83,32 +83,6 @@ fn start_client(addr: &SocketAddr, domain: &str, chain: Option>>) -> io::Result<()> { - use futures::FutureExt; - use futures::io::{ AsyncReadExt, AsyncWriteExt }; - use futures::executor::block_on; - - let domain = webpki::DNSNameRef::try_from_ascii_str(domain).unwrap(); - let mut config = ClientConfig::new(); - if let Some(mut chain) = chain { - config.root_store.add_pem_file(&mut chain).unwrap(); - } - let config = Arc::new(config); - - let done = TcpStream::connect(addr) - .and_then(|stream| config.connect_async(domain, stream)) - .and_then(|stream| stream.write_all(HELLO_WORLD)) - .and_then(|(stream, _)| stream.read_exact(vec![0; HELLO_WORLD.len()])) - .and_then(|(stream, buf)| { - assert_eq!(buf, HELLO_WORLD); - stream.close() - }) - .map(drop); - - block_on(done) -} - #[test] fn pass() { @@ -120,17 +94,6 @@ fn pass() { start_client(&addr, "localhost", Some(chain)).unwrap(); } -#[cfg(feature = "unstable-futures")] -#[test] -fn pass2() { - let cert = certs(&mut BufReader::new(Cursor::new(CERT))).unwrap(); - let mut keys = rsa_private_keys(&mut BufReader::new(Cursor::new(RSA))).unwrap(); - let chain = BufReader::new(Cursor::new(CHAIN)); - - let addr = start_server(cert, keys.pop().unwrap()); - start_client2(&addr, "localhost", Some(chain)).unwrap(); -} - #[should_panic] #[test] fn fail() { From 26046efc3cef6ada79a18a6e208600d87cfe34ee Mon Sep 17 00:00:00 2001 From: quininer Date: Thu, 16 Aug 2018 18:43:35 +0800 Subject: [PATCH 04/15] fix vecbuf --- src/common.rs | 125 +++++++++++++++++++++++++++++++++------------- src/tokio_impl.rs | 2 +- 2 files changed, 91 insertions(+), 36 deletions(-) diff --git a/src/common.rs b/src/common.rs index df83537..070caeb 100644 --- a/src/common.rs +++ b/src/common.rs @@ -12,23 +12,16 @@ pub struct Stream<'a, S: 'a, IO: 'a> { io: &'a mut IO } -pub trait CompleteIo<'a, S: Session, IO: Read + Write>: Read + Write { +pub trait WriteTls<'a, S: Session, IO: Read + Write>: Read + Write { fn write_tls(&mut self) -> io::Result; - fn complete_io(&mut self) -> io::Result<(usize, usize)>; } impl<'a, S: Session, IO: Read + Write> Stream<'a, S, IO> { pub fn new(session: &'a mut S, io: &'a mut IO) -> Self { Stream { session, io } } -} -impl<'a, S: Session, IO: Read + Write> CompleteIo<'a, S, IO> for Stream<'a, S, IO> { - default fn write_tls(&mut self) -> io::Result { - self.session.write_tls(self.io) - } - - fn complete_io(&mut self) -> io::Result<(usize, usize)> { + pub fn complete_io(&mut self) -> io::Result<(usize, usize)> { // fork from https://github.com/ctz/rustls/blob/master/src/session.rs#L161 let until_handshaked = self.session.is_handshaking(); @@ -74,6 +67,32 @@ impl<'a, S: Session, IO: Read + Write> CompleteIo<'a, S, IO> for Stream<'a, S, I } } +impl<'a, S: Session, IO: Read + Write> WriteTls<'a, S, IO> for Stream<'a, S, IO> { + default fn write_tls(&mut self) -> io::Result { + self.session.write_tls(self.io) + } +} + +impl<'a, S: Session, IO: Read + AsyncWrite> WriteTls<'a, S, IO> for Stream<'a, S, IO> { + fn write_tls(&mut self) -> io::Result { + struct V<'a, IO: 'a>(&'a mut IO); + + impl<'a, IO: AsyncWrite> WriteV for V<'a, IO> { + fn writev(&mut self, vbytes: &[&[u8]]) -> io::Result { + let mut vbytes = VecBuf::new(vbytes); + match self.0.write_buf(&mut vbytes) { + Ok(Async::Ready(n)) => Ok(n), + Ok(Async::NotReady) => Err(io::ErrorKind::WouldBlock.into()), + Err(err) => Err(err) + } + } + } + + let mut vecbuf = V(self.io); + self.session.writev_tls(&mut vecbuf) + } +} + impl<'a, S: Session, IO: Read + Write> Read for Stream<'a, S, IO> { fn read(&mut self, buf: &mut [u8]) -> io::Result { while self.session.wants_read() { @@ -102,28 +121,7 @@ impl<'a, S: Session, IO: Read + Write> io::Write for Stream<'a, S, IO> { } } -impl<'a, S: Session, IO: Read + AsyncWrite> CompleteIo<'a, S, IO> for Stream<'a, S, IO> { - fn write_tls(&mut self) -> io::Result { - struct V<'a, IO: 'a>(&'a mut IO); - impl<'a, IO: AsyncWrite> WriteV for V<'a, IO> { - fn writev(&mut self, vbytes: &[&[u8]]) -> io::Result { - let mut vbytes = VecBuf::new(vbytes); - match self.0.write_buf(&mut vbytes) { - Ok(Async::Ready(n)) => Ok(n), - Ok(Async::NotReady) => Err(io::ErrorKind::WouldBlock.into()), - Err(err) => Err(err) - } - } - } - - let mut vecbuf = V(self.io); - self.session.writev_tls(&mut vecbuf) - } -} - - -// TODO test struct VecBuf<'a, 'b: 'a> { pos: usize, cur: usize, @@ -153,14 +151,14 @@ impl<'a, 'b> Buf for VecBuf<'a, 'b> { fn advance(&mut self, cnt: usize) { let current = self.inner[self.pos].len(); match (self.cur + cnt).cmp(¤t) { - Ordering::Equal => { - if self.pos < self.inner.len() { - self.pos += 1; - } + Ordering::Equal => if self.pos + 1 < self.inner.len() { + self.pos += 1; self.cur = 0; + } else { + self.cur += cnt; }, Ordering::Greater => { - if self.pos < self.inner.len() { + if self.pos + 1 < self.inner.len() { self.pos += 1; } let remaining = self.cur + cnt - current; @@ -180,3 +178,60 @@ impl<'a, 'b> Buf for VecBuf<'a, 'b> { len } } + +#[cfg(test)] +mod test_vecbuf { + use super::*; + + #[test] + fn test_fresh_cursor_vec() { + let mut buf = VecBuf::new(&[b"he", b"llo"]); + + assert_eq!(buf.remaining(), 5); + assert_eq!(buf.bytes(), b"he"); + + buf.advance(2); + + assert_eq!(buf.remaining(), 3); + assert_eq!(buf.bytes(), b"llo"); + + buf.advance(3); + + assert_eq!(buf.remaining(), 0); + assert_eq!(buf.bytes(), b""); + } + + #[test] + fn test_get_u8() { + let mut buf = VecBuf::new(&[b"\x21z", b"omg"]); + assert_eq!(0x21, buf.get_u8()); + } + + #[test] + fn test_get_u16() { + let mut buf = VecBuf::new(&[b"\x21\x54z", b"omg"]); + assert_eq!(0x2154, buf.get_u16_be()); + let mut buf = VecBuf::new(&[b"\x21\x54z", b"omg"]); + assert_eq!(0x5421, buf.get_u16_le()); + } + + #[test] + #[should_panic] + fn test_get_u16_buffer_underflow() { + let mut buf = VecBuf::new(&[b"\x21"]); + buf.get_u16_be(); + } + + #[test] + fn test_bufs_vec() { + let buf = VecBuf::new(&[b"he", b"llo"]); + + let b1: &[u8] = &mut [0]; + let b2: &[u8] = &mut [0]; + + let mut dst: [&IoVec; 2] = + [b1.into(), b2.into()]; + + assert_eq!(2, buf.bytes_vec(&mut dst[..])); + } +} diff --git a/src/tokio_impl.rs b/src/tokio_impl.rs index 663d6ca..9f09705 100644 --- a/src/tokio_impl.rs +++ b/src/tokio_impl.rs @@ -2,7 +2,7 @@ use super::*; use tokio::prelude::*; use tokio::io::{ AsyncRead, AsyncWrite }; use tokio::prelude::Poll; -use common::{ Stream, CompleteIo }; +use common::Stream; impl Future for ConnectAsync { From 4a2354c1cc19aaf937da1b1d8f70e5c6705771db Mon Sep 17 00:00:00 2001 From: quininer Date: Thu, 16 Aug 2018 19:13:18 +0800 Subject: [PATCH 05/15] rename tokio feature --- Cargo.toml | 4 ++-- examples/client/Cargo.toml | 6 ++---- examples/server/Cargo.toml | 7 ++----- src/lib.rs | 2 +- 4 files changed, 7 insertions(+), 12 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 367bcca..41c0671 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,5 +25,5 @@ webpki = "0.18.1" tokio = "0.1.6" [features] -default = [ "tokio_impl" ] -tokio_impl = [ "tokio", "bytes", "iovec" ] +default = [ "tokio-support" ] +tokio-support = [ "tokio", "bytes", "iovec" ] diff --git a/examples/client/Cargo.toml b/examples/client/Cargo.toml index acf13bd..2253096 100644 --- a/examples/client/Cargo.toml +++ b/examples/client/Cargo.toml @@ -5,11 +5,9 @@ authors = ["quininer "] [dependencies] webpki = "0.18.1" -tokio-rustls = { path = "../..", default-features = false, features = [ "tokio" ] } - +tokio-rustls = { path = "../.." } tokio = "0.1" - -clap = "2.26" +clap = "2" webpki-roots = "0.15" [target.'cfg(unix)'.dependencies] diff --git a/examples/server/Cargo.toml b/examples/server/Cargo.toml index 98329ce..170693f 100644 --- a/examples/server/Cargo.toml +++ b/examples/server/Cargo.toml @@ -4,9 +4,6 @@ version = "0.1.0" authors = ["quininer "] [dependencies] -tokio-rustls = { path = "../..", default-features = false, features = [ "tokio" ] } - +tokio-rustls = { path = "../.." } tokio = { version = "0.1.6" } -# futures = "0.2.0-beta" - -clap = "2.26" +clap = "2" diff --git a/src/lib.rs b/src/lib.rs index afd3cf0..9d910e1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -11,7 +11,7 @@ extern crate iovec; mod common; -#[cfg(feature = "tokio_impl")] mod tokio_impl; +#[cfg(feature = "tokio-support")] mod tokio_impl; use std::io; use std::sync::Arc; From b040a9a65f1b16569652dd2b7564363055d7682e Mon Sep 17 00:00:00 2001 From: quininer Date: Thu, 16 Aug 2018 20:44:37 +0800 Subject: [PATCH 06/15] Test use lazy_static! --- Cargo.toml | 1 + src/futures_impl.rs | 170 -------------------------------------------- tests/test.rs | 105 +++++++++++++-------------- 3 files changed, 54 insertions(+), 222 deletions(-) delete mode 100644 src/futures_impl.rs diff --git a/Cargo.toml b/Cargo.toml index 5d06420..e8fcb70 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,6 +21,7 @@ webpki = "0.18.1" [dev-dependencies] tokio = "0.1.6" +lazy_static = "1" [features] default = [ "tokio" ] diff --git a/src/futures_impl.rs b/src/futures_impl.rs deleted file mode 100644 index 6771316..0000000 --- a/src/futures_impl.rs +++ /dev/null @@ -1,170 +0,0 @@ -extern crate futures_core; -extern crate futures_io; - -use super::*; -use self::futures_core::{ Future, Poll, Async }; -use self::futures_core::task::Context; -use self::futures_io::{ Error, AsyncRead, AsyncWrite }; - - -impl Future for ConnectAsync { - type Item = TlsStream; - type Error = io::Error; - - fn poll(&mut self, ctx: &mut Context) -> Poll { - self.0.poll(ctx) - } -} - -impl Future for AcceptAsync { - type Item = TlsStream; - type Error = io::Error; - - fn poll(&mut self, ctx: &mut Context) -> Poll { - self.0.poll(ctx) - } -} - -macro_rules! async { - ( to $r:expr ) => { - match $r { - Ok(Async::Ready(n)) => Ok(n), - Ok(Async::Pending) => Err(io::ErrorKind::WouldBlock.into()), - Err(e) => Err(e) - } - }; - ( from $r:expr ) => { - match $r { - Ok(n) => Ok(Async::Ready(n)), - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => Ok(Async::Pending), - Err(e) => Err(e) - } - }; -} - -struct TaskStream<'a, 'b: 'a, S: 'a> { - io: &'a mut S, - task: &'a mut Context<'b> -} - -impl<'a, 'b, S> io::Read for TaskStream<'a, 'b, S> - where S: AsyncRead + AsyncWrite -{ - #[inline] - fn read(&mut self, buf: &mut [u8]) -> io::Result { - async!(to self.io.poll_read(self.task, buf)) - } -} - -impl<'a, 'b, S> io::Write for TaskStream<'a, 'b, S> - where S: AsyncRead + AsyncWrite -{ - #[inline] - fn write(&mut self, buf: &[u8]) -> io::Result { - async!(to self.io.poll_write(self.task, buf)) - } - - #[inline] - fn flush(&mut self) -> io::Result<()> { - async!(to self.io.poll_flush(self.task)) - } -} - -impl Future for MidHandshake - where S: AsyncRead + AsyncWrite, C: Session -{ - type Item = TlsStream; - type Error = io::Error; - - fn poll(&mut self, ctx: &mut Context) -> Poll { - loop { - let stream = self.inner.as_mut().unwrap(); - if !stream.session.is_handshaking() { break }; - - let (io, session) = stream.get_mut(); - let mut taskio = TaskStream { io, task: ctx }; - - match session.complete_io(&mut taskio) { - Ok(_) => (), - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => return Ok(Async::Pending), - Err(e) => return Err(e) - } - } - - Ok(Async::Ready(self.inner.take().unwrap())) - } -} - -impl AsyncRead for TlsStream - where - S: AsyncRead + AsyncWrite, - C: Session -{ - fn poll_read(&mut self, ctx: &mut Context, buf: &mut [u8]) -> Poll { - if self.eof { - return Ok(Async::Ready(0)); - } - - // TODO nll - let result = { - let (io, session) = self.get_mut(); - let mut taskio = TaskStream { io, task: ctx }; - let mut stream = Stream::new(session, &mut taskio); - io::Read::read(&mut stream, buf) - }; - - match result { - Ok(0) => { self.eof = true; Ok(Async::Ready(0)) }, - Ok(n) => Ok(Async::Ready(n)), - Err(ref e) if e.kind() == io::ErrorKind::ConnectionAborted => { - self.eof = true; - self.is_shutdown = true; - self.session.send_close_notify(); - Ok(Async::Ready(0)) - }, - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => Ok(Async::Pending), - Err(e) => Err(e) - } - } -} - -impl AsyncWrite for TlsStream - where - S: AsyncRead + AsyncWrite, - C: Session -{ - fn poll_write(&mut self, ctx: &mut Context, buf: &[u8]) -> Poll { - let (io, session) = self.get_mut(); - let mut taskio = TaskStream { io, task: ctx }; - let mut stream = Stream::new(session, &mut taskio); - - async!(from io::Write::write(&mut stream, buf)) - } - - fn poll_flush(&mut self, ctx: &mut Context) -> Poll<(), Error> { - let (io, session) = self.get_mut(); - let mut taskio = TaskStream { io, task: ctx }; - - { - let mut stream = Stream::new(session, &mut taskio); - async!(from io::Write::flush(&mut stream))?; - } - - async!(from io::Write::flush(&mut taskio)) - } - - fn poll_close(&mut self, ctx: &mut Context) -> Poll<(), Error> { - if !self.is_shutdown { - self.session.send_close_notify(); - self.is_shutdown = true; - } - - { - let (io, session) = self.get_mut(); - let mut taskio = TaskStream { io, task: ctx }; - async!(from session.complete_io(&mut taskio))?; - } - - self.io.poll_close(ctx) - } -} diff --git a/tests/test.rs b/tests/test.rs index c6262e9..fa46f5a 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -1,81 +1,89 @@ +#[macro_use] extern crate lazy_static; extern crate rustls; extern crate tokio; extern crate tokio_rustls; extern crate webpki; -#[cfg(feature = "unstable-futures")] extern crate futures; - use std::{ io, thread }; use std::io::{ BufReader, Cursor }; use std::sync::Arc; use std::sync::mpsc::channel; -use std::net::{ SocketAddr, IpAddr, Ipv4Addr }; +use std::net::SocketAddr; use tokio::net::{ TcpListener, TcpStream }; -use rustls::{ Certificate, PrivateKey, ServerConfig, ClientConfig }; +use rustls::{ ServerConfig, ClientConfig }; use rustls::internal::pemfile::{ certs, rsa_private_keys }; use tokio_rustls::{ ClientConfigExt, ServerConfigExt }; const CERT: &str = include_str!("end.cert"); const CHAIN: &str = include_str!("end.chain"); const RSA: &str = include_str!("end.rsa"); -const HELLO_WORLD: &[u8] = b"Hello world!"; +lazy_static!{ + static ref TEST_SERVER: (SocketAddr, &'static str, &'static str) = { + use tokio::prelude::*; + use tokio::io as aio; -fn start_server(cert: Vec, rsa: PrivateKey) -> SocketAddr { - use tokio::prelude::*; - use tokio::io as aio; + let cert = certs(&mut BufReader::new(Cursor::new(CERT))).unwrap(); + let mut keys = rsa_private_keys(&mut BufReader::new(Cursor::new(RSA))).unwrap(); - let mut config = ServerConfig::new(rustls::NoClientAuth::new()); - config.set_single_cert(cert, rsa) - .expect("invalid key or certificate"); - let config = Arc::new(config); + let mut config = ServerConfig::new(rustls::NoClientAuth::new()); + config.set_single_cert(cert, keys.pop().unwrap()) + .expect("invalid key or certificate"); + let config = Arc::new(config); - let (send, recv) = channel(); + let (send, recv) = channel(); - thread::spawn(move || { - let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 0); - let listener = TcpListener::bind(&addr).unwrap(); + thread::spawn(move || { + let addr = SocketAddr::from(([127, 0, 0, 1], 0)); + let listener = TcpListener::bind(&addr).unwrap(); - send.send(listener.local_addr().unwrap()).unwrap(); + send.send(listener.local_addr().unwrap()).unwrap(); - let done = listener.incoming() - .for_each(move |stream| { - let done = config.accept_async(stream) - .and_then(|stream| aio::read_exact(stream, vec![0; HELLO_WORLD.len()])) - .and_then(|(stream, buf)| { - assert_eq!(buf, HELLO_WORLD); - aio::write_all(stream, HELLO_WORLD) - }) - .then(|_| Ok(())); + let done = listener.incoming() + .for_each(move |stream| { + let done = config.accept_async(stream) + .and_then(|stream| { + let (reader, writer) = stream.split(); + aio::copy(reader, writer) + }) + .then(|_| Ok(())); - tokio::spawn(done); - Ok(()) - }) - .map_err(|err| panic!("{:?}", err)); + tokio::spawn(done); + Ok(()) + }) + .map_err(|err| panic!("{:?}", err)); - tokio::run(done); - }); + tokio::run(done); + }); - recv.recv().unwrap() + let addr = recv.recv().unwrap(); + (addr, "localhost", CHAIN) + }; } -fn start_client(addr: &SocketAddr, domain: &str, chain: Option>>) -> io::Result<()> { + +fn start_server() -> &'static (SocketAddr, &'static str, &'static str) { + &*TEST_SERVER +} + +fn start_client(addr: &SocketAddr, domain: &str, chain: &str) -> io::Result<()> { use tokio::prelude::*; use tokio::io as aio; + const FILE: &'static [u8] = include_bytes!("../README.md"); + let domain = webpki::DNSNameRef::try_from_ascii_str(domain).unwrap(); let mut config = ClientConfig::new(); - if let Some(mut chain) = chain { - config.root_store.add_pem_file(&mut chain).unwrap(); - } + let mut chain = BufReader::new(Cursor::new(chain)); + config.root_store.add_pem_file(&mut chain).unwrap(); let config = Arc::new(config); let done = TcpStream::connect(addr) .and_then(|stream| config.connect_async(domain, stream)) - .and_then(|stream| aio::write_all(stream, HELLO_WORLD)) - .and_then(|(stream, _)| aio::read_exact(stream, vec![0; HELLO_WORLD.len()])) + .and_then(|stream| aio::write_all(stream, FILE)) + .and_then(|(stream, _)| aio::read_exact(stream, vec![0; FILE.len()])) .and_then(|(stream, buf)| { - assert_eq!(buf, HELLO_WORLD); + assert_eq!(buf, FILE); aio::shutdown(stream) }) .map(drop); @@ -86,22 +94,15 @@ fn start_client(addr: &SocketAddr, domain: &str, chain: Option Date: Thu, 16 Aug 2018 22:35:50 +0800 Subject: [PATCH 07/15] fix vecbuf bytes_vec --- src/common.rs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/common.rs b/src/common.rs index 070caeb..f9900ff 100644 --- a/src/common.rs +++ b/src/common.rs @@ -171,7 +171,11 @@ impl<'a, 'b> Buf for VecBuf<'a, 'b> { fn bytes_vec<'c>(&'c self, dst: &mut [&'c IoVec]) -> usize { let len = cmp::min(self.inner.len() - self.pos, dst.len()); - for i in 0..len { + if len > 0 { + dst[0] = self.bytes().into(); + } + + for i in 1..len { dst[i] = self.inner[self.pos + i].into(); } From 5cbd5b8aa0c4e03b3006c4231b576e0618c9effa Mon Sep 17 00:00:00 2001 From: quininer Date: Fri, 17 Aug 2018 09:18:53 +0800 Subject: [PATCH 08/15] fix: handle Stream non-blocking write --- src/common.rs | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/common.rs b/src/common.rs index f9900ff..9e3cb5c 100644 --- a/src/common.rs +++ b/src/common.rs @@ -100,7 +100,6 @@ impl<'a, S: Session, IO: Read + Write> Read for Stream<'a, S, IO> { break } } - self.session.read(buf) } } @@ -108,7 +107,13 @@ impl<'a, S: Session, IO: Read + Write> Read for Stream<'a, S, IO> { impl<'a, S: Session, IO: Read + Write> io::Write for Stream<'a, S, IO> { fn write(&mut self, buf: &[u8]) -> io::Result { let len = self.session.write(buf)?; - self.complete_io()?; + while self.session.wants_write() { + match self.complete_io() { + Ok(_) => (), + Err(ref err) if err.kind() == io::ErrorKind::WouldBlock && len != 0 => break, + Err(err) => return Err(err) + } + } Ok(len) } From 762d7f952582b8430f79e658572578aa12533c4b Mon Sep 17 00:00:00 2001 From: quininer Date: Fri, 17 Aug 2018 10:04:49 +0800 Subject: [PATCH 09/15] Add nightly feature --- Cargo.toml | 5 +- src/{common.rs => common/mod.rs} | 149 +++++-------------------------- src/common/vecbuf.rs | 122 +++++++++++++++++++++++++ src/lib.rs | 7 +- 4 files changed, 154 insertions(+), 129 deletions(-) rename src/{common.rs => common/mod.rs} (57%) create mode 100644 src/common/vecbuf.rs diff --git a/Cargo.toml b/Cargo.toml index 41c0671..58ece24 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,5 +25,6 @@ webpki = "0.18.1" tokio = "0.1.6" [features] -default = [ "tokio-support" ] -tokio-support = [ "tokio", "bytes", "iovec" ] +default = ["tokio-support"] +nightly = ["bytes", "iovec"] +tokio-support = ["tokio"] diff --git a/src/common.rs b/src/common/mod.rs similarity index 57% rename from src/common.rs rename to src/common/mod.rs index 9e3cb5c..1f5d1d2 100644 --- a/src/common.rs +++ b/src/common/mod.rs @@ -1,11 +1,14 @@ -use std::cmp::{ self, Ordering }; -use std::io::{ self, Read, Write }; -use rustls::{ Session, WriteV }; -use tokio::prelude::Async; -use tokio::io::AsyncWrite; -use bytes::Buf; -use iovec::IoVec; +#[cfg(feature = "nightly")] +#[cfg(feature = "tokio-support")] +mod vecbuf; +use std::io::{ self, Read, Write }; +use rustls::Session; +#[cfg(feature = "nightly")] +use rustls::WriteV; +#[cfg(feature = "nightly")] +#[cfg(feature = "tokio-support")] +use tokio::io::AsyncWrite; pub struct Stream<'a, S: 'a, IO: 'a> { session: &'a mut S, @@ -67,14 +70,27 @@ impl<'a, S: Session, IO: Read + Write> Stream<'a, S, IO> { } } +#[cfg(not(feature = "nightly"))] +impl<'a, S: Session, IO: Read + Write> WriteTls<'a, S, IO> for Stream<'a, S, IO> { + fn write_tls(&mut self) -> io::Result { + self.session.write_tls(self.io) + } +} + +#[cfg(feature = "nightly")] impl<'a, S: Session, IO: Read + Write> WriteTls<'a, S, IO> for Stream<'a, S, IO> { default fn write_tls(&mut self) -> io::Result { self.session.write_tls(self.io) } } +#[cfg(feature = "nightly")] +#[cfg(feature = "tokio-support")] impl<'a, S: Session, IO: Read + AsyncWrite> WriteTls<'a, S, IO> for Stream<'a, S, IO> { fn write_tls(&mut self) -> io::Result { + use tokio::prelude::Async; + use self::vecbuf::VecBuf; + struct V<'a, IO: 'a>(&'a mut IO); impl<'a, IO: AsyncWrite> WriteV for V<'a, IO> { @@ -125,122 +141,3 @@ impl<'a, S: Session, IO: Read + Write> io::Write for Stream<'a, S, IO> { Ok(()) } } - - -struct VecBuf<'a, 'b: 'a> { - pos: usize, - cur: usize, - inner: &'a [&'b [u8]] -} - -impl<'a, 'b> VecBuf<'a, 'b> { - fn new(vbytes: &'a [&'b [u8]]) -> Self { - VecBuf { pos: 0, cur: 0, inner: vbytes } - } -} - -impl<'a, 'b> Buf for VecBuf<'a, 'b> { - fn remaining(&self) -> usize { - let sum = self.inner - .iter() - .skip(self.pos) - .map(|bytes| bytes.len()) - .sum::(); - sum - self.cur - } - - fn bytes(&self) -> &[u8] { - &self.inner[self.pos][self.cur..] - } - - fn advance(&mut self, cnt: usize) { - let current = self.inner[self.pos].len(); - match (self.cur + cnt).cmp(¤t) { - Ordering::Equal => if self.pos + 1 < self.inner.len() { - self.pos += 1; - self.cur = 0; - } else { - self.cur += cnt; - }, - Ordering::Greater => { - if self.pos + 1 < self.inner.len() { - self.pos += 1; - } - let remaining = self.cur + cnt - current; - self.advance(remaining); - }, - Ordering::Less => self.cur += cnt, - } - } - - fn bytes_vec<'c>(&'c self, dst: &mut [&'c IoVec]) -> usize { - let len = cmp::min(self.inner.len() - self.pos, dst.len()); - - if len > 0 { - dst[0] = self.bytes().into(); - } - - for i in 1..len { - dst[i] = self.inner[self.pos + i].into(); - } - - len - } -} - -#[cfg(test)] -mod test_vecbuf { - use super::*; - - #[test] - fn test_fresh_cursor_vec() { - let mut buf = VecBuf::new(&[b"he", b"llo"]); - - assert_eq!(buf.remaining(), 5); - assert_eq!(buf.bytes(), b"he"); - - buf.advance(2); - - assert_eq!(buf.remaining(), 3); - assert_eq!(buf.bytes(), b"llo"); - - buf.advance(3); - - assert_eq!(buf.remaining(), 0); - assert_eq!(buf.bytes(), b""); - } - - #[test] - fn test_get_u8() { - let mut buf = VecBuf::new(&[b"\x21z", b"omg"]); - assert_eq!(0x21, buf.get_u8()); - } - - #[test] - fn test_get_u16() { - let mut buf = VecBuf::new(&[b"\x21\x54z", b"omg"]); - assert_eq!(0x2154, buf.get_u16_be()); - let mut buf = VecBuf::new(&[b"\x21\x54z", b"omg"]); - assert_eq!(0x5421, buf.get_u16_le()); - } - - #[test] - #[should_panic] - fn test_get_u16_buffer_underflow() { - let mut buf = VecBuf::new(&[b"\x21"]); - buf.get_u16_be(); - } - - #[test] - fn test_bufs_vec() { - let buf = VecBuf::new(&[b"he", b"llo"]); - - let b1: &[u8] = &mut [0]; - let b2: &[u8] = &mut [0]; - - let mut dst: [&IoVec; 2] = - [b1.into(), b2.into()]; - - assert_eq!(2, buf.bytes_vec(&mut dst[..])); - } -} diff --git a/src/common/vecbuf.rs b/src/common/vecbuf.rs new file mode 100644 index 0000000..dd40163 --- /dev/null +++ b/src/common/vecbuf.rs @@ -0,0 +1,122 @@ +use std::cmp::{ self, Ordering }; +use bytes::Buf; +use iovec::IoVec; + +pub struct VecBuf<'a, 'b: 'a> { + pos: usize, + cur: usize, + inner: &'a [&'b [u8]] +} + +impl<'a, 'b> VecBuf<'a, 'b> { + pub fn new(vbytes: &'a [&'b [u8]]) -> Self { + VecBuf { pos: 0, cur: 0, inner: vbytes } + } +} + +impl<'a, 'b> Buf for VecBuf<'a, 'b> { + fn remaining(&self) -> usize { + let sum = self.inner + .iter() + .skip(self.pos) + .map(|bytes| bytes.len()) + .sum::(); + sum - self.cur + } + + fn bytes(&self) -> &[u8] { + &self.inner[self.pos][self.cur..] + } + + fn advance(&mut self, cnt: usize) { + let current = self.inner[self.pos].len(); + match (self.cur + cnt).cmp(¤t) { + Ordering::Equal => if self.pos + 1 < self.inner.len() { + self.pos += 1; + self.cur = 0; + } else { + self.cur += cnt; + }, + Ordering::Greater => { + if self.pos + 1 < self.inner.len() { + self.pos += 1; + } + let remaining = self.cur + cnt - current; + self.advance(remaining); + }, + Ordering::Less => self.cur += cnt, + } + } + + #[allow(needless_range_loop)] + fn bytes_vec<'c>(&'c self, dst: &mut [&'c IoVec]) -> usize { + let len = cmp::min(self.inner.len() - self.pos, dst.len()); + + if len > 0 { + dst[0] = self.bytes().into(); + } + + for i in 1..len { + dst[i] = self.inner[self.pos + i].into(); + } + + len + } +} + +#[cfg(test)] +mod test_vecbuf { + use super::*; + + #[test] + fn test_fresh_cursor_vec() { + let mut buf = VecBuf::new(&[b"he", b"llo"]); + + assert_eq!(buf.remaining(), 5); + assert_eq!(buf.bytes(), b"he"); + + buf.advance(2); + + assert_eq!(buf.remaining(), 3); + assert_eq!(buf.bytes(), b"llo"); + + buf.advance(3); + + assert_eq!(buf.remaining(), 0); + assert_eq!(buf.bytes(), b""); + } + + #[test] + fn test_get_u8() { + let mut buf = VecBuf::new(&[b"\x21z", b"omg"]); + assert_eq!(0x21, buf.get_u8()); + } + + #[test] + fn test_get_u16() { + let mut buf = VecBuf::new(&[b"\x21\x54z", b"omg"]); + assert_eq!(0x2154, buf.get_u16_be()); + let mut buf = VecBuf::new(&[b"\x21\x54z", b"omg"]); + assert_eq!(0x5421, buf.get_u16_le()); + } + + #[test] + #[should_panic] + fn test_get_u16_buffer_underflow() { + let mut buf = VecBuf::new(&[b"\x21"]); + buf.get_u16_be(); + } + + #[test] + fn test_bufs_vec() { + let buf = VecBuf::new(&[b"he", b"llo"]); + + let b1: &[u8] = &mut [0]; + let b2: &[u8] = &mut [0]; + + let mut dst: [&IoVec; 2] = + [b1.into(), b2.into()]; + + assert_eq!(2, buf.bytes_vec(&mut dst[..])); + } +} diff --git a/src/lib.rs b/src/lib.rs index 9d910e1..b06c227 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,12 +1,17 @@ //! Asynchronous TLS/SSL streams for Tokio using [Rustls](https://github.com/ctz/rustls). -#![feature(specialization)] +#![cfg_attr(feature = "nightly", feature(specialization))] pub extern crate rustls; pub extern crate webpki; +#[cfg(feature = "tokio-support")] extern crate tokio; +#[cfg(feature = "nightly")] +#[cfg(feature = "tokio-support")] extern crate bytes; +#[cfg(feature = "nightly")] +#[cfg(feature = "tokio-support")] extern crate iovec; From cf00bbb2f7d464c64ff2fc9ba2825b9a69fcfae2 Mon Sep 17 00:00:00 2001 From: quininer Date: Fri, 17 Aug 2018 10:14:17 +0800 Subject: [PATCH 10/15] fix ci --- .travis.yml | 20 ++++++++++++++------ appveyor.yml | 2 +- 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/.travis.yml b/.travis.yml index 3d5e1db..043e804 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,13 +1,21 @@ language: rust -rust: - - stable cache: cargo -os: - - linux - - osx + +matrix: + include: + - rust: stable + os: linux + - rust: nightly + env: FEATURE=nightly + os: linux + - rust: stable + os: osx + - rust: nightly + env: FEATURE=nightly + os: osx script: - - cargo test --all-features + - cargo test --features "$FEATURE" - cd examples/server - cargo check - cd ../../examples/client diff --git a/appveyor.yml b/appveyor.yml index 7ede91c..038274b 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -13,7 +13,7 @@ install: build: false test_script: - - 'cargo test --all-features' + - 'cargo test' - 'cd examples/server' - 'cargo check' - 'cd ../../examples/client' From 482f3c3aa6f7c981a6c76f6903c0d7ca3cd416e3 Mon Sep 17 00:00:00 2001 From: quininer Date: Fri, 17 Aug 2018 13:07:26 +0800 Subject: [PATCH 11/15] impl prepare_uninitialized_buffer --- src/common/vecbuf.rs | 2 +- src/tokio_impl.rs | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/common/vecbuf.rs b/src/common/vecbuf.rs index dd40163..81bec86 100644 --- a/src/common/vecbuf.rs +++ b/src/common/vecbuf.rs @@ -48,7 +48,7 @@ impl<'a, 'b> Buf for VecBuf<'a, 'b> { } } - #[allow(needless_range_loop)] + #[cfg_attr(feature = "cargo-clippy", allow(needless_range_loop))] fn bytes_vec<'c>(&'c self, dst: &mut [&'c IoVec]) -> usize { let len = cmp::min(self.inner.len() - self.pos, dst.len()); diff --git a/src/tokio_impl.rs b/src/tokio_impl.rs index 9f09705..e3fd4d5 100644 --- a/src/tokio_impl.rs +++ b/src/tokio_impl.rs @@ -52,7 +52,11 @@ impl AsyncRead for TlsStream where S: AsyncRead + AsyncWrite, C: Session -{} +{ + unsafe fn prepare_uninitialized_buffer(&self, _: &mut [u8]) -> bool { + false + } +} impl AsyncWrite for TlsStream where From 808df2f226bfd8c939209eef504357ee3a15676b Mon Sep 17 00:00:00 2001 From: quininer Date: Fri, 17 Aug 2018 13:09:24 +0800 Subject: [PATCH 12/15] publish 0.7.2 --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 73ca707..6106795 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tokio-rustls" -version = "0.7.1" +version = "0.7.2" authors = ["quininer kel "] license = "MIT/Apache-2.0" repository = "https://github.com/quininer/tokio-rustls" From f698c44e1a352629a786c490a622f6ed5431b1da Mon Sep 17 00:00:00 2001 From: quininer Date: Sat, 18 Aug 2018 13:52:00 +0800 Subject: [PATCH 13/15] Add stream test --- src/common/mod.rs | 3 + src/common/test_stream.rs | 161 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 164 insertions(+) create mode 100644 src/common/test_stream.rs diff --git a/src/common/mod.rs b/src/common/mod.rs index 1f5d1d2..8580cc8 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -141,3 +141,6 @@ impl<'a, S: Session, IO: Read + Write> io::Write for Stream<'a, S, IO> { Ok(()) } } + +#[cfg(test)] +mod test_stream; diff --git a/src/common/test_stream.rs b/src/common/test_stream.rs new file mode 100644 index 0000000..2cd9e87 --- /dev/null +++ b/src/common/test_stream.rs @@ -0,0 +1,161 @@ +use std::sync::Arc; +use std::io::{ self, Read, Write, BufReader, Cursor }; +use webpki::DNSNameRef; +use rustls::internal::pemfile::{ certs, rsa_private_keys }; +use rustls::{ + ServerConfig, ClientConfig, + ServerSession, ClientSession, + Session, NoClientAuth +}; +use super::Stream; + + +struct Good<'a>(&'a mut Session); + +impl<'a> Read for Good<'a> { + fn read(&mut self, mut buf: &mut [u8]) -> io::Result { + self.0.write_tls(buf.by_ref()) + } +} + +impl<'a> Write for Good<'a> { + fn write(&mut self, mut buf: &[u8]) -> io::Result { + let len = self.0.read_tls(buf.by_ref())?; + self.0.process_new_packets() + .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?; + Ok(len) + } + + fn flush(&mut self) -> io::Result<()> { + Ok(()) + } +} + +struct Bad(bool); + +impl Read for Bad { + fn read(&mut self, _: &mut [u8]) -> io::Result { + Ok(0) + } +} + +impl Write for Bad { + fn write(&mut self, buf: &[u8]) -> io::Result { + if self.0 { + Err(io::ErrorKind::WouldBlock.into()) + } else { + Ok(buf.len()) + } + } + + fn flush(&mut self) -> io::Result<()> { + Ok(()) + } +} + + +#[test] +fn stream_good() -> io::Result<()> { + const FILE: &'static [u8] = include_bytes!("../../README.md"); + + let (mut server, mut client) = make_pair(); + do_handshake(&mut client, &mut server); + io::copy(&mut Cursor::new(FILE), &mut server)?; + + { + let mut good = Good(&mut server); + let mut stream = Stream::new(&mut client, &mut good); + + let mut buf = Vec::new(); + stream.read_to_end(&mut buf)?; + assert_eq!(buf, FILE); + stream.write_all(b"Hello World!")? + } + + let mut buf = String::new(); + server.read_to_string(&mut buf)?; + assert_eq!(buf, "Hello World!"); + + Ok(()) +} + +#[test] +fn stream_bad() -> io::Result<()> { + let (mut server, mut client) = make_pair(); + do_handshake(&mut client, &mut server); + client.set_buffer_limit(1024); + + let mut bad = Bad(true); + let mut stream = Stream::new(&mut client, &mut bad); + assert_eq!(stream.write(&[0x42; 8])?, 8); + assert_eq!(stream.write(&[0x42; 8])?, 8); + let r = stream.write(&[0x00; 1024])?; // fill buffer + assert!(r < 1024); + assert_eq!( + stream.write(&[0x01]).unwrap_err().kind(), + io::ErrorKind::WouldBlock + ); + + Ok(()) +} + +#[test] +fn stream_handshake() -> io::Result<()> { + let (mut server, mut client) = make_pair(); + + { + let mut good = Good(&mut server); + let mut stream = Stream::new(&mut client, &mut good); + let (r, w) = stream.complete_io()?; + + assert!(r > 0); + assert!(w > 0); + + stream.complete_io()?; // finish server handshake + } + + assert!(!server.is_handshaking()); + assert!(!client.is_handshaking()); + + Ok(()) +} + +#[test] +fn stream_handshake_eof() -> io::Result<()> { + let (_, mut client) = make_pair(); + + let mut bad = Bad(false); + let mut stream = Stream::new(&mut client, &mut bad); + let r = stream.complete_io(); + + assert_eq!(r.unwrap_err().kind(), io::ErrorKind::UnexpectedEof); + + Ok(()) +} + +fn make_pair() -> (ServerSession, ClientSession) { + const CERT: &str = include_str!("../../tests/end.cert"); + const CHAIN: &str = include_str!("../../tests/end.chain"); + const RSA: &str = include_str!("../../tests/end.rsa"); + + let cert = certs(&mut BufReader::new(Cursor::new(CERT))).unwrap(); + let mut keys = rsa_private_keys(&mut BufReader::new(Cursor::new(RSA))).unwrap(); + let mut sconfig = ServerConfig::new(NoClientAuth::new()); + sconfig.set_single_cert(cert, keys.pop().unwrap()).unwrap(); + let server = ServerSession::new(&Arc::new(sconfig)); + + let domain = DNSNameRef::try_from_ascii_str("localhost").unwrap(); + let mut cconfig = ClientConfig::new(); + let mut chain = BufReader::new(Cursor::new(CHAIN)); + cconfig.root_store.add_pem_file(&mut chain).unwrap(); + let client = ClientSession::new(&Arc::new(cconfig), domain); + + (server, client) +} + +fn do_handshake(client: &mut ClientSession, server: &mut ServerSession) { + let mut good = Good(server); + let mut stream = Stream::new(client, &mut good); + stream.complete_io().unwrap(); + stream.complete_io().unwrap(); +} From 686b75bd4623f9311cc302308a059c87fe475a61 Mon Sep 17 00:00:00 2001 From: quininer Date: Tue, 21 Aug 2018 09:57:27 +0800 Subject: [PATCH 14/15] fix #5 --- examples/client/Cargo.toml | 7 +--- examples/client/src/main.rs | 74 ++++++++++--------------------------- 2 files changed, 20 insertions(+), 61 deletions(-) diff --git a/examples/client/Cargo.toml b/examples/client/Cargo.toml index 2253096..780ea88 100644 --- a/examples/client/Cargo.toml +++ b/examples/client/Cargo.toml @@ -9,9 +9,4 @@ tokio-rustls = { path = "../.." } tokio = "0.1" clap = "2" webpki-roots = "0.15" - -[target.'cfg(unix)'.dependencies] -tokio-file-unix = "0.5" - -[target.'cfg(not(unix))'.dependencies] -tokio-fs = "0.1" +tokio-stdin-stdout = "0.1" diff --git a/examples/client/src/main.rs b/examples/client/src/main.rs index 8499993..e58a633 100644 --- a/examples/client/src/main.rs +++ b/examples/client/src/main.rs @@ -4,8 +4,7 @@ extern crate webpki; extern crate webpki_roots; extern crate tokio_rustls; -#[cfg(unix)] extern crate tokio_file_unix; -#[cfg(not(unix))] extern crate tokio_fs; +extern crate tokio_stdin_stdout; use std::sync::Arc; use std::net::ToSocketAddrs; @@ -16,6 +15,7 @@ use tokio::net::TcpStream; use tokio::prelude::*; use clap::{ App, Arg }; use tokio_rustls::{ ClientConfigExt, rustls::ClientConfig }; +use tokio_stdin_stdout::{ stdin as tokio_stdin, stdout as tokio_stdout }; fn app() -> App<'static, 'static> { App::new("client") @@ -52,59 +52,23 @@ fn main() { let arc_config = Arc::new(config); let socket = TcpStream::connect(&addr); + let (stdin, stdout) = (tokio_stdin(0), tokio_stdout(0)); - #[cfg(unix)] - let resp = { - use tokio::reactor::Handle; - use tokio_file_unix::{ raw_stdin, raw_stdout, File }; + let done = socket + .and_then(move |stream| { + let domain = webpki::DNSNameRef::try_from_ascii_str(&domain).unwrap(); + arc_config.connect_async(domain, stream) + }) + .and_then(move |stream| io::write_all(stream, text)) + .and_then(move |(stream, _)| { + let (r, w) = stream.split(); + io::copy(r, stdout) + .map(drop) + .select2(io::copy(stdin, w).map(drop)) + .map_err(|res| res.split().0) + }) + .map(drop) + .map_err(|err| eprintln!("{:?}", err)); - let stdin = raw_stdin() - .and_then(File::new_nb) - .and_then(|fd| fd.into_reader(&Handle::current())) - .unwrap(); - let stdout = raw_stdout() - .and_then(File::new_nb) - .and_then(|fd| fd.into_io(&Handle::current())) - .unwrap(); - - socket - .and_then(move |stream| { - let domain = webpki::DNSNameRef::try_from_ascii_str(&domain).unwrap(); - arc_config.connect_async(domain, stream) - }) - .and_then(move |stream| io::write_all(stream, text)) - .and_then(move |(stream, _)| { - let (r, w) = stream.split(); - io::copy(r, stdout) - .map(drop) - .select2(io::copy(stdin, w).map(drop)) - .map_err(|res| res.split().0) - }) - .map(drop) - .map_err(|err| eprintln!("{:?}", err)) - }; - - #[cfg(not(unix))] - let resp = { - use tokio_fs::{ stdin as tokio_stdin, stdout as tokio_stdout }; - - let (stdin, stdout) = (tokio_stdin(), tokio_stdout()); - - socket - .and_then(move |stream| { - let domain = webpki::DNSNameRef::try_from_ascii_str(&domain).unwrap(); - arc_config.connect_async(domain, stream) - }) - .and_then(move |stream| io::write_all(stream, text)) - .and_then(move |(stream, _)| { - let (r, w) = stream.split(); - io::copy(r, stdout) - .map(drop) - .join(io::copy(stdin, w).map(drop)) - }) - .map(drop) - .map_err(|err| eprintln!("{:?}", err)) - }; - - tokio::run(resp); + tokio::run(done); } From 9378e415ce407bc8e0b26f6e70517b943070e190 Mon Sep 17 00:00:00 2001 From: quininer Date: Thu, 6 Sep 2018 13:56:00 +0800 Subject: [PATCH 15/15] impl read_initializer --- src/common/mod.rs | 7 +++++++ src/lib.rs | 2 +- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/src/common/mod.rs b/src/common/mod.rs index 8580cc8..7db198e 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -3,6 +3,8 @@ mod vecbuf; use std::io::{ self, Read, Write }; +#[cfg(feature = "nightly")] +use std::io::Initializer; use rustls::Session; #[cfg(feature = "nightly")] use rustls::WriteV; @@ -110,6 +112,11 @@ impl<'a, S: Session, IO: Read + AsyncWrite> WriteTls<'a, S, IO> for Stream<'a, S } impl<'a, S: Session, IO: Read + Write> Read for Stream<'a, S, IO> { + #[cfg(feature = "nightly")] + unsafe fn initializer(&self) -> Initializer { + Initializer::nop() + } + fn read(&mut self, buf: &mut [u8]) -> io::Result { while self.session.wants_read() { if let (0, 0) = self.complete_io()? { diff --git a/src/lib.rs b/src/lib.rs index b06c227..15bce44 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,6 +1,6 @@ //! Asynchronous TLS/SSL streams for Tokio using [Rustls](https://github.com/ctz/rustls). -#![cfg_attr(feature = "nightly", feature(specialization))] +#![cfg_attr(feature = "nightly", feature(specialization, read_initializer))] pub extern crate rustls; pub extern crate webpki;