From 3ffb736d5e5247799b15964cfd377cd4c324fa09 Mon Sep 17 00:00:00 2001 From: quininer Date: Wed, 22 May 2019 00:54:10 +0800 Subject: [PATCH] update server example --- Cargo.toml | 3 +- examples/server/Cargo.toml | 6 +- examples/server/src/main.rs | 150 +++++++++++++++++++----------------- src/common/mod.rs | 17 ++-- src/lib.rs | 3 + tests/test.rs | 7 +- 6 files changed, 98 insertions(+), 88 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index cb60566..84d6d63 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,7 +16,7 @@ travis-ci = { repository = "quininer/tokio-rustls" } appveyor = { repository = "quininer/tokio-rustls" } [dependencies] -smallvec = "*" +smallvec = "0.6" futures = { package = "futures-preview", version = "0.3.0-alpha.16" } rustls = "0.15" webpki = "0.19" @@ -26,6 +26,5 @@ early-data = [] [dev-dependencies] romio = "0.3.0-alpha.8" -tokio = "0.1.6" lazy_static = "1" webpki-roots = "0.16" diff --git a/examples/server/Cargo.toml b/examples/server/Cargo.toml index 170693f..0392ffb 100644 --- a/examples/server/Cargo.toml +++ b/examples/server/Cargo.toml @@ -2,8 +2,10 @@ name = "server" version = "0.1.0" authors = ["quininer "] +edition = "2018" [dependencies] +futures = { package = "futures-preview", version = "0.3.0-alpha.16" } +romio = "0.3.0-alpha.8" +structopt = "*" tokio-rustls = { path = "../.." } -tokio = { version = "0.1.6" } -clap = "2" diff --git a/examples/server/src/main.rs b/examples/server/src/main.rs index 2a94c58..a2a3b13 100644 --- a/examples/server/src/main.rs +++ b/examples/server/src/main.rs @@ -1,88 +1,100 @@ -extern crate clap; -extern crate tokio; -extern crate tokio_rustls; +#![feature(async_await)] +use std::fs::File; use std::sync::Arc; use std::net::ToSocketAddrs; -use std::io::BufReader; -use std::fs::File; -use tokio_rustls::{ - TlsAcceptor, - rustls::{ - Certificate, NoClientAuth, PrivateKey, ServerConfig, - internal::pemfile::{ certs, rsa_private_keys } - }, -}; -use tokio::prelude::{ Future, Stream }; -use tokio::io::{ self, AsyncRead }; -use tokio::net::TcpListener; -use clap::{ App, Arg }; +use std::path::{ PathBuf, Path }; +use std::io::{ self, BufReader }; +use structopt::StructOpt; +use futures::task::SpawnExt; +use futures::prelude::*; +use futures::executor; +use romio::TcpListener; +use tokio_rustls::rustls::{ Certificate, NoClientAuth, PrivateKey, ServerConfig }; +use tokio_rustls::rustls::internal::pemfile::{ certs, rsa_private_keys }; +use tokio_rustls::TlsAcceptor; -fn app() -> App<'static, 'static> { - App::new("server") - .about("tokio-rustls server example") - .arg(Arg::with_name("addr").value_name("ADDR").required(true)) - .arg(Arg::with_name("cert").short("c").long("cert").value_name("FILE").help("cert file.").required(true)) - .arg(Arg::with_name("key").short("k").long("key").value_name("FILE").help("key file, rsa only.").required(true)) - .arg(Arg::with_name("echo").short("e").long("echo-mode").help("echo mode.")) + +#[derive(StructOpt)] +struct Options { + addr: String, + + /// cert file + #[structopt(short="c", long="cert", parse(from_os_str))] + cert: PathBuf, + + /// key file + #[structopt(short="k", long="key", parse(from_os_str))] + key: PathBuf, + + /// echo mode + #[structopt(short="e", long="echo-mode")] + echo: bool } -fn load_certs(path: &str) -> Vec { - certs(&mut BufReader::new(File::open(path).unwrap())).unwrap() +fn load_certs(path: &Path) -> io::Result> { + certs(&mut BufReader::new(File::open(path)?)) + .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid cert")) } -fn load_keys(path: &str) -> Vec { - rsa_private_keys(&mut BufReader::new(File::open(path).unwrap())).unwrap() +fn load_keys(path: &Path) -> io::Result> { + rsa_private_keys(&mut BufReader::new(File::open(path)?)) + .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid key")) } -fn main() { - let matches = app().get_matches(); +fn main() -> io::Result<()> { + let options = Options::from_args(); - let addr = matches.value_of("addr").unwrap() - .to_socket_addrs().unwrap() - .next().unwrap(); - let cert_file = matches.value_of("cert").unwrap(); - let key_file = matches.value_of("key").unwrap(); - let flag_echo = matches.occurrences_of("echo") > 0; + let addr = options.addr.to_socket_addrs()? + .next() + .ok_or_else(|| io::Error::from(io::ErrorKind::AddrNotAvailable))?; + let certs = load_certs(&options.cert)?; + let mut keys = load_keys(&options.key)?; + let flag_echo = options.echo; + let mut pool = executor::ThreadPool::new()?; let mut config = ServerConfig::new(NoClientAuth::new()); - config.set_single_cert(load_certs(cert_file), load_keys(key_file).remove(0)) - .expect("invalid key or certificate"); - let config = TlsAcceptor::from(Arc::new(config)); + config.set_single_cert(certs, keys.remove(0)) + .map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?; + let acceptor = TlsAcceptor::from(Arc::new(config)); - let socket = TcpListener::bind(&addr).unwrap(); - let done = socket.incoming() - .for_each(move |stream| if flag_echo { - let addr = stream.peer_addr().ok(); - let done = config.accept(stream) - .and_then(|stream| { - let (reader, writer) = stream.split(); - io::copy(reader, writer) - }) - .map(move |(n, ..)| println!("Echo: {} - {:?}", n, addr)) - .map_err(move |err| println!("Error: {:?} - {:?}", err, addr)); - tokio::spawn(done); + let fut = async { + let mut listener = TcpListener::bind(&addr)?; + let mut incoming = listener.incoming(); - Ok(()) - } else { - let addr = stream.peer_addr().ok(); - let done = config.accept(stream) - .and_then(|stream| io::write_all( - stream, - &b"HTTP/1.0 200 ok\r\n\ - Connection: close\r\n\ - Content-length: 12\r\n\ - \r\n\ - Hello world!"[..] - )) - .and_then(|(stream, _)| io::flush(stream)) - .map(move |_| println!("Accept: {:?}", addr)) - .map_err(move |err| println!("Error: {:?} - {:?}", err, addr)); - tokio::spawn(done); + while let Some(stream) = incoming.next().await { + let acceptor = acceptor.clone(); - Ok(()) - }); + let fut = async move { + let stream = stream?; + let peer_addr = stream.peer_addr()?; + let mut stream = acceptor.accept(stream).await?; - tokio::run(done.map_err(drop)); + if flag_echo { + let (mut reader, mut writer) = stream.split(); + let n = reader.copy_into(&mut writer).await?; + println!("Echo: {} - {}", peer_addr, n); + } else { + stream.write_all( + &b"HTTP/1.0 200 ok\r\n\ + Connection: close\r\n\ + Content-length: 12\r\n\ + \r\n\ + Hello world!"[..] + ).await?; + stream.flush().await?; + println!("Hello: {}", peer_addr); + } + + Ok(()) as io::Result<()> + }; + + pool.spawn(fut.unwrap_or_else(|err| eprintln!("{:?}", err))).unwrap(); + } + + Ok(()) + }; + + executor::block_on(fut) } diff --git a/src/common/mod.rs b/src/common/mod.rs index 585e6c9..4d65be7 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -127,7 +127,10 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> { }; match (self.eof, self.session.is_handshaking(), would_block) { - (true, true, _) => return Poll::Ready(Err(io::ErrorKind::UnexpectedEof.into())), + (true, true, _) => { + let err = io::Error::new(io::ErrorKind::UnexpectedEof, "tls handshake eof"); + return Poll::Ready(Err(err)); + }, (_, false, true) => { let would_block = match focus { Focus::Empty => rdlen == 0 && wrlen == 0, @@ -224,11 +227,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncWrite for Stream<' this.session.flush()?; while this.session.wants_write() { - match this.complete_inner_io(cx, Focus::Writable) { - Poll::Ready(Ok(_)) => (), - Poll::Pending => return Poll::Pending, - Poll::Ready(Err(err)) => return Poll::Ready(Err(err)) - } + try_ready!(this.complete_inner_io(cx, Focus::Writable)); } Pin::new(&mut this.io).poll_flush(cx) } @@ -237,11 +236,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncWrite for Stream<' let this = self.get_mut(); while this.session.wants_write() { - match this.complete_inner_io(cx, Focus::Writable) { - Poll::Ready(Ok(_)) => (), - Poll::Pending => return Poll::Pending, - Poll::Ready(Err(err)) => return Poll::Ready(Err(err)) - } + try_ready!(this.complete_inner_io(cx, Focus::Writable)); } Pin::new(&mut this.io).poll_close(cx) } diff --git a/src/lib.rs b/src/lib.rs index df3c259..cca1c85 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -26,6 +26,9 @@ use rustls::{ ClientConfig, ClientSession, ServerConfig, ServerSession }; use webpki::DNSNameRef; use common::Stream; +pub use rustls; +pub use webpki; + #[derive(Debug, Copy, Clone)] enum TlsState { #[cfg(feature = "early-data")] diff --git a/tests/test.rs b/tests/test.rs index a7fd2f2..acc67e3 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -26,7 +26,7 @@ lazy_static!{ let mut config = ServerConfig::new(rustls::NoClientAuth::new()); config.set_single_cert(cert, keys.pop().unwrap()) .expect("invalid key or certificate"); - let config = TlsAcceptor::from(Arc::new(config)); + let acceptor = TlsAcceptor::from(Arc::new(config)); let (send, recv) = channel(); @@ -40,11 +40,10 @@ lazy_static!{ let mut incoming = listener.incoming(); while let Some(stream) = incoming.next().await { - let config = config.clone(); + let acceptor = acceptor.clone(); pool.spawn( async move { - let stream = stream?; - let stream = config.accept(stream).await?; + let stream = acceptor.accept(stream?).await?; let (mut reader, mut write) = stream.split(); reader.copy_into(&mut write).await?; Ok(()) as io::Result<()>