diff --git a/Cargo.toml b/Cargo.toml index 5d06420..e8fcb70 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,6 +21,7 @@ webpki = "0.18.1" [dev-dependencies] tokio = "0.1.6" +lazy_static = "1" [features] default = [ "tokio" ] diff --git a/src/futures_impl.rs b/src/futures_impl.rs deleted file mode 100644 index 6771316..0000000 --- a/src/futures_impl.rs +++ /dev/null @@ -1,170 +0,0 @@ -extern crate futures_core; -extern crate futures_io; - -use super::*; -use self::futures_core::{ Future, Poll, Async }; -use self::futures_core::task::Context; -use self::futures_io::{ Error, AsyncRead, AsyncWrite }; - - -impl Future for ConnectAsync { - type Item = TlsStream; - type Error = io::Error; - - fn poll(&mut self, ctx: &mut Context) -> Poll { - self.0.poll(ctx) - } -} - -impl Future for AcceptAsync { - type Item = TlsStream; - type Error = io::Error; - - fn poll(&mut self, ctx: &mut Context) -> Poll { - self.0.poll(ctx) - } -} - -macro_rules! async { - ( to $r:expr ) => { - match $r { - Ok(Async::Ready(n)) => Ok(n), - Ok(Async::Pending) => Err(io::ErrorKind::WouldBlock.into()), - Err(e) => Err(e) - } - }; - ( from $r:expr ) => { - match $r { - Ok(n) => Ok(Async::Ready(n)), - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => Ok(Async::Pending), - Err(e) => Err(e) - } - }; -} - -struct TaskStream<'a, 'b: 'a, S: 'a> { - io: &'a mut S, - task: &'a mut Context<'b> -} - -impl<'a, 'b, S> io::Read for TaskStream<'a, 'b, S> - where S: AsyncRead + AsyncWrite -{ - #[inline] - fn read(&mut self, buf: &mut [u8]) -> io::Result { - async!(to self.io.poll_read(self.task, buf)) - } -} - -impl<'a, 'b, S> io::Write for TaskStream<'a, 'b, S> - where S: AsyncRead + AsyncWrite -{ - #[inline] - fn write(&mut self, buf: &[u8]) -> io::Result { - async!(to self.io.poll_write(self.task, buf)) - } - - #[inline] - fn flush(&mut self) -> io::Result<()> { - async!(to self.io.poll_flush(self.task)) - } -} - -impl Future for MidHandshake - where S: AsyncRead + AsyncWrite, C: Session -{ - type Item = TlsStream; - type Error = io::Error; - - fn poll(&mut self, ctx: &mut Context) -> Poll { - loop { - let stream = self.inner.as_mut().unwrap(); - if !stream.session.is_handshaking() { break }; - - let (io, session) = stream.get_mut(); - let mut taskio = TaskStream { io, task: ctx }; - - match session.complete_io(&mut taskio) { - Ok(_) => (), - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => return Ok(Async::Pending), - Err(e) => return Err(e) - } - } - - Ok(Async::Ready(self.inner.take().unwrap())) - } -} - -impl AsyncRead for TlsStream - where - S: AsyncRead + AsyncWrite, - C: Session -{ - fn poll_read(&mut self, ctx: &mut Context, buf: &mut [u8]) -> Poll { - if self.eof { - return Ok(Async::Ready(0)); - } - - // TODO nll - let result = { - let (io, session) = self.get_mut(); - let mut taskio = TaskStream { io, task: ctx }; - let mut stream = Stream::new(session, &mut taskio); - io::Read::read(&mut stream, buf) - }; - - match result { - Ok(0) => { self.eof = true; Ok(Async::Ready(0)) }, - Ok(n) => Ok(Async::Ready(n)), - Err(ref e) if e.kind() == io::ErrorKind::ConnectionAborted => { - self.eof = true; - self.is_shutdown = true; - self.session.send_close_notify(); - Ok(Async::Ready(0)) - }, - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => Ok(Async::Pending), - Err(e) => Err(e) - } - } -} - -impl AsyncWrite for TlsStream - where - S: AsyncRead + AsyncWrite, - C: Session -{ - fn poll_write(&mut self, ctx: &mut Context, buf: &[u8]) -> Poll { - let (io, session) = self.get_mut(); - let mut taskio = TaskStream { io, task: ctx }; - let mut stream = Stream::new(session, &mut taskio); - - async!(from io::Write::write(&mut stream, buf)) - } - - fn poll_flush(&mut self, ctx: &mut Context) -> Poll<(), Error> { - let (io, session) = self.get_mut(); - let mut taskio = TaskStream { io, task: ctx }; - - { - let mut stream = Stream::new(session, &mut taskio); - async!(from io::Write::flush(&mut stream))?; - } - - async!(from io::Write::flush(&mut taskio)) - } - - fn poll_close(&mut self, ctx: &mut Context) -> Poll<(), Error> { - if !self.is_shutdown { - self.session.send_close_notify(); - self.is_shutdown = true; - } - - { - let (io, session) = self.get_mut(); - let mut taskio = TaskStream { io, task: ctx }; - async!(from session.complete_io(&mut taskio))?; - } - - self.io.poll_close(ctx) - } -} diff --git a/tests/test.rs b/tests/test.rs index c6262e9..fa46f5a 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -1,81 +1,89 @@ +#[macro_use] extern crate lazy_static; extern crate rustls; extern crate tokio; extern crate tokio_rustls; extern crate webpki; -#[cfg(feature = "unstable-futures")] extern crate futures; - use std::{ io, thread }; use std::io::{ BufReader, Cursor }; use std::sync::Arc; use std::sync::mpsc::channel; -use std::net::{ SocketAddr, IpAddr, Ipv4Addr }; +use std::net::SocketAddr; use tokio::net::{ TcpListener, TcpStream }; -use rustls::{ Certificate, PrivateKey, ServerConfig, ClientConfig }; +use rustls::{ ServerConfig, ClientConfig }; use rustls::internal::pemfile::{ certs, rsa_private_keys }; use tokio_rustls::{ ClientConfigExt, ServerConfigExt }; const CERT: &str = include_str!("end.cert"); const CHAIN: &str = include_str!("end.chain"); const RSA: &str = include_str!("end.rsa"); -const HELLO_WORLD: &[u8] = b"Hello world!"; +lazy_static!{ + static ref TEST_SERVER: (SocketAddr, &'static str, &'static str) = { + use tokio::prelude::*; + use tokio::io as aio; -fn start_server(cert: Vec, rsa: PrivateKey) -> SocketAddr { - use tokio::prelude::*; - use tokio::io as aio; + let cert = certs(&mut BufReader::new(Cursor::new(CERT))).unwrap(); + let mut keys = rsa_private_keys(&mut BufReader::new(Cursor::new(RSA))).unwrap(); - let mut config = ServerConfig::new(rustls::NoClientAuth::new()); - config.set_single_cert(cert, rsa) - .expect("invalid key or certificate"); - let config = Arc::new(config); + let mut config = ServerConfig::new(rustls::NoClientAuth::new()); + config.set_single_cert(cert, keys.pop().unwrap()) + .expect("invalid key or certificate"); + let config = Arc::new(config); - let (send, recv) = channel(); + let (send, recv) = channel(); - thread::spawn(move || { - let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 0); - let listener = TcpListener::bind(&addr).unwrap(); + thread::spawn(move || { + let addr = SocketAddr::from(([127, 0, 0, 1], 0)); + let listener = TcpListener::bind(&addr).unwrap(); - send.send(listener.local_addr().unwrap()).unwrap(); + send.send(listener.local_addr().unwrap()).unwrap(); - let done = listener.incoming() - .for_each(move |stream| { - let done = config.accept_async(stream) - .and_then(|stream| aio::read_exact(stream, vec![0; HELLO_WORLD.len()])) - .and_then(|(stream, buf)| { - assert_eq!(buf, HELLO_WORLD); - aio::write_all(stream, HELLO_WORLD) - }) - .then(|_| Ok(())); + let done = listener.incoming() + .for_each(move |stream| { + let done = config.accept_async(stream) + .and_then(|stream| { + let (reader, writer) = stream.split(); + aio::copy(reader, writer) + }) + .then(|_| Ok(())); - tokio::spawn(done); - Ok(()) - }) - .map_err(|err| panic!("{:?}", err)); + tokio::spawn(done); + Ok(()) + }) + .map_err(|err| panic!("{:?}", err)); - tokio::run(done); - }); + tokio::run(done); + }); - recv.recv().unwrap() + let addr = recv.recv().unwrap(); + (addr, "localhost", CHAIN) + }; } -fn start_client(addr: &SocketAddr, domain: &str, chain: Option>>) -> io::Result<()> { + +fn start_server() -> &'static (SocketAddr, &'static str, &'static str) { + &*TEST_SERVER +} + +fn start_client(addr: &SocketAddr, domain: &str, chain: &str) -> io::Result<()> { use tokio::prelude::*; use tokio::io as aio; + const FILE: &'static [u8] = include_bytes!("../README.md"); + let domain = webpki::DNSNameRef::try_from_ascii_str(domain).unwrap(); let mut config = ClientConfig::new(); - if let Some(mut chain) = chain { - config.root_store.add_pem_file(&mut chain).unwrap(); - } + let mut chain = BufReader::new(Cursor::new(chain)); + config.root_store.add_pem_file(&mut chain).unwrap(); let config = Arc::new(config); let done = TcpStream::connect(addr) .and_then(|stream| config.connect_async(domain, stream)) - .and_then(|stream| aio::write_all(stream, HELLO_WORLD)) - .and_then(|(stream, _)| aio::read_exact(stream, vec![0; HELLO_WORLD.len()])) + .and_then(|stream| aio::write_all(stream, FILE)) + .and_then(|(stream, _)| aio::read_exact(stream, vec![0; FILE.len()])) .and_then(|(stream, buf)| { - assert_eq!(buf, HELLO_WORLD); + assert_eq!(buf, FILE); aio::shutdown(stream) }) .map(drop); @@ -86,22 +94,15 @@ fn start_client(addr: &SocketAddr, domain: &str, chain: Option