From b8e3fcb79e053ff2db492d3f92779d16c314f090 Mon Sep 17 00:00:00 2001 From: quininer Date: Wed, 22 May 2019 23:57:14 +0800 Subject: [PATCH] update client example --- examples/client/Cargo.toml | 7 ++- examples/client/src/main.rs | 115 +++++++++++++++++++----------------- examples/server/Cargo.toml | 2 +- 3 files changed, 65 insertions(+), 59 deletions(-) diff --git a/examples/client/Cargo.toml b/examples/client/Cargo.toml index 3765efc..feec249 100644 --- a/examples/client/Cargo.toml +++ b/examples/client/Cargo.toml @@ -2,11 +2,12 @@ name = "client" version = "0.1.0" authors = ["quininer "] +edition = "2018" [dependencies] -webpki = "0.19" +futures = { package = "futures-preview", version = "0.3.0-alpha.16", features = ["io-compat"] } +romio = "0.3.0-alpha.8" +structopt = "0.2" tokio-rustls = { path = "../.." } -tokio = "0.1" -clap = "2" webpki-roots = "0.16" tokio-stdin-stdout = "0.1" diff --git a/examples/client/src/main.rs b/examples/client/src/main.rs index ff6b315..6416db2 100644 --- a/examples/client/src/main.rs +++ b/examples/client/src/main.rs @@ -1,74 +1,79 @@ -extern crate clap; -extern crate tokio; -extern crate webpki; -extern crate webpki_roots; -extern crate tokio_rustls; - -extern crate tokio_stdin_stdout; +#![feature(async_await)] +use std::io; +use std::fs::File; +use std::path::PathBuf; use std::sync::Arc; use std::net::ToSocketAddrs; use std::io::BufReader; -use std::fs; -use tokio::io; -use tokio::net::TcpStream; -use tokio::prelude::*; -use clap::{ App, Arg }; -use tokio_rustls::{ TlsConnector, rustls::ClientConfig }; +use structopt::StructOpt; +use romio::TcpStream; +use futures::prelude::*; +use futures::executor; +use futures::compat::{ AsyncRead01CompatExt, AsyncWrite01CompatExt }; +use tokio_rustls::{ TlsConnector, rustls::ClientConfig, webpki::DNSNameRef }; use tokio_stdin_stdout::{ stdin as tokio_stdin, stdout as tokio_stdout }; -fn app() -> App<'static, 'static> { - App::new("client") - .about("tokio-rustls client example") - .arg(Arg::with_name("host").value_name("HOST").required(true)) - .arg(Arg::with_name("port").short("p").long("port").value_name("PORT").help("port, default `443`")) - .arg(Arg::with_name("domain").short("d").long("domain").value_name("DOMAIN").help("domain")) - .arg(Arg::with_name("cafile").short("c").long("cafile").value_name("FILE").help("CA certificate chain")) + +#[derive(StructOpt)] +struct Options { + host: String, + + /// port + #[structopt(short="p", long="port", default_value="443")] + port: u16, + + /// domain + #[structopt(short="d", long="domain")] + domain: Option, + + /// cafile + #[structopt(short="c", long="cafile", parse(from_os_str))] + cafile: Option } -fn main() { - let matches = app().get_matches(); +fn main() -> io::Result<()> { + let options = Options::from_args(); - let host = matches.value_of("host").unwrap(); - let port = matches.value_of("port") - .map(|port| port.parse().unwrap()) - .unwrap_or(443); - let domain = matches.value_of("domain").unwrap_or(host).to_owned(); - let cafile = matches.value_of("cafile"); - let text = format!("GET / HTTP/1.0\r\nHost: {}\r\n\r\n", domain); - - let addr = (host, port) - .to_socket_addrs().unwrap() - .next().unwrap(); + let addr = (options.host.as_str(), options.port) + .to_socket_addrs()? + .next() + .ok_or_else(|| io::Error::from(io::ErrorKind::NotFound))?; + let domain = options.domain.unwrap_or(options.host); + let content = format!( + "GET / HTTP/1.0\r\nHost: {}\r\n\r\n", + domain + ); let mut config = ClientConfig::new(); - if let Some(cafile) = cafile { - let mut pem = BufReader::new(fs::File::open(cafile).unwrap()); - config.root_store.add_pem_file(&mut pem).unwrap(); + if let Some(cafile) = &options.cafile { + let mut pem = BufReader::new(File::open(cafile)?); + config.root_store.add_pem_file(&mut pem) + .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid cert"))?; } else { config.root_store.add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS); } - let config = TlsConnector::from(Arc::new(config)); + let connector = TlsConnector::from(Arc::new(config)); - let socket = TcpStream::connect(&addr); - let (stdin, stdout) = (tokio_stdin(0), tokio_stdout(0)); + let fut = async { + let stream = TcpStream::connect(&addr).await?; + let (mut stdin, mut stdout) = (tokio_stdin(0).compat(), tokio_stdout(0).compat()); - let done = socket - .and_then(move |stream| { - let domain = webpki::DNSNameRef::try_from_ascii_str(&domain).unwrap(); - config.connect(domain, stream) - }) - .and_then(move |stream| io::write_all(stream, text)) - .and_then(move |(stream, _)| { - let (r, w) = stream.split(); - io::copy(r, stdout) - .map(drop) - .select2(io::copy(stdin, w).map(drop)) - .map_err(|res| res.split().0) - }) - .map(drop) - .map_err(|err| eprintln!("{:?}", err)); + let domain = DNSNameRef::try_from_ascii_str(&domain) + .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid dnsname"))?; - tokio::run(done); + let mut stream = connector.connect(domain, stream).await?; + stream.write_all(content.as_bytes()).await?; + + let (mut reader, mut writer) = stream.split(); + future::try_join( + reader.copy_into(&mut stdout), + stdin.copy_into(&mut writer) + ).await?; + + Ok(()) + }; + + executor::block_on(fut) } diff --git a/examples/server/Cargo.toml b/examples/server/Cargo.toml index 0392ffb..9da4423 100644 --- a/examples/server/Cargo.toml +++ b/examples/server/Cargo.toml @@ -7,5 +7,5 @@ edition = "2018" [dependencies] futures = { package = "futures-preview", version = "0.3.0-alpha.16" } romio = "0.3.0-alpha.8" -structopt = "*" +structopt = "0.2" tokio-rustls = { path = "../.." }