From 66835b50403031bd28eb0ee976b40e69e49579dd Mon Sep 17 00:00:00 2001 From: quininer kel Date: Tue, 21 Feb 2017 11:52:43 +0800 Subject: [PATCH 001/171] [Added] init --- .gitignore | 2 + .gitjournal.toml | 10 +++ Cargo.toml | 9 +++ src/lib.rs | 192 +++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 213 insertions(+) create mode 100644 .gitignore create mode 100644 .gitjournal.toml create mode 100644 Cargo.toml create mode 100644 src/lib.rs diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..a9d37c5 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +target +Cargo.lock diff --git a/.gitjournal.toml b/.gitjournal.toml new file mode 100644 index 0000000..508a97e --- /dev/null +++ b/.gitjournal.toml @@ -0,0 +1,10 @@ +categories = ["Added", "Changed", "Fixed", "Improved", "Removed"] +category_delimiters = ["[", "]"] +colored_output = true +enable_debug = true +enable_footers = false +excluded_commit_tags = [] +show_commit_hash = false +show_prefix = false +sort_by = "date" +template_prefix = "" diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..e8ea275 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,9 @@ +[package] +name = "tokio-rustls" +version = "0.1.0" +authors = ["quininer kel "] + +[dependencies] +futures = "*" +tokio-core = "*" +rustls = "*" diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..b51ae20 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,192 @@ +extern crate futures; +extern crate tokio_core; +extern crate rustls; + +use std::io; +use std::sync::Arc; +use futures::{ Future, Poll, Async }; +use tokio_core::io::Io; +use rustls::{ Session, ClientSession, ServerSession }; +pub use rustls::{ ClientConfig, ServerConfig }; + + +pub trait TlsConnectorExt { + fn connect_async(&self, domain: &str, stream: S) + -> ConnectAsync + where S: Io; +} + +pub trait TlsAcceptorExt { + fn accept_async(&self, stream: S) + -> AcceptAsync + where S: Io; +} + + +pub struct ConnectAsync(MidHandshake); + +pub struct AcceptAsync(MidHandshake); + + +impl TlsConnectorExt for Arc { + fn connect_async(&self, domain: &str, stream: S) + -> ConnectAsync + where S: Io + { + ConnectAsync(MidHandshake { + inner: Some(TlsStream::new(stream, ClientSession::new(self, domain))) + }) + } +} + +impl TlsAcceptorExt for Arc { + fn accept_async(&self, stream: S) + -> AcceptAsync + where S: Io + { + AcceptAsync(MidHandshake { + inner: Some(TlsStream::new(stream, ServerSession::new(self))) + }) + } +} + +impl Future for ConnectAsync { + type Item = TlsStream; + type Error = io::Error; + + fn poll(&mut self) -> Poll { + self.0.poll() + } +} + +impl Future for AcceptAsync { + type Item = TlsStream; + type Error = io::Error; + + fn poll(&mut self) -> Poll { + self.0.poll() + } +} + + +struct MidHandshake { + inner: Option> +} + +impl Future for MidHandshake + where S: Io, C: Session +{ + type Item = TlsStream; + type Error = io::Error; + + fn poll(&mut self) -> Poll { + loop { + let stream = self.inner.as_mut().unwrap_or_else(|| unreachable!()); + if !stream.session.is_handshaking() { break }; + + match stream.do_io() { + Ok(()) => continue, + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => (), + Err(e) => return Err(e) + } + if !stream.session.is_handshaking() { break }; + + if stream.eof { + return Err(io::Error::from(io::ErrorKind::UnexpectedEof)); + } else { + return Ok(Async::NotReady); + } + } + + Ok(Async::Ready(self.inner.take().unwrap_or_else(|| unreachable!()))) + } +} + + +pub struct TlsStream { + eof: bool, + io: S, + session: C +} + +impl TlsStream + where S: Io, C: Session +{ + #[inline] + pub fn new(io: S, session: C) -> TlsStream { + TlsStream { + eof: false, + io: io, + session: session + } + } + + pub fn do_io(&mut self) -> io::Result<()> { + loop { + let read_would_block = match (!self.eof && self.session.wants_read(), self.io.poll_read()) { + (true, Async::Ready(())) => { + match self.session.read_tls(&mut self.io) { + Ok(0) => self.eof = true, + Ok(_) => self.session.process_new_packets() + .map_err(|err| io::Error::new(io::ErrorKind::Other, err))?, + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => (), + Err(e) => return Err(e) + }; + continue + }, + (true, Async::NotReady) => true, + (false, _) => false, + }; + + let write_would_block = match (self.session.wants_write(), self.io.poll_write()) { + (true, Async::Ready(())) => match self.session.write_tls(&mut self.io) { + Ok(_) => continue, + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => continue, + Err(e) => return Err(e) + }, + (true, Async::NotReady) => true, + (false, _) => false + }; + + if read_would_block || write_would_block { + return Err(io::Error::from(io::ErrorKind::WouldBlock)); + } else { + return Ok(()); + } + } + } +} + +impl io::Read for TlsStream + where S: Io, C: Session +{ + fn read(&mut self, buf: &mut [u8]) -> io::Result { + loop { + match self.session.read(buf) { + Ok(0) if !self.eof => self.do_io()?, + output => return output + } + } + } +} + +impl io::Write for TlsStream + where S: Io, C: Session +{ + fn write(&mut self, buf: &[u8]) -> io::Result { + while self.session.wants_write() && self.io.poll_write().is_ready() { + self.session.write_tls(&mut self.io)?; + } + self.session.write(buf) + } + + fn flush(&mut self) -> io::Result<()> { + self.session.flush()?; + while self.session.wants_write() && self.io.poll_write().is_ready() { + self.session.write_tls(&mut self.io)?; + } + Ok(()) + } +} + +impl Io for TlsStream where S: Io, C: Session {} From 6e7d67cccbde6a2b1a54bda840c496244d66f806 Mon Sep 17 00:00:00 2001 From: quininer kel Date: Tue, 21 Feb 2017 19:08:23 +0800 Subject: [PATCH 002/171] [Fixed] TlsStream::read ConnectionAborted --- Cargo.toml | 4 ++++ src/lib.rs | 10 +++++----- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index e8ea275..39a0675 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,3 +7,7 @@ authors = ["quininer kel "] futures = "*" tokio-core = "*" rustls = "*" + +[dev-dependencies] +clap = "*" +webpki-roots = "*" diff --git a/src/lib.rs b/src/lib.rs index b51ae20..197e5cc 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -161,11 +161,11 @@ impl io::Read for TlsStream where S: Io, C: Session { fn read(&mut self, buf: &mut [u8]) -> io::Result { - loop { - match self.session.read(buf) { - Ok(0) if !self.eof => self.do_io()?, - output => return output - } + self.do_io()?; + if self.eof { + Ok(0) + } else { + self.session.read(buf) } } } From 0941db6792216888537d4b14ae1a880fee2427ce Mon Sep 17 00:00:00 2001 From: quininer kel Date: Wed, 22 Feb 2017 11:42:32 +0800 Subject: [PATCH 003/171] [Fixed] TlsStream::write then write_tls - [Added] add example --- Cargo.toml | 1 + examples/client.rs | 71 +++++++++++++++++++++++++++++++++++++ examples/server.rs | 88 ++++++++++++++++++++++++++++++++++++++++++++++ src/lib.rs | 3 +- 4 files changed, 162 insertions(+), 1 deletion(-) create mode 100644 examples/client.rs create mode 100644 examples/server.rs diff --git a/Cargo.toml b/Cargo.toml index 39a0675..c2413bb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,3 +11,4 @@ rustls = "*" [dev-dependencies] clap = "*" webpki-roots = "*" +tokio-file-unix = "*" diff --git a/examples/client.rs b/examples/client.rs new file mode 100644 index 0000000..69da1e8 --- /dev/null +++ b/examples/client.rs @@ -0,0 +1,71 @@ +extern crate clap; +extern crate futures; +extern crate tokio_core; +extern crate webpki_roots; +extern crate tokio_file_unix; +extern crate tokio_rustls; + +use std::sync::Arc; +use std::net::ToSocketAddrs; +use std::io::{ BufReader, stdout }; +use std::fs; +use futures::Future; +use tokio_core::io; +use tokio_core::net::TcpStream; +use tokio_core::reactor::Core; +use clap::{ App, Arg }; +use tokio_file_unix::{ StdFile, File }; +use tokio_rustls::{ ClientConfig, TlsConnectorExt }; + + +fn app() -> App<'static, 'static> { + App::new("client") + .about("tokio-rustls client example") + .arg(Arg::with_name("host").value_name("HOST").required(true)) + .arg(Arg::with_name("port").short("p").long("port").value_name("PORT").help("port, default `443`")) + .arg(Arg::with_name("domain").short("d").long("domain").value_name("DOMAIN").help("domain")) + .arg(Arg::with_name("cafile").short("c").long("cafile").value_name("FILE").help("CA certificate chain")) +} + + +fn main() { + let matches = app().get_matches(); + + let host = matches.value_of("host").unwrap(); + let port = if let Some(port) = matches.value_of("port") { + port.parse().unwrap() + } else { + 443 + }; + let domain = matches.value_of("domain").unwrap_or(host); + let cafile = matches.value_of("cafile"); + let text = format!("GET / HTTP/1.0\r\nHost: {}\r\n\r\n", domain); + + let mut core = Core::new().unwrap(); + let handle = core.handle(); + let addr = (host, port) + .to_socket_addrs().unwrap() + .next().unwrap(); + + let stdout = stdout(); + let mut stdout = File::new_nb(StdFile(stdout.lock())).unwrap(); + stdout.set_nonblocking(true).unwrap(); + let stdout = stdout.into_io(&handle).unwrap(); + + let mut config = ClientConfig::new(); + if let Some(cafile) = cafile { + let mut pem = BufReader::new(fs::File::open(cafile).unwrap()); + config.root_store.add_pem_file(&mut pem).unwrap(); + } else { + config.root_store.add_trust_anchors(&webpki_roots::ROOTS); + } + let arc_config = Arc::new(config); + + let socket = TcpStream::connect(&addr, &handle); + let resp = socket + .and_then(|stream| arc_config.connect_async(domain, stream)) + .and_then(|stream| io::write_all(stream, text.as_bytes())) + .and_then(|(stream, _)| io::copy(stream, stdout)); + + core.run(resp).unwrap(); +} diff --git a/examples/server.rs b/examples/server.rs new file mode 100644 index 0000000..eb197fd --- /dev/null +++ b/examples/server.rs @@ -0,0 +1,88 @@ +extern crate clap; +extern crate rustls; +extern crate futures; +extern crate tokio_core; +extern crate webpki_roots; +extern crate tokio_rustls; + +use std::sync::Arc; +use std::net::ToSocketAddrs; +use std::io::BufReader; +use std::fs::File; +use futures::{ Future, Stream }; +use rustls::{ Certificate, PrivateKey }; +use rustls::internal::pemfile::{ certs, rsa_private_keys }; +use tokio_core::io::{ self, Io }; +use tokio_core::net::TcpListener; +use tokio_core::reactor::Core; +use clap::{ App, Arg }; +use tokio_rustls::{ ServerConfig, TlsAcceptorExt }; + + +fn app() -> App<'static, 'static> { + App::new("server") + .about("tokio-rustls server example") + .arg(Arg::with_name("addr").value_name("ADDR").required(true)) + .arg(Arg::with_name("cert").short("c").long("cert").value_name("FILE").help("cert file.").required(true)) + .arg(Arg::with_name("key").short("k").long("key").value_name("FILE").help("key file, rsa only.").required(true)) + .arg(Arg::with_name("echo").short("e").long("echo-mode").help("echo mode.")) +} + +fn load_certs(path: &str) -> Vec { + certs(&mut BufReader::new(File::open(path).unwrap())).unwrap() +} + +fn load_keys(path: &str) -> Vec { + rsa_private_keys(&mut BufReader::new(File::open(path).unwrap())).unwrap() +} + + +fn main() { + let matches = app().get_matches(); + + let addr = matches.value_of("addr").unwrap() + .to_socket_addrs().unwrap() + .next().unwrap(); + let cert_file = matches.value_of("cert").unwrap(); + let key_file = matches.value_of("key").unwrap(); + let flag_echo = matches.occurrences_of("echo") > 0; + + let mut core = Core::new().unwrap(); + let handle = core.handle(); + + let mut config = ServerConfig::new(); + config.set_single_cert(load_certs(cert_file), load_keys(key_file).remove(0)); + let arc_config = Arc::new(config); + + let socket = TcpListener::bind(&addr, &handle).unwrap(); + let done = socket.incoming() + .for_each(|(stream, addr)| if flag_echo { + let done = arc_config.accept_async(stream) + .and_then(|stream| { + let (reader, writer) = stream.split(); + io::copy(reader, writer) + }) + .map(move |n| println!("Echo: {} - {}", n, addr)) + .map_err(move |err| println!("Error: {:?} - {}", err, addr)); + handle.spawn(done); + + Ok(()) + } else { + let done = arc_config.accept_async(stream) + .and_then(|stream| io::write_all( + stream, + "HTTP/1.0 200 ok\r\n\ + Connection: close\r\n\ + Content-length: 12\r\n\ + \r\n\ + Hello world!".as_bytes() + )) + .map(move |_| println!("Accept: {}", addr)) + .map_err(move |err| println!("Error: {:?} - {}", err, addr)); + handle.spawn(done); + + Ok(()) + }); + + core.run(done).unwrap(); +} diff --git a/src/lib.rs b/src/lib.rs index 197e5cc..15d0ac6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -174,10 +174,11 @@ impl io::Write for TlsStream where S: Io, C: Session { fn write(&mut self, buf: &[u8]) -> io::Result { + let output = self.session.write(buf); while self.session.wants_write() && self.io.poll_write().is_ready() { self.session.write_tls(&mut self.io)?; } - self.session.write(buf) + output } fn flush(&mut self) -> io::Result<()> { From 15e40f3ebfdfb896a5d0b5addbfad97a48380918 Mon Sep 17 00:00:00 2001 From: quininer kel Date: Wed, 22 Feb 2017 13:03:21 +0800 Subject: [PATCH 004/171] [Added] README and more --- Cargo.toml | 21 +++-- LICENSE-APACHE | 201 +++++++++++++++++++++++++++++++++++++++++++++ LICENSE-MIT | 25 ++++++ README.md | 3 + examples/client.rs | 2 +- examples/server.rs | 2 +- src/lib.rs | 31 ++++++- 7 files changed, 273 insertions(+), 12 deletions(-) create mode 100644 LICENSE-APACHE create mode 100644 LICENSE-MIT create mode 100644 README.md diff --git a/Cargo.toml b/Cargo.toml index c2413bb..20e5c2d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,13 +2,22 @@ name = "tokio-rustls" version = "0.1.0" authors = ["quininer kel "] +license = "MIT/Apache-2.0" +repository = "https://github.com/quininer/tokio-rustls" +homepage = "https://github.com/quininer/tokio-rustls" +documentation = "https://docs.rs/tokio-rustls" +description = """ +An implementation of TLS/SSL streams for Tokio giving an implementation of TLS +for nonblocking I/O streams. +""" [dependencies] -futures = "*" -tokio-core = "*" -rustls = "*" +futures = "0.1" +tokio-core = "0.1" +rustls = "0.5" +tokio-proto = { version = "0.1", optional = true } [dev-dependencies] -clap = "*" -webpki-roots = "*" -tokio-file-unix = "*" +clap = "2.20" +webpki-roots = "0.7" +tokio-file-unix = "0.2" diff --git a/LICENSE-APACHE b/LICENSE-APACHE new file mode 100644 index 0000000..4e411cf --- /dev/null +++ b/LICENSE-APACHE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + +TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + +1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + +2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + +3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + +4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + +5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + +6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + +7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + +8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + +9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + +END OF TERMS AND CONDITIONS + +APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + +Copyright 2016 quininer kel + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. diff --git a/LICENSE-MIT b/LICENSE-MIT new file mode 100644 index 0000000..d0dfcc7 --- /dev/null +++ b/LICENSE-MIT @@ -0,0 +1,25 @@ +Copyright (c) 2016 quininer kel + +Permission is hereby granted, free of charge, to any +person obtaining a copy of this software and associated +documentation files (the "Software"), to deal in the +Software without restriction, including without +limitation the rights to use, copy, modify, merge, +publish, distribute, sublicense, and/or sell copies of +the Software, and to permit persons to whom the Software +is furnished to do so, subject to the following +conditions: + +The above copyright notice and this permission notice +shall be included in all copies or substantial portions +of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF +ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED +TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A +PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT +SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR +IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. diff --git a/README.md b/README.md new file mode 100644 index 0000000..5cdcb70 --- /dev/null +++ b/README.md @@ -0,0 +1,3 @@ +# tokio-rustls + +[tokio-tls](https://github.com/tokio-rs/tokio-tls) fork, use [rustls](https://github.com/ctz/rustls). diff --git a/examples/client.rs b/examples/client.rs index 69da1e8..b78d57d 100644 --- a/examples/client.rs +++ b/examples/client.rs @@ -15,7 +15,7 @@ use tokio_core::net::TcpStream; use tokio_core::reactor::Core; use clap::{ App, Arg }; use tokio_file_unix::{ StdFile, File }; -use tokio_rustls::{ ClientConfig, TlsConnectorExt }; +use tokio_rustls::{ ClientConfig, ClientConfigExt }; fn app() -> App<'static, 'static> { diff --git a/examples/server.rs b/examples/server.rs index eb197fd..c5cf2fa 100644 --- a/examples/server.rs +++ b/examples/server.rs @@ -16,7 +16,7 @@ use tokio_core::io::{ self, Io }; use tokio_core::net::TcpListener; use tokio_core::reactor::Core; use clap::{ App, Arg }; -use tokio_rustls::{ ServerConfig, TlsAcceptorExt }; +use tokio_rustls::{ ServerConfig, ServerConfigExt }; fn app() -> App<'static, 'static> { diff --git a/src/lib.rs b/src/lib.rs index 15d0ac6..a7274da 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,3 +1,7 @@ +//! Async TLS streams +//! +//! [tokio-tls](https://github.com/tokio-rs/tokio-tls) fork, use [rustls](https://github.com/ctz/rustls). + extern crate futures; extern crate tokio_core; extern crate rustls; @@ -10,25 +14,31 @@ use rustls::{ Session, ClientSession, ServerSession }; pub use rustls::{ ClientConfig, ServerConfig }; -pub trait TlsConnectorExt { +/// Extension trait for the `Arc` type in the `rustls` crate. +pub trait ClientConfigExt { fn connect_async(&self, domain: &str, stream: S) -> ConnectAsync where S: Io; } -pub trait TlsAcceptorExt { +/// Extension trait for the `Arc` type in the `rustls` crate. +pub trait ServerConfigExt { fn accept_async(&self, stream: S) -> AcceptAsync where S: Io; } +/// Future returned from `ClientConfigExt::connect_async` which will resolve +/// once the connection handshake has finished. pub struct ConnectAsync(MidHandshake); +/// Future returned from `ServerConfigExt::accept_async` which will resolve +/// once the accept handshake has finished. pub struct AcceptAsync(MidHandshake); -impl TlsConnectorExt for Arc { +impl ClientConfigExt for Arc { fn connect_async(&self, domain: &str, stream: S) -> ConnectAsync where S: Io @@ -39,7 +49,7 @@ impl TlsConnectorExt for Arc { } } -impl TlsAcceptorExt for Arc { +impl ServerConfigExt for Arc { fn accept_async(&self, stream: S) -> AcceptAsync where S: Io @@ -103,12 +113,25 @@ impl Future for MidHandshake } +/// A wrapper around an underlying raw stream which implements the TLS or SSL +/// protocol. +#[derive(Debug)] pub struct TlsStream { eof: bool, io: S, session: C } +impl TlsStream { + pub fn get_ref(&self) -> (&S, &C) { + (&self.io, &self.session) + } + + pub fn get_mut(&mut self) -> (&mut S, &mut C) { + (&mut self.io, &mut self.session) + } +} + impl TlsStream where S: Io, C: Session { From ced35f66882cad876636e2e960aee0b70715cabd Mon Sep 17 00:00:00 2001 From: quininer kel Date: Wed, 22 Feb 2017 13:30:01 +0800 Subject: [PATCH 005/171] [Added] fork proto.rs --- examples/client.rs | 4 +- examples/server.rs | 4 +- src/lib.rs | 5 +- src/proto.rs | 551 +++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 560 insertions(+), 4 deletions(-) create mode 100644 src/proto.rs diff --git a/examples/client.rs b/examples/client.rs index b78d57d..6fb5f5b 100644 --- a/examples/client.rs +++ b/examples/client.rs @@ -1,4 +1,5 @@ extern crate clap; +extern crate rustls; extern crate futures; extern crate tokio_core; extern crate webpki_roots; @@ -14,8 +15,9 @@ use tokio_core::io; use tokio_core::net::TcpStream; use tokio_core::reactor::Core; use clap::{ App, Arg }; +use rustls::ClientConfig; use tokio_file_unix::{ StdFile, File }; -use tokio_rustls::{ ClientConfig, ClientConfigExt }; +use tokio_rustls::ClientConfigExt; fn app() -> App<'static, 'static> { diff --git a/examples/server.rs b/examples/server.rs index c5cf2fa..ca67f0e 100644 --- a/examples/server.rs +++ b/examples/server.rs @@ -10,13 +10,13 @@ use std::net::ToSocketAddrs; use std::io::BufReader; use std::fs::File; use futures::{ Future, Stream }; -use rustls::{ Certificate, PrivateKey }; +use rustls::{ Certificate, PrivateKey, ServerConfig }; use rustls::internal::pemfile::{ certs, rsa_private_keys }; use tokio_core::io::{ self, Io }; use tokio_core::net::TcpListener; use tokio_core::reactor::Core; use clap::{ App, Arg }; -use tokio_rustls::{ ServerConfig, ServerConfigExt }; +use tokio_rustls::ServerConfigExt; fn app() -> App<'static, 'static> { diff --git a/src/lib.rs b/src/lib.rs index a7274da..899efd7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,16 +2,19 @@ //! //! [tokio-tls](https://github.com/tokio-rs/tokio-tls) fork, use [rustls](https://github.com/ctz/rustls). +#[cfg_attr(feature = "tokio-proto", macro_use)] extern crate futures; extern crate tokio_core; extern crate rustls; +pub mod proto; + use std::io; use std::sync::Arc; use futures::{ Future, Poll, Async }; use tokio_core::io::Io; use rustls::{ Session, ClientSession, ServerSession }; -pub use rustls::{ ClientConfig, ServerConfig }; +use rustls::{ ClientConfig, ServerConfig }; /// Extension trait for the `Arc` type in the `rustls` crate. diff --git a/src/proto.rs b/src/proto.rs new file mode 100644 index 0000000..24a8541 --- /dev/null +++ b/src/proto.rs @@ -0,0 +1,551 @@ +//! Wrappers for `tokio-proto` +//! +//! This module contains wrappers for protocols defined by the `tokio-proto` +//! crate. These wrappers will all attempt to negotiate a TLS connection first +//! and then delegate all further protocol information to the protocol +//! specified. +//! +//! This module requires the `tokio-proto` feature to be enabled. + +#![cfg(feature = "tokio-proto")] + +extern crate tokio_proto; + +use std::io; +use std::sync::Arc; +use futures::{ Future, IntoFuture, Poll }; +use rustls::{ ServerConfig, ClientConfig, ServerSession, ClientSession }; +use self::tokio_proto::multiplex; +use self::tokio_proto::pipeline; +use self::tokio_proto::streaming; +use tokio_core::io::Io; + +use { TlsStream, ServerConfigExt, ClientConfigExt, AcceptAsync, ConnectAsync }; + +/// TLS server protocol wrapper. +/// +/// This structure is a wrapper for other implementations of `ServerProto` in +/// the `tokio-proto` crate. This structure will negotiate a TLS connection +/// first and then delegate all further operations to the `ServerProto` +/// implementation for the underlying type. +pub struct Server { + inner: Arc, + acceptor: Arc, +} + +impl Server { + /// Constructs a new TLS protocol which will delegate to the underlying + /// `protocol` specified. + /// + /// The `acceptor` provided will be used to accept TLS connections. All new + /// connections will go through the TLS acceptor first and then further I/O + /// will go through the negotiated TLS stream through the `protocol` + /// specified. + pub fn new(protocol: T, acceptor: ServerConfig) -> Server { + Server { + inner: Arc::new(protocol), + acceptor: Arc::new(acceptor), + } + } +} + +/// Future returned from `bind_transport` in the `ServerProto` implementation. +pub struct ServerPipelineBind + where T: pipeline::ServerProto>, + I: Io + 'static, +{ + state: PipelineState, +} + +enum PipelineState + where T: pipeline::ServerProto>, + I: Io + 'static, +{ + First(AcceptAsync, Arc), + Next(::Future), +} + +impl pipeline::ServerProto for Server + where T: pipeline::ServerProto>, + I: Io + 'static, +{ + type Request = T::Request; + type Response = T::Response; + type Transport = T::Transport; + type BindTransport = ServerPipelineBind; + + fn bind_transport(&self, io: I) -> Self::BindTransport { + let proto = self.inner.clone(); + + ServerPipelineBind { + state: PipelineState::First(self.acceptor.accept_async(io), proto), + } + } +} + +impl Future for ServerPipelineBind + where T: pipeline::ServerProto>, + I: Io + 'static, +{ + type Item = T::Transport; + type Error = io::Error; + + fn poll(&mut self) -> Poll { + loop { + let next = match self.state { + PipelineState::First(ref mut a, ref state) => { + let res = a.poll().map_err(|e| { + io::Error::new(io::ErrorKind::Other, e) + }); + state.bind_transport(try_ready!(res)) + } + PipelineState::Next(ref mut b) => return b.poll(), + }; + self.state = PipelineState::Next(next.into_future()); + } + } +} + +/// Future returned from `bind_transport` in the `ServerProto` implementation. +pub struct ServerMultiplexBind + where T: multiplex::ServerProto>, + I: Io + 'static, +{ + state: MultiplexState, +} + +enum MultiplexState + where T: multiplex::ServerProto>, + I: Io + 'static, +{ + First(AcceptAsync, Arc), + Next(::Future), +} + +impl multiplex::ServerProto for Server + where T: multiplex::ServerProto>, + I: Io + 'static, +{ + type Request = T::Request; + type Response = T::Response; + type Transport = T::Transport; + type BindTransport = ServerMultiplexBind; + + fn bind_transport(&self, io: I) -> Self::BindTransport { + let proto = self.inner.clone(); + + ServerMultiplexBind { + state: MultiplexState::First(self.acceptor.accept_async(io), proto), + } + } +} + +impl Future for ServerMultiplexBind + where T: multiplex::ServerProto>, + I: Io + 'static, +{ + type Item = T::Transport; + type Error = io::Error; + + fn poll(&mut self) -> Poll { + loop { + let next = match self.state { + MultiplexState::First(ref mut a, ref state) => { + let res = a.poll().map_err(|e| { + io::Error::new(io::ErrorKind::Other, e) + }); + state.bind_transport(try_ready!(res)) + } + MultiplexState::Next(ref mut b) => return b.poll(), + }; + self.state = MultiplexState::Next(next.into_future()); + } + } +} + +/// Future returned from `bind_transport` in the `ServerProto` implementation. +pub struct ServerStreamingPipelineBind + where T: streaming::pipeline::ServerProto>, + I: Io + 'static, +{ + state: StreamingPipelineState, +} + +enum StreamingPipelineState + where T: streaming::pipeline::ServerProto>, + I: Io + 'static, +{ + First(AcceptAsync, Arc), + Next(::Future), +} + +impl streaming::pipeline::ServerProto for Server + where T: streaming::pipeline::ServerProto>, + I: Io + 'static, +{ + type Request = T::Request; + type RequestBody = T::RequestBody; + type Response = T::Response; + type ResponseBody = T::ResponseBody; + type Error = T::Error; + type Transport = T::Transport; + type BindTransport = ServerStreamingPipelineBind; + + fn bind_transport(&self, io: I) -> Self::BindTransport { + let proto = self.inner.clone(); + + ServerStreamingPipelineBind { + state: StreamingPipelineState::First(self.acceptor.accept_async(io), proto), + } + } +} + +impl Future for ServerStreamingPipelineBind + where T: streaming::pipeline::ServerProto>, + I: Io + 'static, +{ + type Item = T::Transport; + type Error = io::Error; + + fn poll(&mut self) -> Poll { + loop { + let next = match self.state { + StreamingPipelineState::First(ref mut a, ref state) => { + let res = a.poll().map_err(|e| { + io::Error::new(io::ErrorKind::Other, e) + }); + state.bind_transport(try_ready!(res)) + } + StreamingPipelineState::Next(ref mut b) => return b.poll(), + }; + self.state = StreamingPipelineState::Next(next.into_future()); + } + } +} + +/// Future returned from `bind_transport` in the `ServerProto` implementation. +pub struct ServerStreamingMultiplexBind + where T: streaming::multiplex::ServerProto>, + I: Io + 'static, +{ + state: StreamingMultiplexState, +} + +enum StreamingMultiplexState + where T: streaming::multiplex::ServerProto>, + I: Io + 'static, +{ + First(AcceptAsync, Arc), + Next(::Future), +} + +impl streaming::multiplex::ServerProto for Server + where T: streaming::multiplex::ServerProto>, + I: Io + 'static, +{ + type Request = T::Request; + type RequestBody = T::RequestBody; + type Response = T::Response; + type ResponseBody = T::ResponseBody; + type Error = T::Error; + type Transport = T::Transport; + type BindTransport = ServerStreamingMultiplexBind; + + fn bind_transport(&self, io: I) -> Self::BindTransport { + let proto = self.inner.clone(); + + ServerStreamingMultiplexBind { + state: StreamingMultiplexState::First(self.acceptor.accept_async(io), proto), + } + } +} + +impl Future for ServerStreamingMultiplexBind + where T: streaming::multiplex::ServerProto>, + I: Io + 'static, +{ + type Item = T::Transport; + type Error = io::Error; + + fn poll(&mut self) -> Poll { + loop { + let next = match self.state { + StreamingMultiplexState::First(ref mut a, ref state) => { + let res = a.poll().map_err(|e| { + io::Error::new(io::ErrorKind::Other, e) + }); + state.bind_transport(try_ready!(res)) + } + StreamingMultiplexState::Next(ref mut b) => return b.poll(), + }; + self.state = StreamingMultiplexState::Next(next.into_future()); + } + } +} + +/// TLS client protocol wrapper. +/// +/// This structure is a wrapper for other implementations of `ClientProto` in +/// the `tokio-proto` crate. This structure will negotiate a TLS connection +/// first and then delegate all further operations to the `ClientProto` +/// implementation for the underlying type. +pub struct Client { + inner: Arc, + connector: Arc, + hostname: String, +} + +impl Client { + /// Constructs a new TLS protocol which will delegate to the underlying + /// `protocol` specified. + /// + /// The `connector` provided will be used to configure the TLS connection. Further I/O + /// will go through the negotiated TLS stream through the `protocol` specified. + pub fn new(protocol: T, + connector: ClientConfig, + hostname: &str) -> Client { + Client { + inner: Arc::new(protocol), + connector: Arc::new(connector), + hostname: hostname.to_string(), + } + } +} + +/// Future returned from `bind_transport` in the `ClientProto` implementation. +pub struct ClientPipelineBind + where T: pipeline::ClientProto>, + I: Io + 'static, +{ + state: ClientPipelineState, +} + +enum ClientPipelineState + where T: pipeline::ClientProto>, + I: Io + 'static, +{ + First(ConnectAsync, Arc), + Next(::Future), +} + +impl pipeline::ClientProto for Client + where T: pipeline::ClientProto>, + I: Io + 'static, +{ + type Request = T::Request; + type Response = T::Response; + type Transport = T::Transport; + type BindTransport = ClientPipelineBind; + + fn bind_transport(&self, io: I) -> Self::BindTransport { + let proto = self.inner.clone(); + let io = self.connector.connect_async(&self.hostname, io); + + ClientPipelineBind { + state: ClientPipelineState::First(io, proto), + } + } +} + +impl Future for ClientPipelineBind + where T: pipeline::ClientProto>, + I: Io + 'static, +{ + type Item = T::Transport; + type Error = io::Error; + + fn poll(&mut self) -> Poll { + loop { + let next = match self.state { + ClientPipelineState::First(ref mut a, ref state) => { + let res = a.poll().map_err(|e| { + io::Error::new(io::ErrorKind::Other, e) + }); + state.bind_transport(try_ready!(res)) + } + ClientPipelineState::Next(ref mut b) => return b.poll(), + }; + self.state = ClientPipelineState::Next(next.into_future()); + } + } +} + +/// Future returned from `bind_transport` in the `ClientProto` implementation. +pub struct ClientMultiplexBind + where T: multiplex::ClientProto>, + I: Io + 'static, +{ + state: ClientMultiplexState, +} + +enum ClientMultiplexState + where T: multiplex::ClientProto>, + I: Io + 'static, +{ + First(ConnectAsync, Arc), + Next(::Future), +} + +impl multiplex::ClientProto for Client + where T: multiplex::ClientProto>, + I: Io + 'static, +{ + type Request = T::Request; + type Response = T::Response; + type Transport = T::Transport; + type BindTransport = ClientMultiplexBind; + + fn bind_transport(&self, io: I) -> Self::BindTransport { + let proto = self.inner.clone(); + let io = self.connector.connect_async(&self.hostname, io); + + ClientMultiplexBind { + state: ClientMultiplexState::First(io, proto), + } + } +} + +impl Future for ClientMultiplexBind + where T: multiplex::ClientProto>, + I: Io + 'static, +{ + type Item = T::Transport; + type Error = io::Error; + + fn poll(&mut self) -> Poll { + loop { + let next = match self.state { + ClientMultiplexState::First(ref mut a, ref state) => { + let res = a.poll().map_err(|e| { + io::Error::new(io::ErrorKind::Other, e) + }); + state.bind_transport(try_ready!(res)) + } + ClientMultiplexState::Next(ref mut b) => return b.poll(), + }; + self.state = ClientMultiplexState::Next(next.into_future()); + } + } +} + +/// Future returned from `bind_transport` in the `ClientProto` implementation. +pub struct ClientStreamingPipelineBind + where T: streaming::pipeline::ClientProto>, + I: Io + 'static, +{ + state: ClientStreamingPipelineState, +} + +enum ClientStreamingPipelineState + where T: streaming::pipeline::ClientProto>, + I: Io + 'static, +{ + First(ConnectAsync, Arc), + Next(::Future), +} + +impl streaming::pipeline::ClientProto for Client + where T: streaming::pipeline::ClientProto>, + I: Io + 'static, +{ + type Request = T::Request; + type RequestBody = T::RequestBody; + type Response = T::Response; + type ResponseBody = T::ResponseBody; + type Error = T::Error; + type Transport = T::Transport; + type BindTransport = ClientStreamingPipelineBind; + + fn bind_transport(&self, io: I) -> Self::BindTransport { + let proto = self.inner.clone(); + let io = self.connector.connect_async(&self.hostname, io); + + ClientStreamingPipelineBind { + state: ClientStreamingPipelineState::First(io, proto), + } + } +} + +impl Future for ClientStreamingPipelineBind + where T: streaming::pipeline::ClientProto>, + I: Io + 'static, +{ + type Item = T::Transport; + type Error = io::Error; + + fn poll(&mut self) -> Poll { + loop { + let next = match self.state { + ClientStreamingPipelineState::First(ref mut a, ref state) => { + let res = a.poll().map_err(|e| { + io::Error::new(io::ErrorKind::Other, e) + }); + state.bind_transport(try_ready!(res)) + } + ClientStreamingPipelineState::Next(ref mut b) => return b.poll(), + }; + self.state = ClientStreamingPipelineState::Next(next.into_future()); + } + } +} + +/// Future returned from `bind_transport` in the `ClientProto` implementation. +pub struct ClientStreamingMultiplexBind + where T: streaming::multiplex::ClientProto>, + I: Io + 'static, +{ + state: ClientStreamingMultiplexState, +} + +enum ClientStreamingMultiplexState + where T: streaming::multiplex::ClientProto>, + I: Io + 'static, +{ + First(ConnectAsync, Arc), + Next(::Future), +} + +impl streaming::multiplex::ClientProto for Client + where T: streaming::multiplex::ClientProto>, + I: Io + 'static, +{ + type Request = T::Request; + type RequestBody = T::RequestBody; + type Response = T::Response; + type ResponseBody = T::ResponseBody; + type Error = T::Error; + type Transport = T::Transport; + type BindTransport = ClientStreamingMultiplexBind; + + fn bind_transport(&self, io: I) -> Self::BindTransport { + let proto = self.inner.clone(); + let io = self.connector.connect_async(&self.hostname, io); + + ClientStreamingMultiplexBind { + state: ClientStreamingMultiplexState::First(io, proto), + } + } +} + +impl Future for ClientStreamingMultiplexBind + where T: streaming::multiplex::ClientProto>, + I: Io + 'static, +{ + type Item = T::Transport; + type Error = io::Error; + + fn poll(&mut self) -> Poll { + loop { + let next = match self.state { + ClientStreamingMultiplexState::First(ref mut a, ref state) => { + let res = a.poll().map_err(|e| { + io::Error::new(io::ErrorKind::Other, e) + }); + state.bind_transport(try_ready!(res)) + } + ClientStreamingMultiplexState::Next(ref mut b) => return b.poll(), + }; + self.state = ClientStreamingMultiplexState::Next(next.into_future()); + } + } +} From 91e21316537c8d4525bd3b61fa8431dba361a401 Mon Sep 17 00:00:00 2001 From: quininer kel Date: Wed, 22 Feb 2017 20:09:10 +0800 Subject: [PATCH 006/171] [Improved] update README --- README.md | 3 +++ examples/client.rs | 5 ++--- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 5cdcb70..efd4c48 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,6 @@ # tokio-rustls +[![crates](https://img.shields.io/crates/v/tokio-rustls.svg)](https://crates.io/crates/tokio-rustls) +[![license](https://img.shields.io/github/license/quininer/tokio-rustls.svg)](https://github.com/quininer/tokio-rustls/blob/master/LICENSE) +[![docs.rs](https://docs.rs/tokio-rustls/badge.svg)](https://docs.rs/tokio-rustls/) [tokio-tls](https://github.com/tokio-rs/tokio-tls) fork, use [rustls](https://github.com/ctz/rustls). diff --git a/examples/client.rs b/examples/client.rs index 6fb5f5b..471c484 100644 --- a/examples/client.rs +++ b/examples/client.rs @@ -50,9 +50,8 @@ fn main() { .next().unwrap(); let stdout = stdout(); - let mut stdout = File::new_nb(StdFile(stdout.lock())).unwrap(); - stdout.set_nonblocking(true).unwrap(); - let stdout = stdout.into_io(&handle).unwrap(); + let stdout = File::new_nb(StdFile(stdout.lock())).unwrap() + .into_io(&handle).unwrap(); let mut config = ClientConfig::new(); if let Some(cafile) = cafile { From ffdf1ebcb8aa4fbd76ae83932af13b0d0865a551 Mon Sep 17 00:00:00 2001 From: quininer kel Date: Fri, 24 Feb 2017 19:28:19 +0800 Subject: [PATCH 007/171] [Fixed] empty handshake loop --- Cargo.toml | 2 +- examples/client.rs | 14 +++++++++++--- src/lib.rs | 22 ++++++++++++---------- 3 files changed, 24 insertions(+), 14 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 20e5c2d..33930da 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tokio-rustls" -version = "0.1.0" +version = "0.1.1" authors = ["quininer kel "] license = "MIT/Apache-2.0" repository = "https://github.com/quininer/tokio-rustls" diff --git a/examples/client.rs b/examples/client.rs index 471c484..63db56d 100644 --- a/examples/client.rs +++ b/examples/client.rs @@ -8,10 +8,10 @@ extern crate tokio_rustls; use std::sync::Arc; use std::net::ToSocketAddrs; -use std::io::{ BufReader, stdout }; +use std::io::{ BufReader, stdout, stdin }; use std::fs; use futures::Future; -use tokio_core::io; +use tokio_core::io::{ self, Io }; use tokio_core::net::TcpStream; use tokio_core::reactor::Core; use clap::{ App, Arg }; @@ -49,6 +49,9 @@ fn main() { .to_socket_addrs().unwrap() .next().unwrap(); + let stdin = stdin(); + let stdin = File::new_nb(StdFile(stdin.lock())).unwrap() + .into_io(&handle).unwrap(); let stdout = stdout(); let stdout = File::new_nb(StdFile(stdout.lock())).unwrap() .into_io(&handle).unwrap(); @@ -66,7 +69,12 @@ fn main() { let resp = socket .and_then(|stream| arc_config.connect_async(domain, stream)) .and_then(|stream| io::write_all(stream, text.as_bytes())) - .and_then(|(stream, _)| io::copy(stream, stdout)); + .and_then(|(stream, _)| { + let (r, w) = stream.split(); + io::copy(r, stdout).select(io::copy(stdin, w)) + .map(|_| ()) + .map_err(|(e, _)| e) + }); core.run(resp).unwrap(); } diff --git a/src/lib.rs b/src/lib.rs index 899efd7..89b6873 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -98,16 +98,18 @@ impl Future for MidHandshake if !stream.session.is_handshaking() { break }; match stream.do_io() { - Ok(()) => continue, - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => (), - Err(e) => return Err(e) - } - if !stream.session.is_handshaking() { break }; - - if stream.eof { - return Err(io::Error::from(io::ErrorKind::UnexpectedEof)); - } else { - return Ok(Async::NotReady); + Ok(()) => if stream.eof { + return Err(io::Error::from(io::ErrorKind::UnexpectedEof)) + } else if stream.session.is_handshaking() { + continue + } else { + break + }, + Err(e) => match (e.kind(), stream.session.is_handshaking()) { + (io::ErrorKind::WouldBlock, true) => return Ok(Async::NotReady), + (io::ErrorKind::WouldBlock, false) => break, + (..) => return Err(e) + } } } From 93944055129626dbbb445f28b6bb7791cecd2578 Mon Sep 17 00:00:00 2001 From: quininer kel Date: Sat, 25 Feb 2017 12:25:09 +0800 Subject: [PATCH 008/171] [Improved] TlsStream impl poll_{read, write} --- Cargo.toml | 2 +- src/lib.rs | 18 +++++++++++++++++- 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 33930da..f540cea 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tokio-rustls" -version = "0.1.1" +version = "0.1.2" authors = ["quininer kel "] license = "MIT/Apache-2.0" repository = "https://github.com/quininer/tokio-rustls" diff --git a/src/lib.rs b/src/lib.rs index 89b6873..ae56975 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -218,4 +218,20 @@ impl io::Write for TlsStream } } -impl Io for TlsStream where S: Io, C: Session {} +impl Io for TlsStream where S: Io, C: Session { + fn poll_read(&mut self) -> Async<()> { + if !self.eof && self.session.wants_read() && self.io.poll_read().is_not_ready() { + Async::NotReady + } else { + Async::Ready(()) + } + } + + fn poll_write(&mut self) -> Async<()> { + if self.session.wants_write() && self.io.poll_write().is_not_ready() { + Async::NotReady + } else { + Async::Ready(()) + } + } +} From 7e4fcca0321e94f23c9eba67809a5847b349078d Mon Sep 17 00:00:00 2001 From: quininer kel Date: Mon, 27 Feb 2017 20:59:35 +0800 Subject: [PATCH 009/171] [Improved] MidHandshake/TlsStream - [Improved] README.md - [Improved] MidHandshake poll - [Improved] TlsStream read - [Fixed] TlsStream write, possible of repeat write - [Removed] TlsStream poll_{read, write} --- README.md | 18 ++++++++++++++++++ src/lib.rs | 56 ++++++++++++++++++++++++------------------------------ 2 files changed, 43 insertions(+), 31 deletions(-) diff --git a/README.md b/README.md index efd4c48..b521c92 100644 --- a/README.md +++ b/README.md @@ -4,3 +4,21 @@ [![docs.rs](https://docs.rs/tokio-rustls/badge.svg)](https://docs.rs/tokio-rustls/) [tokio-tls](https://github.com/tokio-rs/tokio-tls) fork, use [rustls](https://github.com/ctz/rustls). + +### exmaple + +```rust +// ... + +use rustls::ClientConfig; +use tokio_rustls::ClientConfigExt; + +let mut config = ClientConfig::new(); +config.root_store.add_trust_anchors(&webpki_roots::ROOTS); +let config = Arc::new(config); + +TcpStream::connect(&addr, &handle) + .and_then(|socket| config.connect_async("www.rust-lang.org", socket)) + +// ... +``` diff --git a/src/lib.rs b/src/lib.rs index ae56975..9bd1953 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -98,12 +98,10 @@ impl Future for MidHandshake if !stream.session.is_handshaking() { break }; match stream.do_io() { - Ok(()) => if stream.eof { - return Err(io::Error::from(io::ErrorKind::UnexpectedEof)) - } else if stream.session.is_handshaking() { - continue - } else { - break + Ok(()) => match (stream.eof, stream.session.is_handshaking()) { + (true, true) => return Err(io::Error::from(io::ErrorKind::UnexpectedEof)), + (false, true) => continue, + (..) => break }, Err(e) => match (e.kind(), stream.session.is_handshaking()) { (io::ErrorKind::WouldBlock, true) => return Ok(Async::NotReady), @@ -189,11 +187,17 @@ impl io::Read for TlsStream where S: Io, C: Session { fn read(&mut self, buf: &mut [u8]) -> io::Result { - self.do_io()?; - if self.eof { - Ok(0) - } else { - self.session.read(buf) + loop { + match self.session.read(buf) { + Ok(0) if !self.eof => self.do_io()?, + Ok(n) => return Ok(n), + Err(e) => if e.kind() == io::ErrorKind::ConnectionAborted { + self.do_io()?; + return if self.eof { Ok(0) } else { Err(e) } + } else { + return Err(e) + } + } } } } @@ -202,11 +206,17 @@ impl io::Write for TlsStream where S: Io, C: Session { fn write(&mut self, buf: &[u8]) -> io::Result { - let output = self.session.write(buf); + let output = self.session.write(buf)?; + while self.session.wants_write() && self.io.poll_write().is_ready() { - self.session.write_tls(&mut self.io)?; + match self.session.write_tls(&mut self.io) { + Ok(_) => (), + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => break, + Err(e) => return Err(e) + } } - output + + Ok(output) } fn flush(&mut self) -> io::Result<()> { @@ -218,20 +228,4 @@ impl io::Write for TlsStream } } -impl Io for TlsStream where S: Io, C: Session { - fn poll_read(&mut self) -> Async<()> { - if !self.eof && self.session.wants_read() && self.io.poll_read().is_not_ready() { - Async::NotReady - } else { - Async::Ready(()) - } - } - - fn poll_write(&mut self) -> Async<()> { - if self.session.wants_write() && self.io.poll_write().is_not_ready() { - Async::NotReady - } else { - Async::Ready(()) - } - } -} +impl Io for TlsStream where S: Io, C: Session {} From 1921f2bf4964a9090633420d38d2535e4cd1f647 Mon Sep 17 00:00:00 2001 From: quininer kel Date: Tue, 28 Feb 2017 08:53:52 +0800 Subject: [PATCH 010/171] [Fixed] TlsStream should not check poll_write --- examples/server.rs | 1 + src/lib.rs | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/server.rs b/examples/server.rs index ca67f0e..283d016 100644 --- a/examples/server.rs +++ b/examples/server.rs @@ -77,6 +77,7 @@ fn main() { \r\n\ Hello world!".as_bytes() )) + .and_then(|(stream, _)| io::flush(stream)) .map(move |_| println!("Accept: {}", addr)) .map_err(move |err| println!("Error: {:?} - {}", err, addr)); handle.spawn(done); diff --git a/src/lib.rs b/src/lib.rs index 9bd1953..e815aff 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -221,7 +221,7 @@ impl io::Write for TlsStream fn flush(&mut self) -> io::Result<()> { self.session.flush()?; - while self.session.wants_write() && self.io.poll_write().is_ready() { + while self.session.wants_write() { self.session.write_tls(&mut self.io)?; } Ok(()) From 0db05aa9bfa71e4d9127b36fab9c99040b44e784 Mon Sep 17 00:00:00 2001 From: quininer kel Date: Wed, 1 Mar 2017 10:34:24 +0800 Subject: [PATCH 011/171] [Changed] proto {Server,Client}::new use Arc --- Cargo.toml | 2 +- src/lib.rs | 8 +++++--- src/proto.rs | 8 ++++---- 3 files changed, 10 insertions(+), 8 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index f540cea..bf70636 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tokio-rustls" -version = "0.1.2" +version = "0.1.3" authors = ["quininer kel "] license = "MIT/Apache-2.0" repository = "https://github.com/quininer/tokio-rustls" diff --git a/src/lib.rs b/src/lib.rs index e815aff..e6f5856 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -94,7 +94,7 @@ impl Future for MidHandshake fn poll(&mut self) -> Poll { loop { - let stream = self.inner.as_mut().unwrap_or_else(|| unreachable!()); + let stream = self.inner.as_mut().unwrap(); if !stream.session.is_handshaking() { break }; match stream.do_io() { @@ -111,7 +111,7 @@ impl Future for MidHandshake } } - Ok(Async::Ready(self.inner.take().unwrap_or_else(|| unreachable!()))) + Ok(Async::Ready(self.inner.take().unwrap())) } } @@ -228,4 +228,6 @@ impl io::Write for TlsStream } } -impl Io for TlsStream where S: Io, C: Session {} +impl Io for TlsStream where S: Io, C: Session { + // TODO impl poll_{read, write} +} diff --git a/src/proto.rs b/src/proto.rs index 24a8541..c1b60d4 100644 --- a/src/proto.rs +++ b/src/proto.rs @@ -41,10 +41,10 @@ impl Server { /// connections will go through the TLS acceptor first and then further I/O /// will go through the negotiated TLS stream through the `protocol` /// specified. - pub fn new(protocol: T, acceptor: ServerConfig) -> Server { + pub fn new(protocol: T, acceptor: Arc) -> Server { Server { inner: Arc::new(protocol), - acceptor: Arc::new(acceptor), + acceptor: acceptor, } } } @@ -302,11 +302,11 @@ impl Client { /// The `connector` provided will be used to configure the TLS connection. Further I/O /// will go through the negotiated TLS stream through the `protocol` specified. pub fn new(protocol: T, - connector: ClientConfig, + connector: Arc, hostname: &str) -> Client { Client { inner: Arc::new(protocol), - connector: Arc::new(connector), + connector: connector, hostname: hostname.to_string(), } } From c7041e211181a8eaa1573956242404b7df18adb6 Mon Sep 17 00:00:00 2001 From: quininer kel Date: Thu, 16 Mar 2017 10:15:06 +0800 Subject: [PATCH 012/171] [Changed] update tokio-io --- Cargo.toml | 1 + examples/client.rs | 6 ++-- examples/server.rs | 5 +-- src/lib.rs | 86 +++++++++++++++++++++++++++------------------- 4 files changed, 59 insertions(+), 39 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index bf70636..c5f05f4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,6 +14,7 @@ for nonblocking I/O streams. [dependencies] futures = "0.1" tokio-core = "0.1" +tokio-io = "0.1" rustls = "0.5" tokio-proto = { version = "0.1", optional = true } diff --git a/examples/client.rs b/examples/client.rs index 63db56d..2643aba 100644 --- a/examples/client.rs +++ b/examples/client.rs @@ -1,6 +1,7 @@ extern crate clap; extern crate rustls; extern crate futures; +extern crate tokio_io; extern crate tokio_core; extern crate webpki_roots; extern crate tokio_file_unix; @@ -11,7 +12,7 @@ use std::net::ToSocketAddrs; use std::io::{ BufReader, stdout, stdin }; use std::fs; use futures::Future; -use tokio_core::io::{ self, Io }; +use tokio_io::{ io, AsyncRead }; use tokio_core::net::TcpStream; use tokio_core::reactor::Core; use clap::{ App, Arg }; @@ -71,8 +72,9 @@ fn main() { .and_then(|stream| io::write_all(stream, text.as_bytes())) .and_then(|(stream, _)| { let (r, w) = stream.split(); - io::copy(r, stdout).select(io::copy(stdin, w)) + io::copy(r, stdout) .map(|_| ()) + .select(io::copy(stdin, w).map(|_| ())) .map_err(|(e, _)| e) }); diff --git a/examples/server.rs b/examples/server.rs index 283d016..5d27474 100644 --- a/examples/server.rs +++ b/examples/server.rs @@ -1,6 +1,7 @@ extern crate clap; extern crate rustls; extern crate futures; +extern crate tokio_io; extern crate tokio_core; extern crate webpki_roots; extern crate tokio_rustls; @@ -12,7 +13,7 @@ use std::fs::File; use futures::{ Future, Stream }; use rustls::{ Certificate, PrivateKey, ServerConfig }; use rustls::internal::pemfile::{ certs, rsa_private_keys }; -use tokio_core::io::{ self, Io }; +use tokio_io::{ io, AsyncRead }; use tokio_core::net::TcpListener; use tokio_core::reactor::Core; use clap::{ App, Arg }; @@ -62,7 +63,7 @@ fn main() { let (reader, writer) = stream.split(); io::copy(reader, writer) }) - .map(move |n| println!("Echo: {} - {}", n, addr)) + .map(move |(n, _, _)| println!("Echo: {} - {}", n, addr)) .map_err(move |err| println!("Error: {:?} - {}", err, addr)); handle.spawn(done); diff --git a/src/lib.rs b/src/lib.rs index e6f5856..36b4d09 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,8 +2,8 @@ //! //! [tokio-tls](https://github.com/tokio-rs/tokio-tls) fork, use [rustls](https://github.com/ctz/rustls). -#[cfg_attr(feature = "tokio-proto", macro_use)] -extern crate futures; +#[cfg_attr(feature = "tokio-proto", macro_use)] extern crate futures; +extern crate tokio_io; extern crate tokio_core; extern crate rustls; @@ -12,7 +12,7 @@ pub mod proto; use std::io; use std::sync::Arc; use futures::{ Future, Poll, Async }; -use tokio_core::io::Io; +use tokio_io::{ AsyncRead, AsyncWrite }; use rustls::{ Session, ClientSession, ServerSession }; use rustls::{ ClientConfig, ServerConfig }; @@ -21,14 +21,14 @@ use rustls::{ ClientConfig, ServerConfig }; pub trait ClientConfigExt { fn connect_async(&self, domain: &str, stream: S) -> ConnectAsync - where S: Io; + where S: AsyncRead + AsyncWrite; } /// Extension trait for the `Arc` type in the `rustls` crate. pub trait ServerConfigExt { fn accept_async(&self, stream: S) -> AcceptAsync - where S: Io; + where S: AsyncRead + AsyncWrite; } @@ -44,7 +44,7 @@ pub struct AcceptAsync(MidHandshake); impl ClientConfigExt for Arc { fn connect_async(&self, domain: &str, stream: S) -> ConnectAsync - where S: Io + where S: AsyncRead + AsyncWrite { ConnectAsync(MidHandshake { inner: Some(TlsStream::new(stream, ClientSession::new(self, domain))) @@ -55,7 +55,7 @@ impl ClientConfigExt for Arc { impl ServerConfigExt for Arc { fn accept_async(&self, stream: S) -> AcceptAsync - where S: Io + where S: AsyncRead + AsyncWrite { AcceptAsync(MidHandshake { inner: Some(TlsStream::new(stream, ServerSession::new(self))) @@ -63,7 +63,7 @@ impl ServerConfigExt for Arc { } } -impl Future for ConnectAsync { +impl Future for ConnectAsync { type Item = TlsStream; type Error = io::Error; @@ -72,7 +72,7 @@ impl Future for ConnectAsync { } } -impl Future for AcceptAsync { +impl Future for AcceptAsync { type Item = TlsStream; type Error = io::Error; @@ -87,7 +87,7 @@ struct MidHandshake { } impl Future for MidHandshake - where S: Io, C: Session + where S: AsyncRead + AsyncWrite, C: Session { type Item = TlsStream; type Error = io::Error; @@ -136,7 +136,7 @@ impl TlsStream { } impl TlsStream - where S: Io, C: Session + where S: AsyncRead + AsyncWrite, C: Session { #[inline] pub fn new(io: S, session: C) -> TlsStream { @@ -149,29 +149,32 @@ impl TlsStream pub fn do_io(&mut self) -> io::Result<()> { loop { - let read_would_block = match (!self.eof && self.session.wants_read(), self.io.poll_read()) { - (true, Async::Ready(())) => { - match self.session.read_tls(&mut self.io) { - Ok(0) => self.eof = true, - Ok(_) => self.session.process_new_packets() - .map_err(|err| io::Error::new(io::ErrorKind::Other, err))?, - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => (), - Err(e) => return Err(e) - }; - continue - }, - (true, Async::NotReady) => true, - (false, _) => false, + let read_would_block = if !self.eof && self.session.wants_read() { + match self.session.read_tls(&mut self.io) { + Ok(0) => { + self.eof = true; + continue + }, + Ok(_) => { + self.session.process_new_packets() + .map_err(|err| io::Error::new(io::ErrorKind::Other, err))?; + continue + }, + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => true, + Err(e) => return Err(e) + } + } else { + false }; - let write_would_block = match (self.session.wants_write(), self.io.poll_write()) { - (true, Async::Ready(())) => match self.session.write_tls(&mut self.io) { + let write_would_block = if self.session.wants_write() { + match self.session.write_tls(&mut self.io) { Ok(_) => continue, - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => continue, + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => true, Err(e) => return Err(e) - }, - (true, Async::NotReady) => true, - (false, _) => false + } + } else { + false }; if read_would_block || write_would_block { @@ -184,7 +187,7 @@ impl TlsStream } impl io::Read for TlsStream - where S: Io, C: Session + where S: AsyncRead + AsyncWrite, C: Session { fn read(&mut self, buf: &mut [u8]) -> io::Result { loop { @@ -203,12 +206,12 @@ impl io::Read for TlsStream } impl io::Write for TlsStream - where S: Io, C: Session + where S: AsyncRead + AsyncWrite, C: Session { fn write(&mut self, buf: &[u8]) -> io::Result { let output = self.session.write(buf)?; - while self.session.wants_write() && self.io.poll_write().is_ready() { + while self.session.wants_write() { match self.session.write_tls(&mut self.io) { Ok(_) => (), Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => break, @@ -228,6 +231,19 @@ impl io::Write for TlsStream } } -impl Io for TlsStream where S: Io, C: Session { - // TODO impl poll_{read, write} +impl AsyncRead for TlsStream + where + S: AsyncRead + AsyncWrite, + C: Session +{} + +impl AsyncWrite for TlsStream + where + S: AsyncRead + AsyncWrite, + C: Session +{ + fn shutdown(&mut self) -> Poll<(), io::Error> { + self.session.send_close_notify(); + self.io.shutdown() + } } From a3358230984c3c42093eda7b2b7baea3e01eac53 Mon Sep 17 00:00:00 2001 From: quininer kel Date: Thu, 16 Mar 2017 10:17:30 +0800 Subject: [PATCH 013/171] [Fixed] feature tokio-proto --- src/proto.rs | 66 ++++++++++++++++++++++++++-------------------------- 1 file changed, 33 insertions(+), 33 deletions(-) diff --git a/src/proto.rs b/src/proto.rs index c1b60d4..688bb19 100644 --- a/src/proto.rs +++ b/src/proto.rs @@ -14,11 +14,11 @@ extern crate tokio_proto; use std::io; use std::sync::Arc; use futures::{ Future, IntoFuture, Poll }; +use tokio_io::{ AsyncRead, AsyncWrite }; use rustls::{ ServerConfig, ClientConfig, ServerSession, ClientSession }; use self::tokio_proto::multiplex; use self::tokio_proto::pipeline; use self::tokio_proto::streaming; -use tokio_core::io::Io; use { TlsStream, ServerConfigExt, ClientConfigExt, AcceptAsync, ConnectAsync }; @@ -52,14 +52,14 @@ impl Server { /// Future returned from `bind_transport` in the `ServerProto` implementation. pub struct ServerPipelineBind where T: pipeline::ServerProto>, - I: Io + 'static, + I: AsyncRead + AsyncWrite + 'static, { state: PipelineState, } enum PipelineState where T: pipeline::ServerProto>, - I: Io + 'static, + I: AsyncRead + AsyncWrite + 'static, { First(AcceptAsync, Arc), Next(::Future), @@ -67,7 +67,7 @@ enum PipelineState impl pipeline::ServerProto for Server where T: pipeline::ServerProto>, - I: Io + 'static, + I: AsyncRead + AsyncWrite + 'static, { type Request = T::Request; type Response = T::Response; @@ -85,7 +85,7 @@ impl pipeline::ServerProto for Server impl Future for ServerPipelineBind where T: pipeline::ServerProto>, - I: Io + 'static, + I: AsyncRead + AsyncWrite + 'static, { type Item = T::Transport; type Error = io::Error; @@ -109,14 +109,14 @@ impl Future for ServerPipelineBind /// Future returned from `bind_transport` in the `ServerProto` implementation. pub struct ServerMultiplexBind where T: multiplex::ServerProto>, - I: Io + 'static, + I: AsyncRead + AsyncWrite + 'static, { state: MultiplexState, } enum MultiplexState where T: multiplex::ServerProto>, - I: Io + 'static, + I: AsyncRead + AsyncWrite + 'static, { First(AcceptAsync, Arc), Next(::Future), @@ -124,7 +124,7 @@ enum MultiplexState impl multiplex::ServerProto for Server where T: multiplex::ServerProto>, - I: Io + 'static, + I: AsyncRead + AsyncWrite + 'static, { type Request = T::Request; type Response = T::Response; @@ -142,7 +142,7 @@ impl multiplex::ServerProto for Server impl Future for ServerMultiplexBind where T: multiplex::ServerProto>, - I: Io + 'static, + I: AsyncRead + AsyncWrite + 'static, { type Item = T::Transport; type Error = io::Error; @@ -166,14 +166,14 @@ impl Future for ServerMultiplexBind /// Future returned from `bind_transport` in the `ServerProto` implementation. pub struct ServerStreamingPipelineBind where T: streaming::pipeline::ServerProto>, - I: Io + 'static, + I: AsyncRead + AsyncWrite + 'static, { state: StreamingPipelineState, } enum StreamingPipelineState where T: streaming::pipeline::ServerProto>, - I: Io + 'static, + I: AsyncRead + AsyncWrite + 'static, { First(AcceptAsync, Arc), Next(::Future), @@ -181,7 +181,7 @@ enum StreamingPipelineState impl streaming::pipeline::ServerProto for Server where T: streaming::pipeline::ServerProto>, - I: Io + 'static, + I: AsyncRead + AsyncWrite + 'static, { type Request = T::Request; type RequestBody = T::RequestBody; @@ -202,7 +202,7 @@ impl streaming::pipeline::ServerProto for Server impl Future for ServerStreamingPipelineBind where T: streaming::pipeline::ServerProto>, - I: Io + 'static, + I: AsyncRead + AsyncWrite + 'static, { type Item = T::Transport; type Error = io::Error; @@ -226,14 +226,14 @@ impl Future for ServerStreamingPipelineBind /// Future returned from `bind_transport` in the `ServerProto` implementation. pub struct ServerStreamingMultiplexBind where T: streaming::multiplex::ServerProto>, - I: Io + 'static, + I: AsyncRead + AsyncWrite + 'static, { state: StreamingMultiplexState, } enum StreamingMultiplexState where T: streaming::multiplex::ServerProto>, - I: Io + 'static, + I: AsyncRead + AsyncWrite + 'static, { First(AcceptAsync, Arc), Next(::Future), @@ -241,7 +241,7 @@ enum StreamingMultiplexState impl streaming::multiplex::ServerProto for Server where T: streaming::multiplex::ServerProto>, - I: Io + 'static, + I: AsyncRead + AsyncWrite + 'static, { type Request = T::Request; type RequestBody = T::RequestBody; @@ -262,7 +262,7 @@ impl streaming::multiplex::ServerProto for Server impl Future for ServerStreamingMultiplexBind where T: streaming::multiplex::ServerProto>, - I: Io + 'static, + I: AsyncRead + AsyncWrite + 'static, { type Item = T::Transport; type Error = io::Error; @@ -315,14 +315,14 @@ impl Client { /// Future returned from `bind_transport` in the `ClientProto` implementation. pub struct ClientPipelineBind where T: pipeline::ClientProto>, - I: Io + 'static, + I: AsyncRead + AsyncWrite + 'static, { state: ClientPipelineState, } enum ClientPipelineState where T: pipeline::ClientProto>, - I: Io + 'static, + I: AsyncRead + AsyncWrite + 'static, { First(ConnectAsync, Arc), Next(::Future), @@ -330,7 +330,7 @@ enum ClientPipelineState impl pipeline::ClientProto for Client where T: pipeline::ClientProto>, - I: Io + 'static, + I: AsyncRead + AsyncWrite + 'static, { type Request = T::Request; type Response = T::Response; @@ -349,7 +349,7 @@ impl pipeline::ClientProto for Client impl Future for ClientPipelineBind where T: pipeline::ClientProto>, - I: Io + 'static, + I: AsyncRead + AsyncWrite + 'static, { type Item = T::Transport; type Error = io::Error; @@ -373,14 +373,14 @@ impl Future for ClientPipelineBind /// Future returned from `bind_transport` in the `ClientProto` implementation. pub struct ClientMultiplexBind where T: multiplex::ClientProto>, - I: Io + 'static, + I: AsyncRead + AsyncWrite + 'static, { state: ClientMultiplexState, } enum ClientMultiplexState where T: multiplex::ClientProto>, - I: Io + 'static, + I: AsyncRead + AsyncWrite + 'static, { First(ConnectAsync, Arc), Next(::Future), @@ -388,7 +388,7 @@ enum ClientMultiplexState impl multiplex::ClientProto for Client where T: multiplex::ClientProto>, - I: Io + 'static, + I: AsyncRead + AsyncWrite + 'static, { type Request = T::Request; type Response = T::Response; @@ -407,7 +407,7 @@ impl multiplex::ClientProto for Client impl Future for ClientMultiplexBind where T: multiplex::ClientProto>, - I: Io + 'static, + I: AsyncRead + AsyncWrite + 'static, { type Item = T::Transport; type Error = io::Error; @@ -431,14 +431,14 @@ impl Future for ClientMultiplexBind /// Future returned from `bind_transport` in the `ClientProto` implementation. pub struct ClientStreamingPipelineBind where T: streaming::pipeline::ClientProto>, - I: Io + 'static, + I: AsyncRead + AsyncWrite + 'static, { state: ClientStreamingPipelineState, } enum ClientStreamingPipelineState where T: streaming::pipeline::ClientProto>, - I: Io + 'static, + I: AsyncRead + AsyncWrite + 'static, { First(ConnectAsync, Arc), Next(::Future), @@ -446,7 +446,7 @@ enum ClientStreamingPipelineState impl streaming::pipeline::ClientProto for Client where T: streaming::pipeline::ClientProto>, - I: Io + 'static, + I: AsyncRead + AsyncWrite + 'static, { type Request = T::Request; type RequestBody = T::RequestBody; @@ -468,7 +468,7 @@ impl streaming::pipeline::ClientProto for Client impl Future for ClientStreamingPipelineBind where T: streaming::pipeline::ClientProto>, - I: Io + 'static, + I: AsyncRead + AsyncWrite + 'static, { type Item = T::Transport; type Error = io::Error; @@ -492,14 +492,14 @@ impl Future for ClientStreamingPipelineBind /// Future returned from `bind_transport` in the `ClientProto` implementation. pub struct ClientStreamingMultiplexBind where T: streaming::multiplex::ClientProto>, - I: Io + 'static, + I: AsyncRead + AsyncWrite + 'static, { state: ClientStreamingMultiplexState, } enum ClientStreamingMultiplexState where T: streaming::multiplex::ClientProto>, - I: Io + 'static, + I: AsyncRead + AsyncWrite + 'static, { First(ConnectAsync, Arc), Next(::Future), @@ -507,7 +507,7 @@ enum ClientStreamingMultiplexState impl streaming::multiplex::ClientProto for Client where T: streaming::multiplex::ClientProto>, - I: Io + 'static, + I: AsyncRead + AsyncWrite + 'static, { type Request = T::Request; type RequestBody = T::RequestBody; @@ -529,7 +529,7 @@ impl streaming::multiplex::ClientProto for Client impl Future for ClientStreamingMultiplexBind where T: streaming::multiplex::ClientProto>, - I: Io + 'static, + I: AsyncRead + AsyncWrite + 'static, { type Item = T::Transport; type Error = io::Error; From 9046dcb75eb6bf25818ab799587ba01d66de81f5 Mon Sep 17 00:00:00 2001 From: quininer kel Date: Thu, 16 Mar 2017 10:18:19 +0800 Subject: [PATCH 014/171] [Changed] bump to 0.1.4 --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index c5f05f4..e4e6f86 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tokio-rustls" -version = "0.1.3" +version = "0.1.4" authors = ["quininer kel "] license = "MIT/Apache-2.0" repository = "https://github.com/quininer/tokio-rustls" From 2a5640459b7274796529b8ed1e67ffdddcf6a658 Mon Sep 17 00:00:00 2001 From: quininer kel Date: Thu, 16 Mar 2017 10:41:41 +0800 Subject: [PATCH 015/171] [Removed] dont need tokio-core --- Cargo.toml | 2 +- src/lib.rs | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index e4e6f86..b432583 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,12 +13,12 @@ for nonblocking I/O streams. [dependencies] futures = "0.1" -tokio-core = "0.1" tokio-io = "0.1" rustls = "0.5" tokio-proto = { version = "0.1", optional = true } [dev-dependencies] +tokio-core = "0.1" clap = "2.20" webpki-roots = "0.7" tokio-file-unix = "0.2" diff --git a/src/lib.rs b/src/lib.rs index 36b4d09..1ecbfeb 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -4,7 +4,6 @@ #[cfg_attr(feature = "tokio-proto", macro_use)] extern crate futures; extern crate tokio_io; -extern crate tokio_core; extern crate rustls; pub mod proto; From 2bf0ba169f96e150c450709dba72518791022e3b Mon Sep 17 00:00:00 2001 From: Brian Smith Date: Wed, 29 Mar 2017 17:58:30 -1000 Subject: [PATCH 016/171] Clarify and expand documentation. Fix the license badges to point to the correct files. Fix some typos. Add more links. Explain how to run the examples. --- README.md | 33 ++++++++++++++++++++++++++++----- src/lib.rs | 5 ++--- 2 files changed, 30 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index b521c92..5cf9b39 100644 --- a/README.md +++ b/README.md @@ -1,11 +1,10 @@ # tokio-rustls -[![crates](https://img.shields.io/crates/v/tokio-rustls.svg)](https://crates.io/crates/tokio-rustls) -[![license](https://img.shields.io/github/license/quininer/tokio-rustls.svg)](https://github.com/quininer/tokio-rustls/blob/master/LICENSE) -[![docs.rs](https://docs.rs/tokio-rustls/badge.svg)](https://docs.rs/tokio-rustls/) +[![crates](https://img.shields.io/crates/v/tokio-rustls.svg)](https://crates.io/crates/tokio-rustls) [![license](https://img.shields.io/badge/License-MIT-blue.svg)](https://github.com/quininer/tokio-rustls/blob/master/LICENSE-MIT) [![license](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://github.com/quininer/tokio-rustls/blob/master/LICENSE-APACHE) [![docs.rs](https://docs.rs/tokio-rustls/badge.svg)](https://docs.rs/tokio-rustls/) -[tokio-tls](https://github.com/tokio-rs/tokio-tls) fork, use [rustls](https://github.com/ctz/rustls). +Asynchronous TLS/SSL streams for [Tokio](https://tokio.rs/) using +[Rustls](https://github.com/ctz/rustls). -### exmaple +### Basic Structure of a Client ```rust // ... @@ -22,3 +21,27 @@ TcpStream::connect(&addr, &handle) // ... ``` + +### Client Example Program + +See [examples/client.rs](examples/client.rs). You can run it with: + +```sh +cargo run --example client google.com +``` + +### Server Example Program + +See [examples/server.rs](examples/server.rs). You can run it with: + +```sh +cargo run --example server -- 127.0.0.1 --cert mycert.der --key mykey.der +``` + +### License & Origin + +tokio-rustls is primarily distributed under the terms of both the [MIT license](LICENSE-MIT) and +the [Apache License (Version 2.0)](LICENSE-APACHE), with portions covered by various BSD-like +licenses. + +This started as a fork of [tokio-tls](https://github.com/tokio-rs/tokio-tls). diff --git a/src/lib.rs b/src/lib.rs index 1ecbfeb..e94534c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,6 +1,5 @@ -//! Async TLS streams -//! -//! [tokio-tls](https://github.com/tokio-rs/tokio-tls) fork, use [rustls](https://github.com/ctz/rustls). +//! Asynchronous TLS/SSL streams for Tokio using [Rustls](https://github.com/ctz/rustls). + #[cfg_attr(feature = "tokio-proto", macro_use)] extern crate futures; extern crate tokio_io; From e597250fb875df7ec1f36b72de66a6af5f29a66f Mon Sep 17 00:00:00 2001 From: quininer kel Date: Thu, 30 Mar 2017 14:49:24 +0800 Subject: [PATCH 017/171] [Added] example std-client --- Cargo.toml | 2 ++ examples/client.rs | 2 ++ examples/std-client.rs | 72 ++++++++++++++++++++++++++++++++++++++++++ src/lib.rs | 6 ++-- 4 files changed, 80 insertions(+), 2 deletions(-) create mode 100644 examples/std-client.rs diff --git a/Cargo.toml b/Cargo.toml index b432583..4663061 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,4 +21,6 @@ tokio-proto = { version = "0.1", optional = true } tokio-core = "0.1" clap = "2.20" webpki-roots = "0.7" + +[target.'cfg(unix)'.dev-dependencies] tokio-file-unix = "0.2" diff --git a/examples/client.rs b/examples/client.rs index 2643aba..052c8fe 100644 --- a/examples/client.rs +++ b/examples/client.rs @@ -1,3 +1,5 @@ +#![cfg(unix)] + extern crate clap; extern crate rustls; extern crate futures; diff --git a/examples/std-client.rs b/examples/std-client.rs new file mode 100644 index 0000000..95ed719 --- /dev/null +++ b/examples/std-client.rs @@ -0,0 +1,72 @@ +extern crate clap; +extern crate rustls; +extern crate futures; +extern crate tokio_io; +extern crate tokio_core; +extern crate webpki_roots; +extern crate tokio_rustls; + +use std::sync::Arc; +use std::net::ToSocketAddrs; +use std::io::{ Read, Write, BufReader, stdout, stdin }; +use std::fs; +use futures::Future; +use tokio_io::io; +use tokio_core::net::TcpStream; +use tokio_core::reactor::Core; +use clap::{ App, Arg }; +use rustls::ClientConfig; +use tokio_rustls::ClientConfigExt; + + +fn app() -> App<'static, 'static> { + App::new("client") + .about("tokio-rustls client example") + .arg(Arg::with_name("host").value_name("HOST").required(true)) + .arg(Arg::with_name("port").short("p").long("port").value_name("PORT").help("port, default `443`")) + .arg(Arg::with_name("domain").short("d").long("domain").value_name("DOMAIN").help("domain")) + .arg(Arg::with_name("cafile").short("c").long("cafile").value_name("FILE").help("CA certificate chain")) +} + + +fn main() { + let matches = app().get_matches(); + + let host = matches.value_of("host").unwrap(); + let port = if let Some(port) = matches.value_of("port") { + port.parse().unwrap() + } else { + 443 + }; + let domain = matches.value_of("domain").unwrap_or(host); + let cafile = matches.value_of("cafile"); + let text = format!("GET / HTTP/1.0\r\nHost: {}\r\n\r\n", domain); + + let mut core = Core::new().unwrap(); + let handle = core.handle(); + let addr = (host, port) + .to_socket_addrs().unwrap() + .next().unwrap(); + + let mut input = Vec::new(); + stdin().read_to_end(&mut input).unwrap(); + + let mut config = ClientConfig::new(); + if let Some(cafile) = cafile { + let mut pem = BufReader::new(fs::File::open(cafile).unwrap()); + config.root_store.add_pem_file(&mut pem).unwrap(); + } else { + config.root_store.add_trust_anchors(&webpki_roots::ROOTS); + } + let arc_config = Arc::new(config); + + let socket = TcpStream::connect(&addr, &handle); + let resp = socket + .and_then(|stream| arc_config.connect_async(domain, stream)) + .and_then(|stream| io::write_all(stream, text.as_bytes())) + .and_then(|(stream, _)| io::write_all(stream, &input)) + .and_then(|(stream, _)| io::read_to_end(stream, Vec::new())) + .and_then(|(_, output)| stdout().write_all(&output)); + + core.run(resp).unwrap(); +} diff --git a/src/lib.rs b/src/lib.rs index e94534c..fa945c4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -11,8 +11,10 @@ use std::io; use std::sync::Arc; use futures::{ Future, Poll, Async }; use tokio_io::{ AsyncRead, AsyncWrite }; -use rustls::{ Session, ClientSession, ServerSession }; -use rustls::{ ClientConfig, ServerConfig }; +use rustls::{ + Session, ClientSession, ServerSession, + ClientConfig, ServerConfig +}; /// Extension trait for the `Arc` type in the `rustls` crate. From 0913af1af98ea59f9620ecf4fad38bb68ca8cf9a Mon Sep 17 00:00:00 2001 From: quininer kel Date: Thu, 30 Mar 2017 15:39:11 +0800 Subject: [PATCH 018/171] [Fixed] should flush when shutdown --- Cargo.toml | 8 +++----- src/lib.rs | 3 ++- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 4663061..0ea33f4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,15 +1,13 @@ [package] name = "tokio-rustls" -version = "0.1.4" +version = "0.1.5" authors = ["quininer kel "] license = "MIT/Apache-2.0" repository = "https://github.com/quininer/tokio-rustls" homepage = "https://github.com/quininer/tokio-rustls" documentation = "https://docs.rs/tokio-rustls" -description = """ -An implementation of TLS/SSL streams for Tokio giving an implementation of TLS -for nonblocking I/O streams. -""" +description = "Asynchronous TLS/SSL streams for Tokio using Rustls." +categories = ["asynchronous", "network-programming"] [dependencies] futures = "0.1" diff --git a/src/lib.rs b/src/lib.rs index fa945c4..dc92868 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,7 +2,7 @@ #[cfg_attr(feature = "tokio-proto", macro_use)] extern crate futures; -extern crate tokio_io; +#[macro_use] extern crate tokio_io; extern crate rustls; pub mod proto; @@ -244,6 +244,7 @@ impl AsyncWrite for TlsStream { fn shutdown(&mut self) -> Poll<(), io::Error> { self.session.send_close_notify(); + try_nb!(io::Write::flush(self)); self.io.shutdown() } } From 5b566cadc37b40024cd4f4999cddd887a1ec6a89 Mon Sep 17 00:00:00 2001 From: Brian Smith Date: Wed, 29 Mar 2017 21:22:06 -1000 Subject: [PATCH 019/171] =?UTF-8?q?Add=20the=20crate=20to=20the=20?= =?UTF-8?q?=E2=80=9Ccryptography=E2=80=9D=20category=20on=20crates.io.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 0ea33f4..0fd9c9f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,7 +7,7 @@ repository = "https://github.com/quininer/tokio-rustls" homepage = "https://github.com/quininer/tokio-rustls" documentation = "https://docs.rs/tokio-rustls" description = "Asynchronous TLS/SSL streams for Tokio using Rustls." -categories = ["asynchronous", "network-programming"] +categories = ["asynchronous", "cryptography", "network-programming"] [dependencies] futures = "0.1" From 67c7d8909b1d20925e3c4f33e20c4923d0689722 Mon Sep 17 00:00:00 2001 From: Brian Smith Date: Thu, 30 Mar 2017 10:17:03 -1000 Subject: [PATCH 020/171] Add readme to Cargo.toml. --- Cargo.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/Cargo.toml b/Cargo.toml index 0fd9c9f..06a7126 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,6 +6,7 @@ license = "MIT/Apache-2.0" repository = "https://github.com/quininer/tokio-rustls" homepage = "https://github.com/quininer/tokio-rustls" documentation = "https://docs.rs/tokio-rustls" +readme = "README.md" description = "Asynchronous TLS/SSL streams for Tokio using Rustls." categories = ["asynchronous", "cryptography", "network-programming"] From 4f59ebf87d4286af2448fd2f5ecc1eba4208a49e Mon Sep 17 00:00:00 2001 From: Brian Smith Date: Wed, 29 Mar 2017 21:22:45 -1000 Subject: [PATCH 021/171] Update version in preparation for publishing a new version. --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 06a7126..ace27ad 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tokio-rustls" -version = "0.1.5" +version = "0.1.6" authors = ["quininer kel "] license = "MIT/Apache-2.0" repository = "https://github.com/quininer/tokio-rustls" From 22fbb7497fe11d84d7fff0a966c20ea4dbe839d1 Mon Sep 17 00:00:00 2001 From: Brian Smith Date: Thu, 30 Mar 2017 10:09:35 -1000 Subject: [PATCH 022/171] Merge std-client.rs into client.rs to unbreak Windows build. The Windows build is failing because client.rs doesn't define main on non-Unixy platforms because of its `#![cfg(unix)]`. Merge std-client.rs into client.rs to solve this and to reduce redundancy. Continue using blocking stdin/stdout I/O on non-Unixy platforms until we get nonblocking stdin/stdout working on those platforms. --- README.md | 9 +++++- examples/client.rs | 59 +++++++++++++++++++++++++++------- examples/std-client.rs | 72 ------------------------------------------ 3 files changed, 55 insertions(+), 85 deletions(-) delete mode 100644 examples/std-client.rs diff --git a/README.md b/README.md index 5cf9b39..d229186 100644 --- a/README.md +++ b/README.md @@ -27,7 +27,14 @@ TcpStream::connect(&addr, &handle) See [examples/client.rs](examples/client.rs). You can run it with: ```sh -cargo run --example client google.com +cargo run --example client hsts.badssl.com +``` + +Currently on Windows the example client reads from stdin and writes to stdout using +blocking I/O. Until this is fixed, do something this on Windows: + +```sh +echo | cargo run --example client hsts.badssl.com ``` ### Server Example Program diff --git a/examples/client.rs b/examples/client.rs index 052c8fe..6ccafc9 100644 --- a/examples/client.rs +++ b/examples/client.rs @@ -1,27 +1,34 @@ -#![cfg(unix)] - extern crate clap; extern crate rustls; extern crate futures; extern crate tokio_io; extern crate tokio_core; extern crate webpki_roots; -extern crate tokio_file_unix; extern crate tokio_rustls; +#[cfg(unix)] +extern crate tokio_file_unix; + use std::sync::Arc; use std::net::ToSocketAddrs; use std::io::{ BufReader, stdout, stdin }; use std::fs; use futures::Future; -use tokio_io::{ io, AsyncRead }; use tokio_core::net::TcpStream; use tokio_core::reactor::Core; +use tokio_io::io; use clap::{ App, Arg }; use rustls::ClientConfig; -use tokio_file_unix::{ StdFile, File }; use tokio_rustls::ClientConfigExt; +#[cfg(unix)] +use tokio_io::AsyncRead; + +#[cfg(unix)] +use tokio_file_unix::{ StdFile, File }; + +#[cfg(not(unix))] +use std::io::{Read, Write}; fn app() -> App<'static, 'static> { App::new("client") @@ -52,13 +59,6 @@ fn main() { .to_socket_addrs().unwrap() .next().unwrap(); - let stdin = stdin(); - let stdin = File::new_nb(StdFile(stdin.lock())).unwrap() - .into_io(&handle).unwrap(); - let stdout = stdout(); - let stdout = File::new_nb(StdFile(stdout.lock())).unwrap() - .into_io(&handle).unwrap(); - let mut config = ClientConfig::new(); if let Some(cafile) = cafile { let mut pem = BufReader::new(fs::File::open(cafile).unwrap()); @@ -69,6 +69,24 @@ fn main() { let arc_config = Arc::new(config); let socket = TcpStream::connect(&addr, &handle); + + // Use async non-blocking I/O for stdin/stdout on Unixy platforms. + + #[cfg(unix)] + let stdin = stdin(); + + #[cfg(unix)] + let stdin = File::new_nb(StdFile(stdin.lock())).unwrap() + .into_io(&handle).unwrap(); + + #[cfg(unix)] + let stdout = stdout(); + + #[cfg(unix)] + let stdout = File::new_nb(StdFile(stdout.lock())).unwrap() + .into_io(&handle).unwrap(); + + #[cfg(unix)] let resp = socket .and_then(|stream| arc_config.connect_async(domain, stream)) .and_then(|stream| io::write_all(stream, text.as_bytes())) @@ -80,5 +98,22 @@ fn main() { .map_err(|(e, _)| e) }); + // XXX: For now, just use blocking I/O for stdin/stdout on other platforms. + // The network I/O will still be asynchronous and non-blocking. + + #[cfg(not(unix))] + let mut input = Vec::new(); + + #[cfg(not(unix))] + stdin().read_to_end(&mut input).unwrap(); + + #[cfg(not(unix))] + let resp = socket + .and_then(|stream| arc_config.connect_async(domain, stream)) + .and_then(|stream| io::write_all(stream, text.as_bytes())) + .and_then(|(stream, _)| io::write_all(stream, &input)) + .and_then(|(stream, _)| io::read_to_end(stream, Vec::new())) + .and_then(|(_, output)| stdout().write_all(&output)); + core.run(resp).unwrap(); } diff --git a/examples/std-client.rs b/examples/std-client.rs deleted file mode 100644 index 95ed719..0000000 --- a/examples/std-client.rs +++ /dev/null @@ -1,72 +0,0 @@ -extern crate clap; -extern crate rustls; -extern crate futures; -extern crate tokio_io; -extern crate tokio_core; -extern crate webpki_roots; -extern crate tokio_rustls; - -use std::sync::Arc; -use std::net::ToSocketAddrs; -use std::io::{ Read, Write, BufReader, stdout, stdin }; -use std::fs; -use futures::Future; -use tokio_io::io; -use tokio_core::net::TcpStream; -use tokio_core::reactor::Core; -use clap::{ App, Arg }; -use rustls::ClientConfig; -use tokio_rustls::ClientConfigExt; - - -fn app() -> App<'static, 'static> { - App::new("client") - .about("tokio-rustls client example") - .arg(Arg::with_name("host").value_name("HOST").required(true)) - .arg(Arg::with_name("port").short("p").long("port").value_name("PORT").help("port, default `443`")) - .arg(Arg::with_name("domain").short("d").long("domain").value_name("DOMAIN").help("domain")) - .arg(Arg::with_name("cafile").short("c").long("cafile").value_name("FILE").help("CA certificate chain")) -} - - -fn main() { - let matches = app().get_matches(); - - let host = matches.value_of("host").unwrap(); - let port = if let Some(port) = matches.value_of("port") { - port.parse().unwrap() - } else { - 443 - }; - let domain = matches.value_of("domain").unwrap_or(host); - let cafile = matches.value_of("cafile"); - let text = format!("GET / HTTP/1.0\r\nHost: {}\r\n\r\n", domain); - - let mut core = Core::new().unwrap(); - let handle = core.handle(); - let addr = (host, port) - .to_socket_addrs().unwrap() - .next().unwrap(); - - let mut input = Vec::new(); - stdin().read_to_end(&mut input).unwrap(); - - let mut config = ClientConfig::new(); - if let Some(cafile) = cafile { - let mut pem = BufReader::new(fs::File::open(cafile).unwrap()); - config.root_store.add_pem_file(&mut pem).unwrap(); - } else { - config.root_store.add_trust_anchors(&webpki_roots::ROOTS); - } - let arc_config = Arc::new(config); - - let socket = TcpStream::connect(&addr, &handle); - let resp = socket - .and_then(|stream| arc_config.connect_async(domain, stream)) - .and_then(|stream| io::write_all(stream, text.as_bytes())) - .and_then(|(stream, _)| io::write_all(stream, &input)) - .and_then(|(stream, _)| io::read_to_end(stream, Vec::new())) - .and_then(|(_, output)| stdout().write_all(&output)); - - core.run(resp).unwrap(); -} From 3d5a36590d21a1163558f81c17fde95fa77f53ae Mon Sep 17 00:00:00 2001 From: quininer kel Date: Fri, 14 Apr 2017 12:43:03 +0800 Subject: [PATCH 023/171] [Fixed] shutdown should only flush io --- Cargo.toml | 2 +- README.md | 4 ++-- examples/client.rs | 2 +- examples/server.rs | 6 +++--- src/lib.rs | 7 +++++-- 5 files changed, 12 insertions(+), 9 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index ace27ad..8343c6f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tokio-rustls" -version = "0.1.6" +version = "0.1.7" authors = ["quininer kel "] license = "MIT/Apache-2.0" repository = "https://github.com/quininer/tokio-rustls" diff --git a/README.md b/README.md index d229186..4003d57 100644 --- a/README.md +++ b/README.md @@ -7,11 +7,11 @@ Asynchronous TLS/SSL streams for [Tokio](https://tokio.rs/) using ### Basic Structure of a Client ```rust -// ... - use rustls::ClientConfig; use tokio_rustls::ClientConfigExt; +// ... + let mut config = ClientConfig::new(); config.root_store.add_trust_anchors(&webpki_roots::ROOTS); let config = Arc::new(config); diff --git a/examples/client.rs b/examples/client.rs index 6ccafc9..cfc5bca 100644 --- a/examples/client.rs +++ b/examples/client.rs @@ -78,7 +78,7 @@ fn main() { #[cfg(unix)] let stdin = File::new_nb(StdFile(stdin.lock())).unwrap() .into_io(&handle).unwrap(); - + #[cfg(unix)] let stdout = stdout(); diff --git a/examples/server.rs b/examples/server.rs index 5d27474..89a049f 100644 --- a/examples/server.rs +++ b/examples/server.rs @@ -63,7 +63,7 @@ fn main() { let (reader, writer) = stream.split(); io::copy(reader, writer) }) - .map(move |(n, _, _)| println!("Echo: {} - {}", n, addr)) + .map(move |(n, ..)| println!("Echo: {} - {}", n, addr)) .map_err(move |err| println!("Error: {:?} - {}", err, addr)); handle.spawn(done); @@ -72,11 +72,11 @@ fn main() { let done = arc_config.accept_async(stream) .and_then(|stream| io::write_all( stream, - "HTTP/1.0 200 ok\r\n\ + &b"HTTP/1.0 200 ok\r\n\ Connection: close\r\n\ Content-length: 12\r\n\ \r\n\ - Hello world!".as_bytes() + Hello world!"[..] )) .and_then(|(stream, _)| io::flush(stream)) .map(move |_| println!("Accept: {}", addr)) diff --git a/src/lib.rs b/src/lib.rs index dc92868..1a48ddc 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -227,7 +227,7 @@ impl io::Write for TlsStream while self.session.wants_write() { self.session.write_tls(&mut self.io)?; } - Ok(()) + self.io.flush() } } @@ -244,7 +244,10 @@ impl AsyncWrite for TlsStream { fn shutdown(&mut self) -> Poll<(), io::Error> { self.session.send_close_notify(); - try_nb!(io::Write::flush(self)); + while self.session.wants_write() { + try_nb!(self.session.write_tls(&mut self.io)); + } + try_nb!(self.io.flush()); self.io.shutdown() } } From ade576a40343abfe247274b9800cc1a9cd7cc87b Mon Sep 17 00:00:00 2001 From: Brian Smith Date: Mon, 8 May 2017 17:30:56 -1000 Subject: [PATCH 024/171] 0.2.0: Update Rustls and webpki-roots versions. --- Cargo.toml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 8343c6f..4855870 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tokio-rustls" -version = "0.1.7" +version = "0.2.0" authors = ["quininer kel "] license = "MIT/Apache-2.0" repository = "https://github.com/quininer/tokio-rustls" @@ -13,13 +13,13 @@ categories = ["asynchronous", "cryptography", "network-programming"] [dependencies] futures = "0.1" tokio-io = "0.1" -rustls = "0.5" +rustls = "0.7" tokio-proto = { version = "0.1", optional = true } [dev-dependencies] tokio-core = "0.1" clap = "2.20" -webpki-roots = "0.7" +webpki-roots = "0.10.0" [target.'cfg(unix)'.dev-dependencies] tokio-file-unix = "0.2" From 5a7e49a073b3a732ac6b73fb345295e8baab95fb Mon Sep 17 00:00:00 2001 From: quininer kel Date: Tue, 9 May 2017 12:28:32 +0800 Subject: [PATCH 025/171] [Changed] update dev dependencies --- Cargo.toml | 2 +- LICENSE-APACHE | 2 +- LICENSE-MIT | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 4855870..f0dea28 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,4 +22,4 @@ clap = "2.20" webpki-roots = "0.10.0" [target.'cfg(unix)'.dev-dependencies] -tokio-file-unix = "0.2" +tokio-file-unix = "0.4" diff --git a/LICENSE-APACHE b/LICENSE-APACHE index 4e411cf..2154394 100644 --- a/LICENSE-APACHE +++ b/LICENSE-APACHE @@ -186,7 +186,7 @@ APPENDIX: How to apply the Apache License to your work. same "printed page" as the copyright notice for easier identification within third-party archives. -Copyright 2016 quininer kel +Copyright 2017 quininer kel Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/LICENSE-MIT b/LICENSE-MIT index d0dfcc7..4500636 100644 --- a/LICENSE-MIT +++ b/LICENSE-MIT @@ -1,4 +1,4 @@ -Copyright (c) 2016 quininer kel +Copyright (c) 2017 quininer kel Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated From 0b0ab95419c549b013d3a5bd4d41c49779488a58 Mon Sep 17 00:00:00 2001 From: Benjamin Fry Date: Wed, 17 May 2017 10:27:56 -0700 Subject: [PATCH 026/171] upgrade rustls to 0.8 --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index f0dea28..f01737e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,7 +13,7 @@ categories = ["asynchronous", "cryptography", "network-programming"] [dependencies] futures = "0.1" tokio-io = "0.1" -rustls = "0.7" +rustls = "0.8" tokio-proto = { version = "0.1", optional = true } [dev-dependencies] From d224148327958402f713305fff2ee0aeaaef7334 Mon Sep 17 00:00:00 2001 From: quininer kel Date: Thu, 18 May 2017 02:00:07 +0800 Subject: [PATCH 027/171] [Changed] bump version --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index f01737e..438c91b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tokio-rustls" -version = "0.2.0" +version = "0.2.1" authors = ["quininer kel "] license = "MIT/Apache-2.0" repository = "https://github.com/quininer/tokio-rustls" From 72c9c1d59ec6a2dea0735e41dd0ad1aa12baf24d Mon Sep 17 00:00:00 2001 From: PZ Read Date: Fri, 26 May 2017 15:54:47 +0800 Subject: [PATCH 028/171] Fix plaintext write logic for limited rustls buffer. --- src/lib.rs | 28 ++++++++++++++++++++-------- 1 file changed, 20 insertions(+), 8 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 1a48ddc..8728219 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -209,17 +209,29 @@ impl io::Write for TlsStream where S: AsyncRead + AsyncWrite, C: Session { fn write(&mut self, buf: &[u8]) -> io::Result { - let output = self.session.write(buf)?; + loop { + let output = self.session.write(buf)?; - while self.session.wants_write() { - match self.session.write_tls(&mut self.io) { - Ok(_) => (), - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => break, - Err(e) => return Err(e) + while self.session.wants_write() { + match self.session.write_tls(&mut self.io) { + Ok(_) => (), + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + if output == 0 { + // Both rustls buffer and IO buffer are blocking. + return Err(io::Error::from(io::ErrorKind::WouldBlock)); + } else { + break; + } + } + Err(e) => return Err(e) + } + } + + if output > 0 { + // Already wrote something out. + return Ok(output); } } - - Ok(output) } fn flush(&mut self) -> io::Result<()> { From 185f01093783b1356d3d135eaa1475b3584b723e Mon Sep 17 00:00:00 2001 From: PZ Read Date: Fri, 26 May 2017 16:38:02 +0800 Subject: [PATCH 029/171] Add async builders for custom session. --- src/lib.rs | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/src/lib.rs b/src/lib.rs index 8728219..87f461f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -52,6 +52,15 @@ impl ClientConfigExt for Arc { } } +pub fn connect_async_with_session(stream: S, session: ClientSession) + -> ConnectAsync + where S: AsyncRead + AsyncWrite +{ + ConnectAsync(MidHandshake { + inner: Some(TlsStream::new(stream, session)) + }) +} + impl ServerConfigExt for Arc { fn accept_async(&self, stream: S) -> AcceptAsync @@ -63,6 +72,15 @@ impl ServerConfigExt for Arc { } } +pub fn accept_async_with_session(stream: S, session: ServerSession) + -> AcceptAsync + where S: AsyncRead + AsyncWrite +{ + AcceptAsync(MidHandshake { + inner: Some(TlsStream::new(stream, session)) + }) +} + impl Future for ConnectAsync { type Item = TlsStream; type Error = io::Error; From d6d06041d9735ba64d09e845b184873e36056e41 Mon Sep 17 00:00:00 2001 From: PZ Read Date: Fri, 26 May 2017 17:59:51 +0800 Subject: [PATCH 030/171] Fix empty buffer --- src/lib.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/lib.rs b/src/lib.rs index 87f461f..2c6a7e9 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -227,6 +227,10 @@ impl io::Write for TlsStream where S: AsyncRead + AsyncWrite, C: Session { fn write(&mut self, buf: &[u8]) -> io::Result { + if buf.len() == 0 { + return Ok(0); + } + loop { let output = self.session.write(buf)?; From 076c266fa1340a6e02eaec5352fa9ea69c9b36bc Mon Sep 17 00:00:00 2001 From: quininer kel Date: Fri, 26 May 2017 18:22:03 +0800 Subject: [PATCH 031/171] [Changed] bump version --- Cargo.toml | 2 +- src/lib.rs | 26 +++++++++++--------------- 2 files changed, 12 insertions(+), 16 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 438c91b..863d8cb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tokio-rustls" -version = "0.2.1" +version = "0.2.2" authors = ["quininer kel "] license = "MIT/Apache-2.0" repository = "https://github.com/quininer/tokio-rustls" diff --git a/src/lib.rs b/src/lib.rs index 2c6a7e9..4e3a47f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -46,12 +46,11 @@ impl ClientConfigExt for Arc { -> ConnectAsync where S: AsyncRead + AsyncWrite { - ConnectAsync(MidHandshake { - inner: Some(TlsStream::new(stream, ClientSession::new(self, domain))) - }) + connect_async_with_session(stream, ClientSession::new(self, domain)) } } +#[inline] pub fn connect_async_with_session(stream: S, session: ClientSession) -> ConnectAsync where S: AsyncRead + AsyncWrite @@ -66,12 +65,11 @@ impl ServerConfigExt for Arc { -> AcceptAsync where S: AsyncRead + AsyncWrite { - AcceptAsync(MidHandshake { - inner: Some(TlsStream::new(stream, ServerSession::new(self))) - }) + accept_async_with_session(stream, ServerSession::new(self)) } } +#[inline] pub fn accept_async_with_session(stream: S, session: ServerSession) -> AcceptAsync where S: AsyncRead + AsyncWrite @@ -227,7 +225,7 @@ impl io::Write for TlsStream where S: AsyncRead + AsyncWrite, C: Session { fn write(&mut self, buf: &[u8]) -> io::Result { - if buf.len() == 0 { + if buf.is_empty() { return Ok(0); } @@ -237,14 +235,12 @@ impl io::Write for TlsStream while self.session.wants_write() { match self.session.write_tls(&mut self.io) { Ok(_) => (), - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { - if output == 0 { - // Both rustls buffer and IO buffer are blocking. - return Err(io::Error::from(io::ErrorKind::WouldBlock)); - } else { - break; - } - } + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => if output == 0 { + // Both rustls buffer and IO buffer are blocking. + return Err(io::Error::from(io::ErrorKind::WouldBlock)); + } else { + break; + }, Err(e) => return Err(e) } } From d606a10000ae649106196bd6de62cf219528e522 Mon Sep 17 00:00:00 2001 From: quininer kel Date: Sat, 17 Jun 2017 17:20:25 +0800 Subject: [PATCH 032/171] [Changed] bump rustls to 0.9 --- Cargo.toml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 863d8cb..7a0fe3a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tokio-rustls" -version = "0.2.2" +version = "0.2.3" authors = ["quininer kel "] license = "MIT/Apache-2.0" repository = "https://github.com/quininer/tokio-rustls" @@ -13,13 +13,13 @@ categories = ["asynchronous", "cryptography", "network-programming"] [dependencies] futures = "0.1" tokio-io = "0.1" -rustls = "0.8" +rustls = "0.9" tokio-proto = { version = "0.1", optional = true } [dev-dependencies] tokio-core = "0.1" clap = "2.20" -webpki-roots = "0.10.0" +webpki-roots = "0.11.0" [target.'cfg(unix)'.dev-dependencies] tokio-file-unix = "0.4" From 42e1d72fb287e9e4a173d141230a577198f49721 Mon Sep 17 00:00:00 2001 From: Jack Zhou Date: Tue, 18 Jul 2017 14:23:56 -0700 Subject: [PATCH 033/171] Fixed TlsStream closing the connection abruptly on fatal errors. (#12) Instead, flush queued TLS messages when an error occurs before closing the connection. --- src/lib.rs | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 4e3a47f..d4d1aa5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -172,8 +172,16 @@ impl TlsStream continue }, Ok(_) => { - self.session.process_new_packets() - .map_err(|err| io::Error::new(io::ErrorKind::Other, err))?; + if let Err(err) = self.session.process_new_packets() { + // flush queued messages before returning an Err in + // order to send alerts instead of abruptly closing + // the socket + if self.session.wants_write() { + // ignore result to avoid masking original error + let _ = self.session.write_tls(&mut self.io); + } + return Err(io::Error::new(io::ErrorKind::Other, err)); + } continue }, Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => true, From c3961081ec0584c56858c4b1c8b1087ed1568cc4 Mon Sep 17 00:00:00 2001 From: quininer kel Date: Wed, 19 Jul 2017 09:28:45 +0800 Subject: [PATCH 034/171] [Changed] bump version --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 7a0fe3a..7e438ba 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tokio-rustls" -version = "0.2.3" +version = "0.2.4" authors = ["quininer kel "] license = "MIT/Apache-2.0" repository = "https://github.com/quininer/tokio-rustls" From 36fabdadfd76c99b812ffcd38afb119e16a5162b Mon Sep 17 00:00:00 2001 From: quininer kel Date: Fri, 21 Jul 2017 17:57:57 +0800 Subject: [PATCH 035/171] [Added] danger feature --- Cargo.toml | 3 +++ src/lib.rs | 29 +++++++++++++++++++++++++++++ 2 files changed, 32 insertions(+) diff --git a/Cargo.toml b/Cargo.toml index 7e438ba..9c31b71 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,6 +10,9 @@ readme = "README.md" description = "Asynchronous TLS/SSL streams for Tokio using Rustls." categories = ["asynchronous", "cryptography", "network-programming"] +[features] +danger = [ "rustls/dangerous_configuration" ] + [dependencies] futures = "0.1" tokio-io = "0.1" diff --git a/src/lib.rs b/src/lib.rs index d4d1aa5..1c1aa11 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -22,6 +22,11 @@ pub trait ClientConfigExt { fn connect_async(&self, domain: &str, stream: S) -> ConnectAsync where S: AsyncRead + AsyncWrite; + + #[cfg(feature = "danger")] + fn danger_connect_async_without_providing_domain_for_certificate_verification_and_server_name_indication(&self, stream: S) + -> ConnectAsync + where S: AsyncRead + AsyncWrite; } /// Extension trait for the `Arc` type in the `rustls` crate. @@ -48,6 +53,30 @@ impl ClientConfigExt for Arc { { connect_async_with_session(stream, ClientSession::new(self, domain)) } + + #[cfg(feature = "danger")] + fn danger_connect_async_without_providing_domain_for_certificate_verification_and_server_name_indication(&self, stream: S) + -> ConnectAsync + where S: AsyncRead + AsyncWrite + { + use rustls::{ ServerCertVerifier, RootCertStore, Certificate, TLSError }; + + struct NoCertVerifier; + impl ServerCertVerifier for NoCertVerifier { + fn verify_server_cert(&self, _: &RootCertStore, _: &[Certificate], _: &str) + -> Result<(), TLSError> + { + Ok(()) + } + } + + let mut client_config = ClientConfig::new(); + client_config.clone_from(self); + client_config.dangerous() + .set_certificate_verifier(Box::new(NoCertVerifier)); + + Arc::new(client_config).connect_async("", stream) + } } #[inline] From 4b98a7b07a67061f4be0de762bdf27810eaefd17 Mon Sep 17 00:00:00 2001 From: quininer Date: Sun, 13 Aug 2017 13:05:47 +0800 Subject: [PATCH 036/171] [Changed] update rustls --- Cargo.toml | 4 ++-- src/lib.rs | 10 +++++----- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 9c31b71..4372692 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tokio-rustls" -version = "0.2.4" +version = "0.3.0" authors = ["quininer kel "] license = "MIT/Apache-2.0" repository = "https://github.com/quininer/tokio-rustls" @@ -16,7 +16,7 @@ danger = [ "rustls/dangerous_configuration" ] [dependencies] futures = "0.1" tokio-io = "0.1" -rustls = "0.9" +rustls = "0.10" tokio-proto = { version = "0.1", optional = true } [dev-dependencies] diff --git a/src/lib.rs b/src/lib.rs index 1c1aa11..5c24c18 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -59,21 +59,21 @@ impl ClientConfigExt for Arc { -> ConnectAsync where S: AsyncRead + AsyncWrite { - use rustls::{ ServerCertVerifier, RootCertStore, Certificate, TLSError }; + use rustls::{ ServerCertVerifier, RootCertStore, Certificate, ServerCertVerified, TLSError }; struct NoCertVerifier; impl ServerCertVerifier for NoCertVerifier { - fn verify_server_cert(&self, _: &RootCertStore, _: &[Certificate], _: &str) - -> Result<(), TLSError> + fn verify_server_cert(&self, _: &RootCertStore, _: &[Certificate], _: &str, _: &[u8]) + -> Result { - Ok(()) + Ok(ServerCertVerified::assertion()) } } let mut client_config = ClientConfig::new(); client_config.clone_from(self); client_config.dangerous() - .set_certificate_verifier(Box::new(NoCertVerifier)); + .set_certificate_verifier(Arc::new(NoCertVerifier)); Arc::new(client_config).connect_async("", stream) } From 037f84ea9838c27495630edc8f47aebb16a72c5b Mon Sep 17 00:00:00 2001 From: quininer Date: Sun, 13 Aug 2017 18:19:17 +0800 Subject: [PATCH 037/171] [Added] tests --- .travis.yml | 10 ++++ appveyor.yml | 18 +++++++ tests/end.cert | 24 ++++++++++ tests/end.chain | 89 +++++++++++++++++++++++++++++++++++ tests/end.rsa | 27 +++++++++++ tests/test.rs | 121 ++++++++++++++++++++++++++++++++++++++++++++++++ 6 files changed, 289 insertions(+) create mode 100644 .travis.yml create mode 100644 appveyor.yml create mode 100644 tests/end.cert create mode 100644 tests/end.chain create mode 100644 tests/end.rsa create mode 100644 tests/test.rs diff --git a/.travis.yml b/.travis.yml new file mode 100644 index 0000000..0d7365a --- /dev/null +++ b/.travis.yml @@ -0,0 +1,10 @@ +language: rust +rust: + - stable +cache: cargo +os: + - linux + - osx + +script: + - cargo test --all-features diff --git a/appveyor.yml b/appveyor.yml new file mode 100644 index 0000000..80ae684 --- /dev/null +++ b/appveyor.yml @@ -0,0 +1,18 @@ +environment: + matrix: + - TARGET: x86_64-pc-windows-msvc + BITS: 64 + - TARGET: i686-pc-windows-msvc + BITS: 32 + +install: + - appveyor DownloadFile https://win.rustup.rs/ -FileName rustup-init.exe + - rustup-init.exe -y --default-host %TARGET% + - set PATH=%PATH%;%USERPROFILE%\.cargo\bin + - rustc --version + - cargo --version + +build: false + +test_script: + - 'cargo test --all-features' diff --git a/tests/end.cert b/tests/end.cert new file mode 100644 index 0000000..66f087e --- /dev/null +++ b/tests/end.cert @@ -0,0 +1,24 @@ +-----BEGIN CERTIFICATE----- +MIIEADCCAmigAwIBAgICAcgwDQYJKoZIhvcNAQELBQAwLDEqMCgGA1UEAwwhcG9u +eXRvd24gUlNBIGxldmVsIDIgaW50ZXJtZWRpYXRlMB4XDTE2MTIxMDE3NDIzM1oX +DTIyMDYwMjE3NDIzM1owGTEXMBUGA1UEAwwOdGVzdHNlcnZlci5jb20wggEiMA0G +CSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQC1YDz66+7VD4DL1+/sVHMQ+BbDRgmD +OQlX++mfW8D3QNQm/qDBEbu7T7qqdc9GKDar4WIzBN8SBkzM1EjMGwNnZPV/Tfz0 +qUAR1L/7Zzf1GaFZvWXgksyUpfwvmprH3Iy/dpkETwtPthpTPNlui3hZnm/5kkjR +RWg9HmID4O04Ld6SK313v2ZgrPZbkKvbqlqhUnYWjL3blKVGbpXIsuZzEU9Ph+gH +tPcEhZpFsM6eLe+2TVscIrycMEOTXqAAmO6zZ9sQWtfllu3CElm904H6+jA/9Leg +al72pMmkYr8wWniqDDuijXuCPlVx5EDFFyxBmW18UeDEQaKV3kNfelaTAgMBAAGj +gb4wgbswDAYDVR0TAQH/BAIwADALBgNVHQ8EBAMCBsAwHQYDVR0OBBYEFIYhJkVy +AAKT6cY/ruH1Eu+NNxteMEIGA1UdIwQ7MDmAFNwuPy4Do//Sm5CZDrocHWTrNr96 +oR6kHDAaMRgwFgYDVQQDDA9wb255dG93biBSU0EgQ0GCAXswOwYDVR0RBDQwMoIO +dGVzdHNlcnZlci5jb22CFXNlY29uZC50ZXN0c2VydmVyLmNvbYIJbG9jYWxob3N0 +MA0GCSqGSIb3DQEBCwUAA4IBgQCWV76jfQDZKtfmj45fTwZzoe/PxjWPRbAvSEnt +LRHrPhqQfpMLqpun8uu/w86mHiR/AmiAySMu3zivW6wfGzlRWLi/zCyO6r9LGsgH +bNk5CF642cdZFvn1SiSm1oGXQrolIpcyXu88nUpt74RnY4ETCC1dRQKqxsYufe5T +DOmTm3ChinNW4QRG3yvW6DVuyxVAgZvofyKJOsM3GO6oogIM41aBqZ3UTwmIwp6D +oISdiATslFOzYzjnyXNR8DG8OOkv1ehWuyb8x+hQCZAuogQOWYtCSd6k3kKgd0EM +4CWbt1XDV9ZJwBf2uxZeKuCu/KIy9auNtijAwPsUv9qxuzko018zhl3lWm5p2Sqw +O7fFshU3A6df8hMw7ST6/tgFY7geT88U4iJhfWMwr/CZSRSVMXhTyJgbLIXxKYZj +Ym5v4NAIQP6hI4HixzQaYgrhW6YX6myk+emMjQLRJHT8uHvmT7fuxMJVWWgsCkr1 +C75pRQEagykN/Uzr5e6Tm8sVu88= +-----END CERTIFICATE----- diff --git a/tests/end.chain b/tests/end.chain new file mode 100644 index 0000000..7c39013 --- /dev/null +++ b/tests/end.chain @@ -0,0 +1,89 @@ +-----BEGIN CERTIFICATE----- +MIIGnzCCAoegAwIBAgIBezANBgkqhkiG9w0BAQsFADAaMRgwFgYDVQQDDA9wb255 +dG93biBSU0EgQ0EwHhcNMTYxMjEwMTc0MjMzWhcNMjYxMjA4MTc0MjMzWjAsMSow +KAYDVQQDDCFwb255dG93biBSU0EgbGV2ZWwgMiBpbnRlcm1lZGlhdGUwggGiMA0G +CSqGSIb3DQEBAQUAA4IBjwAwggGKAoIBgQDnfb7vaJbaHEyVTflswWhmHqx5W0NO +KyKbDp2zXEJwDO+NDJq6i1HGnFd/vO4LyjJBU1wUsKtE+m55cfRmUHVuZ2w4n/VF +p7Z7n+SNuvJNcrzDxyKVy4GIZ39zQePnniqtLqXh6eI8Ow6jiMgVxC/wbWcVLKv6 +4RM+2fLjJAC9b27QfjhOlMKVeMOEvPrrpjLSauaHAktQPhuzIAwzxM0+KnvDkWWy +NVqAV/lq6fSO/9vJRhM4E2nxo6yqi7qTdxVxMmKsNn7L6HvjQgx+FXziAUs55Qd9 +cP7etCmPmoefkcgdbxDOIKH8D+DvfacZwngqcnr/q96Ff4uJ13d2OzR1mWVSZ2hE +JQt/BbZBANciqu9OZf3dj6uOOXgFF705ak0GfLtpZpc29M+fVnknXPDSiKFqjzOO +KL+SRGyuNc9ZYjBKkXPJ1OToAs6JSvgDxfOfX0thuo2rslqfpj2qCFugsRIRAqvb +eyFwg+BPM/P/EfauXlAcQtBF04fOi7xN2okCAwEAAaNeMFwwHQYDVR0OBBYEFNwu +Py4Do//Sm5CZDrocHWTrNr96MCAGA1UdJQEB/wQWMBQGCCsGAQUFBwMBBggrBgEF +BQcDAjAMBgNVHRMEBTADAQH/MAsGA1UdDwQEAwIB/jANBgkqhkiG9w0BAQsFAAOC +BAEAMHZpBqDIUAVFZNw4XbuimXQ4K8q4uePrLGHLb4F/gHbr8kYrU4H+cy4l+xXf +2dlEBdZoqjSF7uXzQg5Fd8Ff3ZgutXd1xeUJnxo0VdpKIhqeaTPqhffC2X6FQQH5 +KrN7NVWQSnUhPNpBFELpmdpY1lHigFW7nytYj0C6VJ4QsbqhfW+n/t+Zgqtfh/Od +ZbclzxFwMM55zRA2HP6IwXS2+d61Jk/RpDHTzhWdjGH4906zGNNMa7slHpCTA9Ju +TrtjEAGt2PBSievBJOHZW80KVAoEX2n9B3ZABaz+uX0VVZG0D2FwhPpUeA57YiXu +qiktZR4Ankph3LabXp4IlAX16qpYsEW8TWE/HLreeqoM0WDoI6rF9qnTpV2KWqBf +ziMYkfSkT7hQ2bWc493lW+QwSxCsuBsDwlrCwAl6jFSf1+jEQx98/8n9rDNyD9dL +PvECmtF30WY98nwZ9/kO2DufQrd0mwSHcIT0pAwl5fimpkwTjj+TTbytO3M4jK5L +tuIzsViQ95BmJQ3XuLdkQ/Ug8rpECYRX5fQX1qXkkvl920ohpKqKyEji1OmfmJ0Z +tZChaEcu3Mp3U+gD4az2ogmle3i/Phz8ZEPFo4/21G5Qd72z0lBgaQIeyyCk5MHt +Yg0vA7X0/w4bz+OJv5tf7zJsPCYSprr+c/7YUJk9Fqu6+g9ZAavI99xFKdGhz4Og +w0trnKNCxYc6+NPopTDbXuY+fo4DK7C0CSae5sKs7013Ne6w4KvgfLKpvlemkGfg +ZA3+1FMXVfFIEH7Cw9cx6F02Sr3k1VrU68oM3wH5nvTUkELOf8nRMlzliQjVCpKB +yFSe9dzRVSFEbMDxChiEulGgNUHj/6wwpg0ZmCwPRHutppT3jkfEqizN5iHb69GH +k6kol6knJofkaL656Q3Oc9o0ZrMlFh1RwmOvAk5fVK0/CV88/phROz2Wdmy5Bz4a +t0vzqFWA54y6+9EEVoOk9SU0CYfpGtpX4URjLK1EUG/l+RR3366Uee6TPrtEZ9cg +56VQMxhSaRNAvJ6DfiSuscSCNJzwuXaMXSZydGYnnP9Tb9p6c1uy1sXdluZkBIcK +CgC+gdDMSNlDn9ghc4xZGkuA8bjzfAYuRuGKmfTt8uuklkjw2b9w3SHjC4/Cmd2W +cFRnzfg2oL6e78hNg2ZGgsLzvb6Lu6/5IhXCO7RitzYf2+HLBbc+YLFsnG3qeGe1 +28yGnXOQd97Cr4+IzFucVy/33gMQkesNUSDFJSq1gE/hGrMgTTMQJ7yC3PRqg0kG +tpqTyKNdM0g1adxlR1qfDPvpUBApkgBbySnMyWEr5+tBuoHUtH2m49oV9YD4odMJ +yJjlGxituO/YNN6O8oANlraG1Q== +-----END CERTIFICATE----- +-----BEGIN CERTIFICATE----- +MIIJBzCCBO+gAwIBAgIJAN7WS1mRS9A+MA0GCSqGSIb3DQEBCwUAMBoxGDAWBgNV +BAMMD3Bvbnl0b3duIFJTQSBDQTAeFw0xNjEyMTAxNzQyMzNaFw0yNjEyMDgxNzQy +MzNaMBoxGDAWBgNVBAMMD3Bvbnl0b3duIFJTQSBDQTCCBCIwDQYJKoZIhvcNAQEB +BQADggQPADCCBAoCggQBAMNEzJ7aNdD2JSk9+NF9Hh2za9OQnt1d/7j6DtE3ieoT +ms8mMSXzoImXZayZ9Glx3yx/RhEb2vmINyb0vRUM4I/GH+XHdOBcs9kaJNv/Mpw4 +Ggd4e1LUqV1pzNrhYwRrTQTKyaDiDX2WEBNfQaaYnHltmSmsfyt3Klj+IMc6CyqV +q8SOQ6Go414Vn++Jj7p3E6owdwuvSvO8ERLobiA6vYB+qrS7E48c4zRIAFIO4uwt +g4TiCJLLWc1fRSoqGGX7KS+LzQF8Pq67IOHVna4e9peSe6nQnm0LQZAmaosYHvF4 +AX0Bj6TLv9PXCAGtB7Pciev5Br0tRZEdVyYfmwiVKUWcp77TghV3W+VaJVhPh5LN +X91ktvpeYek3uglqv2ZHtSG2S1KkBtTkbMOD+a2BEUfq0c0+BIsj6jdvt4cvIfet +4gUOxCvYMBs4/dmNT1zoe/kJ0lf8YXYLsXwVWdIW3jEE8QdkLtLI9XfyU9OKLZuD +mmoAf7ezvv/T3nKLFqhcwUFGgGtCIX+oWC16XSbDPBcKDBwNZn8C49b7BLdxqAg3 +msfxwhYzSs9F1MXt/h2dh7FVmkCSxtgNDX3NJn5/yT6USws2y0AS5vXVP9hRf0NV +KfKn9XlmHCxnZExwm68uZkUUYHB05jSWFojbfWE+Mf9djUeQ4FuwusztZdbyQ4yS +mMtBXO0I6SQBmjCoOa1ySW3DTuw/eKCfq+PoxqWD434bYA9nUa+pE27MP7GLyjCS +6+ED3MACizSF0YxkcC9pWUo4L5FKp+DxnNbtzMIILnsDZTVHOvKUy/gjTyTWm/+7 +2t98l7vBE8gn3Aux0V5WFe2uZIZ07wIi/OThoBO8mpt9Bm5cJTG07JStKEXX/UH1 +nL7cDZ2V5qbf4hJdDy4qixxxIZtmf//1BRlVQ9iYTOsMoy+36DXWbc3vSmjRefW1 +YENt4zxOPe4LUq2Z+LXq1OgVQrHrVevux0vieys7Rr2gA1sH8FaaNwTr7Q8dq+Av +Evk+iOUH4FuYorU1HuGHPkAkvLWosVwlB+VhfEai0V6+PmttmaOnCJNHfFTu5wCu +B9CFJ1tdzTzAbrLwgtWmO70KV7CfZPHO7lMWhSvplU0i5T9WytxP91IoFtXwRSO8 ++Ghyu0ynB3HywCH2dez89Vy903P6PEU0qTnYWRz6D/wi5+yHHNrm9CilWurs/Qex +kyB7lLD7Cb1JJc8QIFTqT6vj+cids3xd245hUdpFyZTX99YbF6IkiB2zGi5wvUmP +f1GPvkTLb7eF7bne9OClEjEqvc0hVJ2abO2WXkqxlQFEYZHNofm+y6bnby/BZZJo +beaSFcLOCe2Z8iZvVnzfHBCeLyWE89gc94z784S3LEsCAwEAAaNQME4wHQYDVR0O +BBYEFNz2wEPCQbx9OdRCNE4eALwHJfIgMB8GA1UdIwQYMBaAFNz2wEPCQbx9OdRC +NE4eALwHJfIgMAwGA1UdEwQFMAMBAf8wDQYJKoZIhvcNAQELBQADggQBACbm2YX7 +sBG0Aslj36gmVlCTTluNg2tuK2isHbK3YhNwujrH/o/o2OV7UeUkZkPwE4g4/SjC +OwDWYniRNyDKBOeD9Q0XxR5z5IZQO+pRVvXF8DXO6kygWCOJM9XheKxp9Uke0aDg +m8F02NslKLUdy7piGlLSz1sgdjiE3izIwFZRpZY7sMozNWWvSAmzprbkE78LghIm +VEydQzIQlr5soWqc65uFLNbEA6QBPoFc6dDW+mnzXf8nrZUM03CACxAsuq/YkjRp +OHgwgfdNRdlu4YhZtuQNak4BUvDmigTGxDC+aMJw0ldL1bLtqLG6BvQbyLNPOOfo +5S8lGh4y06gb//052xHaqtCh5Ax5sHUE5By6wKHAKbuJy26qyKfaRoc3Jigs4Fd5 +3CuoDWHbyXfkgKiU+sc+1mvCxQKFRJ2fpGEFP8iEcLvdUae7ZkRM4Kb0vST+QhQV +fDaFkM3Bwqtui5YaZ6cHHQVyXQdujCmfesoZXKil2yduQ3KWgePjewzRV+aDWMzk +qKaF+TRANSqWbBU6JTwwQ4veKQThU3ir7nS2ovdPbhNS/FnWoKodj6eaqXfdYuBh +XOXLewIF568MJsLOuBubeAO2a9LOlhnv6eLGp2P4M7vwEdN/LRRQtwBBmqq8C3h+ +ewrJP12B/ag0bJDi9vCgPhYtDEpjpfsnxZEIqVZwshJ/MqXykFp2kYk62ylyfDWq +veI/aHwpzT2k+4CI/XmPWXl9NlI50HPdpcwCBDy8xVHwb/x7stNgQdIhaj9tzmKa +S+eqitclc8Iqrbd523H//QDzm8yiqRZUdveNa9gioTMErR0ujCpK8tO8mVZcVfNX +i1/Vsar5++nXcPhxKsd1t8XV2dk3gUZIfMgzLLzs+KSiFg+bT3c7LkCd+I3w30Iv +fh9cxFBAyYO9giwxaCfJgoz7OYqaHOOtASF85UV7gK9ELT7/z+RAcS/UfY1xbd54 +hIi1vRZj8lfkAYNtnYlud44joi1BvW/GZGFCiJ13SSvfHNs9v/5xguyCSgyCc0qx +ZkN/fzj/5wFQbxSl3MPn/JrsvlH6wvJht1SA50uVdUvJ5e5V8EgLYfMqlJNNpTHP +wZcHF+Dw126oyu2KhUxD126Gusxp+tV6I0EEZnVwwduFQWq9xm/gT+qohpveeylf +Q2XGz56DF2udJJnSFGSqzQOl9XopNC/4ecBMwIzqdFSpaWgK3VNAcigyDajgoE4v +ZuiVDEiLhLowZvi1V8GOWzcka7R2BQBjhOLWByQGDcm8cOMS7w8oCSQCaYmJyHvE +tTHq7fX6/sXv0AJqM3ysSdU01IVBNahnr5WEkmQMaFF0DGvRfqkVdKcChwrKv7r2 +DLxargy39i2aQGg= +-----END CERTIFICATE----- diff --git a/tests/end.rsa b/tests/end.rsa new file mode 100644 index 0000000..744bba5 --- /dev/null +++ b/tests/end.rsa @@ -0,0 +1,27 @@ +-----BEGIN RSA PRIVATE KEY----- +MIIEogIBAAKCAQEAtWA8+uvu1Q+Ay9fv7FRzEPgWw0YJgzkJV/vpn1vA90DUJv6g +wRG7u0+6qnXPRig2q+FiMwTfEgZMzNRIzBsDZ2T1f0389KlAEdS/+2c39RmhWb1l +4JLMlKX8L5qax9yMv3aZBE8LT7YaUzzZbot4WZ5v+ZJI0UVoPR5iA+DtOC3ekit9 +d79mYKz2W5Cr26paoVJ2Foy925SlRm6VyLLmcxFPT4foB7T3BIWaRbDOni3vtk1b +HCK8nDBDk16gAJjus2fbEFrX5ZbtwhJZvdOB+vowP/S3oGpe9qTJpGK/MFp4qgw7 +oo17gj5VceRAxRcsQZltfFHgxEGild5DX3pWkwIDAQABAoIBAFDTazlSbGML/pRY +TTWeyIw2UkaA7npIr45C13BJfitw+1nJPK/tDCDDveZ6i3yzLPHZhV5A/HtWzWC1 +9R7nptOrnO83PNN2nPOVQFxzOe+ClXGdQkoagQp5EXHRTspj0WD9I+FUrDDAcOjJ +BAgMJPyi6zlnZAXGDVa3NGyQDoZqwU2k36L4rEsJIkG0NVurZhpiCexNkkf32495 +TOINQ0iKdfJ4iZoEYQ9G+x4NiuAJRCHuIcH76SNfT+Uv3wX0ut5EFPtflnvtdgcp +QVcoKwYdO0+mgO5xqWlBcsujSvgBdiNAGnAxKHWiEaacuIJi4+yYovyEebP6QI2X +Zg/U2wkCgYEA794dE5CPXLOmv6nioVC/ubOESk7vjSlEka/XFbKr4EY794YEqrB1 +8TUqg09Bn3396AS1e6P2shr3bxos5ybhOxDGSLnJ+aC0tRFjd1BPKnA80vZM7ggt +5cjmdD5Zp0tIQTIAAYU5bONQOwj0ej4PE7lny26eLa5vfvCwlrD+rM0CgYEAwZMN +W/5PA2A+EM08IaHic8my0dCunrNLF890ouZnDG99SbgMGvvEsGIcCP1sai702hNh +VgGDxCz6/HUy+4O4YNFVtjY7uGEpfIEcEI7CsLQRP2ggWEFxThZtnEtO8PbM3J/i +qcS6njHdE+0XuCjgZwGgva5xH2pkWFzw/AIpEN8CgYB2HOo2axWc8T2n3TCifI+c +EqCOsqXU3cBM+MgxgASQcCUxMkX0AuZguuxPMmS+85xmdoMi+c8NTqgOhlYcEJIR +sqXgw9OH3zF8g6513w7Md+4Ld4rUHyTypGWOUfF1pmVS7RsBpKdtTdWA7FzuIMbt +0HsiujqbheyTFlPuMAOH9QKBgBWS1gJSrWuq5j/pH7J/4EUXTZ6kq1F0mgHlVRJy +qzlvk38LzA2V0a32wTkfRV3wLcnALzDuqkjK2o4YYb42R+5CZlMQaEd8TKtbmE0g +HAKljuaKLFCpun8BcOXiXsHsP5i3GQPisQnAdOsrmWEk7R2NyORa9LCToutWMGVl +uD3xAoGAA183Vldm+m4KPsKS17t8MbwBryDXvowGzruh/Z+PGA0spr+ke4XxwT1y +kMMP1+5flzmjlAf4+W8LehKuVqvQoMlPn5UVHmSxQ7cGx/O/o6Gbn8Q25/6UT+sM +B1Y0rlLoKG62pnkeXp1O4I57gnClatWRg5qw11a8V8e3jvDKIYM= +-----END RSA PRIVATE KEY----- diff --git a/tests/test.rs b/tests/test.rs new file mode 100644 index 0000000..ecf659d --- /dev/null +++ b/tests/test.rs @@ -0,0 +1,121 @@ +extern crate rustls; +extern crate futures; +extern crate tokio_core; +extern crate tokio_io; +extern crate tokio_rustls; + +use std::{ io, thread }; +use std::io::{ BufReader, Cursor }; +use std::sync::Arc; +use std::sync::mpsc::channel; +use std::net::{ SocketAddr, IpAddr, Ipv4Addr }; +use futures::{ Future, Stream }; +use tokio_core::reactor::Core; +use tokio_core::net::{ TcpListener, TcpStream }; +use tokio_io::io as aio; +use rustls::{ Certificate, PrivateKey, ServerConfig, ClientConfig }; +use rustls::internal::pemfile::{ certs, rsa_private_keys }; +use tokio_rustls::{ ClientConfigExt, ServerConfigExt }; + +const CERT: &str = include_str!("end.cert"); +const CHAIN: &str = include_str!("end.chain"); +const RSA: &str = include_str!("end.rsa"); +const HELLO_WORLD: &[u8] = b"Hello world!"; + + +fn start_server(cert: Vec, rsa: PrivateKey) -> SocketAddr { + let mut config = ServerConfig::new(); + config.set_single_cert(cert, rsa); + let config = Arc::new(config); + + let (send, recv) = channel(); + + thread::spawn(move || { + let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 0); + let mut core = Core::new().unwrap(); + let handle = core.handle(); + let listener = TcpListener::bind(&addr, &handle).unwrap(); + + send.send(listener.local_addr().unwrap()).unwrap(); + + let done = listener.incoming() + .for_each(|(stream, _)| { + let done = config.accept_async(stream) + .and_then(|stream| aio::read_exact(stream, vec![0; HELLO_WORLD.len()])) + .and_then(|(stream, buf)| { + assert_eq!(buf, HELLO_WORLD); + aio::write_all(stream, HELLO_WORLD) + }) + .map(drop) + .map_err(drop); + + handle.spawn(done); + Ok(()) + }) + .map(drop) + .map_err(drop); + core.run(done).unwrap(); + }); + + recv.recv().unwrap() +} + +fn start_client(addr: &SocketAddr, domain: Option<&str>, chain: Option>>) -> io::Result<()> { + let mut config = ClientConfig::new(); + if let Some(mut chain) = chain { + config.root_store.add_pem_file(&mut chain).unwrap(); + } + let config = Arc::new(config); + + let mut core = Core::new()?; + let handle = core.handle(); + + #[allow(unreachable_code, unused_variables)] + let done = TcpStream::connect(addr, &handle) + .and_then(|stream| if let Some(domain) = domain { + config.connect_async(domain, stream) + } else { + #[cfg(feature = "danger")] + let c = config.danger_connect_async_without_providing_domain_for_certificate_verification_and_server_name_indication(stream); + + #[cfg(not(feature = "danger"))] + let c = panic!(); + + c + }) + .and_then(|stream| aio::write_all(stream, HELLO_WORLD)) + .and_then(|(stream, _)| aio::read_exact(stream, vec![0; HELLO_WORLD.len()])) + .and_then(|(_, buf)| { + assert_eq!(buf, HELLO_WORLD); + Ok(()) + }); + + core.run(done) +} + + +#[test] +fn main() { + let cert = certs(&mut BufReader::new(Cursor::new(CERT))).unwrap(); + let mut keys = rsa_private_keys(&mut BufReader::new(Cursor::new(RSA))).unwrap(); + let chain = BufReader::new(Cursor::new(CHAIN)); + + let addr = start_server(cert, keys.pop().unwrap()); + + start_client(&addr, Some("localhost"), Some(chain)).unwrap(); + + #[cfg(feature = "danger")] + start_client(&addr, None, None).unwrap(); +} + +#[should_panic] +#[test] +fn fail() { + let cert = certs(&mut BufReader::new(Cursor::new(CERT))).unwrap(); + let mut keys = rsa_private_keys(&mut BufReader::new(Cursor::new(RSA))).unwrap(); + let chain = BufReader::new(Cursor::new(CHAIN)); + + let addr = start_server(cert, keys.pop().unwrap()); + + start_client(&addr, Some("google.com"), Some(chain)).unwrap(); +} From aefc023dd4abb710bb2381eacfe8077729ffbd88 Mon Sep 17 00:00:00 2001 From: quininer Date: Tue, 15 Aug 2017 22:00:20 +0800 Subject: [PATCH 038/171] [Fixed] call only once send_close_notify --- Cargo.toml | 8 ++++++-- README.md | 7 ++++++- src/lib.rs | 7 ++++++- 3 files changed, 18 insertions(+), 4 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 4372692..796d24f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tokio-rustls" -version = "0.3.0" +version = "0.3.1" authors = ["quininer kel "] license = "MIT/Apache-2.0" repository = "https://github.com/quininer/tokio-rustls" @@ -10,6 +10,10 @@ readme = "README.md" description = "Asynchronous TLS/SSL streams for Tokio using Rustls." categories = ["asynchronous", "cryptography", "network-programming"] +[badges] +travis-ci = { repository = "quininer/tokio-rustls" } +appveyor = { repository = "quininer/tokio-rustls" } + [features] danger = [ "rustls/dangerous_configuration" ] @@ -22,7 +26,7 @@ tokio-proto = { version = "0.1", optional = true } [dev-dependencies] tokio-core = "0.1" clap = "2.20" -webpki-roots = "0.11.0" +webpki-roots = "0.12" [target.'cfg(unix)'.dev-dependencies] tokio-file-unix = "0.4" diff --git a/README.md b/README.md index 4003d57..9612f64 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,10 @@ # tokio-rustls -[![crates](https://img.shields.io/crates/v/tokio-rustls.svg)](https://crates.io/crates/tokio-rustls) [![license](https://img.shields.io/badge/License-MIT-blue.svg)](https://github.com/quininer/tokio-rustls/blob/master/LICENSE-MIT) [![license](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://github.com/quininer/tokio-rustls/blob/master/LICENSE-APACHE) [![docs.rs](https://docs.rs/tokio-rustls/badge.svg)](https://docs.rs/tokio-rustls/) +[![travis-ci](https://travis-ci.org/quininer/tokio-rustls.svg?branch=master)](https://travis-ci.org/quininer/tokio-rustls) +[![appveyor](https://ci.appveyor.com/api/projects/status/4ukw15enii50suqi?svg=true)](https://ci.appveyor.com/project/quininer/tokio-rustls) +[![crates](https://img.shields.io/crates/v/tokio-rustls.svg)](https://crates.io/crates/tokio-rustls) +[![license](https://img.shields.io/badge/License-MIT-blue.svg)](https://github.com/quininer/tokio-rustls/blob/master/LICENSE-MIT) +[![license](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://github.com/quininer/tokio-rustls/blob/master/LICENSE-APACHE) +[![docs.rs](https://docs.rs/tokio-rustls/badge.svg)](https://docs.rs/tokio-rustls/) Asynchronous TLS/SSL streams for [Tokio](https://tokio.rs/) using [Rustls](https://github.com/ctz/rustls). diff --git a/src/lib.rs b/src/lib.rs index 5c24c18..978eb6c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -165,6 +165,7 @@ impl Future for MidHandshake /// protocol. #[derive(Debug)] pub struct TlsStream { + is_shutdown: bool, eof: bool, io: S, session: C @@ -186,6 +187,7 @@ impl TlsStream #[inline] pub fn new(io: S, session: C) -> TlsStream { TlsStream { + is_shutdown: false, eof: false, io: io, session: session @@ -310,7 +312,10 @@ impl AsyncWrite for TlsStream C: Session { fn shutdown(&mut self) -> Poll<(), io::Error> { - self.session.send_close_notify(); + if !self.is_shutdown { + self.session.send_close_notify(); + self.is_shutdown = true; + } while self.session.wants_write() { try_nb!(self.session.write_tls(&mut self.io)); } From eccf90a5343cbc093eb94f0fc9d0bc0778117613 Mon Sep 17 00:00:00 2001 From: Brian Smith Date: Mon, 28 Aug 2017 18:40:16 -1000 Subject: [PATCH 039/171] Remove `danger` feature & the API it controls. The singular purpose of this crate should be to integrate Tokio and Rustls. Therefore, any feature that isn't about making Rustls work nicely with Tokio should be assumed a priori to be out of scope. In particular, it is out of scope for tokio-rustls to provide APIs to control SNI behavior. Instead, the application should configure Rustls's SNI behavior using Rustls's configuration APIs, and pass the configuration to tokio-rustls. Similarly, it is out of scope for tokio-rustls to provide APIs to control the certificate validation behavior. Instead, the application should configure certificate validation using Rustls's APIs. Perhaps there should be a crate that makes it convenient to do "dangerous" certificate validation, but IMO that shouldn't be tokio-rustls, but a different one. FWIW, the `danger` API was inherited from tokio-tls, and I'm working on making an analogous change there. --- Cargo.toml | 3 --- src/lib.rs | 29 ----------------------------- tests/test.rs | 21 ++++----------------- 3 files changed, 4 insertions(+), 49 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 796d24f..a739599 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,9 +14,6 @@ categories = ["asynchronous", "cryptography", "network-programming"] travis-ci = { repository = "quininer/tokio-rustls" } appveyor = { repository = "quininer/tokio-rustls" } -[features] -danger = [ "rustls/dangerous_configuration" ] - [dependencies] futures = "0.1" tokio-io = "0.1" diff --git a/src/lib.rs b/src/lib.rs index 978eb6c..12cdf53 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -22,11 +22,6 @@ pub trait ClientConfigExt { fn connect_async(&self, domain: &str, stream: S) -> ConnectAsync where S: AsyncRead + AsyncWrite; - - #[cfg(feature = "danger")] - fn danger_connect_async_without_providing_domain_for_certificate_verification_and_server_name_indication(&self, stream: S) - -> ConnectAsync - where S: AsyncRead + AsyncWrite; } /// Extension trait for the `Arc` type in the `rustls` crate. @@ -53,30 +48,6 @@ impl ClientConfigExt for Arc { { connect_async_with_session(stream, ClientSession::new(self, domain)) } - - #[cfg(feature = "danger")] - fn danger_connect_async_without_providing_domain_for_certificate_verification_and_server_name_indication(&self, stream: S) - -> ConnectAsync - where S: AsyncRead + AsyncWrite - { - use rustls::{ ServerCertVerifier, RootCertStore, Certificate, ServerCertVerified, TLSError }; - - struct NoCertVerifier; - impl ServerCertVerifier for NoCertVerifier { - fn verify_server_cert(&self, _: &RootCertStore, _: &[Certificate], _: &str, _: &[u8]) - -> Result - { - Ok(ServerCertVerified::assertion()) - } - } - - let mut client_config = ClientConfig::new(); - client_config.clone_from(self); - client_config.dangerous() - .set_certificate_verifier(Arc::new(NoCertVerifier)); - - Arc::new(client_config).connect_async("", stream) - } } #[inline] diff --git a/tests/test.rs b/tests/test.rs index ecf659d..86c715e 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -60,7 +60,7 @@ fn start_server(cert: Vec, rsa: PrivateKey) -> SocketAddr { recv.recv().unwrap() } -fn start_client(addr: &SocketAddr, domain: Option<&str>, chain: Option>>) -> io::Result<()> { +fn start_client(addr: &SocketAddr, domain: &str, chain: Option>>) -> io::Result<()> { let mut config = ClientConfig::new(); if let Some(mut chain) = chain { config.root_store.add_pem_file(&mut chain).unwrap(); @@ -72,17 +72,7 @@ fn start_client(addr: &SocketAddr, domain: Option<&str>, chain: Option Date: Sun, 27 Aug 2017 18:13:08 -1000 Subject: [PATCH 040/171] 0.4.0: Use rustls 0.11, webpki-roots 0.13, and update other deps. --- Cargo.toml | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 796d24f..c82a3d3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tokio-rustls" -version = "0.3.1" +version = "0.4.0" authors = ["quininer kel "] license = "MIT/Apache-2.0" repository = "https://github.com/quininer/tokio-rustls" @@ -18,15 +18,15 @@ appveyor = { repository = "quininer/tokio-rustls" } danger = [ "rustls/dangerous_configuration" ] [dependencies] -futures = "0.1" -tokio-io = "0.1" -rustls = "0.10" -tokio-proto = { version = "0.1", optional = true } +futures = "0.1.15" +tokio-io = "0.1.3" +rustls = "0.11" +tokio-proto = { version = "0.1.1", optional = true } [dev-dependencies] -tokio-core = "0.1" -clap = "2.20" -webpki-roots = "0.12" +tokio-core = "0.1.9" +clap = "2.26" +webpki-roots = "0.13" [target.'cfg(unix)'.dev-dependencies] tokio-file-unix = "0.4" From 4b2b016024850ec8c671617c58711937a30894a6 Mon Sep 17 00:00:00 2001 From: Brian Smith Date: Sun, 27 Aug 2017 18:53:47 -1000 Subject: [PATCH 041/171] Update examples for webpki-roots API changes. --- README.md | 2 +- examples/client.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 9612f64..d9db909 100644 --- a/README.md +++ b/README.md @@ -18,7 +18,7 @@ use tokio_rustls::ClientConfigExt; // ... let mut config = ClientConfig::new(); -config.root_store.add_trust_anchors(&webpki_roots::ROOTS); +config.root_store.add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS); let config = Arc::new(config); TcpStream::connect(&addr, &handle) diff --git a/examples/client.rs b/examples/client.rs index cfc5bca..4a737f6 100644 --- a/examples/client.rs +++ b/examples/client.rs @@ -64,7 +64,7 @@ fn main() { let mut pem = BufReader::new(fs::File::open(cafile).unwrap()); config.root_store.add_pem_file(&mut pem).unwrap(); } else { - config.root_store.add_trust_anchors(&webpki_roots::ROOTS); + config.root_store.add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS); } let arc_config = Arc::new(config); From 51ed8da9cbfc4ad7a4a790ffaa2a28e3c3dcec48 Mon Sep 17 00:00:00 2001 From: Brian Smith Date: Sun, 3 Sep 2017 12:58:55 -1000 Subject: [PATCH 042/171] Update to in-progress Rustls, webpki, and webpki-roots. Use the new, less error-prone, API in Rustls. --- Cargo.toml | 5 +++-- examples/client.rs | 3 +++ examples/server.rs | 4 ++-- src/lib.rs | 5 +++-- src/proto.rs | 15 ++++++++------- tests/test.rs | 8 +++++--- 6 files changed, 24 insertions(+), 16 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 6e8ec9f..cc6a0c8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,13 +17,14 @@ appveyor = { repository = "quininer/tokio-rustls" } [dependencies] futures = "0.1.15" tokio-io = "0.1.3" -rustls = "0.11" +rustls = { git = "https://github.com/ctz/rustls" } tokio-proto = { version = "0.1.1", optional = true } +webpki = { git = "https://github.com/briansmith/webpki" } [dev-dependencies] tokio-core = "0.1.9" clap = "2.26" -webpki-roots = "0.13" +webpki-roots = { git = "https://github.com/briansmith/webpki-roots", branch = "webpki-github" } [target.'cfg(unix)'.dev-dependencies] tokio-file-unix = "0.4" diff --git a/examples/client.rs b/examples/client.rs index 4a737f6..418a6a4 100644 --- a/examples/client.rs +++ b/examples/client.rs @@ -3,6 +3,7 @@ extern crate rustls; extern crate futures; extern crate tokio_io; extern crate tokio_core; +extern crate webpki; extern crate webpki_roots; extern crate tokio_rustls; @@ -68,6 +69,8 @@ fn main() { } let arc_config = Arc::new(config); + let domain = webpki::DNSNameRef::try_from_ascii_str(domain).unwrap(); + let socket = TcpStream::connect(&addr, &handle); // Use async non-blocking I/O for stdin/stdout on Unixy platforms. diff --git a/examples/server.rs b/examples/server.rs index 89a049f..9d407b9 100644 --- a/examples/server.rs +++ b/examples/server.rs @@ -11,7 +11,7 @@ use std::net::ToSocketAddrs; use std::io::BufReader; use std::fs::File; use futures::{ Future, Stream }; -use rustls::{ Certificate, PrivateKey, ServerConfig }; +use rustls::{ Certificate, NoClientAuth, PrivateKey, ServerConfig }; use rustls::internal::pemfile::{ certs, rsa_private_keys }; use tokio_io::{ io, AsyncRead }; use tokio_core::net::TcpListener; @@ -51,7 +51,7 @@ fn main() { let mut core = Core::new().unwrap(); let handle = core.handle(); - let mut config = ServerConfig::new(); + let mut config = ServerConfig::new(NoClientAuth::new()); config.set_single_cert(load_certs(cert_file), load_keys(key_file).remove(0)); let arc_config = Arc::new(config); diff --git a/src/lib.rs b/src/lib.rs index 12cdf53..8ec6017 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -4,6 +4,7 @@ #[cfg_attr(feature = "tokio-proto", macro_use)] extern crate futures; #[macro_use] extern crate tokio_io; extern crate rustls; +extern crate webpki; pub mod proto; @@ -19,7 +20,7 @@ use rustls::{ /// Extension trait for the `Arc` type in the `rustls` crate. pub trait ClientConfigExt { - fn connect_async(&self, domain: &str, stream: S) + fn connect_async(&self, domain: webpki::DNSNameRef, stream: S) -> ConnectAsync where S: AsyncRead + AsyncWrite; } @@ -42,7 +43,7 @@ pub struct AcceptAsync(MidHandshake); impl ClientConfigExt for Arc { - fn connect_async(&self, domain: &str, stream: S) + fn connect_async(&self, domain: webpki::DNSNameRef, stream: S) -> ConnectAsync where S: AsyncRead + AsyncWrite { diff --git a/src/proto.rs b/src/proto.rs index 688bb19..7c659e4 100644 --- a/src/proto.rs +++ b/src/proto.rs @@ -19,6 +19,7 @@ use rustls::{ ServerConfig, ClientConfig, ServerSession, ClientSession }; use self::tokio_proto::multiplex; use self::tokio_proto::pipeline; use self::tokio_proto::streaming; +use webpki; use { TlsStream, ServerConfigExt, ClientConfigExt, AcceptAsync, ConnectAsync }; @@ -292,7 +293,7 @@ impl Future for ServerStreamingMultiplexBind pub struct Client { inner: Arc, connector: Arc, - hostname: String, + hostname: webpki::DNSName, } impl Client { @@ -303,11 +304,11 @@ impl Client { /// will go through the negotiated TLS stream through the `protocol` specified. pub fn new(protocol: T, connector: Arc, - hostname: &str) -> Client { + hostname: webpki::DNSName) -> Client { Client { inner: Arc::new(protocol), connector: connector, - hostname: hostname.to_string(), + hostname: hostname, } } } @@ -339,7 +340,7 @@ impl pipeline::ClientProto for Client fn bind_transport(&self, io: I) -> Self::BindTransport { let proto = self.inner.clone(); - let io = self.connector.connect_async(&self.hostname, io); + let io = self.connector.connect_async(self.hostname.as_ref(), io); ClientPipelineBind { state: ClientPipelineState::First(io, proto), @@ -397,7 +398,7 @@ impl multiplex::ClientProto for Client fn bind_transport(&self, io: I) -> Self::BindTransport { let proto = self.inner.clone(); - let io = self.connector.connect_async(&self.hostname, io); + let io = self.connector.connect_async(self.hostname.as_ref(), io); ClientMultiplexBind { state: ClientMultiplexState::First(io, proto), @@ -458,7 +459,7 @@ impl streaming::pipeline::ClientProto for Client fn bind_transport(&self, io: I) -> Self::BindTransport { let proto = self.inner.clone(); - let io = self.connector.connect_async(&self.hostname, io); + let io = self.connector.connect_async(self.hostname.as_ref(), io); ClientStreamingPipelineBind { state: ClientStreamingPipelineState::First(io, proto), @@ -519,7 +520,7 @@ impl streaming::multiplex::ClientProto for Client fn bind_transport(&self, io: I) -> Self::BindTransport { let proto = self.inner.clone(); - let io = self.connector.connect_async(&self.hostname, io); + let io = self.connector.connect_async(self.hostname.as_ref(), io); ClientStreamingMultiplexBind { state: ClientStreamingMultiplexState::First(io, proto), diff --git a/tests/test.rs b/tests/test.rs index 86c715e..e66e2aa 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -3,6 +3,7 @@ extern crate futures; extern crate tokio_core; extern crate tokio_io; extern crate tokio_rustls; +extern crate webpki; use std::{ io, thread }; use std::io::{ BufReader, Cursor }; @@ -24,7 +25,7 @@ const HELLO_WORLD: &[u8] = b"Hello world!"; fn start_server(cert: Vec, rsa: PrivateKey) -> SocketAddr { - let mut config = ServerConfig::new(); + let mut config = ServerConfig::new(rustls::NoClientAuth::new()); config.set_single_cert(cert, rsa); let config = Arc::new(config); @@ -60,7 +61,9 @@ fn start_server(cert: Vec, rsa: PrivateKey) -> SocketAddr { recv.recv().unwrap() } -fn start_client(addr: &SocketAddr, domain: &str, chain: Option>>) -> io::Result<()> { +fn start_client(addr: &SocketAddr, domain: &str, + chain: Option>>) -> io::Result<()> { + let domain = webpki::DNSNameRef::try_from_ascii_str(domain).unwrap(); let mut config = ClientConfig::new(); if let Some(mut chain) = chain { config.root_store.add_pem_file(&mut chain).unwrap(); @@ -91,7 +94,6 @@ fn main() { let chain = BufReader::new(Cursor::new(CHAIN)); let addr = start_server(cert, keys.pop().unwrap()); - start_client(&addr, "localhost", Some(chain)).unwrap(); } From 8aa3f3a14b9e285a1f77a0fbadd05878368d9c76 Mon Sep 17 00:00:00 2001 From: quininer Date: Mon, 8 Jan 2018 20:45:49 +0800 Subject: [PATCH 043/171] bump to 0.5.0 --- Cargo.toml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index cc6a0c8..7c6acd6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tokio-rustls" -version = "0.4.0" +version = "0.5.0" authors = ["quininer kel "] license = "MIT/Apache-2.0" repository = "https://github.com/quininer/tokio-rustls" @@ -17,14 +17,14 @@ appveyor = { repository = "quininer/tokio-rustls" } [dependencies] futures = "0.1.15" tokio-io = "0.1.3" -rustls = { git = "https://github.com/ctz/rustls" } +rustls = "0.12" +webpki = "0.18.0-alpha" tokio-proto = { version = "0.1.1", optional = true } -webpki = { git = "https://github.com/briansmith/webpki" } [dev-dependencies] tokio-core = "0.1.9" clap = "2.26" -webpki-roots = { git = "https://github.com/briansmith/webpki-roots", branch = "webpki-github" } +webpki-roots = "0.14" [target.'cfg(unix)'.dev-dependencies] tokio-file-unix = "0.4" From 8d6140a7b981c7877fb18ed135204c9836604fb2 Mon Sep 17 00:00:00 2001 From: quininer Date: Wed, 28 Feb 2018 14:36:37 +0800 Subject: [PATCH 044/171] upgrade example/test to tokio --- Cargo.toml | 3 ++- examples/server.rs | 31 ++++++++++++++++--------------- tests/test.rs | 25 ++++++++++--------------- 3 files changed, 28 insertions(+), 31 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 7c6acd6..151320b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,7 +22,8 @@ webpki = "0.18.0-alpha" tokio-proto = { version = "0.1.1", optional = true } [dev-dependencies] -tokio-core = "0.1.9" +tokio-core = "0.1" +tokio = "0.1" clap = "2.26" webpki-roots = "0.14" diff --git a/examples/server.rs b/examples/server.rs index 9d407b9..a450393 100644 --- a/examples/server.rs +++ b/examples/server.rs @@ -2,7 +2,7 @@ extern crate clap; extern crate rustls; extern crate futures; extern crate tokio_io; -extern crate tokio_core; +extern crate tokio; extern crate webpki_roots; extern crate tokio_rustls; @@ -14,8 +14,8 @@ use futures::{ Future, Stream }; use rustls::{ Certificate, NoClientAuth, PrivateKey, ServerConfig }; use rustls::internal::pemfile::{ certs, rsa_private_keys }; use tokio_io::{ io, AsyncRead }; -use tokio_core::net::TcpListener; -use tokio_core::reactor::Core; +use tokio::net::TcpListener; +use tokio::executor::current_thread; use clap::{ App, Arg }; use tokio_rustls::ServerConfigExt; @@ -48,27 +48,28 @@ fn main() { let key_file = matches.value_of("key").unwrap(); let flag_echo = matches.occurrences_of("echo") > 0; - let mut core = Core::new().unwrap(); - let handle = core.handle(); - let mut config = ServerConfig::new(NoClientAuth::new()); config.set_single_cert(load_certs(cert_file), load_keys(key_file).remove(0)); let arc_config = Arc::new(config); - let socket = TcpListener::bind(&addr, &handle).unwrap(); + let socket = TcpListener::bind(&addr).unwrap(); let done = socket.incoming() - .for_each(|(stream, addr)| if flag_echo { + .for_each(move |stream| if flag_echo { + let addr = stream.peer_addr().ok(); + let addr2 = addr.clone(); let done = arc_config.accept_async(stream) .and_then(|stream| { let (reader, writer) = stream.split(); io::copy(reader, writer) }) - .map(move |(n, ..)| println!("Echo: {} - {}", n, addr)) - .map_err(move |err| println!("Error: {:?} - {}", err, addr)); - handle.spawn(done); + .map(move |(n, ..)| println!("Echo: {} - {:?}", n, addr)) + .map_err(move |err| println!("Error: {:?} - {:?}", err, addr2)); + current_thread::spawn(done); Ok(()) } else { + let addr = stream.peer_addr().ok(); + let addr2 = addr.clone(); let done = arc_config.accept_async(stream) .and_then(|stream| io::write_all( stream, @@ -79,12 +80,12 @@ fn main() { Hello world!"[..] )) .and_then(|(stream, _)| io::flush(stream)) - .map(move |_| println!("Accept: {}", addr)) - .map_err(move |err| println!("Error: {:?} - {}", err, addr)); - handle.spawn(done); + .map(move |_| println!("Accept: {:?}", addr)) + .map_err(move |err| println!("Error: {:?} - {:?}", err, addr2)); + current_thread::spawn(done); Ok(()) }); - core.run(done).unwrap(); + current_thread::run(|_| current_thread::spawn(done.map_err(drop))); } diff --git a/tests/test.rs b/tests/test.rs index e66e2aa..baa22b2 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -1,6 +1,6 @@ extern crate rustls; extern crate futures; -extern crate tokio_core; +extern crate tokio; extern crate tokio_io; extern crate tokio_rustls; extern crate webpki; @@ -11,8 +11,8 @@ use std::sync::Arc; use std::sync::mpsc::channel; use std::net::{ SocketAddr, IpAddr, Ipv4Addr }; use futures::{ Future, Stream }; -use tokio_core::reactor::Core; -use tokio_core::net::{ TcpListener, TcpStream }; +use tokio::executor::current_thread; +use tokio::net::{ TcpListener, TcpStream }; use tokio_io::io as aio; use rustls::{ Certificate, PrivateKey, ServerConfig, ClientConfig }; use rustls::internal::pemfile::{ certs, rsa_private_keys }; @@ -33,14 +33,12 @@ fn start_server(cert: Vec, rsa: PrivateKey) -> SocketAddr { thread::spawn(move || { let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 0); - let mut core = Core::new().unwrap(); - let handle = core.handle(); - let listener = TcpListener::bind(&addr, &handle).unwrap(); + let listener = TcpListener::bind(&addr).unwrap(); send.send(listener.local_addr().unwrap()).unwrap(); let done = listener.incoming() - .for_each(|(stream, _)| { + .for_each(move |stream| { let done = config.accept_async(stream) .and_then(|stream| aio::read_exact(stream, vec![0; HELLO_WORLD.len()])) .and_then(|(stream, buf)| { @@ -50,12 +48,13 @@ fn start_server(cert: Vec, rsa: PrivateKey) -> SocketAddr { .map(drop) .map_err(drop); - handle.spawn(done); + current_thread::spawn(done); Ok(()) }) .map(drop) .map_err(drop); - core.run(done).unwrap(); + + current_thread::run(|_| current_thread::spawn(done)); }); recv.recv().unwrap() @@ -70,11 +69,7 @@ fn start_client(addr: &SocketAddr, domain: &str, } let config = Arc::new(config); - let mut core = Core::new()?; - let handle = core.handle(); - - #[allow(unreachable_code, unused_variables)] - let done = TcpStream::connect(addr, &handle) + let done = TcpStream::connect(addr) .and_then(|stream| config.connect_async(domain, stream)) .and_then(|stream| aio::write_all(stream, HELLO_WORLD)) .and_then(|(stream, _)| aio::read_exact(stream, vec![0; HELLO_WORLD.len()])) @@ -83,7 +78,7 @@ fn start_client(addr: &SocketAddr, domain: &str, Ok(()) }); - core.run(done) + done.wait() } From daac8f585fcff3b9092ea01b98251dc6971577fc Mon Sep 17 00:00:00 2001 From: quininer Date: Wed, 7 Mar 2018 12:24:16 +0800 Subject: [PATCH 045/171] fix outdated README --- README.md | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index d9db909..66eb1f9 100644 --- a/README.md +++ b/README.md @@ -12,6 +12,7 @@ Asynchronous TLS/SSL streams for [Tokio](https://tokio.rs/) using ### Basic Structure of a Client ```rust +use webpki::DNSNameRef; use rustls::ClientConfig; use tokio_rustls::ClientConfigExt; @@ -20,9 +21,10 @@ use tokio_rustls::ClientConfigExt; let mut config = ClientConfig::new(); config.root_store.add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS); let config = Arc::new(config); +let domain = DNSNameRef::try_from_ascii_str("www.rust-lang.org").unwrap(); -TcpStream::connect(&addr, &handle) - .and_then(|socket| config.connect_async("www.rust-lang.org", socket)) +TcpStream::connect(&addr) + .and_then(|socket| config.connect_async(domain, socket)) // ... ``` From 9f78454cf1736846c40c09b70a1c93e8a8fa9743 Mon Sep 17 00:00:00 2001 From: quininer Date: Tue, 20 Mar 2018 20:17:44 +0800 Subject: [PATCH 046/171] feat: try futures 0.2 --- Cargo.toml | 12 +++---- src/lib.rs | 97 ++++++++++++++++++++++++++++----------------------- tests/test.rs | 13 +++---- 3 files changed, 64 insertions(+), 58 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 151320b..68d12c2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,17 +15,15 @@ travis-ci = { repository = "quininer/tokio-rustls" } appveyor = { repository = "quininer/tokio-rustls" } [dependencies] -futures = "0.1.15" -tokio-io = "0.1.3" +futures = "0.2.0-alpha" +tokio = { version = "0.1", features = [ "unstable-futures" ] } rustls = "0.12" webpki = "0.18.0-alpha" -tokio-proto = { version = "0.1.1", optional = true } [dev-dependencies] -tokio-core = "0.1" -tokio = "0.1" +tokio = { version = "0.1", features = [ "unstable-futures" ] } clap = "2.26" webpki-roots = "0.14" -[target.'cfg(unix)'.dev-dependencies] -tokio-file-unix = "0.4" +[patch.crates-io] +tokio = { git = "https://github.com/tokio-rs/tokio" } diff --git a/src/lib.rs b/src/lib.rs index 8ec6017..33a1dc3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,17 +1,14 @@ //! Asynchronous TLS/SSL streams for Tokio using [Rustls](https://github.com/ctz/rustls). - -#[cfg_attr(feature = "tokio-proto", macro_use)] extern crate futures; -#[macro_use] extern crate tokio_io; +extern crate futures; +extern crate tokio; extern crate rustls; extern crate webpki; -pub mod proto; - use std::io; use std::sync::Arc; use futures::{ Future, Poll, Async }; -use tokio_io::{ AsyncRead, AsyncWrite }; +use futures::task::Context; use rustls::{ Session, ClientSession, ServerSession, ClientConfig, ServerConfig @@ -22,14 +19,14 @@ use rustls::{ pub trait ClientConfigExt { fn connect_async(&self, domain: webpki::DNSNameRef, stream: S) -> ConnectAsync - where S: AsyncRead + AsyncWrite; + where S: io::Read + io::Write; } /// Extension trait for the `Arc` type in the `rustls` crate. pub trait ServerConfigExt { fn accept_async(&self, stream: S) -> AcceptAsync - where S: AsyncRead + AsyncWrite; + where S: io::Read + io::Write; } @@ -45,7 +42,7 @@ pub struct AcceptAsync(MidHandshake); impl ClientConfigExt for Arc { fn connect_async(&self, domain: webpki::DNSNameRef, stream: S) -> ConnectAsync - where S: AsyncRead + AsyncWrite + where S: io::Read + io::Write { connect_async_with_session(stream, ClientSession::new(self, domain)) } @@ -54,7 +51,7 @@ impl ClientConfigExt for Arc { #[inline] pub fn connect_async_with_session(stream: S, session: ClientSession) -> ConnectAsync - where S: AsyncRead + AsyncWrite + where S: io::Read + io::Write { ConnectAsync(MidHandshake { inner: Some(TlsStream::new(stream, session)) @@ -64,7 +61,7 @@ pub fn connect_async_with_session(stream: S, session: ClientSession) impl ServerConfigExt for Arc { fn accept_async(&self, stream: S) -> AcceptAsync - where S: AsyncRead + AsyncWrite + where S: io::Read + io::Write { accept_async_with_session(stream, ServerSession::new(self)) } @@ -73,28 +70,28 @@ impl ServerConfigExt for Arc { #[inline] pub fn accept_async_with_session(stream: S, session: ServerSession) -> AcceptAsync - where S: AsyncRead + AsyncWrite + where S: io::Read + io::Write { AcceptAsync(MidHandshake { inner: Some(TlsStream::new(stream, session)) }) } -impl Future for ConnectAsync { +impl Future for ConnectAsync { type Item = TlsStream; type Error = io::Error; - fn poll(&mut self) -> Poll { - self.0.poll() + fn poll(&mut self, ctx: &mut Context) -> Poll { + self.0.poll(ctx) } } -impl Future for AcceptAsync { +impl Future for AcceptAsync { type Item = TlsStream; type Error = io::Error; - fn poll(&mut self) -> Poll { - self.0.poll() + fn poll(&mut self, ctx: &mut Context) -> Poll { + self.0.poll(ctx) } } @@ -104,12 +101,12 @@ struct MidHandshake { } impl Future for MidHandshake - where S: AsyncRead + AsyncWrite, C: Session + where S: io::Read + io::Write, C: Session { type Item = TlsStream; type Error = io::Error; - fn poll(&mut self) -> Poll { + fn poll(&mut self, _: &mut Context) -> Poll { loop { let stream = self.inner.as_mut().unwrap(); if !stream.session.is_handshaking() { break }; @@ -121,7 +118,7 @@ impl Future for MidHandshake (..) => break }, Err(e) => match (e.kind(), stream.session.is_handshaking()) { - (io::ErrorKind::WouldBlock, true) => return Ok(Async::NotReady), + (io::ErrorKind::WouldBlock, true) => return Ok(Async::Pending), (io::ErrorKind::WouldBlock, false) => break, (..) => return Err(e) } @@ -154,7 +151,7 @@ impl TlsStream { } impl TlsStream - where S: AsyncRead + AsyncWrite, C: Session + where S: io::Read + io::Write, C: Session { #[inline] pub fn new(io: S, session: C) -> TlsStream { @@ -214,7 +211,7 @@ impl TlsStream } impl io::Read for TlsStream - where S: AsyncRead + AsyncWrite, C: Session + where S: io::Read + io::Write, C: Session { fn read(&mut self, buf: &mut [u8]) -> io::Result { loop { @@ -233,7 +230,7 @@ impl io::Read for TlsStream } impl io::Write for TlsStream - where S: AsyncRead + AsyncWrite, C: Session + where S: io::Read + io::Write, C: Session { fn write(&mut self, buf: &[u8]) -> io::Result { if buf.is_empty() { @@ -272,26 +269,40 @@ impl io::Write for TlsStream } } -impl AsyncRead for TlsStream - where - S: AsyncRead + AsyncWrite, - C: Session -{} -impl AsyncWrite for TlsStream - where - S: AsyncRead + AsyncWrite, - C: Session -{ - fn shutdown(&mut self) -> Poll<(), io::Error> { - if !self.is_shutdown { - self.session.send_close_notify(); - self.is_shutdown = true; +mod tokio_impl { + use super::*; + use tokio::io::{ AsyncRead, AsyncWrite }; + use tokio::prelude::Poll; + + impl AsyncRead for TlsStream + where + S: AsyncRead + AsyncWrite, + C: Session + {} + + impl AsyncWrite for TlsStream + where + S: AsyncRead + AsyncWrite, + C: Session + { + fn shutdown(&mut self) -> Poll<(), io::Error> { + if !self.is_shutdown { + self.session.send_close_notify(); + self.is_shutdown = true; + } + while self.session.wants_write() { + self.session.write_tls(&mut self.io)?; + } + self.io.flush()?; + self.io.shutdown() } - while self.session.wants_write() { - try_nb!(self.session.write_tls(&mut self.io)); - } - try_nb!(self.io.flush()); - self.io.shutdown() } } + +mod futures_impl { + use super::*; + use futures::io::{ AsyncRead, AsyncWrite }; + + // TODO +} diff --git a/tests/test.rs b/tests/test.rs index baa22b2..5698ec0 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -1,7 +1,6 @@ extern crate rustls; extern crate futures; extern crate tokio; -extern crate tokio_io; extern crate tokio_rustls; extern crate webpki; @@ -10,10 +9,9 @@ use std::io::{ BufReader, Cursor }; use std::sync::Arc; use std::sync::mpsc::channel; use std::net::{ SocketAddr, IpAddr, Ipv4Addr }; -use futures::{ Future, Stream }; -use tokio::executor::current_thread; +use futures::{ FutureExt, StreamExt }; use tokio::net::{ TcpListener, TcpStream }; -use tokio_io::io as aio; +use tokio::io as aio; use rustls::{ Certificate, PrivateKey, ServerConfig, ClientConfig }; use rustls::internal::pemfile::{ certs, rsa_private_keys }; use tokio_rustls::{ ClientConfigExt, ServerConfigExt }; @@ -48,13 +46,12 @@ fn start_server(cert: Vec, rsa: PrivateKey) -> SocketAddr { .map(drop) .map_err(drop); - current_thread::spawn(done); + tokio::spawn2(done); Ok(()) }) - .map(drop) - .map_err(drop); + .then(|_| Ok(())); - current_thread::run(|_| current_thread::spawn(done)); + tokio::runtime::run2(done); }); recv.recv().unwrap() From 8c79329c7a8338a6519d33dc3ca493e6d4832542 Mon Sep 17 00:00:00 2001 From: quininer Date: Wed, 21 Mar 2018 13:08:47 +0800 Subject: [PATCH 047/171] feat: split tokio_impl/futures_impl --- Cargo.toml | 10 ++--- src/futures_impl.rs | 80 ++++++++++++++++++++++++++++++++++++++++ src/lib.rs | 90 ++------------------------------------------- src/tokio_impl.rs | 76 ++++++++++++++++++++++++++++++++++++++ tests/test.rs | 7 ++-- 5 files changed, 168 insertions(+), 95 deletions(-) create mode 100644 src/futures_impl.rs create mode 100644 src/tokio_impl.rs diff --git a/Cargo.toml b/Cargo.toml index 68d12c2..765cdde 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,15 +15,15 @@ travis-ci = { repository = "quininer/tokio-rustls" } appveyor = { repository = "quininer/tokio-rustls" } [dependencies] -futures = "0.2.0-alpha" -tokio = { version = "0.1", features = [ "unstable-futures" ] } +futures = { version = "0.2.0-alpha", optional = true } +tokio = { version = "0.1", optional = true } rustls = "0.12" webpki = "0.18.0-alpha" [dev-dependencies] -tokio = { version = "0.1", features = [ "unstable-futures" ] } +tokio = "0.1" clap = "2.26" webpki-roots = "0.14" -[patch.crates-io] -tokio = { git = "https://github.com/tokio-rs/tokio" } +[features] +default = [ "futures", "tokio" ] diff --git a/src/futures_impl.rs b/src/futures_impl.rs new file mode 100644 index 0000000..b8b91bf --- /dev/null +++ b/src/futures_impl.rs @@ -0,0 +1,80 @@ +use super::*; +use futures::{ Future, Poll, Async }; +use futures::io::{ Error, AsyncRead, AsyncWrite }; +use futures::task::Context; + + +impl Future for ConnectAsync { + type Item = TlsStream; + type Error = io::Error; + + fn poll(&mut self, ctx: &mut Context) -> Poll { + self.0.poll(ctx) + } +} + +impl Future for AcceptAsync { + type Item = TlsStream; + type Error = io::Error; + + fn poll(&mut self, ctx: &mut Context) -> Poll { + self.0.poll(ctx) + } +} + +impl Future for MidHandshake + where S: io::Read + io::Write, C: Session +{ + type Item = TlsStream; + type Error = io::Error; + + fn poll(&mut self, _: &mut Context) -> Poll { + loop { + let stream = self.inner.as_mut().unwrap(); + if !stream.session.is_handshaking() { break }; + + match stream.do_io() { + Ok(()) => match (stream.eof, stream.session.is_handshaking()) { + (true, true) => return Err(io::Error::from(io::ErrorKind::UnexpectedEof)), + (false, true) => continue, + (..) => break + }, + Err(e) => match (e.kind(), stream.session.is_handshaking()) { + (io::ErrorKind::WouldBlock, true) => return Ok(Async::Pending), + (io::ErrorKind::WouldBlock, false) => break, + (..) => return Err(e) + } + } + } + + Ok(Async::Ready(self.inner.take().unwrap())) + } +} + +impl AsyncRead for TlsStream + where + S: AsyncRead + AsyncWrite, + C: Session +{ + fn poll_read(&mut self, _: &mut Context, buf: &mut [u8]) -> Poll { + unimplemented!() + } +} + +impl AsyncWrite for TlsStream + where + S: AsyncRead + AsyncWrite, + C: Session +{ + fn poll_write(&mut self, _: &mut Context, buf: &[u8]) -> Poll { + unimplemented!() + } + + fn poll_flush(&mut self, _: &mut Context) -> Poll<(), Error> { + unimplemented!() + } + + fn poll_close(&mut self, _: &mut Context) -> Poll<(), Error> { + unimplemented!() + } +} diff --git a/src/lib.rs b/src/lib.rs index 33a1dc3..98ff6c0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,10 +5,11 @@ extern crate tokio; extern crate rustls; extern crate webpki; +mod tokio_impl; +mod futures_impl; + use std::io; use std::sync::Arc; -use futures::{ Future, Poll, Async }; -use futures::task::Context; use rustls::{ Session, ClientSession, ServerSession, ClientConfig, ServerConfig @@ -77,58 +78,11 @@ pub fn accept_async_with_session(stream: S, session: ServerSession) }) } -impl Future for ConnectAsync { - type Item = TlsStream; - type Error = io::Error; - - fn poll(&mut self, ctx: &mut Context) -> Poll { - self.0.poll(ctx) - } -} - -impl Future for AcceptAsync { - type Item = TlsStream; - type Error = io::Error; - - fn poll(&mut self, ctx: &mut Context) -> Poll { - self.0.poll(ctx) - } -} - struct MidHandshake { inner: Option> } -impl Future for MidHandshake - where S: io::Read + io::Write, C: Session -{ - type Item = TlsStream; - type Error = io::Error; - - fn poll(&mut self, _: &mut Context) -> Poll { - loop { - let stream = self.inner.as_mut().unwrap(); - if !stream.session.is_handshaking() { break }; - - match stream.do_io() { - Ok(()) => match (stream.eof, stream.session.is_handshaking()) { - (true, true) => return Err(io::Error::from(io::ErrorKind::UnexpectedEof)), - (false, true) => continue, - (..) => break - }, - Err(e) => match (e.kind(), stream.session.is_handshaking()) { - (io::ErrorKind::WouldBlock, true) => return Ok(Async::Pending), - (io::ErrorKind::WouldBlock, false) => break, - (..) => return Err(e) - } - } - } - - Ok(Async::Ready(self.inner.take().unwrap())) - } -} - /// A wrapper around an underlying raw stream which implements the TLS or SSL /// protocol. @@ -268,41 +222,3 @@ impl io::Write for TlsStream self.io.flush() } } - - -mod tokio_impl { - use super::*; - use tokio::io::{ AsyncRead, AsyncWrite }; - use tokio::prelude::Poll; - - impl AsyncRead for TlsStream - where - S: AsyncRead + AsyncWrite, - C: Session - {} - - impl AsyncWrite for TlsStream - where - S: AsyncRead + AsyncWrite, - C: Session - { - fn shutdown(&mut self) -> Poll<(), io::Error> { - if !self.is_shutdown { - self.session.send_close_notify(); - self.is_shutdown = true; - } - while self.session.wants_write() { - self.session.write_tls(&mut self.io)?; - } - self.io.flush()?; - self.io.shutdown() - } - } -} - -mod futures_impl { - use super::*; - use futures::io::{ AsyncRead, AsyncWrite }; - - // TODO -} diff --git a/src/tokio_impl.rs b/src/tokio_impl.rs new file mode 100644 index 0000000..117b56b --- /dev/null +++ b/src/tokio_impl.rs @@ -0,0 +1,76 @@ +use super::*; +use tokio::prelude::*; +use tokio::io::{ AsyncRead, AsyncWrite }; +use tokio::prelude::Poll; + + +impl Future for ConnectAsync { + type Item = TlsStream; + type Error = io::Error; + + fn poll(&mut self) -> Poll { + self.0.poll() + } +} + +impl Future for AcceptAsync { + type Item = TlsStream; + type Error = io::Error; + + fn poll(&mut self) -> Poll { + self.0.poll() + } +} + +impl Future for MidHandshake + where S: io::Read + io::Write, C: Session +{ + type Item = TlsStream; + type Error = io::Error; + + fn poll(&mut self) -> Poll { + loop { + let stream = self.inner.as_mut().unwrap(); + if !stream.session.is_handshaking() { break }; + + match stream.do_io() { + Ok(()) => match (stream.eof, stream.session.is_handshaking()) { + (true, true) => return Err(io::Error::from(io::ErrorKind::UnexpectedEof)), + (false, true) => continue, + (..) => break + }, + Err(e) => match (e.kind(), stream.session.is_handshaking()) { + (io::ErrorKind::WouldBlock, true) => return Ok(Async::NotReady), + (io::ErrorKind::WouldBlock, false) => break, + (..) => return Err(e) + } + } + } + + Ok(Async::Ready(self.inner.take().unwrap())) + } +} + +impl AsyncRead for TlsStream + where + S: AsyncRead + AsyncWrite, + C: Session +{} + +impl AsyncWrite for TlsStream + where + S: AsyncRead + AsyncWrite, + C: Session +{ + fn shutdown(&mut self) -> Poll<(), io::Error> { + if !self.is_shutdown { + self.session.send_close_notify(); + self.is_shutdown = true; + } + while self.session.wants_write() { + self.session.write_tls(&mut self.io)?; + } + self.io.flush()?; + self.io.shutdown() + } +} diff --git a/tests/test.rs b/tests/test.rs index 5698ec0..5e37e73 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -9,7 +9,8 @@ use std::io::{ BufReader, Cursor }; use std::sync::Arc; use std::sync::mpsc::channel; use std::net::{ SocketAddr, IpAddr, Ipv4Addr }; -use futures::{ FutureExt, StreamExt }; +use tokio::prelude::*; +// use futures::{ FutureExt, StreamExt }; use tokio::net::{ TcpListener, TcpStream }; use tokio::io as aio; use rustls::{ Certificate, PrivateKey, ServerConfig, ClientConfig }; @@ -46,12 +47,12 @@ fn start_server(cert: Vec, rsa: PrivateKey) -> SocketAddr { .map(drop) .map_err(drop); - tokio::spawn2(done); + tokio::spawn(done); Ok(()) }) .then(|_| Ok(())); - tokio::runtime::run2(done); + tokio::runtime::run(done); }); recv.recv().unwrap() From 72de25ebce08a9c9421bfac4d23cf23deb8694cf Mon Sep 17 00:00:00 2001 From: quininer Date: Wed, 21 Mar 2018 21:44:36 +0800 Subject: [PATCH 048/171] change: impl io::{Read,Write} --- Cargo.toml | 2 +- examples/server.rs | 13 ++-- src/futures_impl.rs | 8 ++- src/lib.rs | 160 +++++++++++++++++++++++++------------------- src/tokio_impl.rs | 13 ++-- 5 files changed, 107 insertions(+), 89 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 765cdde..2d42213 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tokio-rustls" -version = "0.5.0" +version = "0.6.0-alpha" authors = ["quininer kel "] license = "MIT/Apache-2.0" repository = "https://github.com/quininer/tokio-rustls" diff --git a/examples/server.rs b/examples/server.rs index a450393..178ff75 100644 --- a/examples/server.rs +++ b/examples/server.rs @@ -1,7 +1,5 @@ extern crate clap; extern crate rustls; -extern crate futures; -extern crate tokio_io; extern crate tokio; extern crate webpki_roots; extern crate tokio_rustls; @@ -10,12 +8,11 @@ use std::sync::Arc; use std::net::ToSocketAddrs; use std::io::BufReader; use std::fs::File; -use futures::{ Future, Stream }; use rustls::{ Certificate, NoClientAuth, PrivateKey, ServerConfig }; use rustls::internal::pemfile::{ certs, rsa_private_keys }; -use tokio_io::{ io, AsyncRead }; +use tokio::prelude::{ Future, Stream }; +use tokio::io::{ self, AsyncRead }; use tokio::net::TcpListener; -use tokio::executor::current_thread; use clap::{ App, Arg }; use tokio_rustls::ServerConfigExt; @@ -64,7 +61,7 @@ fn main() { }) .map(move |(n, ..)| println!("Echo: {} - {:?}", n, addr)) .map_err(move |err| println!("Error: {:?} - {:?}", err, addr2)); - current_thread::spawn(done); + tokio::spawn(done); Ok(()) } else { @@ -82,10 +79,10 @@ fn main() { .and_then(|(stream, _)| io::flush(stream)) .map(move |_| println!("Accept: {:?}", addr)) .map_err(move |err| println!("Error: {:?} - {:?}", err, addr2)); - current_thread::spawn(done); + tokio::spawn(done); Ok(()) }); - current_thread::run(|_| current_thread::spawn(done.map_err(drop))); + tokio::run(done.map_err(drop)); } diff --git a/src/futures_impl.rs b/src/futures_impl.rs index b8b91bf..e8c1f79 100644 --- a/src/futures_impl.rs +++ b/src/futures_impl.rs @@ -1,7 +1,9 @@ +extern crate futures; + use super::*; -use futures::{ Future, Poll, Async }; -use futures::io::{ Error, AsyncRead, AsyncWrite }; -use futures::task::Context; +use self::futures::{ Future, Poll, Async }; +use self::futures::io::{ Error, AsyncRead, AsyncWrite }; +use self::futures::task::Context; impl Future for ConnectAsync { diff --git a/src/lib.rs b/src/lib.rs index 98ff6c0..b7cd5b0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,12 +1,10 @@ //! Asynchronous TLS/SSL streams for Tokio using [Rustls](https://github.com/ctz/rustls). -extern crate futures; -extern crate tokio; extern crate rustls; extern crate webpki; -mod tokio_impl; -mod futures_impl; +#[cfg(feature = "tokio")] mod tokio_impl; +#[cfg(feature = "futures")] mod futures_impl; use std::io; use std::sync::Arc; @@ -17,14 +15,14 @@ use rustls::{ /// Extension trait for the `Arc` type in the `rustls` crate. -pub trait ClientConfigExt { +pub trait ClientConfigExt: sealed::Sealed { fn connect_async(&self, domain: webpki::DNSNameRef, stream: S) -> ConnectAsync where S: io::Read + io::Write; } /// Extension trait for the `Arc` type in the `rustls` crate. -pub trait ServerConfigExt { +pub trait ServerConfigExt: sealed::Sealed { fn accept_async(&self, stream: S) -> AcceptAsync where S: io::Read + io::Write; @@ -39,6 +37,7 @@ pub struct ConnectAsync(MidHandshake); /// once the accept handshake has finished. pub struct AcceptAsync(MidHandshake); +impl sealed::Sealed for Arc {} impl ClientConfigExt for Arc { fn connect_async(&self, domain: webpki::DNSNameRef, stream: S) @@ -54,11 +53,11 @@ pub fn connect_async_with_session(stream: S, session: ClientSession) -> ConnectAsync where S: io::Read + io::Write { - ConnectAsync(MidHandshake { - inner: Some(TlsStream::new(stream, session)) - }) + ConnectAsync(MidHandshake { inner: Some(TlsStream::new(stream, session)) }) } +impl sealed::Sealed for Arc {} + impl ServerConfigExt for Arc { fn accept_async(&self, stream: S) -> AcceptAsync @@ -73,9 +72,7 @@ pub fn accept_async_with_session(stream: S, session: ServerSession) -> AcceptAsync where S: io::Read + io::Write { - AcceptAsync(MidHandshake { - inner: Some(TlsStream::new(stream, session)) - }) + AcceptAsync(MidHandshake { inner: Some(TlsStream::new(stream, session)) }) } @@ -104,11 +101,30 @@ impl TlsStream { } } + +macro_rules! try_wouldblock { + ( continue $r:expr ) => { + match $r { + Ok(true) => continue, + Ok(false) => false, + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => true, + Err(e) => return Err(e) + } + }; + ( ignore $r:expr ) => { + match $r { + Ok(_) => (), + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => (), + Err(e) => return Err(e) + } + }; +} + impl TlsStream where S: io::Read + io::Write, C: Session { #[inline] - pub fn new(io: S, session: C) -> TlsStream { + fn new(io: S, session: C) -> TlsStream { TlsStream { is_shutdown: false, eof: false, @@ -117,45 +133,46 @@ impl TlsStream } } + fn do_read(&mut self) -> io::Result { + if !self.eof && self.session.wants_read() { + if self.session.read_tls(&mut self.io)? == 0 { + self.eof = true; + } + + if let Err(err) = self.session.process_new_packets() { + // flush queued messages before returning an Err in + // order to send alerts instead of abruptly closing + // the socket + if self.session.wants_write() { + // ignore result to avoid masking original error + let _ = self.session.write_tls(&mut self.io); + } + return Err(io::Error::new(io::ErrorKind::InvalidData, err)); + } + + Ok(true) + } else { + Ok(false) + } + } + + fn do_write(&mut self) -> io::Result { + if self.session.wants_write() { + self.session.write_tls(&mut self.io)?; + + Ok(true) + } else { + Ok(false) + } + } + + #[inline] pub fn do_io(&mut self) -> io::Result<()> { loop { - let read_would_block = if !self.eof && self.session.wants_read() { - match self.session.read_tls(&mut self.io) { - Ok(0) => { - self.eof = true; - continue - }, - Ok(_) => { - if let Err(err) = self.session.process_new_packets() { - // flush queued messages before returning an Err in - // order to send alerts instead of abruptly closing - // the socket - if self.session.wants_write() { - // ignore result to avoid masking original error - let _ = self.session.write_tls(&mut self.io); - } - return Err(io::Error::new(io::ErrorKind::Other, err)); - } - continue - }, - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => true, - Err(e) => return Err(e) - } - } else { - false - }; + let write_would_block = try_wouldblock!(continue self.do_write()); + let read_would_block = try_wouldblock!(continue self.do_read()); - let write_would_block = if self.session.wants_write() { - match self.session.write_tls(&mut self.io) { - Ok(_) => continue, - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => true, - Err(e) => return Err(e) - } - } else { - false - }; - - if read_would_block || write_would_block { + if write_would_block || read_would_block { return Err(io::Error::from(io::ErrorKind::WouldBlock)); } else { return Ok(()); @@ -168,12 +185,14 @@ impl io::Read for TlsStream where S: io::Read + io::Write, C: Session { fn read(&mut self, buf: &mut [u8]) -> io::Result { + try_wouldblock!(ignore self.do_io()); + loop { match self.session.read(buf) { - Ok(0) if !self.eof => self.do_io()?, + Ok(0) if !self.eof => while self.do_read()? {}, Ok(n) => return Ok(n), Err(e) => if e.kind() == io::ErrorKind::ConnectionAborted { - self.do_io()?; + try_wouldblock!(ignore self.do_read()); return if self.eof { Ok(0) } else { Err(e) } } else { return Err(e) @@ -187,38 +206,39 @@ impl io::Write for TlsStream where S: io::Read + io::Write, C: Session { fn write(&mut self, buf: &[u8]) -> io::Result { - if buf.is_empty() { - return Ok(0); - } + try_wouldblock!(ignore self.do_io()); + + let mut wlen = self.session.write(buf)?; loop { - let output = self.session.write(buf)?; - - while self.session.wants_write() { - match self.session.write_tls(&mut self.io) { - Ok(_) => (), - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => if output == 0 { + match self.do_write() { + Ok(true) => continue, + Ok(false) if wlen == 0 => (), + Ok(false) => break, + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => + if wlen == 0 { // Both rustls buffer and IO buffer are blocking. return Err(io::Error::from(io::ErrorKind::WouldBlock)); } else { - break; + continue }, - Err(e) => return Err(e) - } + Err(e) => return Err(e) } - if output > 0 { - // Already wrote something out. - return Ok(output); - } + assert_eq!(wlen, 0); + wlen = self.session.write(buf)?; } + + Ok(wlen) } fn flush(&mut self) -> io::Result<()> { self.session.flush()?; - while self.session.wants_write() { - self.session.write_tls(&mut self.io)?; - } + while self.do_write()? {}; self.io.flush() } } + +mod sealed { + pub trait Sealed {} +} diff --git a/src/tokio_impl.rs b/src/tokio_impl.rs index 117b56b..edbda5b 100644 --- a/src/tokio_impl.rs +++ b/src/tokio_impl.rs @@ -1,7 +1,9 @@ +extern crate tokio; + use super::*; -use tokio::prelude::*; -use tokio::io::{ AsyncRead, AsyncWrite }; -use tokio::prelude::Poll; +use self::tokio::prelude::*; +use self::tokio::io::{ AsyncRead, AsyncWrite }; +use self::tokio::prelude::Poll; impl Future for ConnectAsync { @@ -67,10 +69,7 @@ impl AsyncWrite for TlsStream self.session.send_close_notify(); self.is_shutdown = true; } - while self.session.wants_write() { - self.session.write_tls(&mut self.io)?; - } - self.io.flush()?; + while self.do_write()? {}; self.io.shutdown() } } From d4cb46e8950a141e55ba76e774b7c5340ea1f862 Mon Sep 17 00:00:00 2001 From: quininer Date: Thu, 22 Mar 2018 19:47:27 +0800 Subject: [PATCH 049/171] feat: start futures_impl --- Cargo.toml | 9 +++- examples/client.rs | 9 ++-- src/futures_impl.rs | 91 ++++++++++++++++++++++++++++++++------ src/lib.rs | 80 +++++++++++++++++---------------- src/tokio_impl.rs | 4 +- tests/test.rs | 105 +++++++++++++++++++++++++++++++++++++++----- 6 files changed, 227 insertions(+), 71 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 2d42213..47a9ddd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,8 +22,15 @@ webpki = "0.18.0-alpha" [dev-dependencies] tokio = "0.1" +tokio-io = "0.1" +# tokio-core = "0.1" +# tokio-file-unix = "0.4" clap = "2.26" webpki-roots = "0.14" [features] -default = [ "futures", "tokio" ] +unstable-futures = [ "futures", "tokio/unstable-futures" ] +default = [ "unstable-futures", "tokio" ] + +[patch.crates-io] +tokio = { path = "../ref/tokio" } diff --git a/examples/client.rs b/examples/client.rs index 418a6a4..d454b01 100644 --- a/examples/client.rs +++ b/examples/client.rs @@ -1,7 +1,6 @@ extern crate clap; extern crate rustls; -extern crate futures; -extern crate tokio_io; +extern crate tokio; extern crate tokio_core; extern crate webpki; extern crate webpki_roots; @@ -14,16 +13,16 @@ use std::sync::Arc; use std::net::ToSocketAddrs; use std::io::{ BufReader, stdout, stdin }; use std::fs; -use futures::Future; +use tokio::io; +use tokio::prelude::*; use tokio_core::net::TcpStream; use tokio_core::reactor::Core; -use tokio_io::io; use clap::{ App, Arg }; use rustls::ClientConfig; use tokio_rustls::ClientConfigExt; #[cfg(unix)] -use tokio_io::AsyncRead; +use tokio::io::AsyncRead; #[cfg(unix)] use tokio_file_unix::{ StdFile, File }; diff --git a/src/futures_impl.rs b/src/futures_impl.rs index e8c1f79..0f5cc80 100644 --- a/src/futures_impl.rs +++ b/src/futures_impl.rs @@ -6,7 +6,7 @@ use self::futures::io::{ Error, AsyncRead, AsyncWrite }; use self::futures::task::Context; -impl Future for ConnectAsync { +impl Future for ConnectAsync { type Item = TlsStream; type Error = io::Error; @@ -15,7 +15,7 @@ impl Future for ConnectAsync { } } -impl Future for AcceptAsync { +impl Future for AcceptAsync { type Item = TlsStream; type Error = io::Error; @@ -24,20 +24,67 @@ impl Future for AcceptAsync { } } +macro_rules! async { + ( to $r:expr ) => { + match $r { + Ok(Async::Ready(n)) => Ok(n), + Ok(Async::Pending) => Err(io::ErrorKind::WouldBlock.into()), + Err(e) => Err(e) + } + }; + ( from $r:expr ) => { + match $r { + Ok(n) => Ok(Async::Ready(n)), + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => Ok(Async::Pending), + Err(e) => Err(e) + } + }; +} + +struct TaskStream<'a, 'b: 'a, S: 'a> { + io: &'a mut S, + task: &'a mut Context<'b> +} + +impl<'a, 'b, S> io::Read for TaskStream<'a, 'b, S> + where S: AsyncRead + AsyncWrite +{ + #[inline] + fn read(&mut self, buf: &mut [u8]) -> io::Result { + async!(to self.io.poll_read(self.task, buf)) + } +} + +impl<'a, 'b, S> io::Write for TaskStream<'a, 'b, S> + where S: AsyncRead + AsyncWrite +{ + #[inline] + fn write(&mut self, buf: &[u8]) -> io::Result { + async!(to self.io.poll_write(self.task, buf)) + } + + #[inline] + fn flush(&mut self) -> io::Result<()> { + async!(to self.io.poll_flush(self.task)) + } +} + impl Future for MidHandshake - where S: io::Read + io::Write, C: Session + where S: AsyncRead + AsyncWrite, C: Session { type Item = TlsStream; type Error = io::Error; - fn poll(&mut self, _: &mut Context) -> Poll { + fn poll(&mut self, ctx: &mut Context) -> Poll { loop { let stream = self.inner.as_mut().unwrap(); if !stream.session.is_handshaking() { break }; - match stream.do_io() { + let mut taskio = TaskStream { io: &mut stream.io, task: ctx }; + + match TlsStream::do_io(&mut stream.session, &mut taskio, &mut stream.eof) { Ok(()) => match (stream.eof, stream.session.is_handshaking()) { - (true, true) => return Err(io::Error::from(io::ErrorKind::UnexpectedEof)), + (true, true) => return Err(io::ErrorKind::UnexpectedEof.into()), (false, true) => continue, (..) => break }, @@ -58,8 +105,10 @@ impl AsyncRead for TlsStream S: AsyncRead + AsyncWrite, C: Session { - fn poll_read(&mut self, _: &mut Context, buf: &mut [u8]) -> Poll { - unimplemented!() + fn poll_read(&mut self, ctx: &mut Context, buf: &mut [u8]) -> Poll { + let mut taskio = TaskStream { io: &mut self.io, task: ctx }; + // FIXME TlsStream + TaskStream + async!(from io::Read::read(&mut taskio, buf)) } } @@ -68,15 +117,29 @@ impl AsyncWrite for TlsStream S: AsyncRead + AsyncWrite, C: Session { - fn poll_write(&mut self, _: &mut Context, buf: &[u8]) -> Poll { - unimplemented!() + fn poll_write(&mut self, ctx: &mut Context, buf: &[u8]) -> Poll { + let mut taskio = TaskStream { io: &mut self.io, task: ctx }; + // FIXME TlsStream + TaskStream + async!(from io::Write::write(&mut taskio, buf)) } - fn poll_flush(&mut self, _: &mut Context) -> Poll<(), Error> { - unimplemented!() + fn poll_flush(&mut self, ctx: &mut Context) -> Poll<(), Error> { + let mut taskio = TaskStream { io: &mut self.io, task: ctx }; + // FIXME TlsStream + TaskStream + async!(from io::Write::flush(&mut taskio)) } - fn poll_close(&mut self, _: &mut Context) -> Poll<(), Error> { - unimplemented!() + fn poll_close(&mut self, ctx: &mut Context) -> Poll<(), Error> { + if !self.is_shutdown { + self.session.send_close_notify(); + self.is_shutdown = true; + } + + { + let mut taskio = TaskStream { io: &mut self.io, task: ctx }; + while TlsStream::do_write(&mut self.session, &mut taskio)? {}; + } + + self.io.poll_close(ctx) } } diff --git a/src/lib.rs b/src/lib.rs index b7cd5b0..8167a5c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -4,7 +4,7 @@ extern crate rustls; extern crate webpki; #[cfg(feature = "tokio")] mod tokio_impl; -#[cfg(feature = "futures")] mod futures_impl; +#[cfg(feature = "unstable-futures")] mod futures_impl; use std::io; use std::sync::Arc; @@ -101,25 +101,6 @@ impl TlsStream { } } - -macro_rules! try_wouldblock { - ( continue $r:expr ) => { - match $r { - Ok(true) => continue, - Ok(false) => false, - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => true, - Err(e) => return Err(e) - } - }; - ( ignore $r:expr ) => { - match $r { - Ok(_) => (), - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => (), - Err(e) => return Err(e) - } - }; -} - impl TlsStream where S: io::Read + io::Write, C: Session { @@ -133,19 +114,19 @@ impl TlsStream } } - fn do_read(&mut self) -> io::Result { - if !self.eof && self.session.wants_read() { - if self.session.read_tls(&mut self.io)? == 0 { - self.eof = true; + fn do_read(session: &mut C, io: &mut S, eof: &mut bool) -> io::Result { + if !*eof && session.wants_read() { + if session.read_tls(io)? == 0 { + *eof = true; } - if let Err(err) = self.session.process_new_packets() { + if let Err(err) = session.process_new_packets() { // flush queued messages before returning an Err in // order to send alerts instead of abruptly closing // the socket - if self.session.wants_write() { + if session.wants_write() { // ignore result to avoid masking original error - let _ = self.session.write_tls(&mut self.io); + let _ = session.write_tls(io); } return Err(io::Error::new(io::ErrorKind::InvalidData, err)); } @@ -156,9 +137,9 @@ impl TlsStream } } - fn do_write(&mut self) -> io::Result { - if self.session.wants_write() { - self.session.write_tls(&mut self.io)?; + fn do_write(session: &mut C, io: &mut S) -> io::Result { + if session.wants_write() { + session.write_tls(io)?; Ok(true) } else { @@ -167,10 +148,21 @@ impl TlsStream } #[inline] - pub fn do_io(&mut self) -> io::Result<()> { + pub fn do_io(session: &mut C, io: &mut S, eof: &mut bool) -> io::Result<()> { + macro_rules! try_wouldblock { + ( $r:expr ) => { + match $r { + Ok(true) => continue, + Ok(false) => false, + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => true, + Err(e) => return Err(e) + } + }; + } + loop { - let write_would_block = try_wouldblock!(continue self.do_write()); - let read_would_block = try_wouldblock!(continue self.do_read()); + let write_would_block = try_wouldblock!(Self::do_write(session, io)); + let read_would_block = try_wouldblock!(Self::do_read(session, io, eof)); if write_would_block || read_would_block { return Err(io::Error::from(io::ErrorKind::WouldBlock)); @@ -181,18 +173,28 @@ impl TlsStream } } +macro_rules! try_ignore { + ( $r:expr ) => { + match $r { + Ok(_) => (), + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => (), + Err(e) => return Err(e) + } + } +} + impl io::Read for TlsStream where S: io::Read + io::Write, C: Session { fn read(&mut self, buf: &mut [u8]) -> io::Result { - try_wouldblock!(ignore self.do_io()); + try_ignore!(Self::do_io(&mut self.session, &mut self.io, &mut self.eof)); loop { match self.session.read(buf) { - Ok(0) if !self.eof => while self.do_read()? {}, + Ok(0) if !self.eof => while Self::do_read(&mut self.session, &mut self.io, &mut self.eof)? {}, Ok(n) => return Ok(n), Err(e) => if e.kind() == io::ErrorKind::ConnectionAborted { - try_wouldblock!(ignore self.do_read()); + try_ignore!(Self::do_read(&mut self.session, &mut self.io, &mut self.eof)); return if self.eof { Ok(0) } else { Err(e) } } else { return Err(e) @@ -206,12 +208,12 @@ impl io::Write for TlsStream where S: io::Read + io::Write, C: Session { fn write(&mut self, buf: &[u8]) -> io::Result { - try_wouldblock!(ignore self.do_io()); + try_ignore!(Self::do_io(&mut self.session, &mut self.io, &mut self.eof)); let mut wlen = self.session.write(buf)?; loop { - match self.do_write() { + match Self::do_write(&mut self.session, &mut self.io) { Ok(true) => continue, Ok(false) if wlen == 0 => (), Ok(false) => break, @@ -234,7 +236,7 @@ impl io::Write for TlsStream fn flush(&mut self) -> io::Result<()> { self.session.flush()?; - while self.do_write()? {}; + while Self::do_write(&mut self.session, &mut self.io)? {}; self.io.flush() } } diff --git a/src/tokio_impl.rs b/src/tokio_impl.rs index edbda5b..294f915 100644 --- a/src/tokio_impl.rs +++ b/src/tokio_impl.rs @@ -35,7 +35,7 @@ impl Future for MidHandshake let stream = self.inner.as_mut().unwrap(); if !stream.session.is_handshaking() { break }; - match stream.do_io() { + match TlsStream::do_io(&mut stream.session, &mut stream.io, &mut stream.eof) { Ok(()) => match (stream.eof, stream.session.is_handshaking()) { (true, true) => return Err(io::Error::from(io::ErrorKind::UnexpectedEof)), (false, true) => continue, @@ -69,7 +69,7 @@ impl AsyncWrite for TlsStream self.session.send_close_notify(); self.is_shutdown = true; } - while self.do_write()? {}; + while TlsStream::do_write(&mut self.session, &mut self.io)? {}; self.io.shutdown() } } diff --git a/tests/test.rs b/tests/test.rs index 5e37e73..c0e2c8f 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -1,18 +1,16 @@ extern crate rustls; -extern crate futures; extern crate tokio; extern crate tokio_rustls; extern crate webpki; +#[cfg(feature = "unstable-futures")] extern crate futures; + use std::{ io, thread }; use std::io::{ BufReader, Cursor }; use std::sync::Arc; use std::sync::mpsc::channel; use std::net::{ SocketAddr, IpAddr, Ipv4Addr }; -use tokio::prelude::*; -// use futures::{ FutureExt, StreamExt }; use tokio::net::{ TcpListener, TcpStream }; -use tokio::io as aio; use rustls::{ Certificate, PrivateKey, ServerConfig, ClientConfig }; use rustls::internal::pemfile::{ certs, rsa_private_keys }; use tokio_rustls::{ ClientConfigExt, ServerConfigExt }; @@ -24,6 +22,9 @@ const HELLO_WORLD: &[u8] = b"Hello world!"; fn start_server(cert: Vec, rsa: PrivateKey) -> SocketAddr { + use tokio::prelude::*; + use tokio::io as aio; + let mut config = ServerConfig::new(rustls::NoClientAuth::new()); config.set_single_cert(cert, rsa); let config = Arc::new(config); @@ -45,21 +46,62 @@ fn start_server(cert: Vec, rsa: PrivateKey) -> SocketAddr { aio::write_all(stream, HELLO_WORLD) }) .map(drop) - .map_err(drop); + .map_err(|err| panic!("{:?}", err)); tokio::spawn(done); Ok(()) }) - .then(|_| Ok(())); + .map_err(|err| panic!("{:?}", err)); - tokio::runtime::run(done); + tokio::run(done); }); recv.recv().unwrap() } -fn start_client(addr: &SocketAddr, domain: &str, - chain: Option>>) -> io::Result<()> { +fn start_server2(cert: Vec, rsa: PrivateKey) -> SocketAddr { + use futures::{ FutureExt, StreamExt }; + use futures::io::{ AsyncReadExt, AsyncWriteExt }; + + let mut config = ServerConfig::new(rustls::NoClientAuth::new()); + config.set_single_cert(cert, rsa); + let config = Arc::new(config); + + let (send, recv) = channel(); + + thread::spawn(move || { + let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 0); + let listener = TcpListener::bind(&addr).unwrap(); + + send.send(listener.local_addr().unwrap()).unwrap(); + + let done = listener.incoming() + .for_each(move |stream| { + let done = config.accept_async(stream) + .and_then(|stream| stream.read_exact(vec![0; HELLO_WORLD.len()])) + .and_then(|(stream, buf)| { + assert_eq!(buf, HELLO_WORLD); + stream.write_all(HELLO_WORLD) + }) + .map(drop) + .map_err(|err| panic!("{:?}", err)); + + tokio::spawn2(done); + Ok(()) + }) + .map(drop) + .map_err(|err| panic!("{:?}", err)); + + tokio::runtime::run2(done); + }); + + recv.recv().unwrap() +} + +fn start_client(addr: &SocketAddr, domain: &str, chain: Option>>) -> io::Result<()> { + use tokio::prelude::*; + use tokio::io as aio; + let domain = webpki::DNSNameRef::try_from_ascii_str(domain).unwrap(); let mut config = ClientConfig::new(); if let Some(mut chain) = chain { @@ -79,9 +121,41 @@ fn start_client(addr: &SocketAddr, domain: &str, done.wait() } +#[cfg(feature = "unstable-futures")] +fn start_client2(addr: &SocketAddr, domain: &str, chain: Option>>) -> io::Result<()> { + use futures::FutureExt; + use futures::io::{ AsyncReadExt, AsyncWriteExt }; + use futures::executor::block_on; + + let domain = webpki::DNSNameRef::try_from_ascii_str(domain).unwrap(); + let mut config = ClientConfig::new(); + if let Some(mut chain) = chain { + config.root_store.add_pem_file(&mut chain).unwrap(); + } + let config = Arc::new(config); + + let done = TcpStream::connect(addr) + .and_then(|stream| config.connect_async(domain, stream)) + .and_then(|stream| { + eprintln!("WRITE: {:?}", stream); + stream.write_all(HELLO_WORLD) + }) + .and_then(|(stream, _)| { + eprintln!("READ: {:?}", stream); + stream.read_exact(vec![0; HELLO_WORLD.len()]) + }) + .and_then(|(stream, buf)| { + eprintln!("OK: {:?}", stream); + assert_eq!(buf, HELLO_WORLD); + Ok(()) + }); + + block_on(done) +} + #[test] -fn main() { +fn pass() { let cert = certs(&mut BufReader::new(Cursor::new(CERT))).unwrap(); let mut keys = rsa_private_keys(&mut BufReader::new(Cursor::new(RSA))).unwrap(); let chain = BufReader::new(Cursor::new(CHAIN)); @@ -90,6 +164,17 @@ fn main() { start_client(&addr, "localhost", Some(chain)).unwrap(); } +#[cfg(feature = "unstable-futures")] +#[test] +fn pass2() { + let cert = certs(&mut BufReader::new(Cursor::new(CERT))).unwrap(); + let mut keys = rsa_private_keys(&mut BufReader::new(Cursor::new(RSA))).unwrap(); + let chain = BufReader::new(Cursor::new(CHAIN)); + + let addr = start_server2(cert, keys.pop().unwrap()); + start_client2(&addr, "localhost", Some(chain)).unwrap(); +} + #[should_panic] #[test] fn fail() { From 64ca6e290cd4c17e32957ed665e7314b0d812fc4 Mon Sep 17 00:00:00 2001 From: quininer Date: Fri, 23 Mar 2018 17:14:27 +0800 Subject: [PATCH 050/171] change: use rustls Stream --- Cargo.toml | 11 ++++----- src/lib.rs | 57 ++++++++++++++++------------------------------- src/tokio_impl.rs | 15 ++++--------- tests/test.rs | 12 +++------- 4 files changed, 32 insertions(+), 63 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 47a9ddd..427015d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,14 +23,15 @@ webpki = "0.18.0-alpha" [dev-dependencies] tokio = "0.1" tokio-io = "0.1" -# tokio-core = "0.1" -# tokio-file-unix = "0.4" +tokio-core = "0.1" +tokio-file-unix = "0.4" clap = "2.26" webpki-roots = "0.14" [features] -unstable-futures = [ "futures", "tokio/unstable-futures" ] -default = [ "unstable-futures", "tokio" ] +default = [ "tokio" ] +# unstable-futures = [ "futures", "tokio/unstable-futures" ] +# default = [ "unstable-futures", "tokio" ] [patch.crates-io] -tokio = { path = "../ref/tokio" } +# tokio = { path = "../ref/tokio" } diff --git a/src/lib.rs b/src/lib.rs index 8167a5c..1d4ab7f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -10,7 +10,8 @@ use std::io; use std::sync::Arc; use rustls::{ Session, ClientSession, ServerSession, - ClientConfig, ServerConfig + ClientConfig, ServerConfig, + Stream }; @@ -92,10 +93,12 @@ pub struct TlsStream { } impl TlsStream { + #[inline] pub fn get_ref(&self) -> (&S, &C) { (&self.io, &self.session) } + #[inline] pub fn get_mut(&mut self) -> (&mut S, &mut C) { (&mut self.io, &mut self.session) } @@ -187,19 +190,13 @@ impl io::Read for TlsStream where S: io::Read + io::Write, C: Session { fn read(&mut self, buf: &mut [u8]) -> io::Result { - try_ignore!(Self::do_io(&mut self.session, &mut self.io, &mut self.eof)); + let (io, session) = self.get_mut(); + let mut stream = Stream::new(session, io); - loop { - match self.session.read(buf) { - Ok(0) if !self.eof => while Self::do_read(&mut self.session, &mut self.io, &mut self.eof)? {}, - Ok(n) => return Ok(n), - Err(e) => if e.kind() == io::ErrorKind::ConnectionAborted { - try_ignore!(Self::do_read(&mut self.session, &mut self.io, &mut self.eof)); - return if self.eof { Ok(0) } else { Err(e) } - } else { - return Err(e) - } - } + match stream.read(buf) { + Ok(n) => Ok(n), + Err(ref e) if e.kind() == io::ErrorKind::ConnectionAborted => Ok(0), + Err(e) => Err(e) } } } @@ -208,35 +205,19 @@ impl io::Write for TlsStream where S: io::Read + io::Write, C: Session { fn write(&mut self, buf: &[u8]) -> io::Result { - try_ignore!(Self::do_io(&mut self.session, &mut self.io, &mut self.eof)); + let (io, session) = self.get_mut(); + let mut stream = Stream::new(session, io); - let mut wlen = self.session.write(buf)?; - - loop { - match Self::do_write(&mut self.session, &mut self.io) { - Ok(true) => continue, - Ok(false) if wlen == 0 => (), - Ok(false) => break, - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => - if wlen == 0 { - // Both rustls buffer and IO buffer are blocking. - return Err(io::Error::from(io::ErrorKind::WouldBlock)); - } else { - continue - }, - Err(e) => return Err(e) - } - - assert_eq!(wlen, 0); - wlen = self.session.write(buf)?; - } - - Ok(wlen) + stream.write(buf) } fn flush(&mut self) -> io::Result<()> { - self.session.flush()?; - while Self::do_write(&mut self.session, &mut self.io)? {}; + { + let (io, session) = self.get_mut(); + let mut stream = Stream::new(session, io); + stream.flush()?; + } + self.io.flush() } } diff --git a/src/tokio_impl.rs b/src/tokio_impl.rs index 294f915..fa4fbe8 100644 --- a/src/tokio_impl.rs +++ b/src/tokio_impl.rs @@ -35,17 +35,10 @@ impl Future for MidHandshake let stream = self.inner.as_mut().unwrap(); if !stream.session.is_handshaking() { break }; - match TlsStream::do_io(&mut stream.session, &mut stream.io, &mut stream.eof) { - Ok(()) => match (stream.eof, stream.session.is_handshaking()) { - (true, true) => return Err(io::Error::from(io::ErrorKind::UnexpectedEof)), - (false, true) => continue, - (..) => break - }, - Err(e) => match (e.kind(), stream.session.is_handshaking()) { - (io::ErrorKind::WouldBlock, true) => return Ok(Async::NotReady), - (io::ErrorKind::WouldBlock, false) => break, - (..) => return Err(e) - } + match stream.session.complete_io(&mut stream.io) { + Ok(_) => (), + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => return Ok(Async::NotReady), + Err(e) => return Err(e) } } diff --git a/tests/test.rs b/tests/test.rs index c0e2c8f..e231737 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -59,6 +59,7 @@ fn start_server(cert: Vec, rsa: PrivateKey) -> SocketAddr { recv.recv().unwrap() } +#[cfg(feature = "unstable-futures")] fn start_server2(cert: Vec, rsa: PrivateKey) -> SocketAddr { use futures::{ FutureExt, StreamExt }; use futures::io::{ AsyncReadExt, AsyncWriteExt }; @@ -136,16 +137,9 @@ fn start_client2(addr: &SocketAddr, domain: &str, chain: Option Date: Fri, 23 Mar 2018 17:31:55 +0800 Subject: [PATCH 051/171] fix: futures_impl --- Cargo.toml | 16 +++++++------- src/futures_impl.rs | 52 +++++++++++++++++++++++++-------------------- src/lib.rs | 10 --------- src/tokio_impl.rs | 6 ++++-- tests/test.rs | 47 +++------------------------------------- 5 files changed, 44 insertions(+), 87 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 427015d..89d6b49 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,23 +15,23 @@ travis-ci = { repository = "quininer/tokio-rustls" } appveyor = { repository = "quininer/tokio-rustls" } [dependencies] -futures = { version = "0.2.0-alpha", optional = true } +futures = { version = "0.2.0-beta", optional = true } tokio = { version = "0.1", optional = true } rustls = "0.12" webpki = "0.18.0-alpha" [dev-dependencies] tokio = "0.1" -tokio-io = "0.1" -tokio-core = "0.1" -tokio-file-unix = "0.4" +# tokio-io = "0.1" +# tokio-core = "0.1" +# tokio-file-unix = "0.4" clap = "2.26" webpki-roots = "0.14" [features] -default = [ "tokio" ] -# unstable-futures = [ "futures", "tokio/unstable-futures" ] -# default = [ "unstable-futures", "tokio" ] +# default = [ "tokio" ] +default = [ "unstable-futures", "tokio" ] +unstable-futures = [ "futures", "tokio/unstable-futures" ] [patch.crates-io] -# tokio = { path = "../ref/tokio" } +tokio = { path = "../ref/tokio" } diff --git a/src/futures_impl.rs b/src/futures_impl.rs index 0f5cc80..22c637c 100644 --- a/src/futures_impl.rs +++ b/src/futures_impl.rs @@ -80,19 +80,13 @@ impl Future for MidHandshake let stream = self.inner.as_mut().unwrap(); if !stream.session.is_handshaking() { break }; - let mut taskio = TaskStream { io: &mut stream.io, task: ctx }; + let (io, session) = stream.get_mut(); + let mut taskio = TaskStream { io, task: ctx }; - match TlsStream::do_io(&mut stream.session, &mut taskio, &mut stream.eof) { - Ok(()) => match (stream.eof, stream.session.is_handshaking()) { - (true, true) => return Err(io::ErrorKind::UnexpectedEof.into()), - (false, true) => continue, - (..) => break - }, - Err(e) => match (e.kind(), stream.session.is_handshaking()) { - (io::ErrorKind::WouldBlock, true) => return Ok(Async::Pending), - (io::ErrorKind::WouldBlock, false) => break, - (..) => return Err(e) - } + match session.complete_io(&mut taskio) { + Ok(_) => (), + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => return Ok(Async::Pending), + Err(e) => return Err(e) } } @@ -106,9 +100,16 @@ impl AsyncRead for TlsStream C: Session { fn poll_read(&mut self, ctx: &mut Context, buf: &mut [u8]) -> Poll { - let mut taskio = TaskStream { io: &mut self.io, task: ctx }; - // FIXME TlsStream + TaskStream - async!(from io::Read::read(&mut taskio, buf)) + let (io, session) = self.get_mut(); + let mut taskio = TaskStream { io, task: ctx }; + let mut stream = Stream::new(session, &mut taskio); + + match io::Read::read(&mut stream, buf) { + Ok(n) => Ok(Async::Ready(n)), + Err(ref e) if e.kind() == io::ErrorKind::ConnectionAborted => Ok(Async::Ready(0)), + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => Ok(Async::Pending), + Err(e) => Err(e) + } } } @@ -118,15 +119,19 @@ impl AsyncWrite for TlsStream C: Session { fn poll_write(&mut self, ctx: &mut Context, buf: &[u8]) -> Poll { - let mut taskio = TaskStream { io: &mut self.io, task: ctx }; - // FIXME TlsStream + TaskStream - async!(from io::Write::write(&mut taskio, buf)) + let (io, session) = self.get_mut(); + let mut taskio = TaskStream { io, task: ctx }; + let mut stream = Stream::new(session, &mut taskio); + + async!(from io::Write::write(&mut stream, buf)) } fn poll_flush(&mut self, ctx: &mut Context) -> Poll<(), Error> { - let mut taskio = TaskStream { io: &mut self.io, task: ctx }; - // FIXME TlsStream + TaskStream - async!(from io::Write::flush(&mut taskio)) + let (io, session) = self.get_mut(); + let mut taskio = TaskStream { io, task: ctx }; + let mut stream = Stream::new(session, &mut taskio); + + async!(from io::Write::flush(&mut stream)) } fn poll_close(&mut self, ctx: &mut Context) -> Poll<(), Error> { @@ -136,8 +141,9 @@ impl AsyncWrite for TlsStream } { - let mut taskio = TaskStream { io: &mut self.io, task: ctx }; - while TlsStream::do_write(&mut self.session, &mut taskio)? {}; + let (io, session) = self.get_mut(); + let mut taskio = TaskStream { io, task: ctx }; + session.complete_io(&mut taskio)?; } self.io.poll_close(ctx) diff --git a/src/lib.rs b/src/lib.rs index 1d4ab7f..293a112 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -176,16 +176,6 @@ impl TlsStream } } -macro_rules! try_ignore { - ( $r:expr ) => { - match $r { - Ok(_) => (), - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => (), - Err(e) => return Err(e) - } - } -} - impl io::Read for TlsStream where S: io::Read + io::Write, C: Session { diff --git a/src/tokio_impl.rs b/src/tokio_impl.rs index fa4fbe8..f5d3c6c 100644 --- a/src/tokio_impl.rs +++ b/src/tokio_impl.rs @@ -35,7 +35,9 @@ impl Future for MidHandshake let stream = self.inner.as_mut().unwrap(); if !stream.session.is_handshaking() { break }; - match stream.session.complete_io(&mut stream.io) { + let (io, session) = stream.get_mut(); + + match session.complete_io(io) { Ok(_) => (), Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => return Ok(Async::NotReady), Err(e) => return Err(e) @@ -62,7 +64,7 @@ impl AsyncWrite for TlsStream self.session.send_close_notify(); self.is_shutdown = true; } - while TlsStream::do_write(&mut self.session, &mut self.io)? {}; + self.session.complete_io(&mut self.io)?; self.io.shutdown() } } diff --git a/tests/test.rs b/tests/test.rs index e231737..92c904a 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -45,8 +45,7 @@ fn start_server(cert: Vec, rsa: PrivateKey) -> SocketAddr { assert_eq!(buf, HELLO_WORLD); aio::write_all(stream, HELLO_WORLD) }) - .map(drop) - .map_err(|err| panic!("{:?}", err)); + .then(|_| Ok(())); tokio::spawn(done); Ok(()) @@ -59,46 +58,6 @@ fn start_server(cert: Vec, rsa: PrivateKey) -> SocketAddr { recv.recv().unwrap() } -#[cfg(feature = "unstable-futures")] -fn start_server2(cert: Vec, rsa: PrivateKey) -> SocketAddr { - use futures::{ FutureExt, StreamExt }; - use futures::io::{ AsyncReadExt, AsyncWriteExt }; - - let mut config = ServerConfig::new(rustls::NoClientAuth::new()); - config.set_single_cert(cert, rsa); - let config = Arc::new(config); - - let (send, recv) = channel(); - - thread::spawn(move || { - let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 0); - let listener = TcpListener::bind(&addr).unwrap(); - - send.send(listener.local_addr().unwrap()).unwrap(); - - let done = listener.incoming() - .for_each(move |stream| { - let done = config.accept_async(stream) - .and_then(|stream| stream.read_exact(vec![0; HELLO_WORLD.len()])) - .and_then(|(stream, buf)| { - assert_eq!(buf, HELLO_WORLD); - stream.write_all(HELLO_WORLD) - }) - .map(drop) - .map_err(|err| panic!("{:?}", err)); - - tokio::spawn2(done); - Ok(()) - }) - .map(drop) - .map_err(|err| panic!("{:?}", err)); - - tokio::runtime::run2(done); - }); - - recv.recv().unwrap() -} - fn start_client(addr: &SocketAddr, domain: &str, chain: Option>>) -> io::Result<()> { use tokio::prelude::*; use tokio::io as aio; @@ -139,7 +98,7 @@ fn start_client2(addr: &SocketAddr, domain: &str, chain: Option Date: Fri, 23 Mar 2018 17:47:20 +0800 Subject: [PATCH 052/171] fix: example conflict --- Cargo.toml | 8 +------- examples/client/Cargo.toml | 16 ++++++++++++++++ examples/{client.rs => client/src/main.rs} | 0 examples/server/Cargo.toml | 17 +++++++++++++++++ examples/{server.rs => server/src/main.rs} | 1 - 5 files changed, 34 insertions(+), 8 deletions(-) create mode 100644 examples/client/Cargo.toml rename examples/{client.rs => client/src/main.rs} (100%) create mode 100644 examples/server/Cargo.toml rename examples/{server.rs => server/src/main.rs} (99%) diff --git a/Cargo.toml b/Cargo.toml index 89d6b49..dbad35b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,16 +22,10 @@ webpki = "0.18.0-alpha" [dev-dependencies] tokio = "0.1" -# tokio-io = "0.1" -# tokio-core = "0.1" -# tokio-file-unix = "0.4" -clap = "2.26" -webpki-roots = "0.14" [features] -# default = [ "tokio" ] default = [ "unstable-futures", "tokio" ] unstable-futures = [ "futures", "tokio/unstable-futures" ] [patch.crates-io] -tokio = { path = "../ref/tokio" } +tokio = { git = "https://github.com/tokio-rs/tokio" } diff --git a/examples/client/Cargo.toml b/examples/client/Cargo.toml new file mode 100644 index 0000000..790b018 --- /dev/null +++ b/examples/client/Cargo.toml @@ -0,0 +1,16 @@ +[package] +name = "client" +version = "0.1.0" +authors = ["quininer "] + +[dependencies] +rustls = "0.12" +webpki = "0.18.0-alpha" +tokio-rustls = { path = "../..", default-features = false, features = [ "tokio" ] } + +tokio = "0.1" +tokio-core = "0.1" +tokio-file-unix = "0.4" + +clap = "2.26" +webpki-roots = "0.14" diff --git a/examples/client.rs b/examples/client/src/main.rs similarity index 100% rename from examples/client.rs rename to examples/client/src/main.rs diff --git a/examples/server/Cargo.toml b/examples/server/Cargo.toml new file mode 100644 index 0000000..1516e92 --- /dev/null +++ b/examples/server/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "server" +version = "0.1.0" +authors = ["quininer "] + +[dependencies] +rustls = "0.12" +tokio-rustls = { path = "../..", default-features = false, features = [ "unstable-futures" ] } + +tokio = { version = "0.1", features = [ "unstable-futures" ] } +futures = "0.2.0-beta" + +clap = "2.26" + + +[patch.crates-io] +tokio = { git = "https://github.com/tokio-rs/tokio" } diff --git a/examples/server.rs b/examples/server/src/main.rs similarity index 99% rename from examples/server.rs rename to examples/server/src/main.rs index 178ff75..c321048 100644 --- a/examples/server.rs +++ b/examples/server/src/main.rs @@ -1,7 +1,6 @@ extern crate clap; extern crate rustls; extern crate tokio; -extern crate webpki_roots; extern crate tokio_rustls; use std::sync::Arc; From 1892fdb6092c9164f7ac753ae820c1e8b6781cf0 Mon Sep 17 00:00:00 2001 From: quininer Date: Fri, 23 Mar 2018 17:47:44 +0800 Subject: [PATCH 053/171] remove tokio-proto support --- src/proto.rs | 552 --------------------------------------------------- 1 file changed, 552 deletions(-) delete mode 100644 src/proto.rs diff --git a/src/proto.rs b/src/proto.rs deleted file mode 100644 index 7c659e4..0000000 --- a/src/proto.rs +++ /dev/null @@ -1,552 +0,0 @@ -//! Wrappers for `tokio-proto` -//! -//! This module contains wrappers for protocols defined by the `tokio-proto` -//! crate. These wrappers will all attempt to negotiate a TLS connection first -//! and then delegate all further protocol information to the protocol -//! specified. -//! -//! This module requires the `tokio-proto` feature to be enabled. - -#![cfg(feature = "tokio-proto")] - -extern crate tokio_proto; - -use std::io; -use std::sync::Arc; -use futures::{ Future, IntoFuture, Poll }; -use tokio_io::{ AsyncRead, AsyncWrite }; -use rustls::{ ServerConfig, ClientConfig, ServerSession, ClientSession }; -use self::tokio_proto::multiplex; -use self::tokio_proto::pipeline; -use self::tokio_proto::streaming; -use webpki; - -use { TlsStream, ServerConfigExt, ClientConfigExt, AcceptAsync, ConnectAsync }; - -/// TLS server protocol wrapper. -/// -/// This structure is a wrapper for other implementations of `ServerProto` in -/// the `tokio-proto` crate. This structure will negotiate a TLS connection -/// first and then delegate all further operations to the `ServerProto` -/// implementation for the underlying type. -pub struct Server { - inner: Arc, - acceptor: Arc, -} - -impl Server { - /// Constructs a new TLS protocol which will delegate to the underlying - /// `protocol` specified. - /// - /// The `acceptor` provided will be used to accept TLS connections. All new - /// connections will go through the TLS acceptor first and then further I/O - /// will go through the negotiated TLS stream through the `protocol` - /// specified. - pub fn new(protocol: T, acceptor: Arc) -> Server { - Server { - inner: Arc::new(protocol), - acceptor: acceptor, - } - } -} - -/// Future returned from `bind_transport` in the `ServerProto` implementation. -pub struct ServerPipelineBind - where T: pipeline::ServerProto>, - I: AsyncRead + AsyncWrite + 'static, -{ - state: PipelineState, -} - -enum PipelineState - where T: pipeline::ServerProto>, - I: AsyncRead + AsyncWrite + 'static, -{ - First(AcceptAsync, Arc), - Next(::Future), -} - -impl pipeline::ServerProto for Server - where T: pipeline::ServerProto>, - I: AsyncRead + AsyncWrite + 'static, -{ - type Request = T::Request; - type Response = T::Response; - type Transport = T::Transport; - type BindTransport = ServerPipelineBind; - - fn bind_transport(&self, io: I) -> Self::BindTransport { - let proto = self.inner.clone(); - - ServerPipelineBind { - state: PipelineState::First(self.acceptor.accept_async(io), proto), - } - } -} - -impl Future for ServerPipelineBind - where T: pipeline::ServerProto>, - I: AsyncRead + AsyncWrite + 'static, -{ - type Item = T::Transport; - type Error = io::Error; - - fn poll(&mut self) -> Poll { - loop { - let next = match self.state { - PipelineState::First(ref mut a, ref state) => { - let res = a.poll().map_err(|e| { - io::Error::new(io::ErrorKind::Other, e) - }); - state.bind_transport(try_ready!(res)) - } - PipelineState::Next(ref mut b) => return b.poll(), - }; - self.state = PipelineState::Next(next.into_future()); - } - } -} - -/// Future returned from `bind_transport` in the `ServerProto` implementation. -pub struct ServerMultiplexBind - where T: multiplex::ServerProto>, - I: AsyncRead + AsyncWrite + 'static, -{ - state: MultiplexState, -} - -enum MultiplexState - where T: multiplex::ServerProto>, - I: AsyncRead + AsyncWrite + 'static, -{ - First(AcceptAsync, Arc), - Next(::Future), -} - -impl multiplex::ServerProto for Server - where T: multiplex::ServerProto>, - I: AsyncRead + AsyncWrite + 'static, -{ - type Request = T::Request; - type Response = T::Response; - type Transport = T::Transport; - type BindTransport = ServerMultiplexBind; - - fn bind_transport(&self, io: I) -> Self::BindTransport { - let proto = self.inner.clone(); - - ServerMultiplexBind { - state: MultiplexState::First(self.acceptor.accept_async(io), proto), - } - } -} - -impl Future for ServerMultiplexBind - where T: multiplex::ServerProto>, - I: AsyncRead + AsyncWrite + 'static, -{ - type Item = T::Transport; - type Error = io::Error; - - fn poll(&mut self) -> Poll { - loop { - let next = match self.state { - MultiplexState::First(ref mut a, ref state) => { - let res = a.poll().map_err(|e| { - io::Error::new(io::ErrorKind::Other, e) - }); - state.bind_transport(try_ready!(res)) - } - MultiplexState::Next(ref mut b) => return b.poll(), - }; - self.state = MultiplexState::Next(next.into_future()); - } - } -} - -/// Future returned from `bind_transport` in the `ServerProto` implementation. -pub struct ServerStreamingPipelineBind - where T: streaming::pipeline::ServerProto>, - I: AsyncRead + AsyncWrite + 'static, -{ - state: StreamingPipelineState, -} - -enum StreamingPipelineState - where T: streaming::pipeline::ServerProto>, - I: AsyncRead + AsyncWrite + 'static, -{ - First(AcceptAsync, Arc), - Next(::Future), -} - -impl streaming::pipeline::ServerProto for Server - where T: streaming::pipeline::ServerProto>, - I: AsyncRead + AsyncWrite + 'static, -{ - type Request = T::Request; - type RequestBody = T::RequestBody; - type Response = T::Response; - type ResponseBody = T::ResponseBody; - type Error = T::Error; - type Transport = T::Transport; - type BindTransport = ServerStreamingPipelineBind; - - fn bind_transport(&self, io: I) -> Self::BindTransport { - let proto = self.inner.clone(); - - ServerStreamingPipelineBind { - state: StreamingPipelineState::First(self.acceptor.accept_async(io), proto), - } - } -} - -impl Future for ServerStreamingPipelineBind - where T: streaming::pipeline::ServerProto>, - I: AsyncRead + AsyncWrite + 'static, -{ - type Item = T::Transport; - type Error = io::Error; - - fn poll(&mut self) -> Poll { - loop { - let next = match self.state { - StreamingPipelineState::First(ref mut a, ref state) => { - let res = a.poll().map_err(|e| { - io::Error::new(io::ErrorKind::Other, e) - }); - state.bind_transport(try_ready!(res)) - } - StreamingPipelineState::Next(ref mut b) => return b.poll(), - }; - self.state = StreamingPipelineState::Next(next.into_future()); - } - } -} - -/// Future returned from `bind_transport` in the `ServerProto` implementation. -pub struct ServerStreamingMultiplexBind - where T: streaming::multiplex::ServerProto>, - I: AsyncRead + AsyncWrite + 'static, -{ - state: StreamingMultiplexState, -} - -enum StreamingMultiplexState - where T: streaming::multiplex::ServerProto>, - I: AsyncRead + AsyncWrite + 'static, -{ - First(AcceptAsync, Arc), - Next(::Future), -} - -impl streaming::multiplex::ServerProto for Server - where T: streaming::multiplex::ServerProto>, - I: AsyncRead + AsyncWrite + 'static, -{ - type Request = T::Request; - type RequestBody = T::RequestBody; - type Response = T::Response; - type ResponseBody = T::ResponseBody; - type Error = T::Error; - type Transport = T::Transport; - type BindTransport = ServerStreamingMultiplexBind; - - fn bind_transport(&self, io: I) -> Self::BindTransport { - let proto = self.inner.clone(); - - ServerStreamingMultiplexBind { - state: StreamingMultiplexState::First(self.acceptor.accept_async(io), proto), - } - } -} - -impl Future for ServerStreamingMultiplexBind - where T: streaming::multiplex::ServerProto>, - I: AsyncRead + AsyncWrite + 'static, -{ - type Item = T::Transport; - type Error = io::Error; - - fn poll(&mut self) -> Poll { - loop { - let next = match self.state { - StreamingMultiplexState::First(ref mut a, ref state) => { - let res = a.poll().map_err(|e| { - io::Error::new(io::ErrorKind::Other, e) - }); - state.bind_transport(try_ready!(res)) - } - StreamingMultiplexState::Next(ref mut b) => return b.poll(), - }; - self.state = StreamingMultiplexState::Next(next.into_future()); - } - } -} - -/// TLS client protocol wrapper. -/// -/// This structure is a wrapper for other implementations of `ClientProto` in -/// the `tokio-proto` crate. This structure will negotiate a TLS connection -/// first and then delegate all further operations to the `ClientProto` -/// implementation for the underlying type. -pub struct Client { - inner: Arc, - connector: Arc, - hostname: webpki::DNSName, -} - -impl Client { - /// Constructs a new TLS protocol which will delegate to the underlying - /// `protocol` specified. - /// - /// The `connector` provided will be used to configure the TLS connection. Further I/O - /// will go through the negotiated TLS stream through the `protocol` specified. - pub fn new(protocol: T, - connector: Arc, - hostname: webpki::DNSName) -> Client { - Client { - inner: Arc::new(protocol), - connector: connector, - hostname: hostname, - } - } -} - -/// Future returned from `bind_transport` in the `ClientProto` implementation. -pub struct ClientPipelineBind - where T: pipeline::ClientProto>, - I: AsyncRead + AsyncWrite + 'static, -{ - state: ClientPipelineState, -} - -enum ClientPipelineState - where T: pipeline::ClientProto>, - I: AsyncRead + AsyncWrite + 'static, -{ - First(ConnectAsync, Arc), - Next(::Future), -} - -impl pipeline::ClientProto for Client - where T: pipeline::ClientProto>, - I: AsyncRead + AsyncWrite + 'static, -{ - type Request = T::Request; - type Response = T::Response; - type Transport = T::Transport; - type BindTransport = ClientPipelineBind; - - fn bind_transport(&self, io: I) -> Self::BindTransport { - let proto = self.inner.clone(); - let io = self.connector.connect_async(self.hostname.as_ref(), io); - - ClientPipelineBind { - state: ClientPipelineState::First(io, proto), - } - } -} - -impl Future for ClientPipelineBind - where T: pipeline::ClientProto>, - I: AsyncRead + AsyncWrite + 'static, -{ - type Item = T::Transport; - type Error = io::Error; - - fn poll(&mut self) -> Poll { - loop { - let next = match self.state { - ClientPipelineState::First(ref mut a, ref state) => { - let res = a.poll().map_err(|e| { - io::Error::new(io::ErrorKind::Other, e) - }); - state.bind_transport(try_ready!(res)) - } - ClientPipelineState::Next(ref mut b) => return b.poll(), - }; - self.state = ClientPipelineState::Next(next.into_future()); - } - } -} - -/// Future returned from `bind_transport` in the `ClientProto` implementation. -pub struct ClientMultiplexBind - where T: multiplex::ClientProto>, - I: AsyncRead + AsyncWrite + 'static, -{ - state: ClientMultiplexState, -} - -enum ClientMultiplexState - where T: multiplex::ClientProto>, - I: AsyncRead + AsyncWrite + 'static, -{ - First(ConnectAsync, Arc), - Next(::Future), -} - -impl multiplex::ClientProto for Client - where T: multiplex::ClientProto>, - I: AsyncRead + AsyncWrite + 'static, -{ - type Request = T::Request; - type Response = T::Response; - type Transport = T::Transport; - type BindTransport = ClientMultiplexBind; - - fn bind_transport(&self, io: I) -> Self::BindTransport { - let proto = self.inner.clone(); - let io = self.connector.connect_async(self.hostname.as_ref(), io); - - ClientMultiplexBind { - state: ClientMultiplexState::First(io, proto), - } - } -} - -impl Future for ClientMultiplexBind - where T: multiplex::ClientProto>, - I: AsyncRead + AsyncWrite + 'static, -{ - type Item = T::Transport; - type Error = io::Error; - - fn poll(&mut self) -> Poll { - loop { - let next = match self.state { - ClientMultiplexState::First(ref mut a, ref state) => { - let res = a.poll().map_err(|e| { - io::Error::new(io::ErrorKind::Other, e) - }); - state.bind_transport(try_ready!(res)) - } - ClientMultiplexState::Next(ref mut b) => return b.poll(), - }; - self.state = ClientMultiplexState::Next(next.into_future()); - } - } -} - -/// Future returned from `bind_transport` in the `ClientProto` implementation. -pub struct ClientStreamingPipelineBind - where T: streaming::pipeline::ClientProto>, - I: AsyncRead + AsyncWrite + 'static, -{ - state: ClientStreamingPipelineState, -} - -enum ClientStreamingPipelineState - where T: streaming::pipeline::ClientProto>, - I: AsyncRead + AsyncWrite + 'static, -{ - First(ConnectAsync, Arc), - Next(::Future), -} - -impl streaming::pipeline::ClientProto for Client - where T: streaming::pipeline::ClientProto>, - I: AsyncRead + AsyncWrite + 'static, -{ - type Request = T::Request; - type RequestBody = T::RequestBody; - type Response = T::Response; - type ResponseBody = T::ResponseBody; - type Error = T::Error; - type Transport = T::Transport; - type BindTransport = ClientStreamingPipelineBind; - - fn bind_transport(&self, io: I) -> Self::BindTransport { - let proto = self.inner.clone(); - let io = self.connector.connect_async(self.hostname.as_ref(), io); - - ClientStreamingPipelineBind { - state: ClientStreamingPipelineState::First(io, proto), - } - } -} - -impl Future for ClientStreamingPipelineBind - where T: streaming::pipeline::ClientProto>, - I: AsyncRead + AsyncWrite + 'static, -{ - type Item = T::Transport; - type Error = io::Error; - - fn poll(&mut self) -> Poll { - loop { - let next = match self.state { - ClientStreamingPipelineState::First(ref mut a, ref state) => { - let res = a.poll().map_err(|e| { - io::Error::new(io::ErrorKind::Other, e) - }); - state.bind_transport(try_ready!(res)) - } - ClientStreamingPipelineState::Next(ref mut b) => return b.poll(), - }; - self.state = ClientStreamingPipelineState::Next(next.into_future()); - } - } -} - -/// Future returned from `bind_transport` in the `ClientProto` implementation. -pub struct ClientStreamingMultiplexBind - where T: streaming::multiplex::ClientProto>, - I: AsyncRead + AsyncWrite + 'static, -{ - state: ClientStreamingMultiplexState, -} - -enum ClientStreamingMultiplexState - where T: streaming::multiplex::ClientProto>, - I: AsyncRead + AsyncWrite + 'static, -{ - First(ConnectAsync, Arc), - Next(::Future), -} - -impl streaming::multiplex::ClientProto for Client - where T: streaming::multiplex::ClientProto>, - I: AsyncRead + AsyncWrite + 'static, -{ - type Request = T::Request; - type RequestBody = T::RequestBody; - type Response = T::Response; - type ResponseBody = T::ResponseBody; - type Error = T::Error; - type Transport = T::Transport; - type BindTransport = ClientStreamingMultiplexBind; - - fn bind_transport(&self, io: I) -> Self::BindTransport { - let proto = self.inner.clone(); - let io = self.connector.connect_async(self.hostname.as_ref(), io); - - ClientStreamingMultiplexBind { - state: ClientStreamingMultiplexState::First(io, proto), - } - } -} - -impl Future for ClientStreamingMultiplexBind - where T: streaming::multiplex::ClientProto>, - I: AsyncRead + AsyncWrite + 'static, -{ - type Item = T::Transport; - type Error = io::Error; - - fn poll(&mut self) -> Poll { - loop { - let next = match self.state { - ClientStreamingMultiplexState::First(ref mut a, ref state) => { - let res = a.poll().map_err(|e| { - io::Error::new(io::ErrorKind::Other, e) - }); - state.bind_transport(try_ready!(res)) - } - ClientStreamingMultiplexState::Next(ref mut b) => return b.poll(), - }; - self.state = ClientStreamingMultiplexState::Next(next.into_future()); - } - } -} From 034357336ef9b1831bb82e331c562536abc03dfd Mon Sep 17 00:00:00 2001 From: quininer Date: Fri, 23 Mar 2018 17:59:24 +0800 Subject: [PATCH 054/171] remove deadcode --- src/lib.rs | 86 ++++++------------------------------------------------ 1 file changed, 9 insertions(+), 77 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 293a112..3337e0d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -8,6 +8,7 @@ extern crate webpki; use std::io; use std::sync::Arc; +use webpki::DNSNameRef; use rustls::{ Session, ClientSession, ServerSession, ClientConfig, ServerConfig, @@ -17,7 +18,7 @@ use rustls::{ /// Extension trait for the `Arc` type in the `rustls` crate. pub trait ClientConfigExt: sealed::Sealed { - fn connect_async(&self, domain: webpki::DNSNameRef, stream: S) + fn connect_async(&self, domain: DNSNameRef, stream: S) -> ConnectAsync where S: io::Read + io::Write; } @@ -41,7 +42,7 @@ pub struct AcceptAsync(MidHandshake); impl sealed::Sealed for Arc {} impl ClientConfigExt for Arc { - fn connect_async(&self, domain: webpki::DNSNameRef, stream: S) + fn connect_async(&self, domain: DNSNameRef, stream: S) -> ConnectAsync where S: io::Read + io::Write { @@ -54,7 +55,9 @@ pub fn connect_async_with_session(stream: S, session: ClientSession) -> ConnectAsync where S: io::Read + io::Write { - ConnectAsync(MidHandshake { inner: Some(TlsStream::new(stream, session)) }) + ConnectAsync(MidHandshake { + inner: Some(TlsStream { session, io: stream, is_shutdown: false }) + }) } impl sealed::Sealed for Arc {} @@ -73,7 +76,9 @@ pub fn accept_async_with_session(stream: S, session: ServerSession) -> AcceptAsync where S: io::Read + io::Write { - AcceptAsync(MidHandshake { inner: Some(TlsStream::new(stream, session)) }) + AcceptAsync(MidHandshake { + inner: Some(TlsStream { session, io: stream, is_shutdown: false }) + }) } @@ -87,7 +92,6 @@ struct MidHandshake { #[derive(Debug)] pub struct TlsStream { is_shutdown: bool, - eof: bool, io: S, session: C } @@ -104,78 +108,6 @@ impl TlsStream { } } -impl TlsStream - where S: io::Read + io::Write, C: Session -{ - #[inline] - fn new(io: S, session: C) -> TlsStream { - TlsStream { - is_shutdown: false, - eof: false, - io: io, - session: session - } - } - - fn do_read(session: &mut C, io: &mut S, eof: &mut bool) -> io::Result { - if !*eof && session.wants_read() { - if session.read_tls(io)? == 0 { - *eof = true; - } - - if let Err(err) = session.process_new_packets() { - // flush queued messages before returning an Err in - // order to send alerts instead of abruptly closing - // the socket - if session.wants_write() { - // ignore result to avoid masking original error - let _ = session.write_tls(io); - } - return Err(io::Error::new(io::ErrorKind::InvalidData, err)); - } - - Ok(true) - } else { - Ok(false) - } - } - - fn do_write(session: &mut C, io: &mut S) -> io::Result { - if session.wants_write() { - session.write_tls(io)?; - - Ok(true) - } else { - Ok(false) - } - } - - #[inline] - pub fn do_io(session: &mut C, io: &mut S, eof: &mut bool) -> io::Result<()> { - macro_rules! try_wouldblock { - ( $r:expr ) => { - match $r { - Ok(true) => continue, - Ok(false) => false, - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => true, - Err(e) => return Err(e) - } - }; - } - - loop { - let write_would_block = try_wouldblock!(Self::do_write(session, io)); - let read_would_block = try_wouldblock!(Self::do_read(session, io, eof)); - - if write_would_block || read_would_block { - return Err(io::Error::from(io::ErrorKind::WouldBlock)); - } else { - return Ok(()); - } - } - } -} - impl io::Read for TlsStream where S: io::Read + io::Write, C: Session { From fddb77759f0dbaca2d9fdab534f19f2cc2266840 Mon Sep 17 00:00:00 2001 From: quininer Date: Fri, 23 Mar 2018 19:03:30 +0800 Subject: [PATCH 055/171] feat: handle CloseNotify alert --- src/futures_impl.rs | 24 +++++++++++++++++++----- src/lib.rs | 27 +++++++++++++++++++++------ 2 files changed, 40 insertions(+), 11 deletions(-) diff --git a/src/futures_impl.rs b/src/futures_impl.rs index 22c637c..86bc9da 100644 --- a/src/futures_impl.rs +++ b/src/futures_impl.rs @@ -100,13 +100,27 @@ impl AsyncRead for TlsStream C: Session { fn poll_read(&mut self, ctx: &mut Context, buf: &mut [u8]) -> Poll { - let (io, session) = self.get_mut(); - let mut taskio = TaskStream { io, task: ctx }; - let mut stream = Stream::new(session, &mut taskio); + if self.eof { + return Ok(Async::Ready(0)); + } - match io::Read::read(&mut stream, buf) { + // TODO nll + let result = { + let (io, session) = self.get_mut(); + let mut taskio = TaskStream { io, task: ctx }; + let mut stream = Stream::new(session, &mut taskio); + io::Read::read(&mut stream, buf) + }; + + match result { + Ok(0) => { self.eof = true; Ok(Async::Ready(0)) }, Ok(n) => Ok(Async::Ready(n)), - Err(ref e) if e.kind() == io::ErrorKind::ConnectionAborted => Ok(Async::Ready(0)), + Err(ref e) if e.kind() == io::ErrorKind::ConnectionAborted => { + self.eof = true; + self.is_shutdown = true; + self.session.send_close_notify(); + Ok(Async::Ready(0)) + }, Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => Ok(Async::Pending), Err(e) => Err(e) } diff --git a/src/lib.rs b/src/lib.rs index 3337e0d..18b3ae2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -56,7 +56,7 @@ pub fn connect_async_with_session(stream: S, session: ClientSession) where S: io::Read + io::Write { ConnectAsync(MidHandshake { - inner: Some(TlsStream { session, io: stream, is_shutdown: false }) + inner: Some(TlsStream { session, io: stream, is_shutdown: false, eof: false }) }) } @@ -77,7 +77,7 @@ pub fn accept_async_with_session(stream: S, session: ServerSession) where S: io::Read + io::Write { AcceptAsync(MidHandshake { - inner: Some(TlsStream { session, io: stream, is_shutdown: false }) + inner: Some(TlsStream { session, io: stream, is_shutdown: false, eof: false }) }) } @@ -92,6 +92,7 @@ struct MidHandshake { #[derive(Debug)] pub struct TlsStream { is_shutdown: bool, + eof: bool, io: S, session: C } @@ -112,12 +113,26 @@ impl io::Read for TlsStream where S: io::Read + io::Write, C: Session { fn read(&mut self, buf: &mut [u8]) -> io::Result { - let (io, session) = self.get_mut(); - let mut stream = Stream::new(session, io); + if self.eof { + return Ok(0); + } - match stream.read(buf) { + // TODO nll + let result = { + let (io, session) = self.get_mut(); + let mut stream = Stream::new(session, io); + stream.read(buf) + }; + + match result { + Ok(0) => { self.eof = true; Ok(0) }, Ok(n) => Ok(n), - Err(ref e) if e.kind() == io::ErrorKind::ConnectionAborted => Ok(0), + Err(ref e) if e.kind() == io::ErrorKind::ConnectionAborted => { + self.eof = true; + self.is_shutdown = true; + self.session.send_close_notify(); + Ok(0) + }, Err(e) => Err(e) } } From 062c10e31ed71d2a5d05e913dd6c31ebd0ab6a66 Mon Sep 17 00:00:00 2001 From: quininer Date: Sat, 24 Mar 2018 00:46:21 +0800 Subject: [PATCH 056/171] fix futures_impl flush/close and README --- README.md | 13 ++++++++----- examples/server/src/main.rs | 6 ++---- src/futures_impl.rs | 10 +++++++--- 3 files changed, 17 insertions(+), 12 deletions(-) diff --git a/README.md b/README.md index 66eb1f9..74997f1 100644 --- a/README.md +++ b/README.md @@ -31,25 +31,28 @@ TcpStream::connect(&addr) ### Client Example Program -See [examples/client.rs](examples/client.rs). You can run it with: +See [examples/client](examples/client/src/main.rs). You can run it with: ```sh -cargo run --example client hsts.badssl.com +cd examples/client +cargo run -- hsts.badssl.com ``` Currently on Windows the example client reads from stdin and writes to stdout using blocking I/O. Until this is fixed, do something this on Windows: ```sh -echo | cargo run --example client hsts.badssl.com +cd examples/client +echo | cargo run -- hsts.badssl.com ``` ### Server Example Program -See [examples/server.rs](examples/server.rs). You can run it with: +See [examples/server](examples/server/src/main.rs). You can run it with: ```sh -cargo run --example server -- 127.0.0.1 --cert mycert.der --key mykey.der +cd examples/server +cargo run -- 127.0.0.1 --cert mycert.der --key mykey.der ``` ### License & Origin diff --git a/examples/server/src/main.rs b/examples/server/src/main.rs index c321048..7e59cd1 100644 --- a/examples/server/src/main.rs +++ b/examples/server/src/main.rs @@ -52,20 +52,18 @@ fn main() { let done = socket.incoming() .for_each(move |stream| if flag_echo { let addr = stream.peer_addr().ok(); - let addr2 = addr.clone(); let done = arc_config.accept_async(stream) .and_then(|stream| { let (reader, writer) = stream.split(); io::copy(reader, writer) }) .map(move |(n, ..)| println!("Echo: {} - {:?}", n, addr)) - .map_err(move |err| println!("Error: {:?} - {:?}", err, addr2)); + .map_err(move |err| println!("Error: {:?} - {:?}", err, addr)); tokio::spawn(done); Ok(()) } else { let addr = stream.peer_addr().ok(); - let addr2 = addr.clone(); let done = arc_config.accept_async(stream) .and_then(|stream| io::write_all( stream, @@ -77,7 +75,7 @@ fn main() { )) .and_then(|(stream, _)| io::flush(stream)) .map(move |_| println!("Accept: {:?}", addr)) - .map_err(move |err| println!("Error: {:?} - {:?}", err, addr2)); + .map_err(move |err| println!("Error: {:?} - {:?}", err, addr)); tokio::spawn(done); Ok(()) diff --git a/src/futures_impl.rs b/src/futures_impl.rs index 86bc9da..13d3c35 100644 --- a/src/futures_impl.rs +++ b/src/futures_impl.rs @@ -143,9 +143,13 @@ impl AsyncWrite for TlsStream fn poll_flush(&mut self, ctx: &mut Context) -> Poll<(), Error> { let (io, session) = self.get_mut(); let mut taskio = TaskStream { io, task: ctx }; - let mut stream = Stream::new(session, &mut taskio); - async!(from io::Write::flush(&mut stream)) + { + let mut stream = Stream::new(session, &mut taskio); + async!(from io::Write::flush(&mut stream))?; + } + + async!(from io::Write::flush(&mut taskio)) } fn poll_close(&mut self, ctx: &mut Context) -> Poll<(), Error> { @@ -157,7 +161,7 @@ impl AsyncWrite for TlsStream { let (io, session) = self.get_mut(); let mut taskio = TaskStream { io, task: ctx }; - session.complete_io(&mut taskio)?; + async!(from session.complete_io(&mut taskio))?; } self.io.poll_close(ctx) From 8f2306854e8a5c1b9895fed63a194e74f4b8dd0d Mon Sep 17 00:00:00 2001 From: quininer Date: Sat, 24 Mar 2018 01:22:45 +0800 Subject: [PATCH 057/171] change: update client example --- examples/client/src/main.rs | 95 +++++++++++++++++-------------------- 1 file changed, 44 insertions(+), 51 deletions(-) diff --git a/examples/client/src/main.rs b/examples/client/src/main.rs index d454b01..444bc18 100644 --- a/examples/client/src/main.rs +++ b/examples/client/src/main.rs @@ -15,18 +15,10 @@ use std::io::{ BufReader, stdout, stdin }; use std::fs; use tokio::io; use tokio::prelude::*; -use tokio_core::net::TcpStream; -use tokio_core::reactor::Core; use clap::{ App, Arg }; use rustls::ClientConfig; use tokio_rustls::ClientConfigExt; -#[cfg(unix)] -use tokio::io::AsyncRead; - -#[cfg(unix)] -use tokio_file_unix::{ StdFile, File }; - #[cfg(not(unix))] use std::io::{Read, Write}; @@ -44,17 +36,13 @@ fn main() { let matches = app().get_matches(); let host = matches.value_of("host").unwrap(); - let port = if let Some(port) = matches.value_of("port") { - port.parse().unwrap() - } else { - 443 - }; + let port = matches.value_of("port") + .map(|port| port.parse().unwrap()) + .unwrap_or(443); let domain = matches.value_of("domain").unwrap_or(host); let cafile = matches.value_of("cafile"); let text = format!("GET / HTTP/1.0\r\nHost: {}\r\n\r\n", domain); - let mut core = Core::new().unwrap(); - let handle = core.handle(); let addr = (host, port) .to_socket_addrs().unwrap() .next().unwrap(); @@ -67,55 +55,60 @@ fn main() { config.root_store.add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS); } let arc_config = Arc::new(config); - let domain = webpki::DNSNameRef::try_from_ascii_str(domain).unwrap(); - let socket = TcpStream::connect(&addr, &handle); - // Use async non-blocking I/O for stdin/stdout on Unixy platforms. - #[cfg(unix)] - let stdin = stdin(); + { + use tokio::io::AsyncRead; + use tokio_core::reactor::Core; + use tokio_core::net::TcpStream; + use tokio_file_unix::{ StdFile, File }; - #[cfg(unix)] - let stdin = File::new_nb(StdFile(stdin.lock())).unwrap() - .into_io(&handle).unwrap(); + let mut core = Core::new().unwrap(); + let handle = core.handle(); + let socket = TcpStream::connect(&addr, &handle); - #[cfg(unix)] - let stdout = stdout(); + let stdin = stdin(); + let stdin = File::new_nb(StdFile(stdin.lock())).unwrap() + .into_io(&handle).unwrap(); - #[cfg(unix)] - let stdout = File::new_nb(StdFile(stdout.lock())).unwrap() - .into_io(&handle).unwrap(); + let stdout = stdout(); + let stdout = File::new_nb(StdFile(stdout.lock())).unwrap() + .into_io(&handle).unwrap(); - #[cfg(unix)] - let resp = socket - .and_then(|stream| arc_config.connect_async(domain, stream)) - .and_then(|stream| io::write_all(stream, text.as_bytes())) - .and_then(|(stream, _)| { - let (r, w) = stream.split(); - io::copy(r, stdout) - .map(|_| ()) - .select(io::copy(stdin, w).map(|_| ())) - .map_err(|(e, _)| e) - }); + let resp = socket + .and_then(|stream| arc_config.connect_async(domain, stream)) + .and_then(|stream| io::write_all(stream, text.as_bytes())) + .and_then(|(stream, _)| { + let (r, w) = stream.split(); + io::copy(r, stdout) + .map(|_| ()) + .select(io::copy(stdin, w).map(|_| ())) + .map_err(|(e, _)| e) + }); + + core.run(resp).unwrap(); + } // XXX: For now, just use blocking I/O for stdin/stdout on other platforms. // The network I/O will still be asynchronous and non-blocking. - #[cfg(not(unix))] - let mut input = Vec::new(); + { + use tokio::net::TcpStream; - #[cfg(not(unix))] - stdin().read_to_end(&mut input).unwrap(); + let socket = TcpStream::connect(&addr); - #[cfg(not(unix))] - let resp = socket - .and_then(|stream| arc_config.connect_async(domain, stream)) - .and_then(|stream| io::write_all(stream, text.as_bytes())) - .and_then(|(stream, _)| io::write_all(stream, &input)) - .and_then(|(stream, _)| io::read_to_end(stream, Vec::new())) - .and_then(|(_, output)| stdout().write_all(&output)); + let mut input = Vec::new(); + stdin().read_to_end(&mut input).unwrap(); - core.run(resp).unwrap(); + let resp = socket + .and_then(|stream| arc_config.connect_async(domain, stream)) + .and_then(|stream| io::write_all(stream, text.as_bytes())) + .and_then(|(stream, _)| io::write_all(stream, &input)) + .and_then(|(stream, _)| io::read_to_end(stream, Vec::new())) + .and_then(|(_, output)| stdout().write_all(&output)); + + resp.wait().unwrap(); + } } From d0f13ce5f9d5ddbfac052dbcc37fd7fcb6675fa9 Mon Sep 17 00:00:00 2001 From: quininer Date: Sat, 24 Mar 2018 11:20:48 +0800 Subject: [PATCH 058/171] change: update tokio --- .travis.yml | 4 ++++ Cargo.toml | 3 --- appveyor.yml | 6 ++++-- examples/server/Cargo.toml | 4 ---- src/lib.rs | 21 +++------------------ 5 files changed, 11 insertions(+), 27 deletions(-) diff --git a/.travis.yml b/.travis.yml index 0d7365a..3d5e1db 100644 --- a/.travis.yml +++ b/.travis.yml @@ -8,3 +8,7 @@ os: script: - cargo test --all-features + - cd examples/server + - cargo check + - cd ../../examples/client + - cargo check diff --git a/Cargo.toml b/Cargo.toml index dbad35b..bc6d23c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,6 +26,3 @@ tokio = "0.1" [features] default = [ "unstable-futures", "tokio" ] unstable-futures = [ "futures", "tokio/unstable-futures" ] - -[patch.crates-io] -tokio = { git = "https://github.com/tokio-rs/tokio" } diff --git a/appveyor.yml b/appveyor.yml index 80ae684..7ede91c 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -1,9 +1,7 @@ environment: matrix: - TARGET: x86_64-pc-windows-msvc - BITS: 64 - TARGET: i686-pc-windows-msvc - BITS: 32 install: - appveyor DownloadFile https://win.rustup.rs/ -FileName rustup-init.exe @@ -16,3 +14,7 @@ build: false test_script: - 'cargo test --all-features' + - 'cd examples/server' + - 'cargo check' + - 'cd ../../examples/client' + - 'cargo check' diff --git a/examples/server/Cargo.toml b/examples/server/Cargo.toml index 1516e92..4d164ef 100644 --- a/examples/server/Cargo.toml +++ b/examples/server/Cargo.toml @@ -11,7 +11,3 @@ tokio = { version = "0.1", features = [ "unstable-futures" ] } futures = "0.2.0-beta" clap = "2.26" - - -[patch.crates-io] -tokio = { git = "https://github.com/tokio-rs/tokio" } diff --git a/src/lib.rs b/src/lib.rs index 18b3ae2..02dcf39 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -117,14 +117,7 @@ impl io::Read for TlsStream return Ok(0); } - // TODO nll - let result = { - let (io, session) = self.get_mut(); - let mut stream = Stream::new(session, io); - stream.read(buf) - }; - - match result { + match Stream::new(&mut self.session, &mut self.io).read(buf) { Ok(0) => { self.eof = true; Ok(0) }, Ok(n) => Ok(n), Err(ref e) if e.kind() == io::ErrorKind::ConnectionAborted => { @@ -142,19 +135,11 @@ impl io::Write for TlsStream where S: io::Read + io::Write, C: Session { fn write(&mut self, buf: &[u8]) -> io::Result { - let (io, session) = self.get_mut(); - let mut stream = Stream::new(session, io); - - stream.write(buf) + Stream::new(&mut self.session, &mut self.io).write(buf) } fn flush(&mut self) -> io::Result<()> { - { - let (io, session) = self.get_mut(); - let mut stream = Stream::new(session, io); - stream.flush()?; - } - + Stream::new(&mut self.session, &mut self.io).flush()?; self.io.flush() } } From 40837e480595daf7c90228d0ed935cebdc526ecb Mon Sep 17 00:00:00 2001 From: quininer Date: Sat, 24 Mar 2018 11:28:45 +0800 Subject: [PATCH 059/171] fix client example --- examples/client/Cargo.toml | 4 +++- examples/client/src/main.rs | 4 +--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/client/Cargo.toml b/examples/client/Cargo.toml index 790b018..8cbf6c8 100644 --- a/examples/client/Cargo.toml +++ b/examples/client/Cargo.toml @@ -10,7 +10,9 @@ tokio-rustls = { path = "../..", default-features = false, features = [ "tokio" tokio = "0.1" tokio-core = "0.1" -tokio-file-unix = "0.4" clap = "2.26" webpki-roots = "0.14" + +[target.'cfg(unix)'.dependencies] +tokio-file-unix = "0.4" diff --git a/examples/client/src/main.rs b/examples/client/src/main.rs index 444bc18..e9171a4 100644 --- a/examples/client/src/main.rs +++ b/examples/client/src/main.rs @@ -19,9 +19,6 @@ use clap::{ App, Arg }; use rustls::ClientConfig; use tokio_rustls::ClientConfigExt; -#[cfg(not(unix))] -use std::io::{Read, Write}; - fn app() -> App<'static, 'static> { App::new("client") .about("tokio-rustls client example") @@ -95,6 +92,7 @@ fn main() { // The network I/O will still be asynchronous and non-blocking. #[cfg(not(unix))] { + use std::io::{ Read, Write }; use tokio::net::TcpStream; let socket = TcpStream::connect(&addr); From fff2f4a73b19559deb4faa4dabf237cde353b2a6 Mon Sep 17 00:00:00 2001 From: quininer Date: Sat, 31 Mar 2018 15:16:31 +0800 Subject: [PATCH 060/171] fix(tokio_impl): shutdown WouldBlock --- Cargo.toml | 4 ++-- src/tokio_impl.rs | 8 +++++++- tests/test.rs | 14 ++++++++------ 3 files changed, 17 insertions(+), 9 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index bc6d23c..a7a1bb5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tokio-rustls" -version = "0.6.0-alpha" +version = "0.6.0-alpha.1" authors = ["quininer kel "] license = "MIT/Apache-2.0" repository = "https://github.com/quininer/tokio-rustls" @@ -24,5 +24,5 @@ webpki = "0.18.0-alpha" tokio = "0.1" [features] -default = [ "unstable-futures", "tokio" ] +default = [ "tokio" ] unstable-futures = [ "futures", "tokio/unstable-futures" ] diff --git a/src/tokio_impl.rs b/src/tokio_impl.rs index f5d3c6c..936c14b 100644 --- a/src/tokio_impl.rs +++ b/src/tokio_impl.rs @@ -64,7 +64,13 @@ impl AsyncWrite for TlsStream self.session.send_close_notify(); self.is_shutdown = true; } - self.session.complete_io(&mut self.io)?; + + match self.session.complete_io(&mut self.io) { + Ok(_) => (), + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => return Ok(Async::NotReady), + Err(e) => return Err(e) + } + self.io.shutdown() } } diff --git a/tests/test.rs b/tests/test.rs index 92c904a..246e85a 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -73,10 +73,11 @@ fn start_client(addr: &SocketAddr, domain: &str, chain: Option Date: Sat, 7 Apr 2018 20:20:05 +0800 Subject: [PATCH 061/171] update futures 0.2.0 --- Cargo.toml | 8 ++++---- examples/server/Cargo.toml | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index a7a1bb5..6356094 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tokio-rustls" -version = "0.6.0-alpha.1" +version = "0.6.0" authors = ["quininer kel "] license = "MIT/Apache-2.0" repository = "https://github.com/quininer/tokio-rustls" @@ -15,13 +15,13 @@ travis-ci = { repository = "quininer/tokio-rustls" } appveyor = { repository = "quininer/tokio-rustls" } [dependencies] -futures = { version = "0.2.0-beta", optional = true } -tokio = { version = "0.1", optional = true } +futures = { version = "0.2.0", optional = true } +tokio = { version = "0.1.5", optional = true } rustls = "0.12" webpki = "0.18.0-alpha" [dev-dependencies] -tokio = "0.1" +tokio = "0.1.5" [features] default = [ "tokio" ] diff --git a/examples/server/Cargo.toml b/examples/server/Cargo.toml index 4d164ef..0d044f8 100644 --- a/examples/server/Cargo.toml +++ b/examples/server/Cargo.toml @@ -7,7 +7,7 @@ authors = ["quininer "] rustls = "0.12" tokio-rustls = { path = "../..", default-features = false, features = [ "unstable-futures" ] } -tokio = { version = "0.1", features = [ "unstable-futures" ] } +tokio = { version = "0.1.5", features = [ "unstable-futures" ] } futures = "0.2.0-beta" clap = "2.26" From be00ca61687cea34e649410e170b211b8d8cea15 Mon Sep 17 00:00:00 2001 From: quininer Date: Sat, 14 Apr 2018 18:17:04 +0800 Subject: [PATCH 062/171] change: split futures crate --- Cargo.toml | 10 ++++++++-- src/futures_impl.rs | 9 +++++---- 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 6356094..16ed9f4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,14 +15,20 @@ travis-ci = { repository = "quininer/tokio-rustls" } appveyor = { repository = "quininer/tokio-rustls" } [dependencies] -futures = { version = "0.2.0", optional = true } +futures-core = { version = "0.2.0", optional = true } +futures-io = { version = "0.2.0", optional = true } tokio = { version = "0.1.5", optional = true } rustls = "0.12" webpki = "0.18.0-alpha" [dev-dependencies] +futures = "0.2.0" tokio = "0.1.5" [features] default = [ "tokio" ] -unstable-futures = [ "futures", "tokio/unstable-futures" ] +unstable-futures = [ + "futures-core", + "futures-io", + "tokio/unstable-futures" +] diff --git a/src/futures_impl.rs b/src/futures_impl.rs index 13d3c35..6771316 100644 --- a/src/futures_impl.rs +++ b/src/futures_impl.rs @@ -1,9 +1,10 @@ -extern crate futures; +extern crate futures_core; +extern crate futures_io; use super::*; -use self::futures::{ Future, Poll, Async }; -use self::futures::io::{ Error, AsyncRead, AsyncWrite }; -use self::futures::task::Context; +use self::futures_core::{ Future, Poll, Async }; +use self::futures_core::task::Context; +use self::futures_io::{ Error, AsyncRead, AsyncWrite }; impl Future for ConnectAsync { From 8fc4084e068ce742c29d4f2abf51bff453b56a3c Mon Sep 17 00:00:00 2001 From: quininer Date: Thu, 3 May 2018 12:09:08 +0800 Subject: [PATCH 063/171] change: temporarily remove futures 0.2 support --- Cargo.toml | 14 +++++++------- examples/server/Cargo.toml | 6 +++--- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 16ed9f4..81cd294 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,18 +17,18 @@ appveyor = { repository = "quininer/tokio-rustls" } [dependencies] futures-core = { version = "0.2.0", optional = true } futures-io = { version = "0.2.0", optional = true } -tokio = { version = "0.1.5", optional = true } +tokio = { version = "0.1.6", optional = true } rustls = "0.12" webpki = "0.18.0-alpha" [dev-dependencies] futures = "0.2.0" -tokio = "0.1.5" +tokio = "0.1.6" [features] default = [ "tokio" ] -unstable-futures = [ - "futures-core", - "futures-io", - "tokio/unstable-futures" -] +# unstable-futures = [ +# "futures-core", +# "futures-io", +# "tokio/unstable-futures" +# ] diff --git a/examples/server/Cargo.toml b/examples/server/Cargo.toml index 0d044f8..4bc4e8b 100644 --- a/examples/server/Cargo.toml +++ b/examples/server/Cargo.toml @@ -5,9 +5,9 @@ authors = ["quininer "] [dependencies] rustls = "0.12" -tokio-rustls = { path = "../..", default-features = false, features = [ "unstable-futures" ] } +tokio-rustls = { path = "../..", default-features = false, features = [ "tokio" ] } -tokio = { version = "0.1.5", features = [ "unstable-futures" ] } -futures = "0.2.0-beta" +tokio = { version = "0.1.6" } +# futures = "0.2.0-beta" clap = "2.26" From 41425e0c2afd15ca060162fc70551cbda130ea01 Mon Sep 17 00:00:00 2001 From: quininer Date: Thu, 3 May 2018 18:07:31 +0800 Subject: [PATCH 064/171] update examples --- Cargo.toml | 2 +- examples/client/Cargo.toml | 6 ++- examples/client/src/main.rs | 102 ++++++++++++++++++------------------ 3 files changed, 56 insertions(+), 54 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 81cd294..e36dda8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,7 +22,7 @@ rustls = "0.12" webpki = "0.18.0-alpha" [dev-dependencies] -futures = "0.2.0" +# futures = "0.2.0" tokio = "0.1.6" [features] diff --git a/examples/client/Cargo.toml b/examples/client/Cargo.toml index 8cbf6c8..2449245 100644 --- a/examples/client/Cargo.toml +++ b/examples/client/Cargo.toml @@ -9,10 +9,12 @@ webpki = "0.18.0-alpha" tokio-rustls = { path = "../..", default-features = false, features = [ "tokio" ] } tokio = "0.1" -tokio-core = "0.1" clap = "2.26" webpki-roots = "0.14" [target.'cfg(unix)'.dependencies] -tokio-file-unix = "0.4" +tokio-file-unix = "0.5" + +[target.'cfg(not(unix))'.dependencies] +tokio-fs = "0.1" diff --git a/examples/client/src/main.rs b/examples/client/src/main.rs index e9171a4..3bfc157 100644 --- a/examples/client/src/main.rs +++ b/examples/client/src/main.rs @@ -1,19 +1,19 @@ extern crate clap; extern crate rustls; extern crate tokio; -extern crate tokio_core; extern crate webpki; extern crate webpki_roots; extern crate tokio_rustls; -#[cfg(unix)] -extern crate tokio_file_unix; +#[cfg(unix)] extern crate tokio_file_unix; +#[cfg(not(unix))] extern crate tokio_fs; use std::sync::Arc; use std::net::ToSocketAddrs; -use std::io::{ BufReader, stdout, stdin }; +use std::io::BufReader; use std::fs; use tokio::io; +use tokio::net::TcpStream; use tokio::prelude::*; use clap::{ App, Arg }; use rustls::ClientConfig; @@ -36,7 +36,7 @@ fn main() { let port = matches.value_of("port") .map(|port| port.parse().unwrap()) .unwrap_or(443); - let domain = matches.value_of("domain").unwrap_or(host); + let domain = matches.value_of("domain").unwrap_or(host).to_owned(); let cafile = matches.value_of("cafile"); let text = format!("GET / HTTP/1.0\r\nHost: {}\r\n\r\n", domain); @@ -52,61 +52,61 @@ fn main() { config.root_store.add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS); } let arc_config = Arc::new(config); - let domain = webpki::DNSNameRef::try_from_ascii_str(domain).unwrap(); - // Use async non-blocking I/O for stdin/stdout on Unixy platforms. + let socket = TcpStream::connect(&addr); + #[cfg(unix)] - { - use tokio::io::AsyncRead; - use tokio_core::reactor::Core; - use tokio_core::net::TcpStream; - use tokio_file_unix::{ StdFile, File }; + let resp = { + use tokio::reactor::Handle; + use tokio_file_unix::{ raw_stdin, raw_stdout, File }; - let mut core = Core::new().unwrap(); - let handle = core.handle(); - let socket = TcpStream::connect(&addr, &handle); + let stdin = raw_stdin() + .and_then(File::new_nb) + .and_then(|fd| fd.into_reader(&Handle::current())) + .unwrap(); + let stdout = raw_stdout() + .and_then(File::new_nb) + .and_then(|fd| fd.into_io(&Handle::current())) + .unwrap(); - let stdin = stdin(); - let stdin = File::new_nb(StdFile(stdin.lock())).unwrap() - .into_io(&handle).unwrap(); - - let stdout = stdout(); - let stdout = File::new_nb(StdFile(stdout.lock())).unwrap() - .into_io(&handle).unwrap(); - - let resp = socket - .and_then(|stream| arc_config.connect_async(domain, stream)) - .and_then(|stream| io::write_all(stream, text.as_bytes())) - .and_then(|(stream, _)| { + socket + .and_then(move |stream| { + let domain = webpki::DNSNameRef::try_from_ascii_str(&domain).unwrap(); + arc_config.connect_async(domain, stream) + }) + .and_then(move |stream| io::write_all(stream, text)) + .and_then(move |(stream, _)| { let (r, w) = stream.split(); io::copy(r, stdout) - .map(|_| ()) - .select(io::copy(stdin, w).map(|_| ())) - .map_err(|(e, _)| e) - }); + .map(drop) + .select2(io::copy(stdin, w).map(drop)) + .map_err(|res| res.split().0) + }) + .map(drop) + .map_err(|err| eprintln!("{:?}", err)) + }; - core.run(resp).unwrap(); - } - - // XXX: For now, just use blocking I/O for stdin/stdout on other platforms. - // The network I/O will still be asynchronous and non-blocking. #[cfg(not(unix))] - { - use std::io::{ Read, Write }; - use tokio::net::TcpStream; + let resp = { + use tokio_fs::{ stdin as tokio_stdin, stdout as tokio_stdout }; - let socket = TcpStream::connect(&addr); + let (stdin, stdout) = (tokio_stdin(), tokio_stdout()); - let mut input = Vec::new(); - stdin().read_to_end(&mut input).unwrap(); + socket + .and_then(move |stream| { + let domain = webpki::DNSNameRef::try_from_ascii_str(&domain).unwrap(); + arc_config.connect_async(domain, stream) + }) + .and_then(move |stream| io::write_all(stream, text)) + .and_then(move |(stream, _)| { + let (r, w) = stream.split(); + io::copy(r, stdout) + .map(drop) + .join(io::copy(stdin, w).map(drop)) + }) + .map(drop) + .map_err(|err| eprintln!("{:?}", err)) + }; - let resp = socket - .and_then(|stream| arc_config.connect_async(domain, stream)) - .and_then(|stream| io::write_all(stream, text.as_bytes())) - .and_then(|(stream, _)| io::write_all(stream, &input)) - .and_then(|(stream, _)| io::read_to_end(stream, Vec::new())) - .and_then(|(_, output)| stdout().write_all(&output)); - - resp.wait().unwrap(); - } + tokio::run(resp); } From 8e36dd4541bf567ade5c0fccfba7d78061567e4a Mon Sep 17 00:00:00 2001 From: Joseph Birr-Pixton Date: Sun, 15 Jul 2018 12:28:56 +0100 Subject: [PATCH 065/171] Update dependencies --- Cargo.toml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index e36dda8..b6f8f59 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tokio-rustls" -version = "0.6.0" +version = "0.7.0" authors = ["quininer kel "] license = "MIT/Apache-2.0" repository = "https://github.com/quininer/tokio-rustls" @@ -18,8 +18,8 @@ appveyor = { repository = "quininer/tokio-rustls" } futures-core = { version = "0.2.0", optional = true } futures-io = { version = "0.2.0", optional = true } tokio = { version = "0.1.6", optional = true } -rustls = "0.12" -webpki = "0.18.0-alpha" +rustls = "0.13" +webpki = "0.18.1" [dev-dependencies] # futures = "0.2.0" From 5d6d4740804788df799ba5964b993b3fe74432fd Mon Sep 17 00:00:00 2001 From: Joseph Birr-Pixton Date: Sun, 15 Jul 2018 15:02:15 +0100 Subject: [PATCH 066/171] Also update dependencies in example code --- examples/client/Cargo.toml | 6 +++--- examples/server/Cargo.toml | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/client/Cargo.toml b/examples/client/Cargo.toml index 2449245..6c4af55 100644 --- a/examples/client/Cargo.toml +++ b/examples/client/Cargo.toml @@ -4,14 +4,14 @@ version = "0.1.0" authors = ["quininer "] [dependencies] -rustls = "0.12" -webpki = "0.18.0-alpha" +rustls = "0.13" +webpki = "0.18.1" tokio-rustls = { path = "../..", default-features = false, features = [ "tokio" ] } tokio = "0.1" clap = "2.26" -webpki-roots = "0.14" +webpki-roots = "0.15" [target.'cfg(unix)'.dependencies] tokio-file-unix = "0.5" diff --git a/examples/server/Cargo.toml b/examples/server/Cargo.toml index 4bc4e8b..bc4e6a8 100644 --- a/examples/server/Cargo.toml +++ b/examples/server/Cargo.toml @@ -4,7 +4,7 @@ version = "0.1.0" authors = ["quininer "] [dependencies] -rustls = "0.12" +rustls = "0.13" tokio-rustls = { path = "../..", default-features = false, features = [ "tokio" ] } tokio = { version = "0.1.6" } From a8e1e9ac35750453c62121347cf740435e238422 Mon Sep 17 00:00:00 2001 From: Joseph Birr-Pixton Date: Sun, 15 Jul 2018 15:02:42 +0100 Subject: [PATCH 067/171] Fix warnings now set_single_cert yields a Result --- examples/server/src/main.rs | 3 ++- tests/test.rs | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/examples/server/src/main.rs b/examples/server/src/main.rs index 7e59cd1..8d1718a 100644 --- a/examples/server/src/main.rs +++ b/examples/server/src/main.rs @@ -45,7 +45,8 @@ fn main() { let flag_echo = matches.occurrences_of("echo") > 0; let mut config = ServerConfig::new(NoClientAuth::new()); - config.set_single_cert(load_certs(cert_file), load_keys(key_file).remove(0)); + config.set_single_cert(load_certs(cert_file), load_keys(key_file).remove(0)) + .expect("invalid key or certificate"); let arc_config = Arc::new(config); let socket = TcpListener::bind(&addr).unwrap(); diff --git a/tests/test.rs b/tests/test.rs index 246e85a..e64dd82 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -26,7 +26,8 @@ fn start_server(cert: Vec, rsa: PrivateKey) -> SocketAddr { use tokio::io as aio; let mut config = ServerConfig::new(rustls::NoClientAuth::new()); - config.set_single_cert(cert, rsa); + config.set_single_cert(cert, rsa) + .expect("invalid key or certificate"); let config = Arc::new(config); let (send, recv) = channel(); From b4aa18277fd3d3a37a9714822be4878d565585ec Mon Sep 17 00:00:00 2001 From: Sander Maijers Date: Mon, 16 Jul 2018 12:50:05 +0200 Subject: [PATCH 068/171] Reexport deps This helps encapsulate deps and match crate versions for downstream use. --- src/lib.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 02dcf39..d1c6c7d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,7 +1,7 @@ //! Asynchronous TLS/SSL streams for Tokio using [Rustls](https://github.com/ctz/rustls). -extern crate rustls; -extern crate webpki; +pub extern crate rustls; +pub extern crate webpki; #[cfg(feature = "tokio")] mod tokio_impl; #[cfg(feature = "unstable-futures")] mod futures_impl; From 761696ad2219176bf56e52bddb08b7f2590d2867 Mon Sep 17 00:00:00 2001 From: Sander Maijers Date: Mon, 16 Jul 2018 14:05:38 +0200 Subject: [PATCH 069/171] Use reexported `rustls` deps in examples --- README.md | 3 +-- examples/client/Cargo.toml | 1 - examples/client/src/main.rs | 4 +--- examples/server/Cargo.toml | 1 - examples/server/src/main.rs | 10 +++++----- 5 files changed, 7 insertions(+), 12 deletions(-) diff --git a/README.md b/README.md index 74997f1..da08405 100644 --- a/README.md +++ b/README.md @@ -13,8 +13,7 @@ Asynchronous TLS/SSL streams for [Tokio](https://tokio.rs/) using ```rust use webpki::DNSNameRef; -use rustls::ClientConfig; -use tokio_rustls::ClientConfigExt; +use tokio_rustls::{ClientConfigExt, rustls::ClientConfig}; // ... diff --git a/examples/client/Cargo.toml b/examples/client/Cargo.toml index 6c4af55..acf13bd 100644 --- a/examples/client/Cargo.toml +++ b/examples/client/Cargo.toml @@ -4,7 +4,6 @@ version = "0.1.0" authors = ["quininer "] [dependencies] -rustls = "0.13" webpki = "0.18.1" tokio-rustls = { path = "../..", default-features = false, features = [ "tokio" ] } diff --git a/examples/client/src/main.rs b/examples/client/src/main.rs index 3bfc157..4673f47 100644 --- a/examples/client/src/main.rs +++ b/examples/client/src/main.rs @@ -1,5 +1,4 @@ extern crate clap; -extern crate rustls; extern crate tokio; extern crate webpki; extern crate webpki_roots; @@ -16,8 +15,7 @@ use tokio::io; use tokio::net::TcpStream; use tokio::prelude::*; use clap::{ App, Arg }; -use rustls::ClientConfig; -use tokio_rustls::ClientConfigExt; +use tokio_rustls::{ClientConfigExt, rustls::ClientConfig}; fn app() -> App<'static, 'static> { App::new("client") diff --git a/examples/server/Cargo.toml b/examples/server/Cargo.toml index bc4e6a8..98329ce 100644 --- a/examples/server/Cargo.toml +++ b/examples/server/Cargo.toml @@ -4,7 +4,6 @@ version = "0.1.0" authors = ["quininer "] [dependencies] -rustls = "0.13" tokio-rustls = { path = "../..", default-features = false, features = [ "tokio" ] } tokio = { version = "0.1.6" } diff --git a/examples/server/src/main.rs b/examples/server/src/main.rs index 8d1718a..c10f8ee 100644 --- a/examples/server/src/main.rs +++ b/examples/server/src/main.rs @@ -1,5 +1,4 @@ extern crate clap; -extern crate rustls; extern crate tokio; extern crate tokio_rustls; @@ -7,14 +6,15 @@ use std::sync::Arc; use std::net::ToSocketAddrs; use std::io::BufReader; use std::fs::File; -use rustls::{ Certificate, NoClientAuth, PrivateKey, ServerConfig }; -use rustls::internal::pemfile::{ certs, rsa_private_keys }; +use tokio_rustls::{ + ServerConfigExt, + rustls::{ Certificate, NoClientAuth, PrivateKey, ServerConfig, + internal::pemfile::{ certs, rsa_private_keys }}, +}; use tokio::prelude::{ Future, Stream }; use tokio::io::{ self, AsyncRead }; use tokio::net::TcpListener; use clap::{ App, Arg }; -use tokio_rustls::ServerConfigExt; - fn app() -> App<'static, 'static> { App::new("server") From 32d3f46a9efaeb014f6efc8dc6c3a9538fad1e02 Mon Sep 17 00:00:00 2001 From: quininer Date: Mon, 16 Jul 2018 21:26:03 +0800 Subject: [PATCH 070/171] publish 0.7.1 --- Cargo.toml | 2 +- README.md | 2 +- examples/client/src/main.rs | 2 +- examples/server/src/main.rs | 6 ++++-- 4 files changed, 7 insertions(+), 5 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index b6f8f59..e289e0f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tokio-rustls" -version = "0.7.0" +version = "0.7.1" authors = ["quininer kel "] license = "MIT/Apache-2.0" repository = "https://github.com/quininer/tokio-rustls" diff --git a/README.md b/README.md index da08405..45d85b5 100644 --- a/README.md +++ b/README.md @@ -13,7 +13,7 @@ Asynchronous TLS/SSL streams for [Tokio](https://tokio.rs/) using ```rust use webpki::DNSNameRef; -use tokio_rustls::{ClientConfigExt, rustls::ClientConfig}; +use tokio_rustls::{ ClientConfigExt, rustls::ClientConfig }; // ... diff --git a/examples/client/src/main.rs b/examples/client/src/main.rs index 4673f47..8499993 100644 --- a/examples/client/src/main.rs +++ b/examples/client/src/main.rs @@ -15,7 +15,7 @@ use tokio::io; use tokio::net::TcpStream; use tokio::prelude::*; use clap::{ App, Arg }; -use tokio_rustls::{ClientConfigExt, rustls::ClientConfig}; +use tokio_rustls::{ ClientConfigExt, rustls::ClientConfig }; fn app() -> App<'static, 'static> { App::new("client") diff --git a/examples/server/src/main.rs b/examples/server/src/main.rs index c10f8ee..2222c1e 100644 --- a/examples/server/src/main.rs +++ b/examples/server/src/main.rs @@ -8,8 +8,10 @@ use std::io::BufReader; use std::fs::File; use tokio_rustls::{ ServerConfigExt, - rustls::{ Certificate, NoClientAuth, PrivateKey, ServerConfig, - internal::pemfile::{ certs, rsa_private_keys }}, + rustls::{ + Certificate, NoClientAuth, PrivateKey, ServerConfig, + internal::pemfile::{ certs, rsa_private_keys } + }, }; use tokio::prelude::{ Future, Stream }; use tokio::io::{ self, AsyncRead }; From 37954cd647ea6b7d626f15925518b38f2a858bee Mon Sep 17 00:00:00 2001 From: quininer Date: Thu, 9 Aug 2018 10:46:28 +0800 Subject: [PATCH 071/171] use tokio-tls 0.2 api --- examples/client/src/main.rs | 8 +-- examples/server/src/main.rs | 8 +-- src/lib.rs | 111 +++++++++++++++++------------------- src/tokio_impl.rs | 4 +- tests/test.rs | 14 ++--- 5 files changed, 70 insertions(+), 75 deletions(-) diff --git a/examples/client/src/main.rs b/examples/client/src/main.rs index 8499993..0d34f64 100644 --- a/examples/client/src/main.rs +++ b/examples/client/src/main.rs @@ -15,7 +15,7 @@ use tokio::io; use tokio::net::TcpStream; use tokio::prelude::*; use clap::{ App, Arg }; -use tokio_rustls::{ ClientConfigExt, rustls::ClientConfig }; +use tokio_rustls::{ TlsConnector, rustls::ClientConfig }; fn app() -> App<'static, 'static> { App::new("client") @@ -49,7 +49,7 @@ fn main() { } else { config.root_store.add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS); } - let arc_config = Arc::new(config); + let config = TlsConnector::from(Arc::new(config)); let socket = TcpStream::connect(&addr); @@ -70,7 +70,7 @@ fn main() { socket .and_then(move |stream| { let domain = webpki::DNSNameRef::try_from_ascii_str(&domain).unwrap(); - arc_config.connect_async(domain, stream) + config.connect(domain, stream) }) .and_then(move |stream| io::write_all(stream, text)) .and_then(move |(stream, _)| { @@ -93,7 +93,7 @@ fn main() { socket .and_then(move |stream| { let domain = webpki::DNSNameRef::try_from_ascii_str(&domain).unwrap(); - arc_config.connect_async(domain, stream) + config.connect(domain, stream) }) .and_then(move |stream| io::write_all(stream, text)) .and_then(move |(stream, _)| { diff --git a/examples/server/src/main.rs b/examples/server/src/main.rs index 2222c1e..2a94c58 100644 --- a/examples/server/src/main.rs +++ b/examples/server/src/main.rs @@ -7,7 +7,7 @@ use std::net::ToSocketAddrs; use std::io::BufReader; use std::fs::File; use tokio_rustls::{ - ServerConfigExt, + TlsAcceptor, rustls::{ Certificate, NoClientAuth, PrivateKey, ServerConfig, internal::pemfile::{ certs, rsa_private_keys } @@ -49,13 +49,13 @@ fn main() { let mut config = ServerConfig::new(NoClientAuth::new()); config.set_single_cert(load_certs(cert_file), load_keys(key_file).remove(0)) .expect("invalid key or certificate"); - let arc_config = Arc::new(config); + let config = TlsAcceptor::from(Arc::new(config)); let socket = TcpListener::bind(&addr).unwrap(); let done = socket.incoming() .for_each(move |stream| if flag_echo { let addr = stream.peer_addr().ok(); - let done = arc_config.accept_async(stream) + let done = config.accept(stream) .and_then(|stream| { let (reader, writer) = stream.split(); io::copy(reader, writer) @@ -67,7 +67,7 @@ fn main() { Ok(()) } else { let addr = stream.peer_addr().ok(); - let done = arc_config.accept_async(stream) + let done = config.accept(stream) .and_then(|stream| io::write_all( stream, &b"HTTP/1.0 200 ok\r\n\ diff --git a/src/lib.rs b/src/lib.rs index d1c6c7d..8d43c22 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -16,70 +16,69 @@ use rustls::{ }; -/// Extension trait for the `Arc` type in the `rustls` crate. -pub trait ClientConfigExt: sealed::Sealed { - fn connect_async(&self, domain: DNSNameRef, stream: S) - -> ConnectAsync - where S: io::Read + io::Write; +pub struct TlsConnector { + inner: Arc } -/// Extension trait for the `Arc` type in the `rustls` crate. -pub trait ServerConfigExt: sealed::Sealed { - fn accept_async(&self, stream: S) - -> AcceptAsync - where S: io::Read + io::Write; +pub struct TlsAcceptor { + inner: Arc +} + +impl From> for TlsConnector { + fn from(inner: Arc) -> TlsConnector { + TlsConnector { inner } + } +} + +impl From> for TlsAcceptor { + fn from(inner: Arc) -> TlsAcceptor { + TlsAcceptor { inner } + } +} + +impl TlsConnector { + pub fn connect(&self, domain: DNSNameRef, stream: S) -> Connect + where S: io::Read + io::Write + { + Self::connect_with_session(stream, ClientSession::new(&self.inner, domain)) + } + + #[inline] + pub fn connect_with_session(stream: S, session: ClientSession) + -> Connect + where S: io::Read + io::Write + { + Connect(MidHandshake { + inner: Some(TlsStream { session, io: stream, is_shutdown: false, eof: false }) + }) + } +} + +impl TlsAcceptor { + pub fn accept(&self, stream: S) -> Accept + where S: io::Read + io::Write, + { + Self::accept_with_session(stream, ServerSession::new(&self.inner)) + } + + #[inline] + pub fn accept_with_session(stream: S, session: ServerSession) -> Accept + where S: io::Read + io::Write + { + Accept(MidHandshake { + inner: Some(TlsStream { session, io: stream, is_shutdown: false, eof: false }) + }) + } } /// Future returned from `ClientConfigExt::connect_async` which will resolve /// once the connection handshake has finished. -pub struct ConnectAsync(MidHandshake); +pub struct Connect(MidHandshake); /// Future returned from `ServerConfigExt::accept_async` which will resolve /// once the accept handshake has finished. -pub struct AcceptAsync(MidHandshake); - -impl sealed::Sealed for Arc {} - -impl ClientConfigExt for Arc { - fn connect_async(&self, domain: DNSNameRef, stream: S) - -> ConnectAsync - where S: io::Read + io::Write - { - connect_async_with_session(stream, ClientSession::new(self, domain)) - } -} - -#[inline] -pub fn connect_async_with_session(stream: S, session: ClientSession) - -> ConnectAsync - where S: io::Read + io::Write -{ - ConnectAsync(MidHandshake { - inner: Some(TlsStream { session, io: stream, is_shutdown: false, eof: false }) - }) -} - -impl sealed::Sealed for Arc {} - -impl ServerConfigExt for Arc { - fn accept_async(&self, stream: S) - -> AcceptAsync - where S: io::Read + io::Write - { - accept_async_with_session(stream, ServerSession::new(self)) - } -} - -#[inline] -pub fn accept_async_with_session(stream: S, session: ServerSession) - -> AcceptAsync - where S: io::Read + io::Write -{ - AcceptAsync(MidHandshake { - inner: Some(TlsStream { session, io: stream, is_shutdown: false, eof: false }) - }) -} +pub struct Accept(MidHandshake); struct MidHandshake { @@ -143,7 +142,3 @@ impl io::Write for TlsStream self.io.flush() } } - -mod sealed { - pub trait Sealed {} -} diff --git a/src/tokio_impl.rs b/src/tokio_impl.rs index 936c14b..d9598bf 100644 --- a/src/tokio_impl.rs +++ b/src/tokio_impl.rs @@ -6,7 +6,7 @@ use self::tokio::io::{ AsyncRead, AsyncWrite }; use self::tokio::prelude::Poll; -impl Future for ConnectAsync { +impl Future for Connect { type Item = TlsStream; type Error = io::Error; @@ -15,7 +15,7 @@ impl Future for ConnectAsync { } } -impl Future for AcceptAsync { +impl Future for Accept { type Item = TlsStream; type Error = io::Error; diff --git a/tests/test.rs b/tests/test.rs index e64dd82..7eae2af 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -13,7 +13,7 @@ use std::net::{ SocketAddr, IpAddr, Ipv4Addr }; use tokio::net::{ TcpListener, TcpStream }; use rustls::{ Certificate, PrivateKey, ServerConfig, ClientConfig }; use rustls::internal::pemfile::{ certs, rsa_private_keys }; -use tokio_rustls::{ ClientConfigExt, ServerConfigExt }; +use tokio_rustls::{ TlsConnector, TlsAcceptor }; const CERT: &str = include_str!("end.cert"); const CHAIN: &str = include_str!("end.chain"); @@ -28,7 +28,7 @@ fn start_server(cert: Vec, rsa: PrivateKey) -> SocketAddr { let mut config = ServerConfig::new(rustls::NoClientAuth::new()); config.set_single_cert(cert, rsa) .expect("invalid key or certificate"); - let config = Arc::new(config); + let config = TlsAcceptor::from(Arc::new(config)); let (send, recv) = channel(); @@ -40,7 +40,7 @@ fn start_server(cert: Vec, rsa: PrivateKey) -> SocketAddr { let done = listener.incoming() .for_each(move |stream| { - let done = config.accept_async(stream) + let done = config.accept(stream) .and_then(|stream| aio::read_exact(stream, vec![0; HELLO_WORLD.len()])) .and_then(|(stream, buf)| { assert_eq!(buf, HELLO_WORLD); @@ -68,10 +68,10 @@ fn start_client(addr: &SocketAddr, domain: &str, chain: Option Date: Thu, 16 Aug 2018 12:20:27 +0800 Subject: [PATCH 072/171] impl vecbuf for tokio --- Cargo.toml | 2 + src/common.rs | 99 +++++++++++++++++++++++++++++++++++++++++++++++ src/lib.rs | 6 +++ src/tokio_impl.rs | 8 ++-- 4 files changed, 110 insertions(+), 5 deletions(-) create mode 100644 src/common.rs diff --git a/Cargo.toml b/Cargo.toml index e289e0f..94edce4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,6 +18,8 @@ appveyor = { repository = "quininer/tokio-rustls" } futures-core = { version = "0.2.0", optional = true } futures-io = { version = "0.2.0", optional = true } tokio = { version = "0.1.6", optional = true } +bytes = { version = "*" } +iovec = { version = "*" } rustls = "0.13" webpki = "0.18.1" diff --git a/src/common.rs b/src/common.rs new file mode 100644 index 0000000..799ee7e --- /dev/null +++ b/src/common.rs @@ -0,0 +1,99 @@ +use std::cmp::{ self, Ordering }; +use std::io::{ self, Read, Write }; +use rustls::{ Session, WriteV }; +use tokio::prelude::Async; +use tokio::io::AsyncWrite; +use bytes::Buf; +use iovec::IoVec; + + +pub struct Stream<'a, S: 'a, IO: 'a> { + session: &'a mut S, + io: &'a mut IO +} + +/* +impl<'a, S: Session, IO: Write> Stream<'a, S, IO> { + pub default fn write_tls(&mut self) -> io::Result { + self.session.write_tls(self.io) + } +} +*/ + +impl<'a, S: Session, IO: AsyncWrite> Stream<'a, S, IO> { + pub fn write_tls(&mut self) -> io::Result { + struct V<'a, IO: 'a>(&'a mut IO); + + impl<'a, IO: AsyncWrite> WriteV for V<'a, IO> { + fn writev(&mut self, vbytes: &[&[u8]]) -> io::Result { + let mut vbytes = VecBuf::new(vbytes); + match self.0.write_buf(&mut vbytes) { + Ok(Async::Ready(n)) => Ok(n), + Ok(Async::NotReady) => Err(io::ErrorKind::WouldBlock.into()), + Err(err) => Err(err) + } + } + } + + let mut vecbuf = V(self.io); + self.session.writev_tls(&mut vecbuf) + } +} + + +struct VecBuf<'a, 'b: 'a> { + pos: usize, + cur: usize, + inner: &'a [&'b [u8]] +} + +impl<'a, 'b> VecBuf<'a, 'b> { + fn new(vbytes: &'a [&'b [u8]]) -> Self { + VecBuf { pos: 0, cur: 0, inner: vbytes } + } +} + +impl<'a, 'b> Buf for VecBuf<'a, 'b> { + fn remaining(&self) -> usize { + let sum = self.inner + .iter() + .skip(self.pos) + .map(|bytes| bytes.len()) + .sum::(); + sum - self.cur + } + + fn bytes(&self) -> &[u8] { + &self.inner[self.pos][self.cur..] + } + + fn advance(&mut self, cnt: usize) { + let current = self.inner[self.pos].len(); + match (self.cur + cnt).cmp(¤t) { + Ordering::Equal => { + if self.pos < self.inner.len() { + self.pos += 1; + } + self.cur = 0; + }, + Ordering::Greater => { + if self.pos < self.inner.len() { + self.pos += 1; + } + let remaining = self.cur + cnt - current; + self.advance(remaining); + }, + Ordering::Less => self.cur += cnt, + } + } + + fn bytes_vec<'c>(&'c self, dst: &mut [&'c IoVec]) -> usize { + let len = cmp::min(self.inner.len() - self.pos, dst.len()); + + for i in 0..len { + dst[i] = self.inner[self.pos + i].into(); + } + + len + } +} diff --git a/src/lib.rs b/src/lib.rs index d1c6c7d..81da5fe 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,6 +3,12 @@ pub extern crate rustls; pub extern crate webpki; +extern crate tokio; +extern crate bytes; +extern crate iovec; + + +mod common; #[cfg(feature = "tokio")] mod tokio_impl; #[cfg(feature = "unstable-futures")] mod futures_impl; diff --git a/src/tokio_impl.rs b/src/tokio_impl.rs index 936c14b..e9a00a9 100644 --- a/src/tokio_impl.rs +++ b/src/tokio_impl.rs @@ -1,9 +1,7 @@ -extern crate tokio; - use super::*; -use self::tokio::prelude::*; -use self::tokio::io::{ AsyncRead, AsyncWrite }; -use self::tokio::prelude::Poll; +use tokio::prelude::*; +use tokio::io::{ AsyncRead, AsyncWrite }; +use tokio::prelude::Poll; impl Future for ConnectAsync { From 518ad51376ace135487d29dd1e41e0f1392b9c40 Mon Sep 17 00:00:00 2001 From: quininer Date: Thu, 16 Aug 2018 15:29:16 +0800 Subject: [PATCH 073/171] impl complete_io --- src/common.rs | 97 +++++++++++++++++++++++++++++++++++++++++++---- src/lib.rs | 4 +- src/tokio_impl.rs | 18 +++++---- 3 files changed, 103 insertions(+), 16 deletions(-) diff --git a/src/common.rs b/src/common.rs index 799ee7e..df83537 100644 --- a/src/common.rs +++ b/src/common.rs @@ -12,16 +12,98 @@ pub struct Stream<'a, S: 'a, IO: 'a> { io: &'a mut IO } -/* -impl<'a, S: Session, IO: Write> Stream<'a, S, IO> { - pub default fn write_tls(&mut self) -> io::Result { - self.session.write_tls(self.io) +pub trait CompleteIo<'a, S: Session, IO: Read + Write>: Read + Write { + fn write_tls(&mut self) -> io::Result; + fn complete_io(&mut self) -> io::Result<(usize, usize)>; +} + +impl<'a, S: Session, IO: Read + Write> Stream<'a, S, IO> { + pub fn new(session: &'a mut S, io: &'a mut IO) -> Self { + Stream { session, io } } } -*/ -impl<'a, S: Session, IO: AsyncWrite> Stream<'a, S, IO> { - pub fn write_tls(&mut self) -> io::Result { +impl<'a, S: Session, IO: Read + Write> CompleteIo<'a, S, IO> for Stream<'a, S, IO> { + default fn write_tls(&mut self) -> io::Result { + self.session.write_tls(self.io) + } + + fn complete_io(&mut self) -> io::Result<(usize, usize)> { + // fork from https://github.com/ctz/rustls/blob/master/src/session.rs#L161 + + let until_handshaked = self.session.is_handshaking(); + let mut eof = false; + let mut wrlen = 0; + let mut rdlen = 0; + + loop { + while self.session.wants_write() { + wrlen += self.write_tls()?; + } + + if !until_handshaked && wrlen > 0 { + return Ok((rdlen, wrlen)); + } + + if !eof && self.session.wants_read() { + match self.session.read_tls(self.io)? { + 0 => eof = true, + n => rdlen += n + } + } + + match self.session.process_new_packets() { + Ok(_) => {}, + Err(e) => { + // In case we have an alert to send describing this error, + // try a last-gasp write -- but don't predate the primary + // error. + let _ignored = self.write_tls(); + + return Err(io::Error::new(io::ErrorKind::InvalidData, e)); + }, + }; + + match (eof, until_handshaked, self.session.is_handshaking()) { + (_, true, false) => return Ok((rdlen, wrlen)), + (_, false, _) => return Ok((rdlen, wrlen)), + (true, true, true) => return Err(io::Error::from(io::ErrorKind::UnexpectedEof)), + (..) => () + } + } + } +} + +impl<'a, S: Session, IO: Read + Write> Read for Stream<'a, S, IO> { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + while self.session.wants_read() { + if let (0, 0) = self.complete_io()? { + break + } + } + + self.session.read(buf) + } +} + +impl<'a, S: Session, IO: Read + Write> io::Write for Stream<'a, S, IO> { + fn write(&mut self, buf: &[u8]) -> io::Result { + let len = self.session.write(buf)?; + self.complete_io()?; + Ok(len) + } + + fn flush(&mut self) -> io::Result<()> { + self.session.flush()?; + if self.session.wants_write() { + self.complete_io()?; + } + Ok(()) + } +} + +impl<'a, S: Session, IO: Read + AsyncWrite> CompleteIo<'a, S, IO> for Stream<'a, S, IO> { + fn write_tls(&mut self) -> io::Result { struct V<'a, IO: 'a>(&'a mut IO); impl<'a, IO: AsyncWrite> WriteV for V<'a, IO> { @@ -41,6 +123,7 @@ impl<'a, S: Session, IO: AsyncWrite> Stream<'a, S, IO> { } +// TODO test struct VecBuf<'a, 'b: 'a> { pos: usize, cur: usize, diff --git a/src/lib.rs b/src/lib.rs index 81da5fe..f8432d9 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,7 @@ //! Asynchronous TLS/SSL streams for Tokio using [Rustls](https://github.com/ctz/rustls). +#![feature(specialization)] + pub extern crate rustls; pub extern crate webpki; @@ -18,8 +20,8 @@ use webpki::DNSNameRef; use rustls::{ Session, ClientSession, ServerSession, ClientConfig, ServerConfig, - Stream }; +use common::Stream; /// Extension trait for the `Arc` type in the `rustls` crate. diff --git a/src/tokio_impl.rs b/src/tokio_impl.rs index e9a00a9..663d6ca 100644 --- a/src/tokio_impl.rs +++ b/src/tokio_impl.rs @@ -2,6 +2,7 @@ use super::*; use tokio::prelude::*; use tokio::io::{ AsyncRead, AsyncWrite }; use tokio::prelude::Poll; +use common::{ Stream, CompleteIo }; impl Future for ConnectAsync { @@ -29,16 +30,17 @@ impl Future for MidHandshake type Error = io::Error; fn poll(&mut self) -> Poll { - loop { + { let stream = self.inner.as_mut().unwrap(); - if !stream.session.is_handshaking() { break }; + if stream.session.is_handshaking() { + let (io, session) = stream.get_mut(); + let mut stream = Stream::new(session, io); - let (io, session) = stream.get_mut(); - - match session.complete_io(io) { - Ok(_) => (), - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => return Ok(Async::NotReady), - Err(e) => return Err(e) + match stream.complete_io() { + Ok(_) => (), + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => return Ok(Async::NotReady), + Err(e) => return Err(e) + } } } From 32f328fc142e6e6379f1ae6445fc58e1cc883ec7 Mon Sep 17 00:00:00 2001 From: quininer Date: Thu, 16 Aug 2018 17:59:59 +0800 Subject: [PATCH 074/171] remove futures 0.2 code --- Cargo.toml | 8 -------- src/lib.rs | 3 ++- src/tokio_impl.rs | 8 +++----- tests/test.rs | 37 ------------------------------------- 4 files changed, 5 insertions(+), 51 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index e289e0f..5d06420 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,20 +15,12 @@ travis-ci = { repository = "quininer/tokio-rustls" } appveyor = { repository = "quininer/tokio-rustls" } [dependencies] -futures-core = { version = "0.2.0", optional = true } -futures-io = { version = "0.2.0", optional = true } tokio = { version = "0.1.6", optional = true } rustls = "0.13" webpki = "0.18.1" [dev-dependencies] -# futures = "0.2.0" tokio = "0.1.6" [features] default = [ "tokio" ] -# unstable-futures = [ -# "futures-core", -# "futures-io", -# "tokio/unstable-futures" -# ] diff --git a/src/lib.rs b/src/lib.rs index d1c6c7d..69db77d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,8 +3,9 @@ pub extern crate rustls; pub extern crate webpki; +extern crate tokio; + #[cfg(feature = "tokio")] mod tokio_impl; -#[cfg(feature = "unstable-futures")] mod futures_impl; use std::io; use std::sync::Arc; diff --git a/src/tokio_impl.rs b/src/tokio_impl.rs index 936c14b..e9a00a9 100644 --- a/src/tokio_impl.rs +++ b/src/tokio_impl.rs @@ -1,9 +1,7 @@ -extern crate tokio; - use super::*; -use self::tokio::prelude::*; -use self::tokio::io::{ AsyncRead, AsyncWrite }; -use self::tokio::prelude::Poll; +use tokio::prelude::*; +use tokio::io::{ AsyncRead, AsyncWrite }; +use tokio::prelude::Poll; impl Future for ConnectAsync { diff --git a/tests/test.rs b/tests/test.rs index e64dd82..c6262e9 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -83,32 +83,6 @@ fn start_client(addr: &SocketAddr, domain: &str, chain: Option>>) -> io::Result<()> { - use futures::FutureExt; - use futures::io::{ AsyncReadExt, AsyncWriteExt }; - use futures::executor::block_on; - - let domain = webpki::DNSNameRef::try_from_ascii_str(domain).unwrap(); - let mut config = ClientConfig::new(); - if let Some(mut chain) = chain { - config.root_store.add_pem_file(&mut chain).unwrap(); - } - let config = Arc::new(config); - - let done = TcpStream::connect(addr) - .and_then(|stream| config.connect_async(domain, stream)) - .and_then(|stream| stream.write_all(HELLO_WORLD)) - .and_then(|(stream, _)| stream.read_exact(vec![0; HELLO_WORLD.len()])) - .and_then(|(stream, buf)| { - assert_eq!(buf, HELLO_WORLD); - stream.close() - }) - .map(drop); - - block_on(done) -} - #[test] fn pass() { @@ -120,17 +94,6 @@ fn pass() { start_client(&addr, "localhost", Some(chain)).unwrap(); } -#[cfg(feature = "unstable-futures")] -#[test] -fn pass2() { - let cert = certs(&mut BufReader::new(Cursor::new(CERT))).unwrap(); - let mut keys = rsa_private_keys(&mut BufReader::new(Cursor::new(RSA))).unwrap(); - let chain = BufReader::new(Cursor::new(CHAIN)); - - let addr = start_server(cert, keys.pop().unwrap()); - start_client2(&addr, "localhost", Some(chain)).unwrap(); -} - #[should_panic] #[test] fn fail() { From 26046efc3cef6ada79a18a6e208600d87cfe34ee Mon Sep 17 00:00:00 2001 From: quininer Date: Thu, 16 Aug 2018 18:43:35 +0800 Subject: [PATCH 075/171] fix vecbuf --- src/common.rs | 125 +++++++++++++++++++++++++++++++++------------- src/tokio_impl.rs | 2 +- 2 files changed, 91 insertions(+), 36 deletions(-) diff --git a/src/common.rs b/src/common.rs index df83537..070caeb 100644 --- a/src/common.rs +++ b/src/common.rs @@ -12,23 +12,16 @@ pub struct Stream<'a, S: 'a, IO: 'a> { io: &'a mut IO } -pub trait CompleteIo<'a, S: Session, IO: Read + Write>: Read + Write { +pub trait WriteTls<'a, S: Session, IO: Read + Write>: Read + Write { fn write_tls(&mut self) -> io::Result; - fn complete_io(&mut self) -> io::Result<(usize, usize)>; } impl<'a, S: Session, IO: Read + Write> Stream<'a, S, IO> { pub fn new(session: &'a mut S, io: &'a mut IO) -> Self { Stream { session, io } } -} -impl<'a, S: Session, IO: Read + Write> CompleteIo<'a, S, IO> for Stream<'a, S, IO> { - default fn write_tls(&mut self) -> io::Result { - self.session.write_tls(self.io) - } - - fn complete_io(&mut self) -> io::Result<(usize, usize)> { + pub fn complete_io(&mut self) -> io::Result<(usize, usize)> { // fork from https://github.com/ctz/rustls/blob/master/src/session.rs#L161 let until_handshaked = self.session.is_handshaking(); @@ -74,6 +67,32 @@ impl<'a, S: Session, IO: Read + Write> CompleteIo<'a, S, IO> for Stream<'a, S, I } } +impl<'a, S: Session, IO: Read + Write> WriteTls<'a, S, IO> for Stream<'a, S, IO> { + default fn write_tls(&mut self) -> io::Result { + self.session.write_tls(self.io) + } +} + +impl<'a, S: Session, IO: Read + AsyncWrite> WriteTls<'a, S, IO> for Stream<'a, S, IO> { + fn write_tls(&mut self) -> io::Result { + struct V<'a, IO: 'a>(&'a mut IO); + + impl<'a, IO: AsyncWrite> WriteV for V<'a, IO> { + fn writev(&mut self, vbytes: &[&[u8]]) -> io::Result { + let mut vbytes = VecBuf::new(vbytes); + match self.0.write_buf(&mut vbytes) { + Ok(Async::Ready(n)) => Ok(n), + Ok(Async::NotReady) => Err(io::ErrorKind::WouldBlock.into()), + Err(err) => Err(err) + } + } + } + + let mut vecbuf = V(self.io); + self.session.writev_tls(&mut vecbuf) + } +} + impl<'a, S: Session, IO: Read + Write> Read for Stream<'a, S, IO> { fn read(&mut self, buf: &mut [u8]) -> io::Result { while self.session.wants_read() { @@ -102,28 +121,7 @@ impl<'a, S: Session, IO: Read + Write> io::Write for Stream<'a, S, IO> { } } -impl<'a, S: Session, IO: Read + AsyncWrite> CompleteIo<'a, S, IO> for Stream<'a, S, IO> { - fn write_tls(&mut self) -> io::Result { - struct V<'a, IO: 'a>(&'a mut IO); - impl<'a, IO: AsyncWrite> WriteV for V<'a, IO> { - fn writev(&mut self, vbytes: &[&[u8]]) -> io::Result { - let mut vbytes = VecBuf::new(vbytes); - match self.0.write_buf(&mut vbytes) { - Ok(Async::Ready(n)) => Ok(n), - Ok(Async::NotReady) => Err(io::ErrorKind::WouldBlock.into()), - Err(err) => Err(err) - } - } - } - - let mut vecbuf = V(self.io); - self.session.writev_tls(&mut vecbuf) - } -} - - -// TODO test struct VecBuf<'a, 'b: 'a> { pos: usize, cur: usize, @@ -153,14 +151,14 @@ impl<'a, 'b> Buf for VecBuf<'a, 'b> { fn advance(&mut self, cnt: usize) { let current = self.inner[self.pos].len(); match (self.cur + cnt).cmp(¤t) { - Ordering::Equal => { - if self.pos < self.inner.len() { - self.pos += 1; - } + Ordering::Equal => if self.pos + 1 < self.inner.len() { + self.pos += 1; self.cur = 0; + } else { + self.cur += cnt; }, Ordering::Greater => { - if self.pos < self.inner.len() { + if self.pos + 1 < self.inner.len() { self.pos += 1; } let remaining = self.cur + cnt - current; @@ -180,3 +178,60 @@ impl<'a, 'b> Buf for VecBuf<'a, 'b> { len } } + +#[cfg(test)] +mod test_vecbuf { + use super::*; + + #[test] + fn test_fresh_cursor_vec() { + let mut buf = VecBuf::new(&[b"he", b"llo"]); + + assert_eq!(buf.remaining(), 5); + assert_eq!(buf.bytes(), b"he"); + + buf.advance(2); + + assert_eq!(buf.remaining(), 3); + assert_eq!(buf.bytes(), b"llo"); + + buf.advance(3); + + assert_eq!(buf.remaining(), 0); + assert_eq!(buf.bytes(), b""); + } + + #[test] + fn test_get_u8() { + let mut buf = VecBuf::new(&[b"\x21z", b"omg"]); + assert_eq!(0x21, buf.get_u8()); + } + + #[test] + fn test_get_u16() { + let mut buf = VecBuf::new(&[b"\x21\x54z", b"omg"]); + assert_eq!(0x2154, buf.get_u16_be()); + let mut buf = VecBuf::new(&[b"\x21\x54z", b"omg"]); + assert_eq!(0x5421, buf.get_u16_le()); + } + + #[test] + #[should_panic] + fn test_get_u16_buffer_underflow() { + let mut buf = VecBuf::new(&[b"\x21"]); + buf.get_u16_be(); + } + + #[test] + fn test_bufs_vec() { + let buf = VecBuf::new(&[b"he", b"llo"]); + + let b1: &[u8] = &mut [0]; + let b2: &[u8] = &mut [0]; + + let mut dst: [&IoVec; 2] = + [b1.into(), b2.into()]; + + assert_eq!(2, buf.bytes_vec(&mut dst[..])); + } +} diff --git a/src/tokio_impl.rs b/src/tokio_impl.rs index 663d6ca..9f09705 100644 --- a/src/tokio_impl.rs +++ b/src/tokio_impl.rs @@ -2,7 +2,7 @@ use super::*; use tokio::prelude::*; use tokio::io::{ AsyncRead, AsyncWrite }; use tokio::prelude::Poll; -use common::{ Stream, CompleteIo }; +use common::Stream; impl Future for ConnectAsync { From 4a2354c1cc19aaf937da1b1d8f70e5c6705771db Mon Sep 17 00:00:00 2001 From: quininer Date: Thu, 16 Aug 2018 19:13:18 +0800 Subject: [PATCH 076/171] rename tokio feature --- Cargo.toml | 4 ++-- examples/client/Cargo.toml | 6 ++---- examples/server/Cargo.toml | 7 ++----- src/lib.rs | 2 +- 4 files changed, 7 insertions(+), 12 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 367bcca..41c0671 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,5 +25,5 @@ webpki = "0.18.1" tokio = "0.1.6" [features] -default = [ "tokio_impl" ] -tokio_impl = [ "tokio", "bytes", "iovec" ] +default = [ "tokio-support" ] +tokio-support = [ "tokio", "bytes", "iovec" ] diff --git a/examples/client/Cargo.toml b/examples/client/Cargo.toml index acf13bd..2253096 100644 --- a/examples/client/Cargo.toml +++ b/examples/client/Cargo.toml @@ -5,11 +5,9 @@ authors = ["quininer "] [dependencies] webpki = "0.18.1" -tokio-rustls = { path = "../..", default-features = false, features = [ "tokio" ] } - +tokio-rustls = { path = "../.." } tokio = "0.1" - -clap = "2.26" +clap = "2" webpki-roots = "0.15" [target.'cfg(unix)'.dependencies] diff --git a/examples/server/Cargo.toml b/examples/server/Cargo.toml index 98329ce..170693f 100644 --- a/examples/server/Cargo.toml +++ b/examples/server/Cargo.toml @@ -4,9 +4,6 @@ version = "0.1.0" authors = ["quininer "] [dependencies] -tokio-rustls = { path = "../..", default-features = false, features = [ "tokio" ] } - +tokio-rustls = { path = "../.." } tokio = { version = "0.1.6" } -# futures = "0.2.0-beta" - -clap = "2.26" +clap = "2" diff --git a/src/lib.rs b/src/lib.rs index afd3cf0..9d910e1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -11,7 +11,7 @@ extern crate iovec; mod common; -#[cfg(feature = "tokio_impl")] mod tokio_impl; +#[cfg(feature = "tokio-support")] mod tokio_impl; use std::io; use std::sync::Arc; From b040a9a65f1b16569652dd2b7564363055d7682e Mon Sep 17 00:00:00 2001 From: quininer Date: Thu, 16 Aug 2018 20:44:37 +0800 Subject: [PATCH 077/171] Test use lazy_static! --- Cargo.toml | 1 + src/futures_impl.rs | 170 -------------------------------------------- tests/test.rs | 105 +++++++++++++-------------- 3 files changed, 54 insertions(+), 222 deletions(-) delete mode 100644 src/futures_impl.rs diff --git a/Cargo.toml b/Cargo.toml index 5d06420..e8fcb70 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,6 +21,7 @@ webpki = "0.18.1" [dev-dependencies] tokio = "0.1.6" +lazy_static = "1" [features] default = [ "tokio" ] diff --git a/src/futures_impl.rs b/src/futures_impl.rs deleted file mode 100644 index 6771316..0000000 --- a/src/futures_impl.rs +++ /dev/null @@ -1,170 +0,0 @@ -extern crate futures_core; -extern crate futures_io; - -use super::*; -use self::futures_core::{ Future, Poll, Async }; -use self::futures_core::task::Context; -use self::futures_io::{ Error, AsyncRead, AsyncWrite }; - - -impl Future for ConnectAsync { - type Item = TlsStream; - type Error = io::Error; - - fn poll(&mut self, ctx: &mut Context) -> Poll { - self.0.poll(ctx) - } -} - -impl Future for AcceptAsync { - type Item = TlsStream; - type Error = io::Error; - - fn poll(&mut self, ctx: &mut Context) -> Poll { - self.0.poll(ctx) - } -} - -macro_rules! async { - ( to $r:expr ) => { - match $r { - Ok(Async::Ready(n)) => Ok(n), - Ok(Async::Pending) => Err(io::ErrorKind::WouldBlock.into()), - Err(e) => Err(e) - } - }; - ( from $r:expr ) => { - match $r { - Ok(n) => Ok(Async::Ready(n)), - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => Ok(Async::Pending), - Err(e) => Err(e) - } - }; -} - -struct TaskStream<'a, 'b: 'a, S: 'a> { - io: &'a mut S, - task: &'a mut Context<'b> -} - -impl<'a, 'b, S> io::Read for TaskStream<'a, 'b, S> - where S: AsyncRead + AsyncWrite -{ - #[inline] - fn read(&mut self, buf: &mut [u8]) -> io::Result { - async!(to self.io.poll_read(self.task, buf)) - } -} - -impl<'a, 'b, S> io::Write for TaskStream<'a, 'b, S> - where S: AsyncRead + AsyncWrite -{ - #[inline] - fn write(&mut self, buf: &[u8]) -> io::Result { - async!(to self.io.poll_write(self.task, buf)) - } - - #[inline] - fn flush(&mut self) -> io::Result<()> { - async!(to self.io.poll_flush(self.task)) - } -} - -impl Future for MidHandshake - where S: AsyncRead + AsyncWrite, C: Session -{ - type Item = TlsStream; - type Error = io::Error; - - fn poll(&mut self, ctx: &mut Context) -> Poll { - loop { - let stream = self.inner.as_mut().unwrap(); - if !stream.session.is_handshaking() { break }; - - let (io, session) = stream.get_mut(); - let mut taskio = TaskStream { io, task: ctx }; - - match session.complete_io(&mut taskio) { - Ok(_) => (), - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => return Ok(Async::Pending), - Err(e) => return Err(e) - } - } - - Ok(Async::Ready(self.inner.take().unwrap())) - } -} - -impl AsyncRead for TlsStream - where - S: AsyncRead + AsyncWrite, - C: Session -{ - fn poll_read(&mut self, ctx: &mut Context, buf: &mut [u8]) -> Poll { - if self.eof { - return Ok(Async::Ready(0)); - } - - // TODO nll - let result = { - let (io, session) = self.get_mut(); - let mut taskio = TaskStream { io, task: ctx }; - let mut stream = Stream::new(session, &mut taskio); - io::Read::read(&mut stream, buf) - }; - - match result { - Ok(0) => { self.eof = true; Ok(Async::Ready(0)) }, - Ok(n) => Ok(Async::Ready(n)), - Err(ref e) if e.kind() == io::ErrorKind::ConnectionAborted => { - self.eof = true; - self.is_shutdown = true; - self.session.send_close_notify(); - Ok(Async::Ready(0)) - }, - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => Ok(Async::Pending), - Err(e) => Err(e) - } - } -} - -impl AsyncWrite for TlsStream - where - S: AsyncRead + AsyncWrite, - C: Session -{ - fn poll_write(&mut self, ctx: &mut Context, buf: &[u8]) -> Poll { - let (io, session) = self.get_mut(); - let mut taskio = TaskStream { io, task: ctx }; - let mut stream = Stream::new(session, &mut taskio); - - async!(from io::Write::write(&mut stream, buf)) - } - - fn poll_flush(&mut self, ctx: &mut Context) -> Poll<(), Error> { - let (io, session) = self.get_mut(); - let mut taskio = TaskStream { io, task: ctx }; - - { - let mut stream = Stream::new(session, &mut taskio); - async!(from io::Write::flush(&mut stream))?; - } - - async!(from io::Write::flush(&mut taskio)) - } - - fn poll_close(&mut self, ctx: &mut Context) -> Poll<(), Error> { - if !self.is_shutdown { - self.session.send_close_notify(); - self.is_shutdown = true; - } - - { - let (io, session) = self.get_mut(); - let mut taskio = TaskStream { io, task: ctx }; - async!(from session.complete_io(&mut taskio))?; - } - - self.io.poll_close(ctx) - } -} diff --git a/tests/test.rs b/tests/test.rs index c6262e9..fa46f5a 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -1,81 +1,89 @@ +#[macro_use] extern crate lazy_static; extern crate rustls; extern crate tokio; extern crate tokio_rustls; extern crate webpki; -#[cfg(feature = "unstable-futures")] extern crate futures; - use std::{ io, thread }; use std::io::{ BufReader, Cursor }; use std::sync::Arc; use std::sync::mpsc::channel; -use std::net::{ SocketAddr, IpAddr, Ipv4Addr }; +use std::net::SocketAddr; use tokio::net::{ TcpListener, TcpStream }; -use rustls::{ Certificate, PrivateKey, ServerConfig, ClientConfig }; +use rustls::{ ServerConfig, ClientConfig }; use rustls::internal::pemfile::{ certs, rsa_private_keys }; use tokio_rustls::{ ClientConfigExt, ServerConfigExt }; const CERT: &str = include_str!("end.cert"); const CHAIN: &str = include_str!("end.chain"); const RSA: &str = include_str!("end.rsa"); -const HELLO_WORLD: &[u8] = b"Hello world!"; +lazy_static!{ + static ref TEST_SERVER: (SocketAddr, &'static str, &'static str) = { + use tokio::prelude::*; + use tokio::io as aio; -fn start_server(cert: Vec, rsa: PrivateKey) -> SocketAddr { - use tokio::prelude::*; - use tokio::io as aio; + let cert = certs(&mut BufReader::new(Cursor::new(CERT))).unwrap(); + let mut keys = rsa_private_keys(&mut BufReader::new(Cursor::new(RSA))).unwrap(); - let mut config = ServerConfig::new(rustls::NoClientAuth::new()); - config.set_single_cert(cert, rsa) - .expect("invalid key or certificate"); - let config = Arc::new(config); + let mut config = ServerConfig::new(rustls::NoClientAuth::new()); + config.set_single_cert(cert, keys.pop().unwrap()) + .expect("invalid key or certificate"); + let config = Arc::new(config); - let (send, recv) = channel(); + let (send, recv) = channel(); - thread::spawn(move || { - let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 0); - let listener = TcpListener::bind(&addr).unwrap(); + thread::spawn(move || { + let addr = SocketAddr::from(([127, 0, 0, 1], 0)); + let listener = TcpListener::bind(&addr).unwrap(); - send.send(listener.local_addr().unwrap()).unwrap(); + send.send(listener.local_addr().unwrap()).unwrap(); - let done = listener.incoming() - .for_each(move |stream| { - let done = config.accept_async(stream) - .and_then(|stream| aio::read_exact(stream, vec![0; HELLO_WORLD.len()])) - .and_then(|(stream, buf)| { - assert_eq!(buf, HELLO_WORLD); - aio::write_all(stream, HELLO_WORLD) - }) - .then(|_| Ok(())); + let done = listener.incoming() + .for_each(move |stream| { + let done = config.accept_async(stream) + .and_then(|stream| { + let (reader, writer) = stream.split(); + aio::copy(reader, writer) + }) + .then(|_| Ok(())); - tokio::spawn(done); - Ok(()) - }) - .map_err(|err| panic!("{:?}", err)); + tokio::spawn(done); + Ok(()) + }) + .map_err(|err| panic!("{:?}", err)); - tokio::run(done); - }); + tokio::run(done); + }); - recv.recv().unwrap() + let addr = recv.recv().unwrap(); + (addr, "localhost", CHAIN) + }; } -fn start_client(addr: &SocketAddr, domain: &str, chain: Option>>) -> io::Result<()> { + +fn start_server() -> &'static (SocketAddr, &'static str, &'static str) { + &*TEST_SERVER +} + +fn start_client(addr: &SocketAddr, domain: &str, chain: &str) -> io::Result<()> { use tokio::prelude::*; use tokio::io as aio; + const FILE: &'static [u8] = include_bytes!("../README.md"); + let domain = webpki::DNSNameRef::try_from_ascii_str(domain).unwrap(); let mut config = ClientConfig::new(); - if let Some(mut chain) = chain { - config.root_store.add_pem_file(&mut chain).unwrap(); - } + let mut chain = BufReader::new(Cursor::new(chain)); + config.root_store.add_pem_file(&mut chain).unwrap(); let config = Arc::new(config); let done = TcpStream::connect(addr) .and_then(|stream| config.connect_async(domain, stream)) - .and_then(|stream| aio::write_all(stream, HELLO_WORLD)) - .and_then(|(stream, _)| aio::read_exact(stream, vec![0; HELLO_WORLD.len()])) + .and_then(|stream| aio::write_all(stream, FILE)) + .and_then(|(stream, _)| aio::read_exact(stream, vec![0; FILE.len()])) .and_then(|(stream, buf)| { - assert_eq!(buf, HELLO_WORLD); + assert_eq!(buf, FILE); aio::shutdown(stream) }) .map(drop); @@ -86,22 +94,15 @@ fn start_client(addr: &SocketAddr, domain: &str, chain: Option Date: Thu, 16 Aug 2018 22:35:50 +0800 Subject: [PATCH 078/171] fix vecbuf bytes_vec --- src/common.rs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/common.rs b/src/common.rs index 070caeb..f9900ff 100644 --- a/src/common.rs +++ b/src/common.rs @@ -171,7 +171,11 @@ impl<'a, 'b> Buf for VecBuf<'a, 'b> { fn bytes_vec<'c>(&'c self, dst: &mut [&'c IoVec]) -> usize { let len = cmp::min(self.inner.len() - self.pos, dst.len()); - for i in 0..len { + if len > 0 { + dst[0] = self.bytes().into(); + } + + for i in 1..len { dst[i] = self.inner[self.pos + i].into(); } From 5cbd5b8aa0c4e03b3006c4231b576e0618c9effa Mon Sep 17 00:00:00 2001 From: quininer Date: Fri, 17 Aug 2018 09:18:53 +0800 Subject: [PATCH 079/171] fix: handle Stream non-blocking write --- src/common.rs | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/common.rs b/src/common.rs index f9900ff..9e3cb5c 100644 --- a/src/common.rs +++ b/src/common.rs @@ -100,7 +100,6 @@ impl<'a, S: Session, IO: Read + Write> Read for Stream<'a, S, IO> { break } } - self.session.read(buf) } } @@ -108,7 +107,13 @@ impl<'a, S: Session, IO: Read + Write> Read for Stream<'a, S, IO> { impl<'a, S: Session, IO: Read + Write> io::Write for Stream<'a, S, IO> { fn write(&mut self, buf: &[u8]) -> io::Result { let len = self.session.write(buf)?; - self.complete_io()?; + while self.session.wants_write() { + match self.complete_io() { + Ok(_) => (), + Err(ref err) if err.kind() == io::ErrorKind::WouldBlock && len != 0 => break, + Err(err) => return Err(err) + } + } Ok(len) } From 762d7f952582b8430f79e658572578aa12533c4b Mon Sep 17 00:00:00 2001 From: quininer Date: Fri, 17 Aug 2018 10:04:49 +0800 Subject: [PATCH 080/171] Add nightly feature --- Cargo.toml | 5 +- src/{common.rs => common/mod.rs} | 149 +++++-------------------------- src/common/vecbuf.rs | 122 +++++++++++++++++++++++++ src/lib.rs | 7 +- 4 files changed, 154 insertions(+), 129 deletions(-) rename src/{common.rs => common/mod.rs} (57%) create mode 100644 src/common/vecbuf.rs diff --git a/Cargo.toml b/Cargo.toml index 41c0671..58ece24 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,5 +25,6 @@ webpki = "0.18.1" tokio = "0.1.6" [features] -default = [ "tokio-support" ] -tokio-support = [ "tokio", "bytes", "iovec" ] +default = ["tokio-support"] +nightly = ["bytes", "iovec"] +tokio-support = ["tokio"] diff --git a/src/common.rs b/src/common/mod.rs similarity index 57% rename from src/common.rs rename to src/common/mod.rs index 9e3cb5c..1f5d1d2 100644 --- a/src/common.rs +++ b/src/common/mod.rs @@ -1,11 +1,14 @@ -use std::cmp::{ self, Ordering }; -use std::io::{ self, Read, Write }; -use rustls::{ Session, WriteV }; -use tokio::prelude::Async; -use tokio::io::AsyncWrite; -use bytes::Buf; -use iovec::IoVec; +#[cfg(feature = "nightly")] +#[cfg(feature = "tokio-support")] +mod vecbuf; +use std::io::{ self, Read, Write }; +use rustls::Session; +#[cfg(feature = "nightly")] +use rustls::WriteV; +#[cfg(feature = "nightly")] +#[cfg(feature = "tokio-support")] +use tokio::io::AsyncWrite; pub struct Stream<'a, S: 'a, IO: 'a> { session: &'a mut S, @@ -67,14 +70,27 @@ impl<'a, S: Session, IO: Read + Write> Stream<'a, S, IO> { } } +#[cfg(not(feature = "nightly"))] +impl<'a, S: Session, IO: Read + Write> WriteTls<'a, S, IO> for Stream<'a, S, IO> { + fn write_tls(&mut self) -> io::Result { + self.session.write_tls(self.io) + } +} + +#[cfg(feature = "nightly")] impl<'a, S: Session, IO: Read + Write> WriteTls<'a, S, IO> for Stream<'a, S, IO> { default fn write_tls(&mut self) -> io::Result { self.session.write_tls(self.io) } } +#[cfg(feature = "nightly")] +#[cfg(feature = "tokio-support")] impl<'a, S: Session, IO: Read + AsyncWrite> WriteTls<'a, S, IO> for Stream<'a, S, IO> { fn write_tls(&mut self) -> io::Result { + use tokio::prelude::Async; + use self::vecbuf::VecBuf; + struct V<'a, IO: 'a>(&'a mut IO); impl<'a, IO: AsyncWrite> WriteV for V<'a, IO> { @@ -125,122 +141,3 @@ impl<'a, S: Session, IO: Read + Write> io::Write for Stream<'a, S, IO> { Ok(()) } } - - -struct VecBuf<'a, 'b: 'a> { - pos: usize, - cur: usize, - inner: &'a [&'b [u8]] -} - -impl<'a, 'b> VecBuf<'a, 'b> { - fn new(vbytes: &'a [&'b [u8]]) -> Self { - VecBuf { pos: 0, cur: 0, inner: vbytes } - } -} - -impl<'a, 'b> Buf for VecBuf<'a, 'b> { - fn remaining(&self) -> usize { - let sum = self.inner - .iter() - .skip(self.pos) - .map(|bytes| bytes.len()) - .sum::(); - sum - self.cur - } - - fn bytes(&self) -> &[u8] { - &self.inner[self.pos][self.cur..] - } - - fn advance(&mut self, cnt: usize) { - let current = self.inner[self.pos].len(); - match (self.cur + cnt).cmp(¤t) { - Ordering::Equal => if self.pos + 1 < self.inner.len() { - self.pos += 1; - self.cur = 0; - } else { - self.cur += cnt; - }, - Ordering::Greater => { - if self.pos + 1 < self.inner.len() { - self.pos += 1; - } - let remaining = self.cur + cnt - current; - self.advance(remaining); - }, - Ordering::Less => self.cur += cnt, - } - } - - fn bytes_vec<'c>(&'c self, dst: &mut [&'c IoVec]) -> usize { - let len = cmp::min(self.inner.len() - self.pos, dst.len()); - - if len > 0 { - dst[0] = self.bytes().into(); - } - - for i in 1..len { - dst[i] = self.inner[self.pos + i].into(); - } - - len - } -} - -#[cfg(test)] -mod test_vecbuf { - use super::*; - - #[test] - fn test_fresh_cursor_vec() { - let mut buf = VecBuf::new(&[b"he", b"llo"]); - - assert_eq!(buf.remaining(), 5); - assert_eq!(buf.bytes(), b"he"); - - buf.advance(2); - - assert_eq!(buf.remaining(), 3); - assert_eq!(buf.bytes(), b"llo"); - - buf.advance(3); - - assert_eq!(buf.remaining(), 0); - assert_eq!(buf.bytes(), b""); - } - - #[test] - fn test_get_u8() { - let mut buf = VecBuf::new(&[b"\x21z", b"omg"]); - assert_eq!(0x21, buf.get_u8()); - } - - #[test] - fn test_get_u16() { - let mut buf = VecBuf::new(&[b"\x21\x54z", b"omg"]); - assert_eq!(0x2154, buf.get_u16_be()); - let mut buf = VecBuf::new(&[b"\x21\x54z", b"omg"]); - assert_eq!(0x5421, buf.get_u16_le()); - } - - #[test] - #[should_panic] - fn test_get_u16_buffer_underflow() { - let mut buf = VecBuf::new(&[b"\x21"]); - buf.get_u16_be(); - } - - #[test] - fn test_bufs_vec() { - let buf = VecBuf::new(&[b"he", b"llo"]); - - let b1: &[u8] = &mut [0]; - let b2: &[u8] = &mut [0]; - - let mut dst: [&IoVec; 2] = - [b1.into(), b2.into()]; - - assert_eq!(2, buf.bytes_vec(&mut dst[..])); - } -} diff --git a/src/common/vecbuf.rs b/src/common/vecbuf.rs new file mode 100644 index 0000000..dd40163 --- /dev/null +++ b/src/common/vecbuf.rs @@ -0,0 +1,122 @@ +use std::cmp::{ self, Ordering }; +use bytes::Buf; +use iovec::IoVec; + +pub struct VecBuf<'a, 'b: 'a> { + pos: usize, + cur: usize, + inner: &'a [&'b [u8]] +} + +impl<'a, 'b> VecBuf<'a, 'b> { + pub fn new(vbytes: &'a [&'b [u8]]) -> Self { + VecBuf { pos: 0, cur: 0, inner: vbytes } + } +} + +impl<'a, 'b> Buf for VecBuf<'a, 'b> { + fn remaining(&self) -> usize { + let sum = self.inner + .iter() + .skip(self.pos) + .map(|bytes| bytes.len()) + .sum::(); + sum - self.cur + } + + fn bytes(&self) -> &[u8] { + &self.inner[self.pos][self.cur..] + } + + fn advance(&mut self, cnt: usize) { + let current = self.inner[self.pos].len(); + match (self.cur + cnt).cmp(¤t) { + Ordering::Equal => if self.pos + 1 < self.inner.len() { + self.pos += 1; + self.cur = 0; + } else { + self.cur += cnt; + }, + Ordering::Greater => { + if self.pos + 1 < self.inner.len() { + self.pos += 1; + } + let remaining = self.cur + cnt - current; + self.advance(remaining); + }, + Ordering::Less => self.cur += cnt, + } + } + + #[allow(needless_range_loop)] + fn bytes_vec<'c>(&'c self, dst: &mut [&'c IoVec]) -> usize { + let len = cmp::min(self.inner.len() - self.pos, dst.len()); + + if len > 0 { + dst[0] = self.bytes().into(); + } + + for i in 1..len { + dst[i] = self.inner[self.pos + i].into(); + } + + len + } +} + +#[cfg(test)] +mod test_vecbuf { + use super::*; + + #[test] + fn test_fresh_cursor_vec() { + let mut buf = VecBuf::new(&[b"he", b"llo"]); + + assert_eq!(buf.remaining(), 5); + assert_eq!(buf.bytes(), b"he"); + + buf.advance(2); + + assert_eq!(buf.remaining(), 3); + assert_eq!(buf.bytes(), b"llo"); + + buf.advance(3); + + assert_eq!(buf.remaining(), 0); + assert_eq!(buf.bytes(), b""); + } + + #[test] + fn test_get_u8() { + let mut buf = VecBuf::new(&[b"\x21z", b"omg"]); + assert_eq!(0x21, buf.get_u8()); + } + + #[test] + fn test_get_u16() { + let mut buf = VecBuf::new(&[b"\x21\x54z", b"omg"]); + assert_eq!(0x2154, buf.get_u16_be()); + let mut buf = VecBuf::new(&[b"\x21\x54z", b"omg"]); + assert_eq!(0x5421, buf.get_u16_le()); + } + + #[test] + #[should_panic] + fn test_get_u16_buffer_underflow() { + let mut buf = VecBuf::new(&[b"\x21"]); + buf.get_u16_be(); + } + + #[test] + fn test_bufs_vec() { + let buf = VecBuf::new(&[b"he", b"llo"]); + + let b1: &[u8] = &mut [0]; + let b2: &[u8] = &mut [0]; + + let mut dst: [&IoVec; 2] = + [b1.into(), b2.into()]; + + assert_eq!(2, buf.bytes_vec(&mut dst[..])); + } +} diff --git a/src/lib.rs b/src/lib.rs index 9d910e1..b06c227 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,12 +1,17 @@ //! Asynchronous TLS/SSL streams for Tokio using [Rustls](https://github.com/ctz/rustls). -#![feature(specialization)] +#![cfg_attr(feature = "nightly", feature(specialization))] pub extern crate rustls; pub extern crate webpki; +#[cfg(feature = "tokio-support")] extern crate tokio; +#[cfg(feature = "nightly")] +#[cfg(feature = "tokio-support")] extern crate bytes; +#[cfg(feature = "nightly")] +#[cfg(feature = "tokio-support")] extern crate iovec; From cf00bbb2f7d464c64ff2fc9ba2825b9a69fcfae2 Mon Sep 17 00:00:00 2001 From: quininer Date: Fri, 17 Aug 2018 10:14:17 +0800 Subject: [PATCH 081/171] fix ci --- .travis.yml | 20 ++++++++++++++------ appveyor.yml | 2 +- 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/.travis.yml b/.travis.yml index 3d5e1db..043e804 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,13 +1,21 @@ language: rust -rust: - - stable cache: cargo -os: - - linux - - osx + +matrix: + include: + - rust: stable + os: linux + - rust: nightly + env: FEATURE=nightly + os: linux + - rust: stable + os: osx + - rust: nightly + env: FEATURE=nightly + os: osx script: - - cargo test --all-features + - cargo test --features "$FEATURE" - cd examples/server - cargo check - cd ../../examples/client diff --git a/appveyor.yml b/appveyor.yml index 7ede91c..038274b 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -13,7 +13,7 @@ install: build: false test_script: - - 'cargo test --all-features' + - 'cargo test' - 'cd examples/server' - 'cargo check' - 'cd ../../examples/client' From 482f3c3aa6f7c981a6c76f6903c0d7ca3cd416e3 Mon Sep 17 00:00:00 2001 From: quininer Date: Fri, 17 Aug 2018 13:07:26 +0800 Subject: [PATCH 082/171] impl prepare_uninitialized_buffer --- src/common/vecbuf.rs | 2 +- src/tokio_impl.rs | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/common/vecbuf.rs b/src/common/vecbuf.rs index dd40163..81bec86 100644 --- a/src/common/vecbuf.rs +++ b/src/common/vecbuf.rs @@ -48,7 +48,7 @@ impl<'a, 'b> Buf for VecBuf<'a, 'b> { } } - #[allow(needless_range_loop)] + #[cfg_attr(feature = "cargo-clippy", allow(needless_range_loop))] fn bytes_vec<'c>(&'c self, dst: &mut [&'c IoVec]) -> usize { let len = cmp::min(self.inner.len() - self.pos, dst.len()); diff --git a/src/tokio_impl.rs b/src/tokio_impl.rs index 9f09705..e3fd4d5 100644 --- a/src/tokio_impl.rs +++ b/src/tokio_impl.rs @@ -52,7 +52,11 @@ impl AsyncRead for TlsStream where S: AsyncRead + AsyncWrite, C: Session -{} +{ + unsafe fn prepare_uninitialized_buffer(&self, _: &mut [u8]) -> bool { + false + } +} impl AsyncWrite for TlsStream where From 808df2f226bfd8c939209eef504357ee3a15676b Mon Sep 17 00:00:00 2001 From: quininer Date: Fri, 17 Aug 2018 13:09:24 +0800 Subject: [PATCH 083/171] publish 0.7.2 --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 73ca707..6106795 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tokio-rustls" -version = "0.7.1" +version = "0.7.2" authors = ["quininer kel "] license = "MIT/Apache-2.0" repository = "https://github.com/quininer/tokio-rustls" From f698c44e1a352629a786c490a622f6ed5431b1da Mon Sep 17 00:00:00 2001 From: quininer Date: Sat, 18 Aug 2018 13:52:00 +0800 Subject: [PATCH 084/171] Add stream test --- src/common/mod.rs | 3 + src/common/test_stream.rs | 161 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 164 insertions(+) create mode 100644 src/common/test_stream.rs diff --git a/src/common/mod.rs b/src/common/mod.rs index 1f5d1d2..8580cc8 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -141,3 +141,6 @@ impl<'a, S: Session, IO: Read + Write> io::Write for Stream<'a, S, IO> { Ok(()) } } + +#[cfg(test)] +mod test_stream; diff --git a/src/common/test_stream.rs b/src/common/test_stream.rs new file mode 100644 index 0000000..2cd9e87 --- /dev/null +++ b/src/common/test_stream.rs @@ -0,0 +1,161 @@ +use std::sync::Arc; +use std::io::{ self, Read, Write, BufReader, Cursor }; +use webpki::DNSNameRef; +use rustls::internal::pemfile::{ certs, rsa_private_keys }; +use rustls::{ + ServerConfig, ClientConfig, + ServerSession, ClientSession, + Session, NoClientAuth +}; +use super::Stream; + + +struct Good<'a>(&'a mut Session); + +impl<'a> Read for Good<'a> { + fn read(&mut self, mut buf: &mut [u8]) -> io::Result { + self.0.write_tls(buf.by_ref()) + } +} + +impl<'a> Write for Good<'a> { + fn write(&mut self, mut buf: &[u8]) -> io::Result { + let len = self.0.read_tls(buf.by_ref())?; + self.0.process_new_packets() + .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?; + Ok(len) + } + + fn flush(&mut self) -> io::Result<()> { + Ok(()) + } +} + +struct Bad(bool); + +impl Read for Bad { + fn read(&mut self, _: &mut [u8]) -> io::Result { + Ok(0) + } +} + +impl Write for Bad { + fn write(&mut self, buf: &[u8]) -> io::Result { + if self.0 { + Err(io::ErrorKind::WouldBlock.into()) + } else { + Ok(buf.len()) + } + } + + fn flush(&mut self) -> io::Result<()> { + Ok(()) + } +} + + +#[test] +fn stream_good() -> io::Result<()> { + const FILE: &'static [u8] = include_bytes!("../../README.md"); + + let (mut server, mut client) = make_pair(); + do_handshake(&mut client, &mut server); + io::copy(&mut Cursor::new(FILE), &mut server)?; + + { + let mut good = Good(&mut server); + let mut stream = Stream::new(&mut client, &mut good); + + let mut buf = Vec::new(); + stream.read_to_end(&mut buf)?; + assert_eq!(buf, FILE); + stream.write_all(b"Hello World!")? + } + + let mut buf = String::new(); + server.read_to_string(&mut buf)?; + assert_eq!(buf, "Hello World!"); + + Ok(()) +} + +#[test] +fn stream_bad() -> io::Result<()> { + let (mut server, mut client) = make_pair(); + do_handshake(&mut client, &mut server); + client.set_buffer_limit(1024); + + let mut bad = Bad(true); + let mut stream = Stream::new(&mut client, &mut bad); + assert_eq!(stream.write(&[0x42; 8])?, 8); + assert_eq!(stream.write(&[0x42; 8])?, 8); + let r = stream.write(&[0x00; 1024])?; // fill buffer + assert!(r < 1024); + assert_eq!( + stream.write(&[0x01]).unwrap_err().kind(), + io::ErrorKind::WouldBlock + ); + + Ok(()) +} + +#[test] +fn stream_handshake() -> io::Result<()> { + let (mut server, mut client) = make_pair(); + + { + let mut good = Good(&mut server); + let mut stream = Stream::new(&mut client, &mut good); + let (r, w) = stream.complete_io()?; + + assert!(r > 0); + assert!(w > 0); + + stream.complete_io()?; // finish server handshake + } + + assert!(!server.is_handshaking()); + assert!(!client.is_handshaking()); + + Ok(()) +} + +#[test] +fn stream_handshake_eof() -> io::Result<()> { + let (_, mut client) = make_pair(); + + let mut bad = Bad(false); + let mut stream = Stream::new(&mut client, &mut bad); + let r = stream.complete_io(); + + assert_eq!(r.unwrap_err().kind(), io::ErrorKind::UnexpectedEof); + + Ok(()) +} + +fn make_pair() -> (ServerSession, ClientSession) { + const CERT: &str = include_str!("../../tests/end.cert"); + const CHAIN: &str = include_str!("../../tests/end.chain"); + const RSA: &str = include_str!("../../tests/end.rsa"); + + let cert = certs(&mut BufReader::new(Cursor::new(CERT))).unwrap(); + let mut keys = rsa_private_keys(&mut BufReader::new(Cursor::new(RSA))).unwrap(); + let mut sconfig = ServerConfig::new(NoClientAuth::new()); + sconfig.set_single_cert(cert, keys.pop().unwrap()).unwrap(); + let server = ServerSession::new(&Arc::new(sconfig)); + + let domain = DNSNameRef::try_from_ascii_str("localhost").unwrap(); + let mut cconfig = ClientConfig::new(); + let mut chain = BufReader::new(Cursor::new(CHAIN)); + cconfig.root_store.add_pem_file(&mut chain).unwrap(); + let client = ClientSession::new(&Arc::new(cconfig), domain); + + (server, client) +} + +fn do_handshake(client: &mut ClientSession, server: &mut ServerSession) { + let mut good = Good(server); + let mut stream = Stream::new(client, &mut good); + stream.complete_io().unwrap(); + stream.complete_io().unwrap(); +} From 686b75bd4623f9311cc302308a059c87fe475a61 Mon Sep 17 00:00:00 2001 From: quininer Date: Tue, 21 Aug 2018 09:57:27 +0800 Subject: [PATCH 085/171] fix #5 --- examples/client/Cargo.toml | 7 +--- examples/client/src/main.rs | 74 ++++++++++--------------------------- 2 files changed, 20 insertions(+), 61 deletions(-) diff --git a/examples/client/Cargo.toml b/examples/client/Cargo.toml index 2253096..780ea88 100644 --- a/examples/client/Cargo.toml +++ b/examples/client/Cargo.toml @@ -9,9 +9,4 @@ tokio-rustls = { path = "../.." } tokio = "0.1" clap = "2" webpki-roots = "0.15" - -[target.'cfg(unix)'.dependencies] -tokio-file-unix = "0.5" - -[target.'cfg(not(unix))'.dependencies] -tokio-fs = "0.1" +tokio-stdin-stdout = "0.1" diff --git a/examples/client/src/main.rs b/examples/client/src/main.rs index 8499993..e58a633 100644 --- a/examples/client/src/main.rs +++ b/examples/client/src/main.rs @@ -4,8 +4,7 @@ extern crate webpki; extern crate webpki_roots; extern crate tokio_rustls; -#[cfg(unix)] extern crate tokio_file_unix; -#[cfg(not(unix))] extern crate tokio_fs; +extern crate tokio_stdin_stdout; use std::sync::Arc; use std::net::ToSocketAddrs; @@ -16,6 +15,7 @@ use tokio::net::TcpStream; use tokio::prelude::*; use clap::{ App, Arg }; use tokio_rustls::{ ClientConfigExt, rustls::ClientConfig }; +use tokio_stdin_stdout::{ stdin as tokio_stdin, stdout as tokio_stdout }; fn app() -> App<'static, 'static> { App::new("client") @@ -52,59 +52,23 @@ fn main() { let arc_config = Arc::new(config); let socket = TcpStream::connect(&addr); + let (stdin, stdout) = (tokio_stdin(0), tokio_stdout(0)); - #[cfg(unix)] - let resp = { - use tokio::reactor::Handle; - use tokio_file_unix::{ raw_stdin, raw_stdout, File }; + let done = socket + .and_then(move |stream| { + let domain = webpki::DNSNameRef::try_from_ascii_str(&domain).unwrap(); + arc_config.connect_async(domain, stream) + }) + .and_then(move |stream| io::write_all(stream, text)) + .and_then(move |(stream, _)| { + let (r, w) = stream.split(); + io::copy(r, stdout) + .map(drop) + .select2(io::copy(stdin, w).map(drop)) + .map_err(|res| res.split().0) + }) + .map(drop) + .map_err(|err| eprintln!("{:?}", err)); - let stdin = raw_stdin() - .and_then(File::new_nb) - .and_then(|fd| fd.into_reader(&Handle::current())) - .unwrap(); - let stdout = raw_stdout() - .and_then(File::new_nb) - .and_then(|fd| fd.into_io(&Handle::current())) - .unwrap(); - - socket - .and_then(move |stream| { - let domain = webpki::DNSNameRef::try_from_ascii_str(&domain).unwrap(); - arc_config.connect_async(domain, stream) - }) - .and_then(move |stream| io::write_all(stream, text)) - .and_then(move |(stream, _)| { - let (r, w) = stream.split(); - io::copy(r, stdout) - .map(drop) - .select2(io::copy(stdin, w).map(drop)) - .map_err(|res| res.split().0) - }) - .map(drop) - .map_err(|err| eprintln!("{:?}", err)) - }; - - #[cfg(not(unix))] - let resp = { - use tokio_fs::{ stdin as tokio_stdin, stdout as tokio_stdout }; - - let (stdin, stdout) = (tokio_stdin(), tokio_stdout()); - - socket - .and_then(move |stream| { - let domain = webpki::DNSNameRef::try_from_ascii_str(&domain).unwrap(); - arc_config.connect_async(domain, stream) - }) - .and_then(move |stream| io::write_all(stream, text)) - .and_then(move |(stream, _)| { - let (r, w) = stream.split(); - io::copy(r, stdout) - .map(drop) - .join(io::copy(stdin, w).map(drop)) - }) - .map(drop) - .map_err(|err| eprintln!("{:?}", err)) - }; - - tokio::run(resp); + tokio::run(done); } From 9378e415ce407bc8e0b26f6e70517b943070e190 Mon Sep 17 00:00:00 2001 From: quininer Date: Thu, 6 Sep 2018 13:56:00 +0800 Subject: [PATCH 086/171] impl read_initializer --- src/common/mod.rs | 7 +++++++ src/lib.rs | 2 +- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/src/common/mod.rs b/src/common/mod.rs index 8580cc8..7db198e 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -3,6 +3,8 @@ mod vecbuf; use std::io::{ self, Read, Write }; +#[cfg(feature = "nightly")] +use std::io::Initializer; use rustls::Session; #[cfg(feature = "nightly")] use rustls::WriteV; @@ -110,6 +112,11 @@ impl<'a, S: Session, IO: Read + AsyncWrite> WriteTls<'a, S, IO> for Stream<'a, S } impl<'a, S: Session, IO: Read + Write> Read for Stream<'a, S, IO> { + #[cfg(feature = "nightly")] + unsafe fn initializer(&self) -> Initializer { + Initializer::nop() + } + fn read(&mut self, buf: &mut [u8]) -> io::Result { while self.session.wants_read() { if let (0, 0) = self.complete_io()? { diff --git a/src/lib.rs b/src/lib.rs index b06c227..15bce44 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,6 +1,6 @@ //! Asynchronous TLS/SSL streams for Tokio using [Rustls](https://github.com/ctz/rustls). -#![cfg_attr(feature = "nightly", feature(specialization))] +#![cfg_attr(feature = "nightly", feature(specialization, read_initializer))] pub extern crate rustls; pub extern crate webpki; From e7231821086ed7d1c3525e9455624f146450713c Mon Sep 17 00:00:00 2001 From: quininer Date: Wed, 19 Sep 2018 18:03:39 +0800 Subject: [PATCH 087/171] publish 0.8.0-alpha --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 6106795..ec25a73 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tokio-rustls" -version = "0.7.2" +version = "0.8.0-alpha" authors = ["quininer kel "] license = "MIT/Apache-2.0" repository = "https://github.com/quininer/tokio-rustls" From 6b493615a928b9ff3bfa39bf93609e13ba7574fc Mon Sep 17 00:00:00 2001 From: quininer Date: Sun, 23 Sep 2018 02:00:52 +0800 Subject: [PATCH 088/171] rename generic name --- src/lib.rs | 69 +++++++++++++++++++++++++++++++---------------- src/tokio_impl.rs | 26 +++++++++--------- 2 files changed, 59 insertions(+), 36 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index d61caf0..43bd51f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -28,10 +28,12 @@ use rustls::{ use common::Stream; +#[derive(Clone)] pub struct TlsConnector { inner: Arc } +#[derive(Clone)] pub struct TlsAcceptor { inner: Arc } @@ -49,16 +51,16 @@ impl From> for TlsAcceptor { } impl TlsConnector { - pub fn connect(&self, domain: DNSNameRef, stream: S) -> Connect - where S: io::Read + io::Write + pub fn connect(&self, domain: DNSNameRef, stream: IO) -> Connect + where IO: io::Read + io::Write { Self::connect_with_session(stream, ClientSession::new(&self.inner, domain)) } #[inline] - pub fn connect_with_session(stream: S, session: ClientSession) - -> Connect - where S: io::Read + io::Write + pub fn connect_with_session(stream: IO, session: ClientSession) + -> Connect + where IO: io::Read + io::Write { Connect(MidHandshake { inner: Some(TlsStream { session, io: stream, is_shutdown: false, eof: false }) @@ -67,15 +69,15 @@ impl TlsConnector { } impl TlsAcceptor { - pub fn accept(&self, stream: S) -> Accept - where S: io::Read + io::Write, + pub fn accept(&self, stream: IO) -> Accept + where IO: io::Read + io::Write, { Self::accept_with_session(stream, ServerSession::new(&self.inner)) } #[inline] - pub fn accept_with_session(stream: S, session: ServerSession) -> Accept - where S: io::Read + io::Write + pub fn accept_with_session(stream: IO, session: ServerSession) -> Accept + where IO: io::Read + io::Write { Accept(MidHandshake { inner: Some(TlsStream { session, io: stream, is_shutdown: false, eof: false }) @@ -86,43 +88,64 @@ impl TlsAcceptor { /// Future returned from `ClientConfigExt::connect_async` which will resolve /// once the connection handshake has finished. -pub struct Connect(MidHandshake); +pub struct Connect(MidHandshake); /// Future returned from `ServerConfigExt::accept_async` which will resolve /// once the accept handshake has finished. -pub struct Accept(MidHandshake); +pub struct Accept(MidHandshake); -struct MidHandshake { - inner: Option> +struct MidHandshake { + inner: Option> } /// A wrapper around an underlying raw stream which implements the TLS or SSL /// protocol. #[derive(Debug)] -pub struct TlsStream { +pub struct TlsStream { is_shutdown: bool, eof: bool, - io: S, - session: C + io: IO, + session: S } -impl TlsStream { +impl TlsStream { #[inline] - pub fn get_ref(&self) -> (&S, &C) { + pub fn get_ref(&self) -> (&IO, &S) { (&self.io, &self.session) } #[inline] - pub fn get_mut(&mut self) -> (&mut S, &mut C) { + pub fn get_mut(&mut self) -> (&mut IO, &mut S) { (&mut self.io, &mut self.session) } + + #[inline] + pub fn into_inner(self) -> (IO, S) { + (self.io, self.session) + } } -impl io::Read for TlsStream - where S: io::Read + io::Write, C: Session +impl From<(IO, S)> for TlsStream { + #[inline] + fn from((io, session): (IO, S)) -> TlsStream { + TlsStream { + is_shutdown: false, + eof: false, + io, session + } + } +} + +impl io::Read for TlsStream + where IO: io::Read + io::Write, S: Session { + #[cfg(feature = "nightly")] + unsafe fn initializer(&self) -> Initializer { + Initializer::nop() + } + fn read(&mut self, buf: &mut [u8]) -> io::Result { if self.eof { return Ok(0); @@ -142,8 +165,8 @@ impl io::Read for TlsStream } } -impl io::Write for TlsStream - where S: io::Read + io::Write, C: Session +impl io::Write for TlsStream + where IO: io::Read + io::Write, S: Session { fn write(&mut self, buf: &[u8]) -> io::Result { Stream::new(&mut self.session, &mut self.io).write(buf) diff --git a/src/tokio_impl.rs b/src/tokio_impl.rs index 11179dc..00b4722 100644 --- a/src/tokio_impl.rs +++ b/src/tokio_impl.rs @@ -5,8 +5,8 @@ use tokio::prelude::Poll; use common::Stream; -impl Future for Connect { - type Item = TlsStream; +impl Future for Connect { + type Item = TlsStream; type Error = io::Error; fn poll(&mut self) -> Poll { @@ -14,8 +14,8 @@ impl Future for Connect { } } -impl Future for Accept { - type Item = TlsStream; +impl Future for Accept { + type Item = TlsStream; type Error = io::Error; fn poll(&mut self) -> Poll { @@ -23,10 +23,10 @@ impl Future for Accept { } } -impl Future for MidHandshake - where S: io::Read + io::Write, C: Session +impl Future for MidHandshake + where IO: io::Read + io::Write, S: Session { - type Item = TlsStream; + type Item = TlsStream; type Error = io::Error; fn poll(&mut self) -> Poll { @@ -48,20 +48,20 @@ impl Future for MidHandshake } } -impl AsyncRead for TlsStream +impl AsyncRead for TlsStream where - S: AsyncRead + AsyncWrite, - C: Session + IO: AsyncRead + AsyncWrite, + S: Session { unsafe fn prepare_uninitialized_buffer(&self, _: &mut [u8]) -> bool { false } } -impl AsyncWrite for TlsStream +impl AsyncWrite for TlsStream where - S: AsyncRead + AsyncWrite, - C: Session + IO: AsyncRead + AsyncWrite, + S: Session { fn shutdown(&mut self) -> Poll<(), io::Error> { if !self.is_shutdown { From 30cacd04a07f51e41003b5f1f1fa118025a1ee63 Mon Sep 17 00:00:00 2001 From: quininer Date: Mon, 24 Sep 2018 02:09:28 +0800 Subject: [PATCH 089/171] fix nightly feature --- src/lib.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/lib.rs b/src/lib.rs index 43bd51f..736cad6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -20,6 +20,8 @@ mod common; use std::io; use std::sync::Arc; +#[cfg(feature = "nightly")] +use std::io::Initializer; use webpki::DNSNameRef; use rustls::{ Session, ClientSession, ServerSession, From 1f98d87a620cf23b09fb851e8154fda2f3e4ccde Mon Sep 17 00:00:00 2001 From: quininer Date: Mon, 24 Sep 2018 23:02:30 +0800 Subject: [PATCH 090/171] more complete handshake --- src/common/mod.rs | 5 +++-- src/lib.rs | 4 +++- src/tokio_impl.rs | 39 +++++++++++++++++++++++++-------------- 3 files changed, 31 insertions(+), 17 deletions(-) diff --git a/src/common/mod.rs b/src/common/mod.rs index 7db198e..e18aa6e 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -12,9 +12,10 @@ use rustls::WriteV; #[cfg(feature = "tokio-support")] use tokio::io::AsyncWrite; + pub struct Stream<'a, S: 'a, IO: 'a> { - session: &'a mut S, - io: &'a mut IO + pub session: &'a mut S, + pub io: &'a mut IO } pub trait WriteTls<'a, S: Session, IO: Read + Write>: Read + Write { diff --git a/src/lib.rs b/src/lib.rs index 736cad6..ee1524d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -129,9 +129,11 @@ impl TlsStream { } } -impl From<(IO, S)> for TlsStream { +impl From<(IO, S)> for TlsStream { #[inline] fn from((io, session): (IO, S)) -> TlsStream { + assert!(!session.is_handshaking()); + TlsStream { is_shutdown: false, eof: false, diff --git a/src/tokio_impl.rs b/src/tokio_impl.rs index 00b4722..644e4f0 100644 --- a/src/tokio_impl.rs +++ b/src/tokio_impl.rs @@ -5,6 +5,17 @@ use tokio::prelude::Poll; use common::Stream; +macro_rules! try_async { + ( $e:expr ) => { + match $e { + Ok(n) => n, + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => + return Ok(Async::NotReady), + Err(e) => return Err(e) + } + } +} + impl Future for Connect { type Item = TlsStream; type Error = io::Error; @@ -24,7 +35,9 @@ impl Future for Accept { } impl Future for MidHandshake - where IO: io::Read + io::Write, S: Session +where + IO: io::Read + io::Write, + S: Session { type Item = TlsStream; type Error = io::Error; @@ -32,15 +45,15 @@ impl Future for MidHandshake fn poll(&mut self) -> Poll { { let stream = self.inner.as_mut().unwrap(); - if stream.session.is_handshaking() { - let (io, session) = stream.get_mut(); - let mut stream = Stream::new(session, io); + let (io, session) = stream.get_mut(); + let mut stream = Stream::new(session, io); - match stream.complete_io() { - Ok(_) => (), - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => return Ok(Async::NotReady), - Err(e) => return Err(e) - } + if stream.session.is_handshaking() { + try_async!(stream.complete_io()); + } + + if stream.session.wants_write() { + try_async!(stream.complete_io()); } } @@ -69,12 +82,10 @@ impl AsyncWrite for TlsStream self.is_shutdown = true; } - match self.session.complete_io(&mut self.io) { - Ok(_) => (), - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => return Ok(Async::NotReady), - Err(e) => return Err(e) + { + let mut stream = Stream::new(&mut self.session, &mut self.io); + try_async!(stream.complete_io()); } - self.io.shutdown() } } From f6e8f86382330143b3fdc8d82eac1eba09616db6 Mon Sep 17 00:00:00 2001 From: quininer Date: Mon, 1 Oct 2018 01:38:42 +0800 Subject: [PATCH 091/171] publish 0.8.0 --- Cargo.toml | 4 ++-- README.md | 16 ++++------------ 2 files changed, 6 insertions(+), 14 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index ec25a73..b4cd9e1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tokio-rustls" -version = "0.8.0-alpha" +version = "0.8.0" authors = ["quininer kel "] license = "MIT/Apache-2.0" repository = "https://github.com/quininer/tokio-rustls" @@ -18,7 +18,7 @@ appveyor = { repository = "quininer/tokio-rustls" } tokio = { version = "0.1.6", optional = true } bytes = { version = "0.4", optional = true } iovec = { version = "0.1", optional = true } -rustls = "0.13" +rustls = "0.14" webpki = "0.18.1" [dev-dependencies] diff --git a/README.md b/README.md index 45d85b5..f22e97e 100644 --- a/README.md +++ b/README.md @@ -13,17 +13,17 @@ Asynchronous TLS/SSL streams for [Tokio](https://tokio.rs/) using ```rust use webpki::DNSNameRef; -use tokio_rustls::{ ClientConfigExt, rustls::ClientConfig }; +use tokio_rustls::{ TlsConnector, rustls::ClientConfig }; // ... let mut config = ClientConfig::new(); config.root_store.add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS); -let config = Arc::new(config); -let domain = DNSNameRef::try_from_ascii_str("www.rust-lang.org").unwrap(); +let config = TlsConnector::from(Arc::new(config)); +let dnsname = DNSNameRef::try_from_ascii_str("www.rust-lang.org").unwrap(); TcpStream::connect(&addr) - .and_then(|socket| config.connect_async(domain, socket)) + .and_then(move |socket| config.connect(dnsname, socket)) // ... ``` @@ -37,14 +37,6 @@ cd examples/client cargo run -- hsts.badssl.com ``` -Currently on Windows the example client reads from stdin and writes to stdout using -blocking I/O. Until this is fixed, do something this on Windows: - -```sh -cd examples/client -echo | cargo run -- hsts.badssl.com -``` - ### Server Example Program See [examples/server](examples/server/src/main.rs). You can run it with: From 6f1787e9d1adf96828630bca320de236b6c6801a Mon Sep 17 00:00:00 2001 From: Erick Tryzelaar Date: Fri, 11 Jan 2019 14:03:37 -0800 Subject: [PATCH 092/171] Shrink down the dependency on tokio it turns out that tokio-rustls only requires a small portion of the tokio stack. This patch slims down the dependencies since not all clients need the full tokio stack. --- Cargo.toml | 5 +++-- src/common/mod.rs | 4 ++-- src/lib.rs | 4 +++- src/tokio_impl.rs | 5 ++--- 4 files changed, 10 insertions(+), 8 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index b4cd9e1..a0e2017 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,7 +15,8 @@ travis-ci = { repository = "quininer/tokio-rustls" } appveyor = { repository = "quininer/tokio-rustls" } [dependencies] -tokio = { version = "0.1.6", optional = true } +futures = { version = "0.1", optional = true } +tokio-io = { version = "0.1.6", optional = true } bytes = { version = "0.4", optional = true } iovec = { version = "0.1", optional = true } rustls = "0.14" @@ -28,4 +29,4 @@ lazy_static = "1" [features] default = ["tokio-support"] nightly = ["bytes", "iovec"] -tokio-support = ["tokio"] +tokio-support = ["futures", "tokio-io"] diff --git a/src/common/mod.rs b/src/common/mod.rs index e18aa6e..8e4a5b1 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -10,7 +10,7 @@ use rustls::Session; use rustls::WriteV; #[cfg(feature = "nightly")] #[cfg(feature = "tokio-support")] -use tokio::io::AsyncWrite; +use tokio_io::AsyncWrite; pub struct Stream<'a, S: 'a, IO: 'a> { @@ -91,7 +91,7 @@ impl<'a, S: Session, IO: Read + Write> WriteTls<'a, S, IO> for Stream<'a, S, IO> #[cfg(feature = "tokio-support")] impl<'a, S: Session, IO: Read + AsyncWrite> WriteTls<'a, S, IO> for Stream<'a, S, IO> { fn write_tls(&mut self) -> io::Result { - use tokio::prelude::Async; + use futures::Async; use self::vecbuf::VecBuf; struct V<'a, IO: 'a>(&'a mut IO); diff --git a/src/lib.rs b/src/lib.rs index ee1524d..2aaf2e4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -6,7 +6,9 @@ pub extern crate rustls; pub extern crate webpki; #[cfg(feature = "tokio-support")] -extern crate tokio; +extern crate futures; +#[cfg(feature = "tokio-support")] +extern crate tokio_io; #[cfg(feature = "nightly")] #[cfg(feature = "tokio-support")] extern crate bytes; diff --git a/src/tokio_impl.rs b/src/tokio_impl.rs index 644e4f0..7d9ba57 100644 --- a/src/tokio_impl.rs +++ b/src/tokio_impl.rs @@ -1,7 +1,6 @@ use super::*; -use tokio::prelude::*; -use tokio::io::{ AsyncRead, AsyncWrite }; -use tokio::prelude::Poll; +use tokio_io::{ AsyncRead, AsyncWrite }; +use futures::{Async, Future, Poll}; use common::Stream; From d72eb459b2de777f07da65fcdbbe293bb54a6b3b Mon Sep 17 00:00:00 2001 From: quininer Date: Sat, 12 Jan 2019 12:38:20 +0800 Subject: [PATCH 093/171] publish 0.8.1 --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index a0e2017..81fd3ff 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tokio-rustls" -version = "0.8.0" +version = "0.8.1" authors = ["quininer kel "] license = "MIT/Apache-2.0" repository = "https://github.com/quininer/tokio-rustls" From 7a54c9fa079c67dd12360a964ee61542c1df7ad8 Mon Sep 17 00:00:00 2001 From: Joseph Birr-Pixton Date: Sun, 20 Jan 2019 20:38:36 +0000 Subject: [PATCH 094/171] Update to rustls 0.15, webpki 0.19 --- Cargo.toml | 4 ++-- examples/client/Cargo.toml | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 81fd3ff..9fb9730 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,8 +19,8 @@ futures = { version = "0.1", optional = true } tokio-io = { version = "0.1.6", optional = true } bytes = { version = "0.4", optional = true } iovec = { version = "0.1", optional = true } -rustls = "0.14" -webpki = "0.18.1" +rustls = "0.15" +webpki = "0.19" [dev-dependencies] tokio = "0.1.6" diff --git a/examples/client/Cargo.toml b/examples/client/Cargo.toml index 780ea88..3765efc 100644 --- a/examples/client/Cargo.toml +++ b/examples/client/Cargo.toml @@ -4,9 +4,9 @@ version = "0.1.0" authors = ["quininer "] [dependencies] -webpki = "0.18.1" +webpki = "0.19" tokio-rustls = { path = "../.." } tokio = "0.1" clap = "2" -webpki-roots = "0.15" +webpki-roots = "0.16" tokio-stdin-stdout = "0.1" From db21c4c94751be3a5040af1eea9c58641dc37bda Mon Sep 17 00:00:00 2001 From: quininer Date: Tue, 22 Jan 2019 09:48:55 +0800 Subject: [PATCH 095/171] publish 0.9.0 --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 9fb9730..58973f0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tokio-rustls" -version = "0.8.1" +version = "0.9.0" authors = ["quininer kel "] license = "MIT/Apache-2.0" repository = "https://github.com/quininer/tokio-rustls" From 5f6d0233ed971ece8955851c26a93fd852151e2d Mon Sep 17 00:00:00 2001 From: quininer Date: Sat, 16 Feb 2019 01:31:46 +0800 Subject: [PATCH 096/171] tokio only * remove io::Read/io::Write support * stable vecio --- Cargo.toml | 15 +++++---------- src/common/mod.rs | 40 ++++++--------------------------------- src/common/test_stream.rs | 16 ++++++++++++++++ src/lib.rs | 30 ++++++++--------------------- src/tokio_impl.rs | 2 +- 5 files changed, 36 insertions(+), 67 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 58973f0..9de53ac 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tokio-rustls" -version = "0.9.0" +version = "0.10.0-alpha" authors = ["quininer kel "] license = "MIT/Apache-2.0" repository = "https://github.com/quininer/tokio-rustls" @@ -15,18 +15,13 @@ travis-ci = { repository = "quininer/tokio-rustls" } appveyor = { repository = "quininer/tokio-rustls" } [dependencies] -futures = { version = "0.1", optional = true } -tokio-io = { version = "0.1.6", optional = true } -bytes = { version = "0.4", optional = true } -iovec = { version = "0.1", optional = true } +futures = "0.1" +tokio-io = "0.1.6" +bytes = "0.4" +iovec = "0.1" rustls = "0.15" webpki = "0.19" [dev-dependencies] tokio = "0.1.6" lazy_static = "1" - -[features] -default = ["tokio-support"] -nightly = ["bytes", "iovec"] -tokio-support = ["futures", "tokio-io"] diff --git a/src/common/mod.rs b/src/common/mod.rs index 8e4a5b1..19fedb1 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -1,16 +1,9 @@ -#[cfg(feature = "nightly")] -#[cfg(feature = "tokio-support")] mod vecbuf; use std::io::{ self, Read, Write }; -#[cfg(feature = "nightly")] -use std::io::Initializer; use rustls::Session; -#[cfg(feature = "nightly")] use rustls::WriteV; -#[cfg(feature = "nightly")] -#[cfg(feature = "tokio-support")] -use tokio_io::AsyncWrite; +use tokio_io::{ AsyncRead, AsyncWrite }; pub struct Stream<'a, S: 'a, IO: 'a> { @@ -18,11 +11,11 @@ pub struct Stream<'a, S: 'a, IO: 'a> { pub io: &'a mut IO } -pub trait WriteTls<'a, S: Session, IO: Read + Write>: Read + Write { +pub trait WriteTls<'a, S: Session, IO: AsyncRead + AsyncWrite>: Read + Write { fn write_tls(&mut self) -> io::Result; } -impl<'a, S: Session, IO: Read + Write> Stream<'a, S, IO> { +impl<'a, S: Session, IO: AsyncRead + AsyncWrite> Stream<'a, S, IO> { pub fn new(session: &'a mut S, io: &'a mut IO) -> Self { Stream { session, io } } @@ -73,23 +66,7 @@ impl<'a, S: Session, IO: Read + Write> Stream<'a, S, IO> { } } -#[cfg(not(feature = "nightly"))] -impl<'a, S: Session, IO: Read + Write> WriteTls<'a, S, IO> for Stream<'a, S, IO> { - fn write_tls(&mut self) -> io::Result { - self.session.write_tls(self.io) - } -} - -#[cfg(feature = "nightly")] -impl<'a, S: Session, IO: Read + Write> WriteTls<'a, S, IO> for Stream<'a, S, IO> { - default fn write_tls(&mut self) -> io::Result { - self.session.write_tls(self.io) - } -} - -#[cfg(feature = "nightly")] -#[cfg(feature = "tokio-support")] -impl<'a, S: Session, IO: Read + AsyncWrite> WriteTls<'a, S, IO> for Stream<'a, S, IO> { +impl<'a, S: Session, IO: AsyncRead + AsyncWrite> WriteTls<'a, S, IO> for Stream<'a, S, IO> { fn write_tls(&mut self) -> io::Result { use futures::Async; use self::vecbuf::VecBuf; @@ -112,12 +89,7 @@ impl<'a, S: Session, IO: Read + AsyncWrite> WriteTls<'a, S, IO> for Stream<'a, S } } -impl<'a, S: Session, IO: Read + Write> Read for Stream<'a, S, IO> { - #[cfg(feature = "nightly")] - unsafe fn initializer(&self) -> Initializer { - Initializer::nop() - } - +impl<'a, S: Session, IO: AsyncRead + AsyncWrite> Read for Stream<'a, S, IO> { fn read(&mut self, buf: &mut [u8]) -> io::Result { while self.session.wants_read() { if let (0, 0) = self.complete_io()? { @@ -128,7 +100,7 @@ impl<'a, S: Session, IO: Read + Write> Read for Stream<'a, S, IO> { } } -impl<'a, S: Session, IO: Read + Write> io::Write for Stream<'a, S, IO> { +impl<'a, S: Session, IO: AsyncRead + AsyncWrite> Write for Stream<'a, S, IO> { fn write(&mut self, buf: &[u8]) -> io::Result { let len = self.session.write(buf)?; while self.session.wants_write() { diff --git a/src/common/test_stream.rs b/src/common/test_stream.rs index 2cd9e87..66b34b6 100644 --- a/src/common/test_stream.rs +++ b/src/common/test_stream.rs @@ -7,6 +7,8 @@ use rustls::{ ServerSession, ClientSession, Session, NoClientAuth }; +use futures::{ Async, Poll }; +use tokio_io::{ AsyncRead, AsyncWrite }; use super::Stream; @@ -31,6 +33,13 @@ impl<'a> Write for Good<'a> { } } +impl<'a> AsyncRead for Good<'a> {} +impl<'a> AsyncWrite for Good<'a> { + fn shutdown(&mut self) -> Poll<(), io::Error> { + Ok(Async::Ready(())) + } +} + struct Bad(bool); impl Read for Bad { @@ -53,6 +62,13 @@ impl Write for Bad { } } +impl AsyncRead for Bad {} +impl AsyncWrite for Bad { + fn shutdown(&mut self) -> Poll<(), io::Error> { + Ok(Async::Ready(())) + } +} + #[test] fn stream_good() -> io::Result<()> { diff --git a/src/lib.rs b/src/lib.rs index 2aaf2e4..0cb0e3c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,34 +1,25 @@ //! Asynchronous TLS/SSL streams for Tokio using [Rustls](https://github.com/ctz/rustls). -#![cfg_attr(feature = "nightly", feature(specialization, read_initializer))] - pub extern crate rustls; pub extern crate webpki; -#[cfg(feature = "tokio-support")] extern crate futures; -#[cfg(feature = "tokio-support")] extern crate tokio_io; -#[cfg(feature = "nightly")] -#[cfg(feature = "tokio-support")] extern crate bytes; -#[cfg(feature = "nightly")] -#[cfg(feature = "tokio-support")] extern crate iovec; mod common; -#[cfg(feature = "tokio-support")] mod tokio_impl; +mod tokio_impl; use std::io; use std::sync::Arc; -#[cfg(feature = "nightly")] -use std::io::Initializer; use webpki::DNSNameRef; use rustls::{ Session, ClientSession, ServerSession, ClientConfig, ServerConfig, }; +use tokio_io::{ AsyncRead, AsyncWrite }; use common::Stream; @@ -56,7 +47,7 @@ impl From> for TlsAcceptor { impl TlsConnector { pub fn connect(&self, domain: DNSNameRef, stream: IO) -> Connect - where IO: io::Read + io::Write + where IO: AsyncRead + AsyncWrite { Self::connect_with_session(stream, ClientSession::new(&self.inner, domain)) } @@ -64,7 +55,7 @@ impl TlsConnector { #[inline] pub fn connect_with_session(stream: IO, session: ClientSession) -> Connect - where IO: io::Read + io::Write + where IO: AsyncRead + AsyncWrite { Connect(MidHandshake { inner: Some(TlsStream { session, io: stream, is_shutdown: false, eof: false }) @@ -74,14 +65,14 @@ impl TlsConnector { impl TlsAcceptor { pub fn accept(&self, stream: IO) -> Accept - where IO: io::Read + io::Write, + where IO: AsyncRead + AsyncWrite, { Self::accept_with_session(stream, ServerSession::new(&self.inner)) } #[inline] pub fn accept_with_session(stream: IO, session: ServerSession) -> Accept - where IO: io::Read + io::Write + where IO: AsyncRead + AsyncWrite { Accept(MidHandshake { inner: Some(TlsStream { session, io: stream, is_shutdown: false, eof: false }) @@ -145,13 +136,8 @@ impl From<(IO, S)> for TlsStream { } impl io::Read for TlsStream - where IO: io::Read + io::Write, S: Session + where IO: AsyncRead + AsyncWrite, S: Session { - #[cfg(feature = "nightly")] - unsafe fn initializer(&self) -> Initializer { - Initializer::nop() - } - fn read(&mut self, buf: &mut [u8]) -> io::Result { if self.eof { return Ok(0); @@ -172,7 +158,7 @@ impl io::Read for TlsStream } impl io::Write for TlsStream - where IO: io::Read + io::Write, S: Session + where IO: AsyncRead + AsyncWrite, S: Session { fn write(&mut self, buf: &[u8]) -> io::Result { Stream::new(&mut self.session, &mut self.io).write(buf) diff --git a/src/tokio_impl.rs b/src/tokio_impl.rs index 7d9ba57..0897e93 100644 --- a/src/tokio_impl.rs +++ b/src/tokio_impl.rs @@ -35,7 +35,7 @@ impl Future for Accept { impl Future for MidHandshake where - IO: io::Read + io::Write, + IO: AsyncRead + AsyncWrite, S: Session { type Item = TlsStream; From 7d6ed0acfcfcdbeb38492463d80741df65b88e37 Mon Sep 17 00:00:00 2001 From: quininer Date: Sat, 16 Feb 2019 13:05:36 +0800 Subject: [PATCH 097/171] fix travis ci --- .travis.yml | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/.travis.yml b/.travis.yml index 043e804..3653f1f 100644 --- a/.travis.yml +++ b/.travis.yml @@ -6,16 +6,14 @@ matrix: - rust: stable os: linux - rust: nightly - env: FEATURE=nightly os: linux - rust: stable os: osx - rust: nightly - env: FEATURE=nightly os: osx script: - - cargo test --features "$FEATURE" + - cargo test - cd examples/server - cargo check - cd ../../examples/client From 3e605aafe4746bcf8156aef6c065b2b66ddd87bb Mon Sep 17 00:00:00 2001 From: quininer Date: Mon, 18 Feb 2019 16:51:31 +0800 Subject: [PATCH 098/171] Add 0-RTT support --- src/common/mod.rs | 20 ++-- src/common/test_stream.rs | 10 +- src/lib.rs | 228 +++++++++++++++++++++++++++++--------- src/tokio_impl.rs | 79 +++++++++---- 4 files changed, 247 insertions(+), 90 deletions(-) diff --git a/src/common/mod.rs b/src/common/mod.rs index 19fedb1..9010d8d 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -6,18 +6,18 @@ use rustls::WriteV; use tokio_io::{ AsyncRead, AsyncWrite }; -pub struct Stream<'a, S: 'a, IO: 'a> { - pub session: &'a mut S, - pub io: &'a mut IO +pub struct Stream<'a, IO: 'a, S: 'a> { + pub io: &'a mut IO, + pub session: &'a mut S } -pub trait WriteTls<'a, S: Session, IO: AsyncRead + AsyncWrite>: Read + Write { +pub trait WriteTls<'a, IO: AsyncRead + AsyncWrite, S: Session>: Read + Write { fn write_tls(&mut self) -> io::Result; } -impl<'a, S: Session, IO: AsyncRead + AsyncWrite> Stream<'a, S, IO> { - pub fn new(session: &'a mut S, io: &'a mut IO) -> Self { - Stream { session, io } +impl<'a, IO: AsyncRead + AsyncWrite, S: Session> Stream<'a, IO, S> { + pub fn new(io: &'a mut IO, session: &'a mut S) -> Self { + Stream { io, session } } pub fn complete_io(&mut self) -> io::Result<(usize, usize)> { @@ -66,7 +66,7 @@ impl<'a, S: Session, IO: AsyncRead + AsyncWrite> Stream<'a, S, IO> { } } -impl<'a, S: Session, IO: AsyncRead + AsyncWrite> WriteTls<'a, S, IO> for Stream<'a, S, IO> { +impl<'a, IO: AsyncRead + AsyncWrite, S: Session> WriteTls<'a, IO, S> for Stream<'a, IO, S> { fn write_tls(&mut self) -> io::Result { use futures::Async; use self::vecbuf::VecBuf; @@ -89,7 +89,7 @@ impl<'a, S: Session, IO: AsyncRead + AsyncWrite> WriteTls<'a, S, IO> for Stream< } } -impl<'a, S: Session, IO: AsyncRead + AsyncWrite> Read for Stream<'a, S, IO> { +impl<'a, IO: AsyncRead + AsyncWrite, S: Session> Read for Stream<'a, IO, S> { fn read(&mut self, buf: &mut [u8]) -> io::Result { while self.session.wants_read() { if let (0, 0) = self.complete_io()? { @@ -100,7 +100,7 @@ impl<'a, S: Session, IO: AsyncRead + AsyncWrite> Read for Stream<'a, S, IO> { } } -impl<'a, S: Session, IO: AsyncRead + AsyncWrite> Write for Stream<'a, S, IO> { +impl<'a, IO: AsyncRead + AsyncWrite, S: Session> Write for Stream<'a, IO, S> { fn write(&mut self, buf: &[u8]) -> io::Result { let len = self.session.write(buf)?; while self.session.wants_write() { diff --git a/src/common/test_stream.rs b/src/common/test_stream.rs index 66b34b6..a43622c 100644 --- a/src/common/test_stream.rs +++ b/src/common/test_stream.rs @@ -80,7 +80,7 @@ fn stream_good() -> io::Result<()> { { let mut good = Good(&mut server); - let mut stream = Stream::new(&mut client, &mut good); + let mut stream = Stream::new(&mut good, &mut client); let mut buf = Vec::new(); stream.read_to_end(&mut buf)?; @@ -102,7 +102,7 @@ fn stream_bad() -> io::Result<()> { client.set_buffer_limit(1024); let mut bad = Bad(true); - let mut stream = Stream::new(&mut client, &mut bad); + let mut stream = Stream::new(&mut bad, &mut client); assert_eq!(stream.write(&[0x42; 8])?, 8); assert_eq!(stream.write(&[0x42; 8])?, 8); let r = stream.write(&[0x00; 1024])?; // fill buffer @@ -121,7 +121,7 @@ fn stream_handshake() -> io::Result<()> { { let mut good = Good(&mut server); - let mut stream = Stream::new(&mut client, &mut good); + let mut stream = Stream::new(&mut good, &mut client); let (r, w) = stream.complete_io()?; assert!(r > 0); @@ -141,7 +141,7 @@ fn stream_handshake_eof() -> io::Result<()> { let (_, mut client) = make_pair(); let mut bad = Bad(false); - let mut stream = Stream::new(&mut client, &mut bad); + let mut stream = Stream::new(&mut bad, &mut client); let r = stream.complete_io(); assert_eq!(r.unwrap_err().kind(), io::ErrorKind::UnexpectedEof); @@ -171,7 +171,7 @@ fn make_pair() -> (ServerSession, ClientSession) { fn do_handshake(client: &mut ClientSession, server: &mut ServerSession) { let mut good = Good(server); - let mut stream = Stream::new(client, &mut good); + let mut stream = Stream::new(&mut good, client); stream.complete_io().unwrap(); stream.complete_io().unwrap(); } diff --git a/src/lib.rs b/src/lib.rs index 0cb0e3c..378c693 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -12,12 +12,13 @@ extern crate iovec; mod common; mod tokio_impl; -use std::io; +use std::mem; +use std::io::{ self, Write }; use std::sync::Arc; use webpki::DNSNameRef; use rustls::{ Session, ClientSession, ServerSession, - ClientConfig, ServerConfig, + ClientConfig, ServerConfig }; use tokio_io::{ AsyncRead, AsyncWrite }; use common::Stream; @@ -25,7 +26,8 @@ use common::Stream; #[derive(Clone)] pub struct TlsConnector { - inner: Arc + inner: Arc, + early_data: bool } #[derive(Clone)] @@ -35,7 +37,7 @@ pub struct TlsAcceptor { impl From> for TlsConnector { fn from(inner: Arc) -> TlsConnector { - TlsConnector { inner } + TlsConnector { inner, early_data: false } } } @@ -46,19 +48,39 @@ impl From> for TlsAcceptor { } impl TlsConnector { + pub fn early_data(mut self, flag: bool) -> TlsConnector { + self.early_data = flag; + self + } + pub fn connect(&self, domain: DNSNameRef, stream: IO) -> Connect where IO: AsyncRead + AsyncWrite { - Self::connect_with_session(stream, ClientSession::new(&self.inner, domain)) + self.connect_with(domain, stream, |_| ()) } #[inline] - pub fn connect_with_session(stream: IO, session: ClientSession) + pub fn connect_with(&self, domain: DNSNameRef, stream: IO, f: F) -> Connect - where IO: AsyncRead + AsyncWrite + where + IO: AsyncRead + AsyncWrite, + F: FnOnce(&mut ClientSession) { - Connect(MidHandshake { - inner: Some(TlsStream { session, io: stream, is_shutdown: false, eof: false }) + let mut session = ClientSession::new(&self.inner, domain); + f(&mut session); + + Connect(if self.early_data { + MidHandshake::EarlyData(TlsStream { + session, io: stream, + state: TlsState::EarlyData, + early_data: (0, Vec::new()) + }) + } else { + MidHandshake::Handshaking(TlsStream { + session, io: stream, + state: TlsState::Stream, + early_data: (0, Vec::new()) + }) }) } } @@ -67,16 +89,24 @@ impl TlsAcceptor { pub fn accept(&self, stream: IO) -> Accept where IO: AsyncRead + AsyncWrite, { - Self::accept_with_session(stream, ServerSession::new(&self.inner)) + self.accept_with(stream, |_| ()) } #[inline] - pub fn accept_with_session(stream: IO, session: ServerSession) -> Accept - where IO: AsyncRead + AsyncWrite + pub fn accept_with(&self, stream: IO, f: F) + -> Accept + where + IO: AsyncRead + AsyncWrite, + F: FnOnce(&mut ServerSession) { - Accept(MidHandshake { - inner: Some(TlsStream { session, io: stream, is_shutdown: false, eof: false }) - }) + let mut session = ServerSession::new(&self.inner); + f(&mut session); + + Accept(MidHandshake::Handshaking(TlsStream { + session, io: stream, + state: TlsState::Stream, + early_data: (0, Vec::new()) + })) } } @@ -89,9 +119,10 @@ pub struct Connect(MidHandshake); /// once the accept handshake has finished. pub struct Accept(MidHandshake); - -struct MidHandshake { - inner: Option> +enum MidHandshake { + Handshaking(TlsStream), + EarlyData(TlsStream), + End } @@ -99,10 +130,18 @@ struct MidHandshake { /// protocol. #[derive(Debug)] pub struct TlsStream { - is_shutdown: bool, - eof: bool, io: IO, - session: S + session: S, + state: TlsState, + early_data: (usize, Vec) +} + +#[derive(Debug)] +enum TlsState { + EarlyData, + Stream, + Eof, + Shutdown } impl TlsStream { @@ -122,50 +161,135 @@ impl TlsStream { } } -impl From<(IO, S)> for TlsStream { - #[inline] - fn from((io, session): (IO, S)) -> TlsStream { - assert!(!session.is_handshaking()); - - TlsStream { - is_shutdown: false, - eof: false, - io, session - } - } -} - -impl io::Read for TlsStream - where IO: AsyncRead + AsyncWrite, S: Session +impl io::Read for TlsStream +where IO: AsyncRead + AsyncWrite { fn read(&mut self, buf: &mut [u8]) -> io::Result { - if self.eof { - return Ok(0); - } + let mut stream = Stream::new(&mut self.io, &mut self.session); - match Stream::new(&mut self.session, &mut self.io).read(buf) { - Ok(0) => { self.eof = true; Ok(0) }, - Ok(n) => Ok(n), - Err(ref e) if e.kind() == io::ErrorKind::ConnectionAborted => { - self.eof = true; - self.is_shutdown = true; - self.session.send_close_notify(); - Ok(0) + match self.state { + TlsState::EarlyData => { + let (pos, data) = &mut self.early_data; + + // complete handshake + if stream.session.is_handshaking() { + stream.complete_io()?; + } + + // write early data (fallback) + if !stream.session.is_early_data_accepted() { + while *pos < data.len() { + let len = stream.write(&data[*pos..])?; + *pos += len; + } + } + + // end + self.state = TlsState::Stream; + *pos = 0; + data.clear(); + stream.read(buf) }, - Err(e) => Err(e) + TlsState::Stream => match stream.read(buf) { + Ok(0) => { + self.state = TlsState::Eof; + Ok(0) + }, + Ok(n) => Ok(n), + Err(ref e) if e.kind() == io::ErrorKind::ConnectionAborted => { + self.state = TlsState::Shutdown; + stream.session.send_close_notify(); + Ok(0) + }, + Err(e) => Err(e) + }, + TlsState::Eof | TlsState::Shutdown => Ok(0), } } } -impl io::Write for TlsStream - where IO: AsyncRead + AsyncWrite, S: Session +impl io::Read for TlsStream +where IO: AsyncRead + AsyncWrite +{ + fn read(&mut self, buf: &mut [u8]) -> io::Result { + let mut stream = Stream::new(&mut self.io, &mut self.session); + + match self.state { + TlsState::Stream => match stream.read(buf) { + Ok(0) => { + self.state = TlsState::Eof; + Ok(0) + }, + Ok(n) => Ok(n), + Err(ref e) if e.kind() == io::ErrorKind::ConnectionAborted => { + self.state = TlsState::Shutdown; + stream.session.send_close_notify(); + Ok(0) + }, + Err(e) => Err(e) + }, + TlsState::Eof | TlsState::Shutdown => Ok(0), + TlsState::EarlyData => unreachable!() + } + } +} + +impl io::Write for TlsStream +where IO: AsyncRead + AsyncWrite { fn write(&mut self, buf: &[u8]) -> io::Result { - Stream::new(&mut self.session, &mut self.io).write(buf) + let mut stream = Stream::new(&mut self.io, &mut self.session); + + match self.state { + TlsState::EarlyData => { + let (pos, data) = &mut self.early_data; + + // write early data + if let Some(mut early_data) = stream.session.early_data() { + let len = early_data.write(buf)?; + data.extend_from_slice(&buf[..len]); + return Ok(len); + } + + // complete handshake + if stream.session.is_handshaking() { + stream.complete_io()?; + } + + // write early data (fallback) + if !stream.session.is_early_data_accepted() { + while *pos < data.len() { + let len = stream.write(&data[*pos..])?; + *pos += len; + } + } + + // end + self.state = TlsState::Stream; + *pos = 0; + data.clear(); + stream.write(buf) + }, + _ => stream.write(buf) + } } fn flush(&mut self) -> io::Result<()> { - Stream::new(&mut self.session, &mut self.io).flush()?; + Stream::new(&mut self.io, &mut self.session).flush()?; + self.io.flush() + } +} + +impl io::Write for TlsStream +where IO: AsyncRead + AsyncWrite +{ + fn write(&mut self, buf: &[u8]) -> io::Result { + let mut stream = Stream::new(&mut self.io, &mut self.session); + stream.write(buf) + } + + fn flush(&mut self) -> io::Result<()> { + Stream::new(&mut self.io, &mut self.session).flush()?; self.io.flush() } } diff --git a/src/tokio_impl.rs b/src/tokio_impl.rs index 0897e93..f97cde3 100644 --- a/src/tokio_impl.rs +++ b/src/tokio_impl.rs @@ -42,47 +42,80 @@ where type Error = io::Error; fn poll(&mut self) -> Poll { - { - let stream = self.inner.as_mut().unwrap(); - let (io, session) = stream.get_mut(); - let mut stream = Stream::new(session, io); + match self { + MidHandshake::Handshaking(stream) => { + let (io, session) = stream.get_mut(); + let mut stream = Stream::new(io, session); - if stream.session.is_handshaking() { - try_async!(stream.complete_io()); - } + if stream.session.is_handshaking() { + try_async!(stream.complete_io()); + } - if stream.session.wants_write() { - try_async!(stream.complete_io()); - } + if stream.session.wants_write() { + try_async!(stream.complete_io()); + } + }, + _ => () } - Ok(Async::Ready(self.inner.take().unwrap())) + match mem::replace(self, MidHandshake::End) { + MidHandshake::Handshaking(stream) + | MidHandshake::EarlyData(stream) => Ok(Async::Ready(stream)), + MidHandshake::End => panic!() + } } } -impl AsyncRead for TlsStream - where - IO: AsyncRead + AsyncWrite, - S: Session +impl AsyncRead for TlsStream +where IO: AsyncRead + AsyncWrite { unsafe fn prepare_uninitialized_buffer(&self, _: &mut [u8]) -> bool { false } } -impl AsyncWrite for TlsStream - where - IO: AsyncRead + AsyncWrite, - S: Session +impl AsyncRead for TlsStream +where IO: AsyncRead + AsyncWrite +{ + unsafe fn prepare_uninitialized_buffer(&self, _: &mut [u8]) -> bool { + false + } +} + +impl AsyncWrite for TlsStream +where IO: AsyncRead + AsyncWrite, { fn shutdown(&mut self) -> Poll<(), io::Error> { - if !self.is_shutdown { - self.session.send_close_notify(); - self.is_shutdown = true; + match self.state { + TlsState::Shutdown => (), + _ => { + self.session.send_close_notify(); + self.state = TlsState::Shutdown; + } } { - let mut stream = Stream::new(&mut self.session, &mut self.io); + let mut stream = Stream::new(&mut self.io, &mut self.session); + try_async!(stream.complete_io()); + } + self.io.shutdown() + } +} + +impl AsyncWrite for TlsStream +where IO: AsyncRead + AsyncWrite, +{ + fn shutdown(&mut self) -> Poll<(), io::Error> { + match self.state { + TlsState::Shutdown => (), + _ => { + self.session.send_close_notify(); + self.state = TlsState::Shutdown; + } + } + + { + let mut stream = Stream::new(&mut self.io, &mut self.session); try_async!(stream.complete_io()); } self.io.shutdown() From 65932f5150158aa1816b4e5915a34cce637637cf Mon Sep 17 00:00:00 2001 From: quininer Date: Mon, 18 Feb 2019 20:01:37 +0800 Subject: [PATCH 099/171] Add 0-RTT test --- Cargo.toml | 1 + src/lib.rs | 9 +++++++-- src/test_0rtt.rs | 51 ++++++++++++++++++++++++++++++++++++++++++++++++ tests/test.rs | 21 +++++++++++++------- 4 files changed, 73 insertions(+), 9 deletions(-) create mode 100644 src/test_0rtt.rs diff --git a/Cargo.toml b/Cargo.toml index 9de53ac..b15bea2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,3 +25,4 @@ webpki = "0.19" [dev-dependencies] tokio = "0.1.6" lazy_static = "1" +webpki-roots = "0.16" diff --git a/src/lib.rs b/src/lib.rs index 378c693..dd34452 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -48,6 +48,10 @@ impl From> for TlsAcceptor { } impl TlsConnector { + /// Enable 0-RTT. + /// + /// Note that you want to use 0-RTT. + /// You must set `enable_early_data` to `true` in `ClientConfig`. pub fn early_data(mut self, flag: bool) -> TlsConnector { self.early_data = flag; self @@ -186,7 +190,6 @@ where IO: AsyncRead + AsyncWrite // end self.state = TlsState::Stream; - *pos = 0; data.clear(); stream.read(buf) }, @@ -266,7 +269,6 @@ where IO: AsyncRead + AsyncWrite // end self.state = TlsState::Stream; - *pos = 0; data.clear(); stream.write(buf) }, @@ -293,3 +295,6 @@ where IO: AsyncRead + AsyncWrite self.io.flush() } } + +#[cfg(test)] +mod test_0rtt; diff --git a/src/test_0rtt.rs b/src/test_0rtt.rs new file mode 100644 index 0000000..56c9d7b --- /dev/null +++ b/src/test_0rtt.rs @@ -0,0 +1,51 @@ +extern crate tokio; +extern crate webpki; +extern crate webpki_roots; + +use std::io; +use std::sync::Arc; +use std::net::ToSocketAddrs; +use self::tokio::io as aio; +use self::tokio::prelude::*; +use self::tokio::net::TcpStream; +use rustls::{ ClientConfig, ClientSession }; +use ::{ TlsConnector, TlsStream }; + + +fn get(config: Arc, domain: &str, rtt0: bool) + -> io::Result<(TlsStream, String)> +{ + let config = TlsConnector::from(config).early_data(rtt0); + let input = format!("GET / HTTP/1.0\r\nHost: {}\r\n\r\n", domain); + + let addr = (domain, 443) + .to_socket_addrs()? + .next().unwrap(); + + TcpStream::connect(&addr) + .and_then(move |stream| { + let domain = webpki::DNSNameRef::try_from_ascii_str(&domain).unwrap(); + config.connect(domain, stream) + }) + .and_then(move |stream| aio::write_all(stream, input)) + .and_then(move |(stream, _)| aio::read_to_end(stream, Vec::new())) + .map(|(stream, buf)| (stream, String::from_utf8(buf).unwrap())) + .wait() +} + +#[test] +fn test_0rtt() { + let mut config = ClientConfig::new(); + config.root_store.add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS); + config.enable_early_data = true; + let config = Arc::new(config); + let domain = "mozilla-modern.badssl.com"; + + let (_, output) = get(config.clone(), domain, false).unwrap(); + assert!(output.contains("mozilla-modern.badssl.com")); + + let (io, output) = get(config.clone(), domain, true).unwrap(); + assert!(output.contains("mozilla-modern.badssl.com")); + + assert_eq!(io.early_data.0, 0); +} diff --git a/tests/test.rs b/tests/test.rs index 8833253..f0703f8 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -66,17 +66,14 @@ fn start_server() -> &'static (SocketAddr, &'static str, &'static str) { &*TEST_SERVER } -fn start_client(addr: &SocketAddr, domain: &str, chain: &str) -> io::Result<()> { +fn start_client(addr: &SocketAddr, domain: &str, config: Arc) -> io::Result<()> { use tokio::prelude::*; use tokio::io as aio; const FILE: &'static [u8] = include_bytes!("../README.md"); let domain = webpki::DNSNameRef::try_from_ascii_str(domain).unwrap(); - let mut config = ClientConfig::new(); - let mut chain = BufReader::new(Cursor::new(chain)); - config.root_store.add_pem_file(&mut chain).unwrap(); - let config = TlsConnector::from(Arc::new(config)); + let config = TlsConnector::from(config); let done = TcpStream::connect(addr) .and_then(|stream| config.connect(domain, stream)) @@ -95,13 +92,23 @@ fn start_client(addr: &SocketAddr, domain: &str, chain: &str) -> io::Result<()> fn pass() { let (addr, domain, chain) = start_server(); - start_client(addr, domain, chain).unwrap(); + let mut config = ClientConfig::new(); + let mut chain = BufReader::new(Cursor::new(chain)); + config.root_store.add_pem_file(&mut chain).unwrap(); + let config = Arc::new(config); + + start_client(addr, domain, config.clone()).unwrap(); } #[test] fn fail() { let (addr, domain, chain) = start_server(); + let mut config = ClientConfig::new(); + let mut chain = BufReader::new(Cursor::new(chain)); + config.root_store.add_pem_file(&mut chain).unwrap(); + let config = Arc::new(config); + assert_ne!(domain, &"google.com"); - assert!(start_client(addr, "google.com", chain).is_err()); + assert!(start_client(addr, "google.com", config).is_err()); } From 527db99d02772ba339c82159eb52aab0f7ded154 Mon Sep 17 00:00:00 2001 From: quininer Date: Mon, 18 Feb 2019 20:41:52 +0800 Subject: [PATCH 100/171] Improve for ServerSesssion --- src/client.rs | 196 ++++++++++++++++++++++++++++++++++++++++++++ src/lib.rs | 203 ++++++---------------------------------------- src/server.rs | 139 +++++++++++++++++++++++++++++++ src/test_0rtt.rs | 6 +- src/tokio_impl.rs | 123 ---------------------------- 5 files changed, 362 insertions(+), 305 deletions(-) create mode 100644 src/client.rs create mode 100644 src/server.rs delete mode 100644 src/tokio_impl.rs diff --git a/src/client.rs b/src/client.rs new file mode 100644 index 0000000..8d6758e --- /dev/null +++ b/src/client.rs @@ -0,0 +1,196 @@ +use super::*; +use std::io::Write; +use rustls::Session; + + +/// A wrapper around an underlying raw stream which implements the TLS or SSL +/// protocol. +#[derive(Debug)] +pub struct TlsStream { + pub(crate) io: IO, + pub(crate) session: ClientSession, + pub(crate) state: TlsState, + pub(crate) early_data: (usize, Vec) +} + +#[derive(Debug)] +pub(crate) enum TlsState { + EarlyData, + Stream, + Eof, + Shutdown +} + +pub(crate) enum MidHandshake { + Handshaking(TlsStream), + EarlyData(TlsStream), + End +} + +impl TlsStream { + #[inline] + pub fn get_ref(&self) -> (&IO, &ClientSession) { + (&self.io, &self.session) + } + + #[inline] + pub fn get_mut(&mut self) -> (&mut IO, &mut ClientSession) { + (&mut self.io, &mut self.session) + } + + #[inline] + pub fn into_inner(self) -> (IO, ClientSession) { + (self.io, self.session) + } +} + +impl Future for MidHandshake +where IO: AsyncRead + AsyncWrite, +{ + type Item = TlsStream; + type Error = io::Error; + + #[inline] + fn poll(&mut self) -> Poll { + match self { + MidHandshake::Handshaking(stream) => { + let (io, session) = stream.get_mut(); + let mut stream = Stream::new(io, session); + + if stream.session.is_handshaking() { + try_nb!(stream.complete_io()); + } + + if stream.session.wants_write() { + try_nb!(stream.complete_io()); + } + }, + _ => () + } + + match mem::replace(self, MidHandshake::End) { + MidHandshake::Handshaking(stream) + | MidHandshake::EarlyData(stream) => Ok(Async::Ready(stream)), + MidHandshake::End => panic!() + } + } +} + +impl io::Read for TlsStream +where IO: AsyncRead + AsyncWrite +{ + fn read(&mut self, buf: &mut [u8]) -> io::Result { + let mut stream = Stream::new(&mut self.io, &mut self.session); + + match self.state { + TlsState::EarlyData => { + let (pos, data) = &mut self.early_data; + + // complete handshake + if stream.session.is_handshaking() { + stream.complete_io()?; + } + + // write early data (fallback) + if !stream.session.is_early_data_accepted() { + while *pos < data.len() { + let len = stream.write(&data[*pos..])?; + *pos += len; + } + } + + // end + self.state = TlsState::Stream; + data.clear(); + stream.read(buf) + }, + TlsState::Stream => match stream.read(buf) { + Ok(0) => { + self.state = TlsState::Eof; + Ok(0) + }, + Ok(n) => Ok(n), + Err(ref e) if e.kind() == io::ErrorKind::ConnectionAborted => { + self.state = TlsState::Shutdown; + stream.session.send_close_notify(); + Ok(0) + }, + Err(e) => Err(e) + }, + TlsState::Eof | TlsState::Shutdown => Ok(0), + } + } +} + +impl io::Write for TlsStream +where IO: AsyncRead + AsyncWrite +{ + fn write(&mut self, buf: &[u8]) -> io::Result { + let mut stream = Stream::new(&mut self.io, &mut self.session); + + match self.state { + TlsState::EarlyData => { + let (pos, data) = &mut self.early_data; + + // write early data + if let Some(mut early_data) = stream.session.early_data() { + let len = early_data.write(buf)?; + data.extend_from_slice(&buf[..len]); + return Ok(len); + } + + // complete handshake + if stream.session.is_handshaking() { + stream.complete_io()?; + } + + // write early data (fallback) + if !stream.session.is_early_data_accepted() { + while *pos < data.len() { + let len = stream.write(&data[*pos..])?; + *pos += len; + } + } + + // end + self.state = TlsState::Stream; + data.clear(); + stream.write(buf) + }, + _ => stream.write(buf) + } + } + + fn flush(&mut self) -> io::Result<()> { + Stream::new(&mut self.io, &mut self.session).flush()?; + self.io.flush() + } +} + +impl AsyncRead for TlsStream +where IO: AsyncRead + AsyncWrite +{ + unsafe fn prepare_uninitialized_buffer(&self, _: &mut [u8]) -> bool { + false + } +} + +impl AsyncWrite for TlsStream +where IO: AsyncRead + AsyncWrite +{ + fn shutdown(&mut self) -> Poll<(), io::Error> { + match self.state { + TlsState::Shutdown => (), + _ => { + self.session.send_close_notify(); + self.state = TlsState::Shutdown; + } + } + + { + let mut stream = Stream::new(&mut self.io, &mut self.session); + try_nb!(stream.complete_io()); + } + self.io.shutdown() + } +} diff --git a/src/lib.rs b/src/lib.rs index dd34452..446e80d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -8,19 +8,19 @@ extern crate tokio_io; extern crate bytes; extern crate iovec; - mod common; -mod tokio_impl; +pub mod client; +pub mod server; -use std::mem; -use std::io::{ self, Write }; +use std::{ io, mem }; use std::sync::Arc; use webpki::DNSNameRef; use rustls::{ - Session, ClientSession, ServerSession, + ClientSession, ServerSession, ClientConfig, ServerConfig }; -use tokio_io::{ AsyncRead, AsyncWrite }; +use futures::{Async, Future, Poll}; +use tokio_io::{ AsyncRead, AsyncWrite, try_nb }; use common::Stream; @@ -74,15 +74,15 @@ impl TlsConnector { f(&mut session); Connect(if self.early_data { - MidHandshake::EarlyData(TlsStream { + client::MidHandshake::EarlyData(client::TlsStream { session, io: stream, - state: TlsState::EarlyData, + state: client::TlsState::EarlyData, early_data: (0, Vec::new()) }) } else { - MidHandshake::Handshaking(TlsStream { + client::MidHandshake::Handshaking(client::TlsStream { session, io: stream, - state: TlsState::Stream, + state: client::TlsState::Stream, early_data: (0, Vec::new()) }) }) @@ -106,10 +106,9 @@ impl TlsAcceptor { let mut session = ServerSession::new(&self.inner); f(&mut session); - Accept(MidHandshake::Handshaking(TlsStream { + Accept(server::MidHandshake::Handshaking(server::TlsStream { session, io: stream, - state: TlsState::Stream, - early_data: (0, Vec::new()) + state: server::TlsState::Stream, })) } } @@ -117,182 +116,28 @@ impl TlsAcceptor { /// Future returned from `ClientConfigExt::connect_async` which will resolve /// once the connection handshake has finished. -pub struct Connect(MidHandshake); +pub struct Connect(client::MidHandshake); /// Future returned from `ServerConfigExt::accept_async` which will resolve /// once the accept handshake has finished. -pub struct Accept(MidHandshake); - -enum MidHandshake { - Handshaking(TlsStream), - EarlyData(TlsStream), - End -} +pub struct Accept(server::MidHandshake); -/// A wrapper around an underlying raw stream which implements the TLS or SSL -/// protocol. -#[derive(Debug)] -pub struct TlsStream { - io: IO, - session: S, - state: TlsState, - early_data: (usize, Vec) -} +impl Future for Connect { + type Item = client::TlsStream; + type Error = io::Error; -#[derive(Debug)] -enum TlsState { - EarlyData, - Stream, - Eof, - Shutdown -} - -impl TlsStream { - #[inline] - pub fn get_ref(&self) -> (&IO, &S) { - (&self.io, &self.session) - } - - #[inline] - pub fn get_mut(&mut self) -> (&mut IO, &mut S) { - (&mut self.io, &mut self.session) - } - - #[inline] - pub fn into_inner(self) -> (IO, S) { - (self.io, self.session) + fn poll(&mut self) -> Poll { + self.0.poll() } } -impl io::Read for TlsStream -where IO: AsyncRead + AsyncWrite -{ - fn read(&mut self, buf: &mut [u8]) -> io::Result { - let mut stream = Stream::new(&mut self.io, &mut self.session); +impl Future for Accept { + type Item = server::TlsStream; + type Error = io::Error; - match self.state { - TlsState::EarlyData => { - let (pos, data) = &mut self.early_data; - - // complete handshake - if stream.session.is_handshaking() { - stream.complete_io()?; - } - - // write early data (fallback) - if !stream.session.is_early_data_accepted() { - while *pos < data.len() { - let len = stream.write(&data[*pos..])?; - *pos += len; - } - } - - // end - self.state = TlsState::Stream; - data.clear(); - stream.read(buf) - }, - TlsState::Stream => match stream.read(buf) { - Ok(0) => { - self.state = TlsState::Eof; - Ok(0) - }, - Ok(n) => Ok(n), - Err(ref e) if e.kind() == io::ErrorKind::ConnectionAborted => { - self.state = TlsState::Shutdown; - stream.session.send_close_notify(); - Ok(0) - }, - Err(e) => Err(e) - }, - TlsState::Eof | TlsState::Shutdown => Ok(0), - } - } -} - -impl io::Read for TlsStream -where IO: AsyncRead + AsyncWrite -{ - fn read(&mut self, buf: &mut [u8]) -> io::Result { - let mut stream = Stream::new(&mut self.io, &mut self.session); - - match self.state { - TlsState::Stream => match stream.read(buf) { - Ok(0) => { - self.state = TlsState::Eof; - Ok(0) - }, - Ok(n) => Ok(n), - Err(ref e) if e.kind() == io::ErrorKind::ConnectionAborted => { - self.state = TlsState::Shutdown; - stream.session.send_close_notify(); - Ok(0) - }, - Err(e) => Err(e) - }, - TlsState::Eof | TlsState::Shutdown => Ok(0), - TlsState::EarlyData => unreachable!() - } - } -} - -impl io::Write for TlsStream -where IO: AsyncRead + AsyncWrite -{ - fn write(&mut self, buf: &[u8]) -> io::Result { - let mut stream = Stream::new(&mut self.io, &mut self.session); - - match self.state { - TlsState::EarlyData => { - let (pos, data) = &mut self.early_data; - - // write early data - if let Some(mut early_data) = stream.session.early_data() { - let len = early_data.write(buf)?; - data.extend_from_slice(&buf[..len]); - return Ok(len); - } - - // complete handshake - if stream.session.is_handshaking() { - stream.complete_io()?; - } - - // write early data (fallback) - if !stream.session.is_early_data_accepted() { - while *pos < data.len() { - let len = stream.write(&data[*pos..])?; - *pos += len; - } - } - - // end - self.state = TlsState::Stream; - data.clear(); - stream.write(buf) - }, - _ => stream.write(buf) - } - } - - fn flush(&mut self) -> io::Result<()> { - Stream::new(&mut self.io, &mut self.session).flush()?; - self.io.flush() - } -} - -impl io::Write for TlsStream -where IO: AsyncRead + AsyncWrite -{ - fn write(&mut self, buf: &[u8]) -> io::Result { - let mut stream = Stream::new(&mut self.io, &mut self.session); - stream.write(buf) - } - - fn flush(&mut self) -> io::Result<()> { - Stream::new(&mut self.io, &mut self.session).flush()?; - self.io.flush() + fn poll(&mut self) -> Poll { + self.0.poll() } } diff --git a/src/server.rs b/src/server.rs new file mode 100644 index 0000000..42dd18d --- /dev/null +++ b/src/server.rs @@ -0,0 +1,139 @@ +use super::*; +use rustls::Session; + + +/// A wrapper around an underlying raw stream which implements the TLS or SSL +/// protocol. +#[derive(Debug)] +pub struct TlsStream { + pub(crate) io: IO, + pub(crate) session: ServerSession, + pub(crate) state: TlsState +} + +#[derive(Debug)] +pub(crate) enum TlsState { + Stream, + Eof, + Shutdown +} + +pub(crate) enum MidHandshake { + Handshaking(TlsStream), + End +} + +impl TlsStream { + #[inline] + pub fn get_ref(&self) -> (&IO, &ServerSession) { + (&self.io, &self.session) + } + + #[inline] + pub fn get_mut(&mut self) -> (&mut IO, &mut ServerSession) { + (&mut self.io, &mut self.session) + } + + #[inline] + pub fn into_inner(self) -> (IO, ServerSession) { + (self.io, self.session) + } +} + +impl Future for MidHandshake +where IO: AsyncRead + AsyncWrite, +{ + type Item = TlsStream; + type Error = io::Error; + + #[inline] + fn poll(&mut self) -> Poll { + match self { + MidHandshake::Handshaking(stream) => { + let (io, session) = stream.get_mut(); + let mut stream = Stream::new(io, session); + + if stream.session.is_handshaking() { + try_nb!(stream.complete_io()); + } + + if stream.session.wants_write() { + try_nb!(stream.complete_io()); + } + }, + _ => () + } + + match mem::replace(self, MidHandshake::End) { + MidHandshake::Handshaking(stream) => Ok(Async::Ready(stream)), + MidHandshake::End => panic!() + } + } +} + +impl io::Read for TlsStream +where IO: AsyncRead + AsyncWrite +{ + fn read(&mut self, buf: &mut [u8]) -> io::Result { + let mut stream = Stream::new(&mut self.io, &mut self.session); + + match self.state { + TlsState::Stream => match stream.read(buf) { + Ok(0) => { + self.state = TlsState::Eof; + Ok(0) + }, + Ok(n) => Ok(n), + Err(ref e) if e.kind() == io::ErrorKind::ConnectionAborted => { + self.state = TlsState::Shutdown; + stream.session.send_close_notify(); + Ok(0) + }, + Err(e) => Err(e) + }, + TlsState::Eof | TlsState::Shutdown => Ok(0) + } + } +} + +impl io::Write for TlsStream +where IO: AsyncRead + AsyncWrite +{ + fn write(&mut self, buf: &[u8]) -> io::Result { + let mut stream = Stream::new(&mut self.io, &mut self.session); + stream.write(buf) + } + + fn flush(&mut self) -> io::Result<()> { + Stream::new(&mut self.io, &mut self.session).flush()?; + self.io.flush() + } +} + +impl AsyncRead for TlsStream +where IO: AsyncRead + AsyncWrite +{ + unsafe fn prepare_uninitialized_buffer(&self, _: &mut [u8]) -> bool { + false + } +} + +impl AsyncWrite for TlsStream +where IO: AsyncRead + AsyncWrite, +{ + fn shutdown(&mut self) -> Poll<(), io::Error> { + match self.state { + TlsState::Shutdown => (), + _ => { + self.session.send_close_notify(); + self.state = TlsState::Shutdown; + } + } + + { + let mut stream = Stream::new(&mut self.io, &mut self.session); + try_nb!(stream.complete_io()); + } + self.io.shutdown() + } +} diff --git a/src/test_0rtt.rs b/src/test_0rtt.rs index 56c9d7b..0182406 100644 --- a/src/test_0rtt.rs +++ b/src/test_0rtt.rs @@ -8,12 +8,12 @@ use std::net::ToSocketAddrs; use self::tokio::io as aio; use self::tokio::prelude::*; use self::tokio::net::TcpStream; -use rustls::{ ClientConfig, ClientSession }; -use ::{ TlsConnector, TlsStream }; +use rustls::ClientConfig; +use ::{ TlsConnector, client::TlsStream }; fn get(config: Arc, domain: &str, rtt0: bool) - -> io::Result<(TlsStream, String)> + -> io::Result<(TlsStream, String)> { let config = TlsConnector::from(config).early_data(rtt0); let input = format!("GET / HTTP/1.0\r\nHost: {}\r\n\r\n", domain); diff --git a/src/tokio_impl.rs b/src/tokio_impl.rs deleted file mode 100644 index f97cde3..0000000 --- a/src/tokio_impl.rs +++ /dev/null @@ -1,123 +0,0 @@ -use super::*; -use tokio_io::{ AsyncRead, AsyncWrite }; -use futures::{Async, Future, Poll}; -use common::Stream; - - -macro_rules! try_async { - ( $e:expr ) => { - match $e { - Ok(n) => n, - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => - return Ok(Async::NotReady), - Err(e) => return Err(e) - } - } -} - -impl Future for Connect { - type Item = TlsStream; - type Error = io::Error; - - fn poll(&mut self) -> Poll { - self.0.poll() - } -} - -impl Future for Accept { - type Item = TlsStream; - type Error = io::Error; - - fn poll(&mut self) -> Poll { - self.0.poll() - } -} - -impl Future for MidHandshake -where - IO: AsyncRead + AsyncWrite, - S: Session -{ - type Item = TlsStream; - type Error = io::Error; - - fn poll(&mut self) -> Poll { - match self { - MidHandshake::Handshaking(stream) => { - let (io, session) = stream.get_mut(); - let mut stream = Stream::new(io, session); - - if stream.session.is_handshaking() { - try_async!(stream.complete_io()); - } - - if stream.session.wants_write() { - try_async!(stream.complete_io()); - } - }, - _ => () - } - - match mem::replace(self, MidHandshake::End) { - MidHandshake::Handshaking(stream) - | MidHandshake::EarlyData(stream) => Ok(Async::Ready(stream)), - MidHandshake::End => panic!() - } - } -} - -impl AsyncRead for TlsStream -where IO: AsyncRead + AsyncWrite -{ - unsafe fn prepare_uninitialized_buffer(&self, _: &mut [u8]) -> bool { - false - } -} - -impl AsyncRead for TlsStream -where IO: AsyncRead + AsyncWrite -{ - unsafe fn prepare_uninitialized_buffer(&self, _: &mut [u8]) -> bool { - false - } -} - -impl AsyncWrite for TlsStream -where IO: AsyncRead + AsyncWrite, -{ - fn shutdown(&mut self) -> Poll<(), io::Error> { - match self.state { - TlsState::Shutdown => (), - _ => { - self.session.send_close_notify(); - self.state = TlsState::Shutdown; - } - } - - { - let mut stream = Stream::new(&mut self.io, &mut self.session); - try_async!(stream.complete_io()); - } - self.io.shutdown() - } -} - -impl AsyncWrite for TlsStream -where IO: AsyncRead + AsyncWrite, -{ - fn shutdown(&mut self) -> Poll<(), io::Error> { - match self.state { - TlsState::Shutdown => (), - _ => { - self.session.send_close_notify(); - self.state = TlsState::Shutdown; - } - } - - { - let mut stream = Stream::new(&mut self.io, &mut self.session); - try_async!(stream.complete_io()); - } - self.io.shutdown() - } -} From 681cbe68caadc07b6208d08812c781bb073becfe Mon Sep 17 00:00:00 2001 From: quininer Date: Sat, 23 Feb 2019 01:38:55 +0800 Subject: [PATCH 101/171] fix: not write zero --- src/client.rs | 8 +++----- src/common/mod.rs | 8 +++++++- src/lib.rs | 2 ++ src/server.rs | 8 +++----- 4 files changed, 15 insertions(+), 11 deletions(-) diff --git a/src/client.rs b/src/client.rs index 8d6758e..386d242 100644 --- a/src/client.rs +++ b/src/client.rs @@ -187,10 +187,8 @@ where IO: AsyncRead + AsyncWrite } } - { - let mut stream = Stream::new(&mut self.io, &mut self.session); - try_nb!(stream.complete_io()); - } - self.io.shutdown() + let mut stream = Stream::new(&mut self.io, &mut self.session); + try_nb!(stream.complete_io()); + stream.io.shutdown() } } diff --git a/src/common/mod.rs b/src/common/mod.rs index 9010d8d..a88a278 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -110,7 +110,13 @@ impl<'a, IO: AsyncRead + AsyncWrite, S: Session> Write for Stream<'a, IO, S> { Err(err) => return Err(err) } } - Ok(len) + + if len == 0 && !buf.is_empty() { + // not write zero + Err(io::ErrorKind::WouldBlock.into()) + } else { + Ok(len) + } } fn flush(&mut self) -> io::Result<()> { diff --git a/src/lib.rs b/src/lib.rs index 446e80d..511cdd9 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -24,12 +24,14 @@ use tokio_io::{ AsyncRead, AsyncWrite, try_nb }; use common::Stream; +/// A wrapper around a `rustls::ClientConfig`, providing an async `connect` method. #[derive(Clone)] pub struct TlsConnector { inner: Arc, early_data: bool } +/// A wrapper around a `rustls::ServerConfig`, providing an async `accept` method. #[derive(Clone)] pub struct TlsAcceptor { inner: Arc diff --git a/src/server.rs b/src/server.rs index 42dd18d..fd02a35 100644 --- a/src/server.rs +++ b/src/server.rs @@ -130,10 +130,8 @@ where IO: AsyncRead + AsyncWrite, } } - { - let mut stream = Stream::new(&mut self.io, &mut self.session); - try_nb!(stream.complete_io()); - } - self.io.shutdown() + let mut stream = Stream::new(&mut self.io, &mut self.session); + try_nb!(stream.complete_io()); + stream.io.shutdown() } } From 163a96b0623a24924a37cc01b7bc82c4a5639a4d Mon Sep 17 00:00:00 2001 From: quininer Date: Sat, 23 Feb 2019 01:48:09 +0800 Subject: [PATCH 102/171] fix clippy --- src/client.rs | 21 +++++++++------------ src/common/mod.rs | 4 ++-- src/common/vecbuf.rs | 2 +- src/server.rs | 21 +++++++++------------ 4 files changed, 21 insertions(+), 27 deletions(-) diff --git a/src/client.rs b/src/client.rs index 386d242..3e9d73f 100644 --- a/src/client.rs +++ b/src/client.rs @@ -52,20 +52,17 @@ where IO: AsyncRead + AsyncWrite, #[inline] fn poll(&mut self) -> Poll { - match self { - MidHandshake::Handshaking(stream) => { - let (io, session) = stream.get_mut(); - let mut stream = Stream::new(io, session); + if let MidHandshake::Handshaking(stream) = self { + let (io, session) = stream.get_mut(); + let mut stream = Stream::new(io, session); - if stream.session.is_handshaking() { - try_nb!(stream.complete_io()); - } + if stream.session.is_handshaking() { + try_nb!(stream.complete_io()); + } - if stream.session.wants_write() { - try_nb!(stream.complete_io()); - } - }, - _ => () + if stream.session.wants_write() { + try_nb!(stream.complete_io()); + } } match mem::replace(self, MidHandshake::End) { diff --git a/src/common/mod.rs b/src/common/mod.rs index a88a278..c150189 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -84,8 +84,8 @@ impl<'a, IO: AsyncRead + AsyncWrite, S: Session> WriteTls<'a, IO, S> for Stream< } } - let mut vecbuf = V(self.io); - self.session.writev_tls(&mut vecbuf) + let mut vecio = V(self.io); + self.session.writev_tls(&mut vecio) } } diff --git a/src/common/vecbuf.rs b/src/common/vecbuf.rs index 81bec86..e550505 100644 --- a/src/common/vecbuf.rs +++ b/src/common/vecbuf.rs @@ -48,7 +48,7 @@ impl<'a, 'b> Buf for VecBuf<'a, 'b> { } } - #[cfg_attr(feature = "cargo-clippy", allow(needless_range_loop))] + #[allow(clippy::needless_range_loop)] fn bytes_vec<'c>(&'c self, dst: &mut [&'c IoVec]) -> usize { let len = cmp::min(self.inner.len() - self.pos, dst.len()); diff --git a/src/server.rs b/src/server.rs index fd02a35..67d47d3 100644 --- a/src/server.rs +++ b/src/server.rs @@ -48,20 +48,17 @@ where IO: AsyncRead + AsyncWrite, #[inline] fn poll(&mut self) -> Poll { - match self { - MidHandshake::Handshaking(stream) => { - let (io, session) = stream.get_mut(); - let mut stream = Stream::new(io, session); + if let MidHandshake::Handshaking(stream) = self { + let (io, session) = stream.get_mut(); + let mut stream = Stream::new(io, session); - if stream.session.is_handshaking() { - try_nb!(stream.complete_io()); - } + if stream.session.is_handshaking() { + try_nb!(stream.complete_io()); + } - if stream.session.wants_write() { - try_nb!(stream.complete_io()); - } - }, - _ => () + if stream.session.wants_write() { + try_nb!(stream.complete_io()); + } } match mem::replace(self, MidHandshake::End) { From 02ff36428ce1680465514920f1e3d061e03a0b66 Mon Sep 17 00:00:00 2001 From: quininer Date: Mon, 25 Feb 2019 00:23:34 +0800 Subject: [PATCH 103/171] write buf again --- src/common/mod.rs | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/src/common/mod.rs b/src/common/mod.rs index c150189..e4e25cb 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -111,11 +111,16 @@ impl<'a, IO: AsyncRead + AsyncWrite, S: Session> Write for Stream<'a, IO, S> { } } - if len == 0 && !buf.is_empty() { - // not write zero - Err(io::ErrorKind::WouldBlock.into()) - } else { + if len != 0 || buf.is_empty() { Ok(len) + } else { + // not write zero + self.session.write(buf) + .and_then(|len| if len != 0 { + Ok(len) + } else { + Err(io::ErrorKind::WouldBlock.into()) + }) } } From 485cf8463989c25f0faa8d66c9c3dfdb7bce0063 Mon Sep 17 00:00:00 2001 From: quininer Date: Mon, 25 Feb 2019 23:48:06 +0800 Subject: [PATCH 104/171] make 0-RTT optional --- .travis.yml | 1 + Cargo.toml | 3 +++ appveyor.yml | 1 + src/client.rs | 16 +++++++++++----- src/lib.rs | 38 +++++++++++++++++++++++++++----------- 5 files changed, 43 insertions(+), 16 deletions(-) diff --git a/.travis.yml b/.travis.yml index 3653f1f..9efee9d 100644 --- a/.travis.yml +++ b/.travis.yml @@ -14,6 +14,7 @@ matrix: script: - cargo test + - cargo test --features early-data - cd examples/server - cargo check - cd ../../examples/client diff --git a/Cargo.toml b/Cargo.toml index b15bea2..ff95d5f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,6 +22,9 @@ iovec = "0.1" rustls = "0.15" webpki = "0.19" +[features] +early-data = [] + [dev-dependencies] tokio = "0.1.6" lazy_static = "1" diff --git a/appveyor.yml b/appveyor.yml index 038274b..26db365 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -14,6 +14,7 @@ build: false test_script: - 'cargo test' + - 'cargo test --features early-data' - 'cd examples/server' - 'cargo check' - 'cd ../../examples/client' diff --git a/src/client.rs b/src/client.rs index 3e9d73f..91a65aa 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,5 +1,4 @@ use super::*; -use std::io::Write; use rustls::Session; @@ -10,12 +9,14 @@ pub struct TlsStream { pub(crate) io: IO, pub(crate) session: ClientSession, pub(crate) state: TlsState, + + #[cfg(feature = "early-data")] pub(crate) early_data: (usize, Vec) } #[derive(Debug)] pub(crate) enum TlsState { - EarlyData, + #[cfg(feature = "early-data")] EarlyData, Stream, Eof, Shutdown @@ -23,7 +24,7 @@ pub(crate) enum TlsState { pub(crate) enum MidHandshake { Handshaking(TlsStream), - EarlyData(TlsStream), + #[cfg(feature = "early-data")] EarlyData(TlsStream), End } @@ -66,8 +67,9 @@ where IO: AsyncRead + AsyncWrite, } match mem::replace(self, MidHandshake::End) { - MidHandshake::Handshaking(stream) - | MidHandshake::EarlyData(stream) => Ok(Async::Ready(stream)), + MidHandshake::Handshaking(stream) => Ok(Async::Ready(stream)), + #[cfg(feature = "early-data")] + MidHandshake::EarlyData(stream) => Ok(Async::Ready(stream)), MidHandshake::End => panic!() } } @@ -80,7 +82,10 @@ where IO: AsyncRead + AsyncWrite let mut stream = Stream::new(&mut self.io, &mut self.session); match self.state { + #[cfg(feature = "early-data")] TlsState::EarlyData => { + use std::io::Write; + let (pos, data) = &mut self.early_data; // complete handshake @@ -126,6 +131,7 @@ where IO: AsyncRead + AsyncWrite let mut stream = Stream::new(&mut self.io, &mut self.session); match self.state { + #[cfg(feature = "early-data")] TlsState::EarlyData => { let (pos, data) = &mut self.early_data; diff --git a/src/lib.rs b/src/lib.rs index 511cdd9..6a77fbb 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -28,6 +28,7 @@ use common::Stream; #[derive(Clone)] pub struct TlsConnector { inner: Arc, + #[cfg(feature = "early-data")] early_data: bool } @@ -39,7 +40,11 @@ pub struct TlsAcceptor { impl From> for TlsConnector { fn from(inner: Arc) -> TlsConnector { - TlsConnector { inner, early_data: false } + TlsConnector { + inner, + #[cfg(feature = "early-data")] + early_data: false + } } } @@ -54,6 +59,7 @@ impl TlsConnector { /// /// Note that you want to use 0-RTT. /// You must set `enable_early_data` to `true` in `ClientConfig`. + #[cfg(feature = "early-data")] pub fn early_data(mut self, flag: bool) -> TlsConnector { self.early_data = flag; self @@ -75,19 +81,28 @@ impl TlsConnector { let mut session = ClientSession::new(&self.inner, domain); f(&mut session); - Connect(if self.early_data { - client::MidHandshake::EarlyData(client::TlsStream { - session, io: stream, - state: client::TlsState::EarlyData, - early_data: (0, Vec::new()) - }) - } else { - client::MidHandshake::Handshaking(client::TlsStream { + #[cfg(not(feature = "early-data"))] { + Connect(client::MidHandshake::Handshaking(client::TlsStream { session, io: stream, state: client::TlsState::Stream, - early_data: (0, Vec::new()) + })) + } + + #[cfg(feature = "early-data")] { + Connect(if self.early_data { + client::MidHandshake::EarlyData(client::TlsStream { + session, io: stream, + state: client::TlsState::EarlyData, + early_data: (0, Vec::new()) + }) + } else { + client::MidHandshake::Handshaking(client::TlsStream { + session, io: stream, + state: client::TlsState::Stream, + early_data: (0, Vec::new()) + }) }) - }) + } } } @@ -143,5 +158,6 @@ impl Future for Accept { } } +#[cfg(feature = "early-data")] #[cfg(test)] mod test_0rtt; From ee59a7cc8e4638bb2c2a50fc6b79765336ca385b Mon Sep 17 00:00:00 2001 From: Erick Tryzelaar Date: Tue, 12 Mar 2019 09:40:54 -0700 Subject: [PATCH 105/171] Clarify the license Before this PR, this phrasing appears to have been copied from the [rust project](https://github.com/rust-lang/rust/tree/f0fe716dbcbf2363ab8f929325d32a17e51039d0#license), but it does not appear that any of your code is BSD licensed. Also, it is a little ambigious if there are portions that are not covered by MIT/Apache. This patch instead draws it's phrasing from [futures-rs](https://github.com/rust-lang-nursery/futures-rs). Does this phrasing align more with what you intended? Closes #29 --- README.md | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index f22e97e..d1e7a30 100644 --- a/README.md +++ b/README.md @@ -48,8 +48,19 @@ cargo run -- 127.0.0.1 --cert mycert.der --key mykey.der ### License & Origin -tokio-rustls is primarily distributed under the terms of both the [MIT license](LICENSE-MIT) and -the [Apache License (Version 2.0)](LICENSE-APACHE), with portions covered by various BSD-like -licenses. +This project is licensed under either of + + * Apache License, Version 2.0, ([LICENSE-APACHE](LICENSE-APACHE) or + http://www.apache.org/licenses/LICENSE-2.0) + * MIT license ([LICENSE-MIT](LICENSE-MIT) or + http://opensource.org/licenses/MIT) + +at your option. This started as a fork of [tokio-tls](https://github.com/tokio-rs/tokio-tls). + +### Contribution + +Unless you explicitly state otherwise, any contribution intentionally submitted +for inclusion in tokio-rustls by you, as defined in the Apache-2.0 license, shall be +dual licensed as above, without any additional terms or conditions. From d8ab52db551e4d4080b73cf3e338e9f07d265fdd Mon Sep 17 00:00:00 2001 From: quininer Date: Tue, 26 Mar 2019 10:44:38 +0800 Subject: [PATCH 106/171] fix early-data read --- Cargo.toml | 2 +- src/client.rs | 64 ++++++++++++++++++++++++++++----------------------- 2 files changed, 36 insertions(+), 30 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index ff95d5f..1082c02 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tokio-rustls" -version = "0.10.0-alpha" +version = "0.10.0-alpha.1" authors = ["quininer kel "] license = "MIT/Apache-2.0" repository = "https://github.com/quininer/tokio-rustls" diff --git a/src/client.rs b/src/client.rs index 91a65aa..27ab944 100644 --- a/src/client.rs +++ b/src/client.rs @@ -79,45 +79,51 @@ impl io::Read for TlsStream where IO: AsyncRead + AsyncWrite { fn read(&mut self, buf: &mut [u8]) -> io::Result { - let mut stream = Stream::new(&mut self.io, &mut self.session); - match self.state { #[cfg(feature = "early-data")] TlsState::EarlyData => { use std::io::Write; - let (pos, data) = &mut self.early_data; + { + let mut stream = Stream::new(&mut self.io, &mut self.session); + let (pos, data) = &mut self.early_data; - // complete handshake - if stream.session.is_handshaking() { - stream.complete_io()?; - } - - // write early data (fallback) - if !stream.session.is_early_data_accepted() { - while *pos < data.len() { - let len = stream.write(&data[*pos..])?; - *pos += len; + // complete handshake + if stream.session.is_handshaking() { + stream.complete_io()?; } + + // write early data (fallback) + if !stream.session.is_early_data_accepted() { + while *pos < data.len() { + let len = stream.write(&data[*pos..])?; + *pos += len; + } + } + + // end + self.state = TlsState::Stream; + data.clear(); } - // end - self.state = TlsState::Stream; - data.clear(); - stream.read(buf) + self.read(buf) }, - TlsState::Stream => match stream.read(buf) { - Ok(0) => { - self.state = TlsState::Eof; - Ok(0) - }, - Ok(n) => Ok(n), - Err(ref e) if e.kind() == io::ErrorKind::ConnectionAborted => { - self.state = TlsState::Shutdown; - stream.session.send_close_notify(); - Ok(0) - }, - Err(e) => Err(e) + TlsState::Stream => { + let mut stream = Stream::new(&mut self.io, &mut self.session); + + match stream.read(buf) { + Ok(0) => { + self.state = TlsState::Eof; + Ok(0) + }, + Ok(n) => Ok(n), + Err(ref e) if e.kind() == io::ErrorKind::ConnectionAborted => { + self.state = TlsState::Shutdown; + stream.session.send_close_notify(); + Ok(0) + }, + Err(e) => Err(e) + } }, TlsState::Eof | TlsState::Shutdown => Ok(0), } From 75b94acaba72fcab4eaf990dccfe256a3a20053d Mon Sep 17 00:00:00 2001 From: quininer Date: Tue, 26 Mar 2019 10:46:22 +0800 Subject: [PATCH 107/171] remove git journal --- .gitjournal.toml | 10 ---------- 1 file changed, 10 deletions(-) delete mode 100644 .gitjournal.toml diff --git a/.gitjournal.toml b/.gitjournal.toml deleted file mode 100644 index 508a97e..0000000 --- a/.gitjournal.toml +++ /dev/null @@ -1,10 +0,0 @@ -categories = ["Added", "Changed", "Fixed", "Improved", "Removed"] -category_delimiters = ["[", "]"] -colored_output = true -enable_debug = true -enable_footers = false -excluded_commit_tags = [] -show_commit_hash = false -show_prefix = false -sort_by = "date" -template_prefix = "" From 0485be9e4bb9bea80398d5a2c220ec21a5a96443 Mon Sep 17 00:00:00 2001 From: quininer Date: Tue, 16 Apr 2019 20:32:11 +0800 Subject: [PATCH 108/171] refactor complete_io --- src/client.rs | 3 +- src/common/mod.rs | 103 +++++++++++++++++++++++++++----------- src/common/test_stream.rs | 2 +- 3 files changed, 76 insertions(+), 32 deletions(-) diff --git a/src/client.rs b/src/client.rs index 27ab944..c4d93ee 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,4 +1,5 @@ use super::*; +use std::io::Write; use rustls::Session; @@ -197,7 +198,7 @@ where IO: AsyncRead + AsyncWrite } let mut stream = Stream::new(&mut self.io, &mut self.session); - try_nb!(stream.complete_io()); + try_nb!(stream.flush()); stream.io.shutdown() } } diff --git a/src/common/mod.rs b/src/common/mod.rs index e4e25cb..14d2f71 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -15,51 +15,94 @@ pub trait WriteTls<'a, IO: AsyncRead + AsyncWrite, S: Session>: Read + Write { fn write_tls(&mut self) -> io::Result; } +#[derive(Clone, Copy)] +enum Focus { + Empty, + Readable, + Writable +} + impl<'a, IO: AsyncRead + AsyncWrite, S: Session> Stream<'a, IO, S> { pub fn new(io: &'a mut IO, session: &'a mut S) -> Self { Stream { io, session } } pub fn complete_io(&mut self) -> io::Result<(usize, usize)> { - // fork from https://github.com/ctz/rustls/blob/master/src/session.rs#L161 + self.complete_inner_io(Focus::Empty) + } - let until_handshaked = self.session.is_handshaking(); - let mut eof = false; + fn complete_read_io(&mut self) -> io::Result { + let n = self.session.read_tls(self.io)?; + + self.session.process_new_packets() + .map_err(|err| { + // In case we have an alert to send describing this error, + // try a last-gasp write -- but don't predate the primary + // error. + let _ = self.write_tls(); + + io::Error::new(io::ErrorKind::InvalidData, err) + })?; + + Ok(n) + } + + fn complete_write_io(&mut self) -> io::Result { + self.write_tls() + } + + fn complete_inner_io(&mut self, focus: Focus) -> io::Result<(usize, usize)> { let mut wrlen = 0; let mut rdlen = 0; + let mut eof = false; loop { + let mut write_would_block = false; + let mut read_would_block = false; + while self.session.wants_write() { - wrlen += self.write_tls()?; - } - - if !until_handshaked && wrlen > 0 { - return Ok((rdlen, wrlen)); - } - - if !eof && self.session.wants_read() { - match self.session.read_tls(self.io)? { - 0 => eof = true, - n => rdlen += n + match self.complete_write_io() { + Ok(n) => wrlen += n, + Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => { + write_would_block = true; + break + }, + Err(err) => return Err(err) } } - match self.session.process_new_packets() { - Ok(_) => {}, - Err(e) => { - // In case we have an alert to send describing this error, - // try a last-gasp write -- but don't predate the primary - // error. - let _ignored = self.write_tls(); + if !eof && self.session.wants_read() { + match self.complete_read_io() { + Ok(0) => eof = true, + Ok(n) => rdlen += n, + Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => read_would_block = true, + Err(err) => return Err(err) + } + } - return Err(io::Error::new(io::ErrorKind::InvalidData, e)); - }, + let would_block = match focus { + Focus::Empty => write_would_block || read_would_block, + Focus::Readable => read_would_block, + Focus::Writable => write_would_block, }; - match (eof, until_handshaked, self.session.is_handshaking()) { - (_, true, false) => return Ok((rdlen, wrlen)), + match (eof, self.session.is_handshaking(), would_block) { + (true, true, _) => return Err(io::ErrorKind::UnexpectedEof.into()), + (_, false, true) => { + let would_block = match focus { + Focus::Empty => rdlen == 0 && wrlen == 0, + Focus::Readable => rdlen == 0, + Focus::Writable => wrlen == 0 + }; + + return if would_block { + Err(io::ErrorKind::WouldBlock.into()) + } else { + Ok((rdlen, wrlen)) + }; + }, (_, false, _) => return Ok((rdlen, wrlen)), - (true, true, true) => return Err(io::Error::from(io::ErrorKind::UnexpectedEof)), + (_, true, true) => return Err(io::ErrorKind::WouldBlock.into()), (..) => () } } @@ -92,7 +135,7 @@ impl<'a, IO: AsyncRead + AsyncWrite, S: Session> WriteTls<'a, IO, S> for Stream< impl<'a, IO: AsyncRead + AsyncWrite, S: Session> Read for Stream<'a, IO, S> { fn read(&mut self, buf: &mut [u8]) -> io::Result { while self.session.wants_read() { - if let (0, 0) = self.complete_io()? { + if let (0, _) = self.complete_inner_io(Focus::Readable)? { break } } @@ -104,7 +147,7 @@ impl<'a, IO: AsyncRead + AsyncWrite, S: Session> Write for Stream<'a, IO, S> { fn write(&mut self, buf: &[u8]) -> io::Result { let len = self.session.write(buf)?; while self.session.wants_write() { - match self.complete_io() { + match self.complete_inner_io(Focus::Writable) { Ok(_) => (), Err(ref err) if err.kind() == io::ErrorKind::WouldBlock && len != 0 => break, Err(err) => return Err(err) @@ -126,8 +169,8 @@ impl<'a, IO: AsyncRead + AsyncWrite, S: Session> Write for Stream<'a, IO, S> { fn flush(&mut self) -> io::Result<()> { self.session.flush()?; - if self.session.wants_write() { - self.complete_io()?; + while self.session.wants_write() { + self.complete_inner_io(Focus::Writable)?; } Ok(()) } diff --git a/src/common/test_stream.rs b/src/common/test_stream.rs index a43622c..744758a 100644 --- a/src/common/test_stream.rs +++ b/src/common/test_stream.rs @@ -85,7 +85,7 @@ fn stream_good() -> io::Result<()> { let mut buf = Vec::new(); stream.read_to_end(&mut buf)?; assert_eq!(buf, FILE); - stream.write_all(b"Hello World!")? + stream.write_all(b"Hello World!")?; } let mut buf = String::new(); From b1a98b908872ddb3762b5dfa7cd960955da4d115 Mon Sep 17 00:00:00 2001 From: quininer Date: Wed, 17 Apr 2019 10:33:00 +0800 Subject: [PATCH 109/171] bump version fix #32 --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 1082c02..98d70b8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tokio-rustls" -version = "0.10.0-alpha.1" +version = "0.10.0-alpha.2" authors = ["quininer kel "] license = "MIT/Apache-2.0" repository = "https://github.com/quininer/tokio-rustls" From b6e39450ce48d4b19c0095952df867e31fe5a51d Mon Sep 17 00:00:00 2001 From: quininer Date: Wed, 17 Apr 2019 10:42:57 +0800 Subject: [PATCH 110/171] fix clippy --- src/client.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/client.rs b/src/client.rs index c4d93ee..9d57268 100644 --- a/src/client.rs +++ b/src/client.rs @@ -83,8 +83,6 @@ where IO: AsyncRead + AsyncWrite match self.state { #[cfg(feature = "early-data")] TlsState::EarlyData => { - use std::io::Write; - { let mut stream = Stream::new(&mut self.io, &mut self.session); let (pos, data) = &mut self.early_data; From 87916dade66f1a30a2440396cd07eacbb6c9e89a Mon Sep 17 00:00:00 2001 From: Yan Zhai Date: Fri, 19 Apr 2019 21:08:18 +0000 Subject: [PATCH 111/171] #34 properly implement TLS-1.3 shutdown behavior --- src/client.rs | 66 +++++++++++++++---------------- src/lib.rs | 107 +++++++++++++++++++++++++++++++++----------------- src/server.rs | 54 +++++++++++++------------ 3 files changed, 131 insertions(+), 96 deletions(-) diff --git a/src/client.rs b/src/client.rs index 9d57268..9961503 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,7 +1,6 @@ use super::*; -use std::io::Write; use rustls::Session; - +use std::io::Write; /// A wrapper around an underlying raw stream which implements the TLS or SSL /// protocol. @@ -12,21 +11,14 @@ pub struct TlsStream { pub(crate) state: TlsState, #[cfg(feature = "early-data")] - pub(crate) early_data: (usize, Vec) -} - -#[derive(Debug)] -pub(crate) enum TlsState { - #[cfg(feature = "early-data")] EarlyData, - Stream, - Eof, - Shutdown + pub(crate) early_data: (usize, Vec), } pub(crate) enum MidHandshake { Handshaking(TlsStream), - #[cfg(feature = "early-data")] EarlyData(TlsStream), - End + #[cfg(feature = "early-data")] + EarlyData(TlsStream), + End, } impl TlsStream { @@ -47,7 +39,8 @@ impl TlsStream { } impl Future for MidHandshake -where IO: AsyncRead + AsyncWrite, +where + IO: AsyncRead + AsyncWrite, { type Item = TlsStream; type Error = io::Error; @@ -71,13 +64,14 @@ where IO: AsyncRead + AsyncWrite, MidHandshake::Handshaking(stream) => Ok(Async::Ready(stream)), #[cfg(feature = "early-data")] MidHandshake::EarlyData(stream) => Ok(Async::Ready(stream)), - MidHandshake::End => panic!() + MidHandshake::End => panic!(), } } } impl io::Read for TlsStream -where IO: AsyncRead + AsyncWrite +where + IO: AsyncRead + AsyncWrite, { fn read(&mut self, buf: &mut [u8]) -> io::Result { match self.state { @@ -106,31 +100,35 @@ where IO: AsyncRead + AsyncWrite } self.read(buf) - }, - TlsState::Stream => { + } + TlsState::Stream | TlsState::WriteShutdown => { let mut stream = Stream::new(&mut self.io, &mut self.session); match stream.read(buf) { Ok(0) => { - self.state = TlsState::Eof; + self.state.shutdown_read(); Ok(0) - }, + } Ok(n) => Ok(n), Err(ref e) if e.kind() == io::ErrorKind::ConnectionAborted => { - self.state = TlsState::Shutdown; - stream.session.send_close_notify(); + self.state.shutdown_read(); + if self.state.writeable() { + stream.session.send_close_notify(); + self.state.shutdown_write(); + } Ok(0) - }, - Err(e) => Err(e) + } + Err(e) => Err(e), } - }, - TlsState::Eof | TlsState::Shutdown => Ok(0), + } + TlsState::ReadShutdown | TlsState::FullyShutdown => Ok(0), } } } impl io::Write for TlsStream -where IO: AsyncRead + AsyncWrite +where + IO: AsyncRead + AsyncWrite, { fn write(&mut self, buf: &[u8]) -> io::Result { let mut stream = Stream::new(&mut self.io, &mut self.session); @@ -164,8 +162,8 @@ where IO: AsyncRead + AsyncWrite self.state = TlsState::Stream; data.clear(); stream.write(buf) - }, - _ => stream.write(buf) + } + _ => stream.write(buf), } } @@ -176,7 +174,8 @@ where IO: AsyncRead + AsyncWrite } impl AsyncRead for TlsStream -where IO: AsyncRead + AsyncWrite +where + IO: AsyncRead + AsyncWrite, { unsafe fn prepare_uninitialized_buffer(&self, _: &mut [u8]) -> bool { false @@ -184,14 +183,15 @@ where IO: AsyncRead + AsyncWrite } impl AsyncWrite for TlsStream -where IO: AsyncRead + AsyncWrite +where + IO: AsyncRead + AsyncWrite, { fn shutdown(&mut self) -> Poll<(), io::Error> { match self.state { - TlsState::Shutdown => (), + s if !s.writeable() => (), _ => { self.session.send_close_notify(); - self.state = TlsState::Shutdown; + self.state.shutdown_write(); } } diff --git a/src/lib.rs b/src/lib.rs index 6a77fbb..b77fee4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,39 +3,68 @@ pub extern crate rustls; pub extern crate webpki; -extern crate futures; -extern crate tokio_io; extern crate bytes; +extern crate futures; extern crate iovec; +extern crate tokio_io; -mod common; pub mod client; +mod common; pub mod server; -use std::{ io, mem }; -use std::sync::Arc; -use webpki::DNSNameRef; -use rustls::{ - ClientSession, ServerSession, - ClientConfig, ServerConfig -}; -use futures::{Async, Future, Poll}; -use tokio_io::{ AsyncRead, AsyncWrite, try_nb }; use common::Stream; +use futures::{Async, Future, Poll}; +use rustls::{ClientConfig, ClientSession, ServerConfig, ServerSession}; +use std::sync::Arc; +use std::{io, mem}; +use tokio_io::{try_nb, AsyncRead, AsyncWrite}; +use webpki::DNSNameRef; +#[derive(Debug, Copy, Clone)] +pub enum TlsState { + #[cfg(feature = "early-data")] + EarlyData, + Stream, + ReadShutdown, + WriteShutdown, + FullyShutdown, +} + +impl TlsState { + pub(crate) fn shutdown_read(&mut self) { + match *self { + TlsState::WriteShutdown | TlsState::FullyShutdown => *self = TlsState::FullyShutdown, + _ => *self = TlsState::ReadShutdown, + } + } + + pub(crate) fn shutdown_write(&mut self) { + match *self { + TlsState::ReadShutdown | TlsState::FullyShutdown => *self = TlsState::FullyShutdown, + _ => *self = TlsState::WriteShutdown, + } + } + + pub(crate) fn writeable(&self) -> bool { + match *self { + TlsState::WriteShutdown | TlsState::FullyShutdown => true, + _ => false, + } + } +} /// A wrapper around a `rustls::ClientConfig`, providing an async `connect` method. #[derive(Clone)] pub struct TlsConnector { inner: Arc, #[cfg(feature = "early-data")] - early_data: bool + early_data: bool, } /// A wrapper around a `rustls::ServerConfig`, providing an async `accept` method. #[derive(Clone)] pub struct TlsAcceptor { - inner: Arc + inner: Arc, } impl From> for TlsConnector { @@ -43,7 +72,7 @@ impl From> for TlsConnector { TlsConnector { inner, #[cfg(feature = "early-data")] - early_data: false + early_data: false, } } } @@ -66,40 +95,45 @@ impl TlsConnector { } pub fn connect(&self, domain: DNSNameRef, stream: IO) -> Connect - where IO: AsyncRead + AsyncWrite + where + IO: AsyncRead + AsyncWrite, { self.connect_with(domain, stream, |_| ()) } #[inline] - pub fn connect_with(&self, domain: DNSNameRef, stream: IO, f: F) - -> Connect + pub fn connect_with(&self, domain: DNSNameRef, stream: IO, f: F) -> Connect where IO: AsyncRead + AsyncWrite, - F: FnOnce(&mut ClientSession) + F: FnOnce(&mut ClientSession), { let mut session = ClientSession::new(&self.inner, domain); f(&mut session); - #[cfg(not(feature = "early-data"))] { + #[cfg(not(feature = "early-data"))] + { Connect(client::MidHandshake::Handshaking(client::TlsStream { - session, io: stream, - state: client::TlsState::Stream, + session, + io: stream, + state: TlsState::Stream, })) } - #[cfg(feature = "early-data")] { + #[cfg(feature = "early-data")] + { Connect(if self.early_data { client::MidHandshake::EarlyData(client::TlsStream { - session, io: stream, - state: client::TlsState::EarlyData, - early_data: (0, Vec::new()) + session, + io: stream, + state: TlsState::EarlyData, + early_data: (0, Vec::new()), }) } else { client::MidHandshake::Handshaking(client::TlsStream { - session, io: stream, - state: client::TlsState::Stream, - early_data: (0, Vec::new()) + session, + io: stream, + state: TlsState::Stream, + early_data: (0, Vec::new()), }) }) } @@ -108,29 +142,29 @@ impl TlsConnector { impl TlsAcceptor { pub fn accept(&self, stream: IO) -> Accept - where IO: AsyncRead + AsyncWrite, + where + IO: AsyncRead + AsyncWrite, { self.accept_with(stream, |_| ()) } #[inline] - pub fn accept_with(&self, stream: IO, f: F) - -> Accept + pub fn accept_with(&self, stream: IO, f: F) -> Accept where IO: AsyncRead + AsyncWrite, - F: FnOnce(&mut ServerSession) + F: FnOnce(&mut ServerSession), { let mut session = ServerSession::new(&self.inner); f(&mut session); Accept(server::MidHandshake::Handshaking(server::TlsStream { - session, io: stream, - state: server::TlsState::Stream, + session, + io: stream, + state: TlsState::Stream, })) } } - /// Future returned from `ClientConfigExt::connect_async` which will resolve /// once the connection handshake has finished. pub struct Connect(client::MidHandshake); @@ -139,7 +173,6 @@ pub struct Connect(client::MidHandshake); /// once the accept handshake has finished. pub struct Accept(server::MidHandshake); - impl Future for Connect { type Item = client::TlsStream; type Error = io::Error; diff --git a/src/server.rs b/src/server.rs index 67d47d3..e6f5701 100644 --- a/src/server.rs +++ b/src/server.rs @@ -1,26 +1,18 @@ use super::*; use rustls::Session; - /// A wrapper around an underlying raw stream which implements the TLS or SSL /// protocol. #[derive(Debug)] pub struct TlsStream { pub(crate) io: IO, pub(crate) session: ServerSession, - pub(crate) state: TlsState -} - -#[derive(Debug)] -pub(crate) enum TlsState { - Stream, - Eof, - Shutdown + pub(crate) state: TlsState, } pub(crate) enum MidHandshake { Handshaking(TlsStream), - End + End, } impl TlsStream { @@ -41,7 +33,8 @@ impl TlsStream { } impl Future for MidHandshake -where IO: AsyncRead + AsyncWrite, +where + IO: AsyncRead + AsyncWrite, { type Item = TlsStream; type Error = io::Error; @@ -63,38 +56,45 @@ where IO: AsyncRead + AsyncWrite, match mem::replace(self, MidHandshake::End) { MidHandshake::Handshaking(stream) => Ok(Async::Ready(stream)), - MidHandshake::End => panic!() + MidHandshake::End => panic!(), } } } impl io::Read for TlsStream -where IO: AsyncRead + AsyncWrite +where + IO: AsyncRead + AsyncWrite, { fn read(&mut self, buf: &mut [u8]) -> io::Result { let mut stream = Stream::new(&mut self.io, &mut self.session); match self.state { - TlsState::Stream => match stream.read(buf) { + TlsState::Stream | TlsState::WriteShutdown => match stream.read(buf) { Ok(0) => { - self.state = TlsState::Eof; + self.state.shutdown_read(); Ok(0) - }, + } Ok(n) => Ok(n), Err(ref e) if e.kind() == io::ErrorKind::ConnectionAborted => { - self.state = TlsState::Shutdown; - stream.session.send_close_notify(); + self.state.shutdown_read(); + if self.state.writeable() { + stream.session.send_close_notify(); + self.state.shutdown_write(); + } Ok(0) - }, - Err(e) => Err(e) + } + Err(e) => Err(e), }, - TlsState::Eof | TlsState::Shutdown => Ok(0) + TlsState::ReadShutdown | TlsState::FullyShutdown => Ok(0), + #[cfg(feature = "early-data")] + s => unreachable!("server TLS can not hit this state: {:?}", s), } } } impl io::Write for TlsStream -where IO: AsyncRead + AsyncWrite +where + IO: AsyncRead + AsyncWrite, { fn write(&mut self, buf: &[u8]) -> io::Result { let mut stream = Stream::new(&mut self.io, &mut self.session); @@ -108,7 +108,8 @@ where IO: AsyncRead + AsyncWrite } impl AsyncRead for TlsStream -where IO: AsyncRead + AsyncWrite +where + IO: AsyncRead + AsyncWrite, { unsafe fn prepare_uninitialized_buffer(&self, _: &mut [u8]) -> bool { false @@ -116,14 +117,15 @@ where IO: AsyncRead + AsyncWrite } impl AsyncWrite for TlsStream -where IO: AsyncRead + AsyncWrite, +where + IO: AsyncRead + AsyncWrite, { fn shutdown(&mut self) -> Poll<(), io::Error> { match self.state { - TlsState::Shutdown => (), + s if !s.writeable() => (), _ => { self.session.send_close_notify(); - self.state = TlsState::Shutdown; + self.state.shutdown_write(); } } From 0cbd252ee49beb90ce60c22540128e36518ab6e0 Mon Sep 17 00:00:00 2001 From: Yan Zhai Date: Mon, 22 Apr 2019 04:08:13 +0000 Subject: [PATCH 112/171] #34 minor style change --- src/client.rs | 9 +++------ src/server.rs | 9 +++------ 2 files changed, 6 insertions(+), 12 deletions(-) diff --git a/src/client.rs b/src/client.rs index 9961503..616c151 100644 --- a/src/client.rs +++ b/src/client.rs @@ -187,12 +187,9 @@ where IO: AsyncRead + AsyncWrite, { fn shutdown(&mut self) -> Poll<(), io::Error> { - match self.state { - s if !s.writeable() => (), - _ => { - self.session.send_close_notify(); - self.state.shutdown_write(); - } + if self.state.writeable() { + self.session.send_close_notify(); + self.state.shutdown_write(); } let mut stream = Stream::new(&mut self.io, &mut self.session); diff --git a/src/server.rs b/src/server.rs index e6f5701..1568414 100644 --- a/src/server.rs +++ b/src/server.rs @@ -121,12 +121,9 @@ where IO: AsyncRead + AsyncWrite, { fn shutdown(&mut self) -> Poll<(), io::Error> { - match self.state { - s if !s.writeable() => (), - _ => { - self.session.send_close_notify(); - self.state.shutdown_write(); - } + if self.state.writeable() { + self.session.send_close_notify(); + self.state.shutdown_write(); } let mut stream = Stream::new(&mut self.io, &mut self.session); From f3c9fece1b785fda2a32f0cba2ead2f6800ecc14 Mon Sep 17 00:00:00 2001 From: Yan Zhai Date: Mon, 22 Apr 2019 04:44:24 +0000 Subject: [PATCH 113/171] #34 writable condition reversed --- src/lib.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index b77fee4..04d7421 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -47,8 +47,8 @@ impl TlsState { pub(crate) fn writeable(&self) -> bool { match *self { - TlsState::WriteShutdown | TlsState::FullyShutdown => true, - _ => false, + TlsState::WriteShutdown | TlsState::FullyShutdown => false, + _ => true, } } } From 017b1b64d18d80875909e7bf0f09c7e2529f0989 Mon Sep 17 00:00:00 2001 From: quininer Date: Sat, 4 May 2019 22:44:40 +0800 Subject: [PATCH 114/171] start migrate to futures 0.3 (again) --- Cargo.toml | 5 +- src/common/mod.rs | 173 ++++++++++++++++++++++++++++------------------ src/lib.rs | 14 ++-- 3 files changed, 114 insertions(+), 78 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 98d70b8..5a61dcb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,14 +9,15 @@ documentation = "https://docs.rs/tokio-rustls" readme = "README.md" description = "Asynchronous TLS/SSL streams for Tokio using Rustls." categories = ["asynchronous", "cryptography", "network-programming"] +edition = "2018" [badges] travis-ci = { repository = "quininer/tokio-rustls" } appveyor = { repository = "quininer/tokio-rustls" } [dependencies] -futures = "0.1" -tokio-io = "0.1.6" +smallvec = "*" +futures = { package = "futures-preview", version = "0.3.0-alpha.15" } bytes = "0.4" iovec = "0.1" rustls = "0.15" diff --git a/src/common/mod.rs b/src/common/mod.rs index 14d2f71..ed29f09 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -1,18 +1,23 @@ -mod vecbuf; +// mod vecbuf; +use std::pin::Pin; +use std::task::Poll; +use std::marker::Unpin; use std::io::{ self, Read, Write }; use rustls::Session; use rustls::WriteV; -use tokio_io::{ AsyncRead, AsyncWrite }; +use futures::task::Context; +use futures::io::{ AsyncRead, AsyncWrite, IoVec }; +use smallvec::SmallVec; -pub struct Stream<'a, IO: 'a, S: 'a> { +pub struct Stream<'a, IO, S> { pub io: &'a mut IO, - pub session: &'a mut S + pub session: &'a mut S, } -pub trait WriteTls<'a, IO: AsyncRead + AsyncWrite, S: Session>: Read + Write { - fn write_tls(&mut self) -> io::Result; +pub trait WriteTls { + fn write_tls(&mut self, cx: &mut Context) -> io::Result; } #[derive(Clone, Copy)] @@ -22,36 +27,59 @@ enum Focus { Writable } -impl<'a, IO: AsyncRead + AsyncWrite, S: Session> Stream<'a, IO, S> { +impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> { pub fn new(io: &'a mut IO, session: &'a mut S) -> Self { Stream { io, session } } - pub fn complete_io(&mut self) -> io::Result<(usize, usize)> { - self.complete_inner_io(Focus::Empty) + pub fn complete_io(&mut self, cx: &mut Context) -> Poll> { + self.complete_inner_io(cx, Focus::Empty) } - fn complete_read_io(&mut self) -> io::Result { - let n = self.session.read_tls(self.io)?; + fn complete_read_io(&mut self, cx: &mut Context) -> Poll> { + struct Reader<'a, 'b, T> { + io: &'a mut T, + cx: &'a mut Context<'b> + } + + impl<'a, 'b, T: AsyncRead + Unpin> Read for Reader<'a, 'b, T> { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + match Pin::new(&mut self.io).poll_read(self.cx, buf) { + Poll::Ready(result) => result, + Poll::Pending => Err(io::ErrorKind::WouldBlock.into()) + } + } + } + + let mut reader = Reader { io: self.io, cx }; + + let n = match self.session.read_tls(&mut reader) { + Ok(n) => n, + Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => return Poll::Pending, + Err(err) => return Poll::Ready(Err(err)) + }; self.session.process_new_packets() .map_err(|err| { // In case we have an alert to send describing this error, // try a last-gasp write -- but don't predate the primary // error. - let _ = self.write_tls(); + let _ = self.write_tls(cx); io::Error::new(io::ErrorKind::InvalidData, err) })?; - Ok(n) + Poll::Ready(Ok(n)) } - fn complete_write_io(&mut self) -> io::Result { - self.write_tls() + fn complete_write_io(&mut self, cx: &mut Context) -> Poll> { + match self.write_tls(cx) { + Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending, + result => Poll::Ready(result) + } } - fn complete_inner_io(&mut self, focus: Focus) -> io::Result<(usize, usize)> { + fn complete_inner_io(&mut self, cx: &mut Context, focus: Focus) -> Poll> { let mut wrlen = 0; let mut rdlen = 0; let mut eof = false; @@ -61,22 +89,22 @@ impl<'a, IO: AsyncRead + AsyncWrite, S: Session> Stream<'a, IO, S> { let mut read_would_block = false; while self.session.wants_write() { - match self.complete_write_io() { - Ok(n) => wrlen += n, - Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => { + match self.complete_write_io(cx) { + Poll::Ready(Ok(n)) => wrlen += n, + Poll::Pending => { write_would_block = true; break }, - Err(err) => return Err(err) + Poll::Ready(Err(err)) => return Poll::Ready(Err(err)) } } if !eof && self.session.wants_read() { - match self.complete_read_io() { - Ok(0) => eof = true, - Ok(n) => rdlen += n, - Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => read_would_block = true, - Err(err) => return Err(err) + match self.complete_read_io(cx) { + Poll::Ready(Ok(0)) => eof = true, + Poll::Ready(Ok(n)) => rdlen += n, + Poll::Pending => read_would_block = true, + Poll::Ready(Err(err)) => return Poll::Ready(Err(err)) } } @@ -87,7 +115,7 @@ impl<'a, IO: AsyncRead + AsyncWrite, S: Session> Stream<'a, IO, S> { }; match (eof, self.session.is_handshaking(), would_block) { - (true, true, _) => return Err(io::ErrorKind::UnexpectedEof.into()), + (true, true, _) => return Poll::Pending, (_, false, true) => { let would_block = match focus { Focus::Empty => rdlen == 0 && wrlen == 0, @@ -96,83 +124,96 @@ impl<'a, IO: AsyncRead + AsyncWrite, S: Session> Stream<'a, IO, S> { }; return if would_block { - Err(io::ErrorKind::WouldBlock.into()) + Poll::Pending } else { - Ok((rdlen, wrlen)) + Poll::Ready(Ok((rdlen, wrlen))) }; }, - (_, false, _) => return Ok((rdlen, wrlen)), - (_, true, true) => return Err(io::ErrorKind::WouldBlock.into()), + (_, false, _) => return Poll::Ready(Ok((rdlen, wrlen))), + (_, true, true) => return Poll::Pending, (..) => () } } } } -impl<'a, IO: AsyncRead + AsyncWrite, S: Session> WriteTls<'a, IO, S> for Stream<'a, IO, S> { - fn write_tls(&mut self) -> io::Result { - use futures::Async; - use self::vecbuf::VecBuf; +impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> WriteTls for Stream<'a, IO, S> { + fn write_tls(&mut self, cx: &mut Context) -> io::Result { + struct Writer<'a, 'b, IO> { + io: &'a mut IO, + cx: &'a mut Context<'b> + } - struct V<'a, IO: 'a>(&'a mut IO); - - impl<'a, IO: AsyncWrite> WriteV for V<'a, IO> { + impl<'a, 'b, IO: AsyncWrite + Unpin> WriteV for Writer<'a, 'b, IO> { fn writev(&mut self, vbytes: &[&[u8]]) -> io::Result { - let mut vbytes = VecBuf::new(vbytes); - match self.0.write_buf(&mut vbytes) { - Ok(Async::Ready(n)) => Ok(n), - Ok(Async::NotReady) => Err(io::ErrorKind::WouldBlock.into()), - Err(err) => Err(err) + let vbytes = vbytes + .into_iter() + .try_fold(SmallVec::<[&'_ IoVec; 16]>::new(), |mut sum, next| { + sum.push(IoVec::from_bytes(next)?); + Some(sum) + }) + .unwrap_or_default(); + + match Pin::new(&mut self.io).poll_vectored_write(self.cx, &vbytes) { + Poll::Ready(result) => result, + Poll::Pending => Err(io::ErrorKind::WouldBlock.into()) } } } - let mut vecio = V(self.io); + let mut vecio = Writer { io: self.io, cx }; self.session.writev_tls(&mut vecio) } } -impl<'a, IO: AsyncRead + AsyncWrite, S: Session> Read for Stream<'a, IO, S> { - fn read(&mut self, buf: &mut [u8]) -> io::Result { +impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> { + fn poll_read(&mut self, cx: &mut Context, buf: &mut [u8]) -> Poll> { while self.session.wants_read() { - if let (0, _) = self.complete_inner_io(Focus::Readable)? { - break + match self.complete_inner_io(cx, Focus::Readable) { + Poll::Ready(Ok((0, _))) => break, + Poll::Ready(Ok(_)) => (), + Poll::Pending => return Poll::Pending, + Poll::Ready(Err(err)) => return Poll::Ready(Err(err)) } } - self.session.read(buf) - } -} -impl<'a, IO: AsyncRead + AsyncWrite, S: Session> Write for Stream<'a, IO, S> { - fn write(&mut self, buf: &[u8]) -> io::Result { + // FIXME rustls always ready ? + Poll::Ready(self.session.read(buf)) + } + + fn poll_write(&mut self, cx: &mut Context, buf: &[u8]) -> Poll> { let len = self.session.write(buf)?; while self.session.wants_write() { - match self.complete_inner_io(Focus::Writable) { - Ok(_) => (), - Err(ref err) if err.kind() == io::ErrorKind::WouldBlock && len != 0 => break, - Err(err) => return Err(err) + match self.complete_inner_io(cx, Focus::Writable) { + Poll::Ready(Ok(_)) => (), + Poll::Pending if len != 0 => break, + Poll::Pending => return Poll::Pending, + Poll::Ready(Err(err)) => return Poll::Ready(Err(err)) } } if len != 0 || buf.is_empty() { - Ok(len) + Poll::Ready(Ok(len)) } else { // not write zero - self.session.write(buf) - .and_then(|len| if len != 0 { - Ok(len) - } else { - Err(io::ErrorKind::WouldBlock.into()) - }) + match self.session.write(buf) { + Ok(0) => Poll::Pending, + Ok(n) => Poll::Ready(Ok(n)), + Err(err) => Poll::Ready(Err(err)) + } } } - fn flush(&mut self) -> io::Result<()> { + fn poll_flush(&mut self, cx: &mut Context) -> Poll> { self.session.flush()?; while self.session.wants_write() { - self.complete_inner_io(Focus::Writable)?; + match self.complete_inner_io(cx, Focus::Writable) { + Poll::Ready(Ok(_)) => (), + Poll::Pending => return Poll::Pending, + Poll::Ready(Err(err)) => return Poll::Ready(Err(err)) + } } - Ok(()) + Poll::Ready(Ok(())) } } diff --git a/src/lib.rs b/src/lib.rs index 04d7421..9e15ed7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,17 +1,10 @@ //! Asynchronous TLS/SSL streams for Tokio using [Rustls](https://github.com/ctz/rustls). -pub extern crate rustls; -pub extern crate webpki; - -extern crate bytes; -extern crate futures; -extern crate iovec; -extern crate tokio_io; - -pub mod client; +// pub mod client; mod common; -pub mod server; +// pub mod server; +/* use common::Stream; use futures::{Async, Future, Poll}; use rustls::{ClientConfig, ClientSession, ServerConfig, ServerSession}; @@ -194,3 +187,4 @@ impl Future for Accept { #[cfg(feature = "early-data")] #[cfg(test)] mod test_0rtt; +*/ From 41c26ee63a63974bd8e2a82932023664a86568e2 Mon Sep 17 00:00:00 2001 From: quininer Date: Sat, 18 May 2019 16:05:10 +0800 Subject: [PATCH 115/171] wip client --- Cargo.toml | 2 +- src/client.rs | 196 ++++++++++++++++++++++------------------------ src/common/mod.rs | 29 ++++--- src/lib.rs | 20 ++++- 4 files changed, 132 insertions(+), 115 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 5a61dcb..f892061 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tokio-rustls" -version = "0.10.0-alpha.2" +version = "0.12.0-alpha" authors = ["quininer kel "] license = "MIT/Apache-2.0" repository = "https://github.com/quininer/tokio-rustls" diff --git a/src/client.rs b/src/client.rs index 616c151..613cd69 100644 --- a/src/client.rs +++ b/src/client.rs @@ -40,160 +40,154 @@ impl TlsStream { impl Future for MidHandshake where - IO: AsyncRead + AsyncWrite, + IO: AsyncRead + AsyncWrite + Unpin, { - type Item = TlsStream; - type Error = io::Error; + type Output = io::Result>; #[inline] - fn poll(&mut self) -> Poll { - if let MidHandshake::Handshaking(stream) = self { + fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { + if let MidHandshake::Handshaking(stream) = &mut *self { let (io, session) = stream.get_mut(); let mut stream = Stream::new(io, session); if stream.session.is_handshaking() { - try_nb!(stream.complete_io()); + try_ready!(stream.complete_io(cx)); } if stream.session.wants_write() { - try_nb!(stream.complete_io()); + try_ready!(stream.complete_io(cx)); } } - match mem::replace(self, MidHandshake::End) { - MidHandshake::Handshaking(stream) => Ok(Async::Ready(stream)), + match mem::replace(&mut *self, MidHandshake::End) { + MidHandshake::Handshaking(stream) => Poll::Ready(Ok(stream)), #[cfg(feature = "early-data")] - MidHandshake::EarlyData(stream) => Ok(Async::Ready(stream)), + MidHandshake::EarlyData(stream) => Poll::Ready(Ok(stream)), MidHandshake::End => panic!(), } } } -impl io::Read for TlsStream +impl AsyncRead for TlsStream where - IO: AsyncRead + AsyncWrite, + IO: AsyncRead + AsyncWrite + Unpin, { - fn read(&mut self, buf: &mut [u8]) -> io::Result { - match self.state { - #[cfg(feature = "early-data")] - TlsState::EarlyData => { - { - let mut stream = Stream::new(&mut self.io, &mut self.session); - let (pos, data) = &mut self.early_data; - - // complete handshake - if stream.session.is_handshaking() { - stream.complete_io()?; - } - - // write early data (fallback) - if !stream.session.is_early_data_accepted() { - while *pos < data.len() { - let len = stream.write(&data[*pos..])?; - *pos += len; - } - } - - // end - self.state = TlsState::Stream; - data.clear(); - } - - self.read(buf) - } - TlsState::Stream | TlsState::WriteShutdown => { - let mut stream = Stream::new(&mut self.io, &mut self.session); - - match stream.read(buf) { - Ok(0) => { - self.state.shutdown_read(); - Ok(0) - } - Ok(n) => Ok(n), - Err(ref e) if e.kind() == io::ErrorKind::ConnectionAborted => { - self.state.shutdown_read(); - if self.state.writeable() { - stream.session.send_close_notify(); - self.state.shutdown_write(); - } - Ok(0) - } - Err(e) => Err(e), - } - } - TlsState::ReadShutdown | TlsState::FullyShutdown => Ok(0), - } + unsafe fn initializer(&self) -> Initializer { + // TODO + Initializer::nop() } -} - -impl io::Write for TlsStream -where - IO: AsyncRead + AsyncWrite, -{ - fn write(&mut self, buf: &[u8]) -> io::Result { - let mut stream = Stream::new(&mut self.io, &mut self.session); + fn poll_read(self: Pin<&mut Self>, cx: &mut Context, buf: &mut [u8]) -> Poll> { match self.state { #[cfg(feature = "early-data")] TlsState::EarlyData => { - let (pos, data) = &mut self.early_data; + let this = self.get_mut(); - // write early data - if let Some(mut early_data) = stream.session.early_data() { - let len = early_data.write(buf)?; - data.extend_from_slice(&buf[..len]); - return Ok(len); - } + let mut stream = Stream::new(&mut this.io, &mut this.session); + let (pos, data) = &mut this.early_data; // complete handshake if stream.session.is_handshaking() { - stream.complete_io()?; + try_ready!(stream.complete_io(cx)); } // write early data (fallback) if !stream.session.is_early_data_accepted() { while *pos < data.len() { - let len = stream.write(&data[*pos..])?; + let len = try_ready!(stream.poll_write(cx, &data[*pos..])); *pos += len; } } // end - self.state = TlsState::Stream; + this.state = TlsState::Stream; data.clear(); - stream.write(buf) + + Pin::new(this).poll_read(cx, buf) } - _ => stream.write(buf), + TlsState::Stream | TlsState::WriteShutdown => { + let this = self.get_mut(); + let mut stream = Stream::new(&mut this.io, &mut this.session); + + match stream.poll_read(cx, buf) { + Poll::Ready(Ok(0)) => { + this.state.shutdown_read(); + Poll::Ready(Ok(0)) + } + Poll::Ready(Ok(n)) => Poll::Ready(Ok(n)), + Poll::Ready(Err(ref e)) if e.kind() == io::ErrorKind::ConnectionAborted => { + this.state.shutdown_read(); + if this.state.writeable() { + stream.session.send_close_notify(); + this.state.shutdown_write(); + } + Poll::Ready(Ok(0)) + } + Poll::Ready(Err(err)) => Poll::Ready(Err(err)), + Poll::Pending => Poll::Pending + } + } + TlsState::ReadShutdown | TlsState::FullyShutdown => Poll::Ready(Ok(0)), } } - - fn flush(&mut self) -> io::Result<()> { - Stream::new(&mut self.io, &mut self.session).flush()?; - self.io.flush() - } -} - -impl AsyncRead for TlsStream -where - IO: AsyncRead + AsyncWrite, -{ - unsafe fn prepare_uninitialized_buffer(&self, _: &mut [u8]) -> bool { - false - } } impl AsyncWrite for TlsStream where - IO: AsyncRead + AsyncWrite, + IO: AsyncRead + AsyncWrite + Unpin, { - fn shutdown(&mut self) -> Poll<(), io::Error> { + fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll> { + let this = self.get_mut(); + let mut stream = Stream::new(&mut this.io, &mut this.session); + + match this.state { + #[cfg(feature = "early-data")] + TlsState::EarlyData => { + let (pos, data) = &mut this.early_data; + + // write early data + if let Some(mut early_data) = stream.session.early_data() { + let len = early_data.write(buf)?; // TODO check pending + data.extend_from_slice(&buf[..len]); + return Poll::Ready(Ok(len)); + } + + // complete handshake + if stream.session.is_handshaking() { + try_ready!(stream.complete_io(cx)); + } + + // write early data (fallback) + if !stream.session.is_early_data_accepted() { + while *pos < data.len() { + let len = try_ready!(stream.poll_write(cx, &data[*pos..])); + *pos += len; + } + } + + // end + this.state = TlsState::Stream; + data.clear(); + stream.poll_write(cx, buf) + } + _ => stream.poll_write(cx, buf), + } + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + let this = self.get_mut(); + Stream::new(&mut this.io, &mut this.session).poll_flush(cx) + } + + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { if self.state.writeable() { self.session.send_close_notify(); self.state.shutdown_write(); } - let mut stream = Stream::new(&mut self.io, &mut self.session); - try_nb!(stream.flush()); - stream.io.shutdown() + let this = self.get_mut(); + let mut stream = Stream::new(&mut this.io, &mut this.session); + try_ready!(stream.poll_flush(cx)); + Pin::new(&mut this.io).poll_close(cx) } } diff --git a/src/common/mod.rs b/src/common/mod.rs index ed29f09..71b442f 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -14,6 +14,7 @@ use smallvec::SmallVec; pub struct Stream<'a, IO, S> { pub io: &'a mut IO, pub session: &'a mut S, + pub eof: bool } pub trait WriteTls { @@ -29,7 +30,18 @@ enum Focus { impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> { pub fn new(io: &'a mut IO, session: &'a mut S) -> Self { - Stream { io, session } + Stream { + io, + session, + // The state so far is only used to detect EOF, so either Stream + // or EarlyData state should both be all right. + eof: false, + } + } + + pub fn set_eof(mut self, eof: bool) -> Self { + self.eof = eof; + self } pub fn complete_io(&mut self, cx: &mut Context) -> Poll> { @@ -82,7 +94,6 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> { fn complete_inner_io(&mut self, cx: &mut Context, focus: Focus) -> Poll> { let mut wrlen = 0; let mut rdlen = 0; - let mut eof = false; loop { let mut write_would_block = false; @@ -99,9 +110,9 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> { } } - if !eof && self.session.wants_read() { + if !self.eof && self.session.wants_read() { match self.complete_read_io(cx) { - Poll::Ready(Ok(0)) => eof = true, + Poll::Ready(Ok(0)) => self.eof = true, Poll::Ready(Ok(n)) => rdlen += n, Poll::Pending => read_would_block = true, Poll::Ready(Err(err)) => return Poll::Ready(Err(err)) @@ -114,7 +125,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> { Focus::Writable => write_would_block, }; - match (eof, self.session.is_handshaking(), would_block) { + match (self.eof, self.session.is_handshaking(), would_block) { (true, true, _) => return Poll::Pending, (_, false, true) => { let would_block = match focus { @@ -167,7 +178,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> WriteTls for Str } impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> { - fn poll_read(&mut self, cx: &mut Context, buf: &mut [u8]) -> Poll> { + pub fn poll_read(&mut self, cx: &mut Context, buf: &mut [u8]) -> Poll> { while self.session.wants_read() { match self.complete_inner_io(cx, Focus::Readable) { Poll::Ready(Ok((0, _))) => break, @@ -181,7 +192,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> { Poll::Ready(self.session.read(buf)) } - fn poll_write(&mut self, cx: &mut Context, buf: &[u8]) -> Poll> { + pub fn poll_write(&mut self, cx: &mut Context, buf: &[u8]) -> Poll> { let len = self.session.write(buf)?; while self.session.wants_write() { match self.complete_inner_io(cx, Focus::Writable) { @@ -204,7 +215,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> { } } - fn poll_flush(&mut self, cx: &mut Context) -> Poll> { + pub fn poll_flush(&mut self, cx: &mut Context) -> Poll> { self.session.flush()?; while self.session.wants_write() { match self.complete_inner_io(cx, Focus::Writable) { @@ -213,7 +224,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> { Poll::Ready(Err(err)) => return Poll::Ready(Err(err)) } } - Poll::Ready(Ok(())) + Pin::new(&mut self.io).poll_flush(cx) } } diff --git a/src/lib.rs b/src/lib.rs index 9e15ed7..19f35dc 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,16 +1,27 @@ //! Asynchronous TLS/SSL streams for Tokio using [Rustls](https://github.com/ctz/rustls). -// pub mod client; +macro_rules! try_ready { + ( $e:expr ) => { + match $e { + Poll::Ready(Ok(output)) => output, + Poll::Ready(Err(err)) => return Poll::Ready(Err(err.into())), + Poll::Pending => return Poll::Pending + } + } +} + +pub mod client; mod common; // pub mod server; -/* use common::Stream; -use futures::{Async, Future, Poll}; +use std::pin::Pin; +use std::task::{ Poll, Context }; +use std::future::Future; +use futures::io::{ AsyncRead, AsyncWrite, Initializer }; use rustls::{ClientConfig, ClientSession, ServerConfig, ServerSession}; use std::sync::Arc; use std::{io, mem}; -use tokio_io::{try_nb, AsyncRead, AsyncWrite}; use webpki::DNSNameRef; #[derive(Debug, Copy, Clone)] @@ -54,6 +65,7 @@ pub struct TlsConnector { early_data: bool, } +/* /// A wrapper around a `rustls::ServerConfig`, providing an async `accept` method. #[derive(Clone)] pub struct TlsAcceptor { From 4cc374fd4ce174ce960fc2c644b461a488ddf1f8 Mon Sep 17 00:00:00 2001 From: quininer Date: Sat, 18 May 2019 18:18:26 +0800 Subject: [PATCH 116/171] wip server --- src/client.rs | 31 ++++++++----- src/common/mod.rs | 2 +- src/lib.rs | 41 +++++++++-------- src/server.rs | 114 +++++++++++++++++++++++----------------------- 4 files changed, 102 insertions(+), 86 deletions(-) diff --git a/src/client.rs b/src/client.rs index 613cd69..8527121 100644 --- a/src/client.rs +++ b/src/client.rs @@ -45,10 +45,13 @@ where type Output = io::Result>; #[inline] - fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { - if let MidHandshake::Handshaking(stream) = &mut *self { + fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { + let this = self.get_mut(); + + if let MidHandshake::Handshaking(stream) = this { + let eof = !stream.state.readable(); let (io, session) = stream.get_mut(); - let mut stream = Stream::new(io, session); + let mut stream = Stream::new(io, session).set_eof(eof); if stream.session.is_handshaking() { try_ready!(stream.complete_io(cx)); @@ -59,7 +62,7 @@ where } } - match mem::replace(&mut *self, MidHandshake::End) { + match mem::replace(this, MidHandshake::End) { MidHandshake::Handshaking(stream) => Poll::Ready(Ok(stream)), #[cfg(feature = "early-data")] MidHandshake::EarlyData(stream) => Poll::Ready(Ok(stream)), @@ -83,7 +86,8 @@ where TlsState::EarlyData => { let this = self.get_mut(); - let mut stream = Stream::new(&mut this.io, &mut this.session); + let mut stream = Stream::new(&mut this.io, &mut this.session) + .set_eof(!this.state.readable()); let (pos, data) = &mut this.early_data; // complete handshake @@ -107,7 +111,8 @@ where } TlsState::Stream | TlsState::WriteShutdown => { let this = self.get_mut(); - let mut stream = Stream::new(&mut this.io, &mut this.session); + let mut stream = Stream::new(&mut this.io, &mut this.session) + .set_eof(!this.state.readable()); match stream.poll_read(cx, buf) { Poll::Ready(Ok(0)) => { @@ -136,9 +141,10 @@ impl AsyncWrite for TlsStream where IO: AsyncRead + AsyncWrite + Unpin, { - fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll> { + fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll> { let this = self.get_mut(); - let mut stream = Stream::new(&mut this.io, &mut this.session); + let mut stream = Stream::new(&mut this.io, &mut this.session) + .set_eof(!this.state.readable()); match this.state { #[cfg(feature = "early-data")] @@ -174,9 +180,11 @@ where } } - fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { let this = self.get_mut(); - Stream::new(&mut this.io, &mut this.session).poll_flush(cx) + Stream::new(&mut this.io, &mut this.session) + .set_eof(!this.state.readable()) + .poll_flush(cx) } fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { @@ -186,7 +194,8 @@ where } let this = self.get_mut(); - let mut stream = Stream::new(&mut this.io, &mut this.session); + let mut stream = Stream::new(&mut this.io, &mut this.session) + .set_eof(!this.state.readable()); try_ready!(stream.poll_flush(cx)); Pin::new(&mut this.io).poll_close(cx) } diff --git a/src/common/mod.rs b/src/common/mod.rs index 71b442f..dabdcbc 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -3,7 +3,7 @@ use std::pin::Pin; use std::task::Poll; use std::marker::Unpin; -use std::io::{ self, Read, Write }; +use std::io::{ self, Read }; use rustls::Session; use rustls::WriteV; use futures::task::Context; diff --git a/src/lib.rs b/src/lib.rs index 19f35dc..f962fc0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -12,7 +12,7 @@ macro_rules! try_ready { pub mod client; mod common; -// pub mod server; +pub mod server; use common::Stream; use std::pin::Pin; @@ -25,7 +25,7 @@ use std::{io, mem}; use webpki::DNSNameRef; #[derive(Debug, Copy, Clone)] -pub enum TlsState { +enum TlsState { #[cfg(feature = "early-data")] EarlyData, Stream, @@ -35,26 +35,33 @@ pub enum TlsState { } impl TlsState { - pub(crate) fn shutdown_read(&mut self) { + fn shutdown_read(&mut self) { match *self { TlsState::WriteShutdown | TlsState::FullyShutdown => *self = TlsState::FullyShutdown, _ => *self = TlsState::ReadShutdown, } } - pub(crate) fn shutdown_write(&mut self) { + fn shutdown_write(&mut self) { match *self { TlsState::ReadShutdown | TlsState::FullyShutdown => *self = TlsState::FullyShutdown, _ => *self = TlsState::WriteShutdown, } } - pub(crate) fn writeable(&self) -> bool { + fn writeable(&self) -> bool { match *self { TlsState::WriteShutdown | TlsState::FullyShutdown => false, _ => true, } } + + fn readable(self) -> bool { + match self { + TlsState::ReadShutdown | TlsState::FullyShutdown => false, + _ => true, + } + } } /// A wrapper around a `rustls::ClientConfig`, providing an async `connect` method. @@ -65,7 +72,6 @@ pub struct TlsConnector { early_data: bool, } -/* /// A wrapper around a `rustls::ServerConfig`, providing an async `accept` method. #[derive(Clone)] pub struct TlsAcceptor { @@ -170,32 +176,31 @@ impl TlsAcceptor { } } -/// Future returned from `ClientConfigExt::connect_async` which will resolve +/// Future returned from `TlsConnector::connect` which will resolve /// once the connection handshake has finished. pub struct Connect(client::MidHandshake); -/// Future returned from `ServerConfigExt::accept_async` which will resolve +/// Future returned from `TlsAcceptor::accept` which will resolve /// once the accept handshake has finished. pub struct Accept(server::MidHandshake); -impl Future for Connect { - type Item = client::TlsStream; - type Error = io::Error; +impl Future for Connect { + type Output = io::Result>; - fn poll(&mut self) -> Poll { - self.0.poll() + fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { + Pin::new(&mut self.0).poll(cx) } } -impl Future for Accept { - type Item = server::TlsStream; - type Error = io::Error; +impl Future for Accept { + type Output = io::Result>; - fn poll(&mut self) -> Poll { - self.0.poll() + fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { + Pin::new(&mut self.0).poll(cx) } } +/* #[cfg(feature = "early-data")] #[cfg(test)] mod test_0rtt; diff --git a/src/server.rs b/src/server.rs index 1568414..9db4867 100644 --- a/src/server.rs +++ b/src/server.rs @@ -34,100 +34,102 @@ impl TlsStream { impl Future for MidHandshake where - IO: AsyncRead + AsyncWrite, + IO: AsyncRead + AsyncWrite + Unpin, { - type Item = TlsStream; - type Error = io::Error; + type Output = io::Result>; #[inline] - fn poll(&mut self) -> Poll { - if let MidHandshake::Handshaking(stream) = self { + fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { + let this = self.get_mut(); + + if let MidHandshake::Handshaking(stream) = this { + let eof = !stream.state.readable(); let (io, session) = stream.get_mut(); - let mut stream = Stream::new(io, session); + let mut stream = Stream::new(io, session).set_eof(eof); if stream.session.is_handshaking() { - try_nb!(stream.complete_io()); + try_ready!(stream.complete_io(cx)); } if stream.session.wants_write() { - try_nb!(stream.complete_io()); + try_ready!(stream.complete_io(cx)); } } - match mem::replace(self, MidHandshake::End) { - MidHandshake::Handshaking(stream) => Ok(Async::Ready(stream)), + match mem::replace(this, MidHandshake::End) { + MidHandshake::Handshaking(stream) => Poll::Ready(Ok(stream)), MidHandshake::End => panic!(), } } } -impl io::Read for TlsStream +impl AsyncRead for TlsStream where - IO: AsyncRead + AsyncWrite, + IO: AsyncRead + AsyncWrite + Unpin, { - fn read(&mut self, buf: &mut [u8]) -> io::Result { - let mut stream = Stream::new(&mut self.io, &mut self.session); + unsafe fn initializer(&self) -> Initializer { + // TODO + Initializer::nop() + } - match self.state { - TlsState::Stream | TlsState::WriteShutdown => match stream.read(buf) { - Ok(0) => { - self.state.shutdown_read(); - Ok(0) + fn poll_read(self: Pin<&mut Self>, cx: &mut Context, buf: &mut [u8]) -> Poll> { + let this = self.get_mut(); + let mut stream = Stream::new(&mut this.io, &mut this.session) + .set_eof(!this.state.readable()); + + match this.state { + TlsState::Stream | TlsState::WriteShutdown => match stream.poll_read(cx, buf) { + Poll::Ready(Ok(0)) => { + this.state.shutdown_read(); + Poll::Ready(Ok(0)) } - Ok(n) => Ok(n), - Err(ref e) if e.kind() == io::ErrorKind::ConnectionAborted => { - self.state.shutdown_read(); - if self.state.writeable() { + Poll::Ready(Ok(n)) => Poll::Ready(Ok(n)), + Poll::Ready(Err(ref err)) if err.kind() == io::ErrorKind::ConnectionAborted => { + this.state.shutdown_read(); + if this.state.writeable() { stream.session.send_close_notify(); - self.state.shutdown_write(); + this.state.shutdown_write(); } - Ok(0) + Poll::Ready(Ok(0)) } - Err(e) => Err(e), + Poll::Ready(Err(e)) => Poll::Ready(Err(e)), + Poll::Pending => Poll::Pending }, - TlsState::ReadShutdown | TlsState::FullyShutdown => Ok(0), + TlsState::ReadShutdown | TlsState::FullyShutdown => Poll::Ready(Ok(0)), #[cfg(feature = "early-data")] s => unreachable!("server TLS can not hit this state: {:?}", s), } } } -impl io::Write for TlsStream -where - IO: AsyncRead + AsyncWrite, -{ - fn write(&mut self, buf: &[u8]) -> io::Result { - let mut stream = Stream::new(&mut self.io, &mut self.session); - stream.write(buf) - } - - fn flush(&mut self) -> io::Result<()> { - Stream::new(&mut self.io, &mut self.session).flush()?; - self.io.flush() - } -} - -impl AsyncRead for TlsStream -where - IO: AsyncRead + AsyncWrite, -{ - unsafe fn prepare_uninitialized_buffer(&self, _: &mut [u8]) -> bool { - false - } -} - impl AsyncWrite for TlsStream where - IO: AsyncRead + AsyncWrite, + IO: AsyncRead + AsyncWrite + Unpin, { - fn shutdown(&mut self) -> Poll<(), io::Error> { + fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll> { + let this = self.get_mut(); + Stream::new(&mut this.io, &mut this.session) + .set_eof(!this.state.readable()) + .poll_write(cx, buf) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + let this = self.get_mut(); + Stream::new(&mut this.io, &mut this.session) + .set_eof(!this.state.readable()) + .poll_flush(cx) + } + + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { if self.state.writeable() { self.session.send_close_notify(); self.state.shutdown_write(); } - let mut stream = Stream::new(&mut self.io, &mut self.session); - try_nb!(stream.complete_io()); - stream.io.shutdown() + let this = self.get_mut(); + let mut stream = Stream::new(&mut this.io, &mut this.session) + .set_eof(!this.state.readable()); + try_ready!(stream.complete_io(cx)); + Pin::new(&mut this.io).poll_close(cx) } } From 4d673f9a72cab0be1fe50bbcb648d218ad924a20 Mon Sep 17 00:00:00 2001 From: quininer Date: Sat, 18 May 2019 23:44:29 +0800 Subject: [PATCH 117/171] update iovec --- Cargo.toml | 4 +--- src/common/mod.rs | 11 ++++------- 2 files changed, 5 insertions(+), 10 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index f892061..8a44728 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,9 +17,7 @@ appveyor = { repository = "quininer/tokio-rustls" } [dependencies] smallvec = "*" -futures = { package = "futures-preview", version = "0.3.0-alpha.15" } -bytes = "0.4" -iovec = "0.1" +futures = { package = "futures-preview", version = "0.3.0-alpha.16" } rustls = "0.15" webpki = "0.19" diff --git a/src/common/mod.rs b/src/common/mod.rs index dabdcbc..98afcd6 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -7,7 +7,7 @@ use std::io::{ self, Read }; use rustls::Session; use rustls::WriteV; use futures::task::Context; -use futures::io::{ AsyncRead, AsyncWrite, IoVec }; +use futures::io::{ AsyncRead, AsyncWrite, IoSlice }; use smallvec::SmallVec; @@ -159,13 +159,10 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> WriteTls for Str fn writev(&mut self, vbytes: &[&[u8]]) -> io::Result { let vbytes = vbytes .into_iter() - .try_fold(SmallVec::<[&'_ IoVec; 16]>::new(), |mut sum, next| { - sum.push(IoVec::from_bytes(next)?); - Some(sum) - }) - .unwrap_or_default(); + .map(|v| IoSlice::new(v)) + .collect::; 64]>>(); - match Pin::new(&mut self.io).poll_vectored_write(self.cx, &vbytes) { + match Pin::new(&mut self.io).poll_write_vectored(self.cx, &vbytes) { Poll::Ready(result) => result, Poll::Pending => Err(io::ErrorKind::WouldBlock.into()) } From f7472e89a214b77826fbc57ec5078fe9ef4068ac Mon Sep 17 00:00:00 2001 From: quininer Date: Sun, 19 May 2019 00:48:56 +0800 Subject: [PATCH 118/171] make early data test work --- Cargo.toml | 1 + src/client.rs | 3 ++- src/common/mod.rs | 4 ++-- src/lib.rs | 20 ++++++++++---------- src/test_0rtt.rs | 37 ++++++++++++++++--------------------- tests/test.rs | 2 ++ 6 files changed, 33 insertions(+), 34 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 8a44728..cb60566 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,6 +25,7 @@ webpki = "0.19" early-data = [] [dev-dependencies] +romio = "0.3.0-alpha.8" tokio = "0.1.6" lazy_static = "1" webpki-roots = "0.16" diff --git a/src/client.rs b/src/client.rs index 8527121..a2ebdd2 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,6 +1,5 @@ use super::*; use rustls::Session; -use std::io::Write; /// A wrapper around an underlying raw stream which implements the TLS or SSL /// protocol. @@ -149,6 +148,8 @@ where match this.state { #[cfg(feature = "early-data")] TlsState::EarlyData => { + use std::io::Write; + let (pos, data) = &mut this.early_data; // write early data diff --git a/src/common/mod.rs b/src/common/mod.rs index 98afcd6..d20d5a8 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -225,5 +225,5 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> { } } -#[cfg(test)] -mod test_stream; +// #[cfg(test)] +// mod test_stream; diff --git a/src/lib.rs b/src/lib.rs index f962fc0..d849f33 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,7 @@ //! Asynchronous TLS/SSL streams for Tokio using [Rustls](https://github.com/ctz/rustls). +#![feature(async_await)] + macro_rules! try_ready { ( $e:expr ) => { match $e { @@ -10,19 +12,19 @@ macro_rules! try_ready { } } -pub mod client; mod common; +pub mod client; pub mod server; -use common::Stream; -use std::pin::Pin; -use std::task::{ Poll, Context }; -use std::future::Future; -use futures::io::{ AsyncRead, AsyncWrite, Initializer }; -use rustls::{ClientConfig, ClientSession, ServerConfig, ServerSession}; +use std::{ io, mem }; use std::sync::Arc; -use std::{io, mem}; +use std::pin::Pin; +use std::future::Future; +use std::task::{ Poll, Context }; +use futures::io::{ AsyncRead, AsyncWrite, Initializer }; +use rustls::{ ClientConfig, ClientSession, ServerConfig, ServerSession }; use webpki::DNSNameRef; +use common::Stream; #[derive(Debug, Copy, Clone)] enum TlsState { @@ -200,8 +202,6 @@ impl Future for Accept { } } -/* #[cfg(feature = "early-data")] #[cfg(test)] mod test_0rtt; -*/ diff --git a/src/test_0rtt.rs b/src/test_0rtt.rs index 0182406..8c8db6c 100644 --- a/src/test_0rtt.rs +++ b/src/test_0rtt.rs @@ -1,36 +1,31 @@ -extern crate tokio; -extern crate webpki; -extern crate webpki_roots; - use std::io; use std::sync::Arc; use std::net::ToSocketAddrs; -use self::tokio::io as aio; -use self::tokio::prelude::*; -use self::tokio::net::TcpStream; +use futures::executor; +use futures::prelude::*; +use romio::tcp::TcpStream; use rustls::ClientConfig; -use ::{ TlsConnector, client::TlsStream }; +use crate::{ TlsConnector, client::TlsStream }; -fn get(config: Arc, domain: &str, rtt0: bool) +async fn get(config: Arc, domain: &str, rtt0: bool) -> io::Result<(TlsStream, String)> { - let config = TlsConnector::from(config).early_data(rtt0); + let connector = TlsConnector::from(config).early_data(rtt0); let input = format!("GET / HTTP/1.0\r\nHost: {}\r\n\r\n", domain); let addr = (domain, 443) .to_socket_addrs()? .next().unwrap(); + let domain = webpki::DNSNameRef::try_from_ascii_str(&domain).unwrap(); + let mut buf = Vec::new(); - TcpStream::connect(&addr) - .and_then(move |stream| { - let domain = webpki::DNSNameRef::try_from_ascii_str(&domain).unwrap(); - config.connect(domain, stream) - }) - .and_then(move |stream| aio::write_all(stream, input)) - .and_then(move |(stream, _)| aio::read_to_end(stream, Vec::new())) - .map(|(stream, buf)| (stream, String::from_utf8(buf).unwrap())) - .wait() + let stream = TcpStream::connect(&addr).await?; + let mut stream = connector.connect(domain, stream).await?; + stream.write_all(input.as_bytes()).await?; + stream.read_to_end(&mut buf).await?; + + Ok((stream, String::from_utf8(buf).unwrap())) } #[test] @@ -41,10 +36,10 @@ fn test_0rtt() { let config = Arc::new(config); let domain = "mozilla-modern.badssl.com"; - let (_, output) = get(config.clone(), domain, false).unwrap(); + let (_, output) = executor::block_on(get(config.clone(), domain, false)).unwrap(); assert!(output.contains("mozilla-modern.badssl.com")); - let (io, output) = get(config.clone(), domain, true).unwrap(); + let (io, output) = executor::block_on(get(config.clone(), domain, true)).unwrap(); assert!(output.contains("mozilla-modern.badssl.com")); assert_eq!(io.early_data.0, 0); diff --git a/tests/test.rs b/tests/test.rs index f0703f8..533e4e4 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -1,3 +1,5 @@ +#![cfg(not(test))] + #[macro_use] extern crate lazy_static; extern crate rustls; extern crate tokio; From b03c327ab6ae86ca89da602e7c4e716a2cb54e7c Mon Sep 17 00:00:00 2001 From: quininer Date: Mon, 20 May 2019 00:28:27 +0800 Subject: [PATCH 119/171] make simple test work --- src/client.rs | 10 ++-- src/common/mod.rs | 2 - src/common/vecbuf.rs | 122 ------------------------------------------- src/lib.rs | 14 ++--- src/server.rs | 10 ++-- tests/test.rs | 84 ++++++++++++++--------------- 6 files changed, 58 insertions(+), 184 deletions(-) delete mode 100644 src/common/vecbuf.rs diff --git a/src/client.rs b/src/client.rs index a2ebdd2..9e89468 100644 --- a/src/client.rs +++ b/src/client.rs @@ -44,7 +44,7 @@ where type Output = io::Result>; #[inline] - fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.get_mut(); if let MidHandshake::Handshaking(stream) = this { @@ -79,7 +79,7 @@ where Initializer::nop() } - fn poll_read(self: Pin<&mut Self>, cx: &mut Context, buf: &mut [u8]) -> Poll> { + fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll> { match self.state { #[cfg(feature = "early-data")] TlsState::EarlyData => { @@ -140,7 +140,7 @@ impl AsyncWrite for TlsStream where IO: AsyncRead + AsyncWrite + Unpin, { - fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll> { + fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { let this = self.get_mut(); let mut stream = Stream::new(&mut this.io, &mut this.session) .set_eof(!this.state.readable()); @@ -181,14 +181,14 @@ where } } - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let this = self.get_mut(); Stream::new(&mut this.io, &mut this.session) .set_eof(!this.state.readable()) .poll_flush(cx) } - fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { if self.state.writeable() { self.session.send_close_notify(); self.state.shutdown_write(); diff --git a/src/common/mod.rs b/src/common/mod.rs index d20d5a8..eacf585 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -1,5 +1,3 @@ -// mod vecbuf; - use std::pin::Pin; use std::task::Poll; use std::marker::Unpin; diff --git a/src/common/vecbuf.rs b/src/common/vecbuf.rs deleted file mode 100644 index e550505..0000000 --- a/src/common/vecbuf.rs +++ /dev/null @@ -1,122 +0,0 @@ -use std::cmp::{ self, Ordering }; -use bytes::Buf; -use iovec::IoVec; - -pub struct VecBuf<'a, 'b: 'a> { - pos: usize, - cur: usize, - inner: &'a [&'b [u8]] -} - -impl<'a, 'b> VecBuf<'a, 'b> { - pub fn new(vbytes: &'a [&'b [u8]]) -> Self { - VecBuf { pos: 0, cur: 0, inner: vbytes } - } -} - -impl<'a, 'b> Buf for VecBuf<'a, 'b> { - fn remaining(&self) -> usize { - let sum = self.inner - .iter() - .skip(self.pos) - .map(|bytes| bytes.len()) - .sum::(); - sum - self.cur - } - - fn bytes(&self) -> &[u8] { - &self.inner[self.pos][self.cur..] - } - - fn advance(&mut self, cnt: usize) { - let current = self.inner[self.pos].len(); - match (self.cur + cnt).cmp(¤t) { - Ordering::Equal => if self.pos + 1 < self.inner.len() { - self.pos += 1; - self.cur = 0; - } else { - self.cur += cnt; - }, - Ordering::Greater => { - if self.pos + 1 < self.inner.len() { - self.pos += 1; - } - let remaining = self.cur + cnt - current; - self.advance(remaining); - }, - Ordering::Less => self.cur += cnt, - } - } - - #[allow(clippy::needless_range_loop)] - fn bytes_vec<'c>(&'c self, dst: &mut [&'c IoVec]) -> usize { - let len = cmp::min(self.inner.len() - self.pos, dst.len()); - - if len > 0 { - dst[0] = self.bytes().into(); - } - - for i in 1..len { - dst[i] = self.inner[self.pos + i].into(); - } - - len - } -} - -#[cfg(test)] -mod test_vecbuf { - use super::*; - - #[test] - fn test_fresh_cursor_vec() { - let mut buf = VecBuf::new(&[b"he", b"llo"]); - - assert_eq!(buf.remaining(), 5); - assert_eq!(buf.bytes(), b"he"); - - buf.advance(2); - - assert_eq!(buf.remaining(), 3); - assert_eq!(buf.bytes(), b"llo"); - - buf.advance(3); - - assert_eq!(buf.remaining(), 0); - assert_eq!(buf.bytes(), b""); - } - - #[test] - fn test_get_u8() { - let mut buf = VecBuf::new(&[b"\x21z", b"omg"]); - assert_eq!(0x21, buf.get_u8()); - } - - #[test] - fn test_get_u16() { - let mut buf = VecBuf::new(&[b"\x21\x54z", b"omg"]); - assert_eq!(0x2154, buf.get_u16_be()); - let mut buf = VecBuf::new(&[b"\x21\x54z", b"omg"]); - assert_eq!(0x5421, buf.get_u16_le()); - } - - #[test] - #[should_panic] - fn test_get_u16_buffer_underflow() { - let mut buf = VecBuf::new(&[b"\x21"]); - buf.get_u16_be(); - } - - #[test] - fn test_bufs_vec() { - let buf = VecBuf::new(&[b"he", b"llo"]); - - let b1: &[u8] = &mut [0]; - let b2: &[u8] = &mut [0]; - - let mut dst: [&IoVec; 2] = - [b1.into(), b2.into()]; - - assert_eq!(2, buf.bytes_vec(&mut dst[..])); - } -} diff --git a/src/lib.rs b/src/lib.rs index d849f33..df3c259 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -109,7 +109,7 @@ impl TlsConnector { pub fn connect(&self, domain: DNSNameRef, stream: IO) -> Connect where - IO: AsyncRead + AsyncWrite, + IO: AsyncRead + AsyncWrite + Unpin, { self.connect_with(domain, stream, |_| ()) } @@ -117,7 +117,7 @@ impl TlsConnector { #[inline] pub fn connect_with(&self, domain: DNSNameRef, stream: IO, f: F) -> Connect where - IO: AsyncRead + AsyncWrite, + IO: AsyncRead + AsyncWrite + Unpin, F: FnOnce(&mut ClientSession), { let mut session = ClientSession::new(&self.inner, domain); @@ -156,7 +156,7 @@ impl TlsConnector { impl TlsAcceptor { pub fn accept(&self, stream: IO) -> Accept where - IO: AsyncRead + AsyncWrite, + IO: AsyncRead + AsyncWrite + Unpin, { self.accept_with(stream, |_| ()) } @@ -164,7 +164,7 @@ impl TlsAcceptor { #[inline] pub fn accept_with(&self, stream: IO, f: F) -> Accept where - IO: AsyncRead + AsyncWrite, + IO: AsyncRead + AsyncWrite + Unpin, F: FnOnce(&mut ServerSession), { let mut session = ServerSession::new(&self.inner); @@ -189,7 +189,8 @@ pub struct Accept(server::MidHandshake); impl Future for Connect { type Output = io::Result>; - fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { + #[inline] + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { Pin::new(&mut self.0).poll(cx) } } @@ -197,7 +198,8 @@ impl Future for Connect { impl Future for Accept { type Output = io::Result>; - fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { + #[inline] + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { Pin::new(&mut self.0).poll(cx) } } diff --git a/src/server.rs b/src/server.rs index 9db4867..ba054a9 100644 --- a/src/server.rs +++ b/src/server.rs @@ -39,7 +39,7 @@ where type Output = io::Result>; #[inline] - fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.get_mut(); if let MidHandshake::Handshaking(stream) = this { @@ -72,7 +72,7 @@ where Initializer::nop() } - fn poll_read(self: Pin<&mut Self>, cx: &mut Context, buf: &mut [u8]) -> Poll> { + fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll> { let this = self.get_mut(); let mut stream = Stream::new(&mut this.io, &mut this.session) .set_eof(!this.state.readable()); @@ -106,21 +106,21 @@ impl AsyncWrite for TlsStream where IO: AsyncRead + AsyncWrite + Unpin, { - fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll> { + fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { let this = self.get_mut(); Stream::new(&mut this.io, &mut this.session) .set_eof(!this.state.readable()) .poll_write(cx, buf) } - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let this = self.get_mut(); Stream::new(&mut this.io, &mut this.session) .set_eof(!this.state.readable()) .poll_flush(cx) } - fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { if self.state.writeable() { self.session.send_close_notify(); self.state.shutdown_write(); diff --git a/tests/test.rs b/tests/test.rs index 533e4e4..a7fd2f2 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -1,17 +1,15 @@ -#![cfg(not(test))] - -#[macro_use] extern crate lazy_static; -extern crate rustls; -extern crate tokio; -extern crate tokio_rustls; -extern crate webpki; +#![feature(async_await)] use std::{ io, thread }; use std::io::{ BufReader, Cursor }; use std::sync::Arc; use std::sync::mpsc::channel; use std::net::SocketAddr; -use tokio::net::{ TcpListener, TcpStream }; +use lazy_static::lazy_static; +use futures::prelude::*; +use futures::executor; +use futures::task::SpawnExt; +use romio::tcp::{ TcpListener, TcpStream }; use rustls::{ ServerConfig, ClientConfig }; use rustls::internal::pemfile::{ certs, rsa_private_keys }; use tokio_rustls::{ TlsConnector, TlsAcceptor }; @@ -22,9 +20,6 @@ const RSA: &str = include_str!("end.rsa"); lazy_static!{ static ref TEST_SERVER: (SocketAddr, &'static str, &'static str) = { - use tokio::prelude::*; - use tokio::io as aio; - let cert = certs(&mut BufReader::new(Cursor::new(CERT))).unwrap(); let mut keys = rsa_private_keys(&mut BufReader::new(Cursor::new(RSA))).unwrap(); @@ -36,26 +31,32 @@ lazy_static!{ let (send, recv) = channel(); thread::spawn(move || { - let addr = SocketAddr::from(([127, 0, 0, 1], 0)); - let listener = TcpListener::bind(&addr).unwrap(); + let done = async { + let addr = SocketAddr::from(([127, 0, 0, 1], 0)); + let mut pool = executor::ThreadPool::new()?; + let mut listener = TcpListener::bind(&addr)?; - send.send(listener.local_addr().unwrap()).unwrap(); + send.send(listener.local_addr()?).unwrap(); - let done = listener.incoming() - .for_each(move |stream| { - let done = config.accept(stream) - .and_then(|stream| { - let (reader, writer) = stream.split(); - aio::copy(reader, writer) - }) - .then(|_| Ok(())); + let mut incoming = listener.incoming(); + while let Some(stream) = incoming.next().await { + let config = config.clone(); + pool.spawn( + async move { + let stream = stream?; + let stream = config.accept(stream).await?; + let (mut reader, mut write) = stream.split(); + reader.copy_into(&mut write).await?; + Ok(()) as io::Result<()> + } + .unwrap_or_else(|err| eprintln!("{:?}", err)) + ).unwrap(); + } - tokio::spawn(done); - Ok(()) - }) - .map_err(|err| panic!("{:?}", err)); + Ok(()) as io::Result<()> + }; - tokio::run(done); + executor::block_on(done).unwrap(); }); let addr = recv.recv().unwrap(); @@ -63,31 +64,26 @@ lazy_static!{ }; } - fn start_server() -> &'static (SocketAddr, &'static str, &'static str) { &*TEST_SERVER } -fn start_client(addr: &SocketAddr, domain: &str, config: Arc) -> io::Result<()> { - use tokio::prelude::*; - use tokio::io as aio; - +async fn start_client(addr: SocketAddr, domain: &str, config: Arc) -> io::Result<()> { const FILE: &'static [u8] = include_bytes!("../README.md"); let domain = webpki::DNSNameRef::try_from_ascii_str(domain).unwrap(); let config = TlsConnector::from(config); + let mut buf = vec![0; FILE.len()]; - let done = TcpStream::connect(addr) - .and_then(|stream| config.connect(domain, stream)) - .and_then(|stream| aio::write_all(stream, FILE)) - .and_then(|(stream, _)| aio::read_exact(stream, vec![0; FILE.len()])) - .and_then(|(stream, buf)| { - assert_eq!(buf, FILE); - aio::shutdown(stream) - }) - .map(drop); + let stream = TcpStream::connect(&addr).await?; + let mut stream = config.connect(domain, stream).await?; + stream.write_all(FILE).await?; + stream.read_exact(&mut buf).await?; - done.wait() + assert_eq!(buf, FILE); + + stream.close().await?; + Ok(()) } #[test] @@ -99,7 +95,7 @@ fn pass() { config.root_store.add_pem_file(&mut chain).unwrap(); let config = Arc::new(config); - start_client(addr, domain, config.clone()).unwrap(); + executor::block_on(start_client(addr.clone(), domain, config.clone())).unwrap(); } #[test] @@ -112,5 +108,5 @@ fn fail() { let config = Arc::new(config); assert_ne!(domain, &"google.com"); - assert!(start_client(addr, "google.com", config).is_err()); + assert!(executor::block_on(start_client(addr.clone(), "google.com", config)).is_err()); } From 7949f4377afa19105b9ae04331b5aa54a780faef Mon Sep 17 00:00:00 2001 From: quininer Date: Tue, 21 May 2019 01:47:50 +0800 Subject: [PATCH 120/171] make all test work! --- src/client.rs | 19 ++-- src/common/mod.rs | 64 ++++++++---- src/common/test_stream.rs | 204 +++++++++++++++++++++++--------------- src/server.rs | 20 ++-- 4 files changed, 184 insertions(+), 123 deletions(-) diff --git a/src/client.rs b/src/client.rs index 9e89468..9f1f7f6 100644 --- a/src/client.rs +++ b/src/client.rs @@ -97,7 +97,7 @@ where // write early data (fallback) if !stream.session.is_early_data_accepted() { while *pos < data.len() { - let len = try_ready!(stream.poll_write(cx, &data[*pos..])); + let len = try_ready!(stream.pin().poll_write(cx, &data[*pos..])); *pos += len; } } @@ -113,7 +113,7 @@ where let mut stream = Stream::new(&mut this.io, &mut this.session) .set_eof(!this.state.readable()); - match stream.poll_read(cx, buf) { + match stream.pin().poll_read(cx, buf) { Poll::Ready(Ok(0)) => { this.state.shutdown_read(); Poll::Ready(Ok(0)) @@ -167,7 +167,7 @@ where // write early data (fallback) if !stream.session.is_early_data_accepted() { while *pos < data.len() { - let len = try_ready!(stream.poll_write(cx, &data[*pos..])); + let len = try_ready!(stream.pin().poll_write(cx, &data[*pos..])); *pos += len; } } @@ -175,17 +175,17 @@ where // end this.state = TlsState::Stream; data.clear(); - stream.poll_write(cx, buf) + stream.pin().poll_write(cx, buf) } - _ => stream.poll_write(cx, buf), + _ => stream.pin().poll_write(cx, buf), } } fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let this = self.get_mut(); - Stream::new(&mut this.io, &mut this.session) - .set_eof(!this.state.readable()) - .poll_flush(cx) + let mut stream = Stream::new(&mut this.io, &mut this.session) + .set_eof(!this.state.readable()); + stream.pin().poll_flush(cx) } fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { @@ -197,7 +197,6 @@ where let this = self.get_mut(); let mut stream = Stream::new(&mut this.io, &mut this.session) .set_eof(!this.state.readable()); - try_ready!(stream.poll_flush(cx)); - Pin::new(&mut this.io).poll_close(cx) + stream.pin().poll_close(cx) } } diff --git a/src/common/mod.rs b/src/common/mod.rs index eacf585..585e6c9 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -2,8 +2,7 @@ use std::pin::Pin; use std::task::Poll; use std::marker::Unpin; use std::io::{ self, Read }; -use rustls::Session; -use rustls::WriteV; +use rustls::{ Session, WriteV }; use futures::task::Context; use futures::io::{ AsyncRead, AsyncWrite, IoSlice }; use smallvec::SmallVec; @@ -42,6 +41,10 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> { self } + pub fn pin(&mut self) -> Pin<&mut Self> { + Pin::new(self) + } + pub fn complete_io(&mut self, cx: &mut Context) -> Poll> { self.complete_inner_io(cx, Focus::Empty) } @@ -124,7 +127,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> { }; match (self.eof, self.session.is_handshaking(), would_block) { - (true, true, _) => return Poll::Pending, + (true, true, _) => return Poll::Ready(Err(io::ErrorKind::UnexpectedEof.into())), (_, false, true) => { let would_block = match focus { Focus::Empty => rdlen == 0 && wrlen == 0, @@ -172,10 +175,12 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> WriteTls for Str } } -impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> { - pub fn poll_read(&mut self, cx: &mut Context, buf: &mut [u8]) -> Poll> { - while self.session.wants_read() { - match self.complete_inner_io(cx, Focus::Readable) { +impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncRead for Stream<'a, IO, S> { + fn poll_read(self: Pin<&mut Self>, cx: &mut Context, buf: &mut [u8]) -> Poll> { + let this = self.get_mut(); + + while this.session.wants_read() { + match this.complete_inner_io(cx, Focus::Readable) { Poll::Ready(Ok((0, _))) => break, Poll::Ready(Ok(_)) => (), Poll::Pending => return Poll::Pending, @@ -184,13 +189,17 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> { } // FIXME rustls always ready ? - Poll::Ready(self.session.read(buf)) + Poll::Ready(this.session.read(buf)) } +} - pub fn poll_write(&mut self, cx: &mut Context, buf: &[u8]) -> Poll> { - let len = self.session.write(buf)?; - while self.session.wants_write() { - match self.complete_inner_io(cx, Focus::Writable) { +impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncWrite for Stream<'a, IO, S> { + fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll> { + let this = self.get_mut(); + + let len = this.session.write(buf)?; + while this.session.wants_write() { + match this.complete_inner_io(cx, Focus::Writable) { Poll::Ready(Ok(_)) => (), Poll::Pending if len != 0 => break, Poll::Pending => return Poll::Pending, @@ -202,7 +211,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> { Poll::Ready(Ok(len)) } else { // not write zero - match self.session.write(buf) { + match this.session.write(buf) { Ok(0) => Poll::Pending, Ok(n) => Poll::Ready(Ok(n)), Err(err) => Poll::Ready(Err(err)) @@ -210,18 +219,33 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> { } } - pub fn poll_flush(&mut self, cx: &mut Context) -> Poll> { - self.session.flush()?; - while self.session.wants_write() { - match self.complete_inner_io(cx, Focus::Writable) { + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + let this = self.get_mut(); + + this.session.flush()?; + while this.session.wants_write() { + match this.complete_inner_io(cx, Focus::Writable) { Poll::Ready(Ok(_)) => (), Poll::Pending => return Poll::Pending, Poll::Ready(Err(err)) => return Poll::Ready(Err(err)) } } - Pin::new(&mut self.io).poll_flush(cx) + Pin::new(&mut this.io).poll_flush(cx) + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.get_mut(); + + while this.session.wants_write() { + match this.complete_inner_io(cx, Focus::Writable) { + Poll::Ready(Ok(_)) => (), + Poll::Pending => return Poll::Pending, + Poll::Ready(Err(err)) => return Poll::Ready(Err(err)) + } + } + Pin::new(&mut this.io).poll_close(cx) } } -// #[cfg(test)] -// mod test_stream; +#[cfg(test)] +mod test_stream; diff --git a/src/common/test_stream.rs b/src/common/test_stream.rs index 744758a..67b9146 100644 --- a/src/common/test_stream.rs +++ b/src/common/test_stream.rs @@ -1,4 +1,10 @@ +use std::pin::Pin; +use std::task::Poll; use std::sync::Arc; +use futures::prelude::*; +use futures::task::{ Context, noop_waker_ref }; +use futures::executor; +use futures::io::{ AsyncRead, AsyncWrite }; use std::io::{ self, Read, Write, BufReader, Cursor }; use webpki::DNSNameRef; use rustls::internal::pemfile::{ certs, rsa_private_keys }; @@ -7,146 +13,172 @@ use rustls::{ ServerSession, ClientSession, Session, NoClientAuth }; -use futures::{ Async, Poll }; -use tokio_io::{ AsyncRead, AsyncWrite }; use super::Stream; struct Good<'a>(&'a mut Session); -impl<'a> Read for Good<'a> { - fn read(&mut self, mut buf: &mut [u8]) -> io::Result { - self.0.write_tls(buf.by_ref()) +impl<'a> AsyncRead for Good<'a> { + fn poll_read(mut self: Pin<&mut Self>, _cx: &mut Context<'_>, mut buf: &mut [u8]) -> Poll> { + Poll::Ready(self.0.write_tls(buf.by_ref())) } } -impl<'a> Write for Good<'a> { - fn write(&mut self, mut buf: &[u8]) -> io::Result { +impl<'a> AsyncWrite for Good<'a> { + fn poll_write(mut self: Pin<&mut Self>, _cx: &mut Context<'_>, mut buf: &[u8]) -> Poll> { let len = self.0.read_tls(buf.by_ref())?; self.0.process_new_packets() .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?; - Ok(len) + Poll::Ready(Ok(len)) } - fn flush(&mut self) -> io::Result<()> { - Ok(()) + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) } -} -impl<'a> AsyncRead for Good<'a> {} -impl<'a> AsyncWrite for Good<'a> { - fn shutdown(&mut self) -> Poll<(), io::Error> { - Ok(Async::Ready(())) + fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) } } struct Bad(bool); -impl Read for Bad { - fn read(&mut self, _: &mut [u8]) -> io::Result { - Ok(0) +impl AsyncRead for Bad { + fn poll_read(self: Pin<&mut Self>, _cx: &mut Context<'_>, _: &mut [u8]) -> Poll> { + Poll::Ready(Ok(0)) } } -impl Write for Bad { - fn write(&mut self, buf: &[u8]) -> io::Result { +impl AsyncWrite for Bad { + fn poll_write(self: Pin<&mut Self>, _cx: &mut Context<'_>, buf: &[u8]) -> Poll> { if self.0 { - Err(io::ErrorKind::WouldBlock.into()) + Poll::Pending } else { - Ok(buf.len()) + Poll::Ready(Ok(buf.len())) } } - fn flush(&mut self) -> io::Result<()> { - Ok(()) + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) } } -impl AsyncRead for Bad {} -impl AsyncWrite for Bad { - fn shutdown(&mut self) -> Poll<(), io::Error> { - Ok(Async::Ready(())) - } -} - - #[test] fn stream_good() -> io::Result<()> { const FILE: &'static [u8] = include_bytes!("../../README.md"); - let (mut server, mut client) = make_pair(); - do_handshake(&mut client, &mut server); - io::copy(&mut Cursor::new(FILE), &mut server)?; + let fut = async { + let (mut server, mut client) = make_pair(); + future::poll_fn(|cx| do_handshake(&mut client, &mut server, cx)).await?; + io::copy(&mut Cursor::new(FILE), &mut server)?; - { - let mut good = Good(&mut server); - let mut stream = Stream::new(&mut good, &mut client); + { + let mut good = Good(&mut server); + let mut stream = Stream::new(&mut good, &mut client); - let mut buf = Vec::new(); - stream.read_to_end(&mut buf)?; - assert_eq!(buf, FILE); - stream.write_all(b"Hello World!")?; - } + let mut buf = Vec::new(); + stream.read_to_end(&mut buf).await?; + assert_eq!(buf, FILE); + stream.write_all(b"Hello World!").await?; + } - let mut buf = String::new(); - server.read_to_string(&mut buf)?; - assert_eq!(buf, "Hello World!"); + let mut buf = String::new(); + server.read_to_string(&mut buf)?; + assert_eq!(buf, "Hello World!"); - Ok(()) + Ok(()) as io::Result<()> + }; + + executor::block_on(fut) } #[test] fn stream_bad() -> io::Result<()> { - let (mut server, mut client) = make_pair(); - do_handshake(&mut client, &mut server); - client.set_buffer_limit(1024); + let fut = async { + let (mut server, mut client) = make_pair(); + future::poll_fn(|cx| do_handshake(&mut client, &mut server, cx)).await?; + client.set_buffer_limit(1024); - let mut bad = Bad(true); - let mut stream = Stream::new(&mut bad, &mut client); - assert_eq!(stream.write(&[0x42; 8])?, 8); - assert_eq!(stream.write(&[0x42; 8])?, 8); - let r = stream.write(&[0x00; 1024])?; // fill buffer - assert!(r < 1024); - assert_eq!( - stream.write(&[0x01]).unwrap_err().kind(), - io::ErrorKind::WouldBlock - ); + let mut bad = Bad(true); + let mut stream = Stream::new(&mut bad, &mut client); + assert_eq!(future::poll_fn(|cx| stream.pin().poll_write(cx, &[0x42; 8])).await?, 8); + assert_eq!(future::poll_fn(|cx| stream.pin().poll_write(cx, &[0x42; 8])).await?, 8); + let r = future::poll_fn(|cx| stream.pin().poll_write(cx, &[0x00; 1024])).await?; // fill buffer + assert!(r < 1024); - Ok(()) + let mut cx = Context::from_waker(noop_waker_ref()); + assert!(stream.pin().poll_write(&mut cx, &[0x01]).is_pending()); + + Ok(()) as io::Result<()> + }; + + executor::block_on(fut) } #[test] fn stream_handshake() -> io::Result<()> { - let (mut server, mut client) = make_pair(); + let fut = async { + let (mut server, mut client) = make_pair(); - { - let mut good = Good(&mut server); - let mut stream = Stream::new(&mut good, &mut client); - let (r, w) = stream.complete_io()?; + { + let mut good = Good(&mut server); + let mut stream = Stream::new(&mut good, &mut client); + let (r, w) = future::poll_fn(|cx| stream.complete_io(cx)).await?; - assert!(r > 0); - assert!(w > 0); + assert!(r > 0); + assert!(w > 0); - stream.complete_io()?; // finish server handshake - } + future::poll_fn(|cx| stream.complete_io(cx)).await?; // finish server handshake + } - assert!(!server.is_handshaking()); - assert!(!client.is_handshaking()); + assert!(!server.is_handshaking()); + assert!(!client.is_handshaking()); - Ok(()) + Ok(()) as io::Result<()> + }; + + executor::block_on(fut) } #[test] fn stream_handshake_eof() -> io::Result<()> { - let (_, mut client) = make_pair(); + let fut = async { + let (_, mut client) = make_pair(); - let mut bad = Bad(false); - let mut stream = Stream::new(&mut bad, &mut client); - let r = stream.complete_io(); + let mut bad = Bad(false); + let mut stream = Stream::new(&mut bad, &mut client); - assert_eq!(r.unwrap_err().kind(), io::ErrorKind::UnexpectedEof); + let mut cx = Context::from_waker(noop_waker_ref()); + let r = stream.complete_io(&mut cx); + assert_eq!(r.map_err(|err| err.kind()), Poll::Ready(Err(io::ErrorKind::UnexpectedEof))); - Ok(()) + Ok(()) as io::Result<()> + }; + + executor::block_on(fut) +} + +#[test] +fn stream_eof() -> io::Result<()> { + let fut = async { + let (mut server, mut client) = make_pair(); + future::poll_fn(|cx| do_handshake(&mut client, &mut server, cx)).await?; + + let mut good = Good(&mut server); + let mut stream = Stream::new(&mut good, &mut client).set_eof(true); + + let mut buf = Vec::new(); + stream.read_to_end(&mut buf).await?; + assert_eq!(buf.len(), 0); + + Ok(()) as io::Result<()> + }; + + executor::block_on(fut) } fn make_pair() -> (ServerSession, ClientSession) { @@ -169,9 +201,17 @@ fn make_pair() -> (ServerSession, ClientSession) { (server, client) } -fn do_handshake(client: &mut ClientSession, server: &mut ServerSession) { +fn do_handshake(client: &mut ClientSession, server: &mut ServerSession, cx: &mut Context<'_>) -> Poll> { let mut good = Good(server); let mut stream = Stream::new(&mut good, client); - stream.complete_io().unwrap(); - stream.complete_io().unwrap(); + + if stream.session.is_handshaking() { + try_ready!(stream.complete_io(cx)); + } + + if stream.session.wants_write() { + try_ready!(stream.complete_io(cx)); + } + + Poll::Ready(Ok(())) } diff --git a/src/server.rs b/src/server.rs index ba054a9..21cc5e6 100644 --- a/src/server.rs +++ b/src/server.rs @@ -78,7 +78,7 @@ where .set_eof(!this.state.readable()); match this.state { - TlsState::Stream | TlsState::WriteShutdown => match stream.poll_read(cx, buf) { + TlsState::Stream | TlsState::WriteShutdown => match stream.pin().poll_read(cx, buf) { Poll::Ready(Ok(0)) => { this.state.shutdown_read(); Poll::Ready(Ok(0)) @@ -108,16 +108,16 @@ where { fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { let this = self.get_mut(); - Stream::new(&mut this.io, &mut this.session) - .set_eof(!this.state.readable()) - .poll_write(cx, buf) + let mut stream = Stream::new(&mut this.io, &mut this.session) + .set_eof(!this.state.readable()); + stream.pin().poll_write(cx, buf) } fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let this = self.get_mut(); - Stream::new(&mut this.io, &mut this.session) - .set_eof(!this.state.readable()) - .poll_flush(cx) + let mut stream = Stream::new(&mut this.io, &mut this.session) + .set_eof(!this.state.readable()); + stream.pin().poll_flush(cx) } fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { @@ -127,9 +127,7 @@ where } let this = self.get_mut(); - let mut stream = Stream::new(&mut this.io, &mut this.session) - .set_eof(!this.state.readable()); - try_ready!(stream.complete_io(cx)); - Pin::new(&mut this.io).poll_close(cx) + let mut stream = Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable()); + stream.pin().poll_close(cx) } } From 3ffb736d5e5247799b15964cfd377cd4c324fa09 Mon Sep 17 00:00:00 2001 From: quininer Date: Wed, 22 May 2019 00:54:10 +0800 Subject: [PATCH 121/171] update server example --- Cargo.toml | 3 +- examples/server/Cargo.toml | 6 +- examples/server/src/main.rs | 150 +++++++++++++++++++----------------- src/common/mod.rs | 17 ++-- src/lib.rs | 3 + tests/test.rs | 7 +- 6 files changed, 98 insertions(+), 88 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index cb60566..84d6d63 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,7 +16,7 @@ travis-ci = { repository = "quininer/tokio-rustls" } appveyor = { repository = "quininer/tokio-rustls" } [dependencies] -smallvec = "*" +smallvec = "0.6" futures = { package = "futures-preview", version = "0.3.0-alpha.16" } rustls = "0.15" webpki = "0.19" @@ -26,6 +26,5 @@ early-data = [] [dev-dependencies] romio = "0.3.0-alpha.8" -tokio = "0.1.6" lazy_static = "1" webpki-roots = "0.16" diff --git a/examples/server/Cargo.toml b/examples/server/Cargo.toml index 170693f..0392ffb 100644 --- a/examples/server/Cargo.toml +++ b/examples/server/Cargo.toml @@ -2,8 +2,10 @@ name = "server" version = "0.1.0" authors = ["quininer "] +edition = "2018" [dependencies] +futures = { package = "futures-preview", version = "0.3.0-alpha.16" } +romio = "0.3.0-alpha.8" +structopt = "*" tokio-rustls = { path = "../.." } -tokio = { version = "0.1.6" } -clap = "2" diff --git a/examples/server/src/main.rs b/examples/server/src/main.rs index 2a94c58..a2a3b13 100644 --- a/examples/server/src/main.rs +++ b/examples/server/src/main.rs @@ -1,88 +1,100 @@ -extern crate clap; -extern crate tokio; -extern crate tokio_rustls; +#![feature(async_await)] +use std::fs::File; use std::sync::Arc; use std::net::ToSocketAddrs; -use std::io::BufReader; -use std::fs::File; -use tokio_rustls::{ - TlsAcceptor, - rustls::{ - Certificate, NoClientAuth, PrivateKey, ServerConfig, - internal::pemfile::{ certs, rsa_private_keys } - }, -}; -use tokio::prelude::{ Future, Stream }; -use tokio::io::{ self, AsyncRead }; -use tokio::net::TcpListener; -use clap::{ App, Arg }; +use std::path::{ PathBuf, Path }; +use std::io::{ self, BufReader }; +use structopt::StructOpt; +use futures::task::SpawnExt; +use futures::prelude::*; +use futures::executor; +use romio::TcpListener; +use tokio_rustls::rustls::{ Certificate, NoClientAuth, PrivateKey, ServerConfig }; +use tokio_rustls::rustls::internal::pemfile::{ certs, rsa_private_keys }; +use tokio_rustls::TlsAcceptor; -fn app() -> App<'static, 'static> { - App::new("server") - .about("tokio-rustls server example") - .arg(Arg::with_name("addr").value_name("ADDR").required(true)) - .arg(Arg::with_name("cert").short("c").long("cert").value_name("FILE").help("cert file.").required(true)) - .arg(Arg::with_name("key").short("k").long("key").value_name("FILE").help("key file, rsa only.").required(true)) - .arg(Arg::with_name("echo").short("e").long("echo-mode").help("echo mode.")) + +#[derive(StructOpt)] +struct Options { + addr: String, + + /// cert file + #[structopt(short="c", long="cert", parse(from_os_str))] + cert: PathBuf, + + /// key file + #[structopt(short="k", long="key", parse(from_os_str))] + key: PathBuf, + + /// echo mode + #[structopt(short="e", long="echo-mode")] + echo: bool } -fn load_certs(path: &str) -> Vec { - certs(&mut BufReader::new(File::open(path).unwrap())).unwrap() +fn load_certs(path: &Path) -> io::Result> { + certs(&mut BufReader::new(File::open(path)?)) + .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid cert")) } -fn load_keys(path: &str) -> Vec { - rsa_private_keys(&mut BufReader::new(File::open(path).unwrap())).unwrap() +fn load_keys(path: &Path) -> io::Result> { + rsa_private_keys(&mut BufReader::new(File::open(path)?)) + .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid key")) } -fn main() { - let matches = app().get_matches(); +fn main() -> io::Result<()> { + let options = Options::from_args(); - let addr = matches.value_of("addr").unwrap() - .to_socket_addrs().unwrap() - .next().unwrap(); - let cert_file = matches.value_of("cert").unwrap(); - let key_file = matches.value_of("key").unwrap(); - let flag_echo = matches.occurrences_of("echo") > 0; + let addr = options.addr.to_socket_addrs()? + .next() + .ok_or_else(|| io::Error::from(io::ErrorKind::AddrNotAvailable))?; + let certs = load_certs(&options.cert)?; + let mut keys = load_keys(&options.key)?; + let flag_echo = options.echo; + let mut pool = executor::ThreadPool::new()?; let mut config = ServerConfig::new(NoClientAuth::new()); - config.set_single_cert(load_certs(cert_file), load_keys(key_file).remove(0)) - .expect("invalid key or certificate"); - let config = TlsAcceptor::from(Arc::new(config)); + config.set_single_cert(certs, keys.remove(0)) + .map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?; + let acceptor = TlsAcceptor::from(Arc::new(config)); - let socket = TcpListener::bind(&addr).unwrap(); - let done = socket.incoming() - .for_each(move |stream| if flag_echo { - let addr = stream.peer_addr().ok(); - let done = config.accept(stream) - .and_then(|stream| { - let (reader, writer) = stream.split(); - io::copy(reader, writer) - }) - .map(move |(n, ..)| println!("Echo: {} - {:?}", n, addr)) - .map_err(move |err| println!("Error: {:?} - {:?}", err, addr)); - tokio::spawn(done); + let fut = async { + let mut listener = TcpListener::bind(&addr)?; + let mut incoming = listener.incoming(); - Ok(()) - } else { - let addr = stream.peer_addr().ok(); - let done = config.accept(stream) - .and_then(|stream| io::write_all( - stream, - &b"HTTP/1.0 200 ok\r\n\ - Connection: close\r\n\ - Content-length: 12\r\n\ - \r\n\ - Hello world!"[..] - )) - .and_then(|(stream, _)| io::flush(stream)) - .map(move |_| println!("Accept: {:?}", addr)) - .map_err(move |err| println!("Error: {:?} - {:?}", err, addr)); - tokio::spawn(done); + while let Some(stream) = incoming.next().await { + let acceptor = acceptor.clone(); - Ok(()) - }); + let fut = async move { + let stream = stream?; + let peer_addr = stream.peer_addr()?; + let mut stream = acceptor.accept(stream).await?; - tokio::run(done.map_err(drop)); + if flag_echo { + let (mut reader, mut writer) = stream.split(); + let n = reader.copy_into(&mut writer).await?; + println!("Echo: {} - {}", peer_addr, n); + } else { + stream.write_all( + &b"HTTP/1.0 200 ok\r\n\ + Connection: close\r\n\ + Content-length: 12\r\n\ + \r\n\ + Hello world!"[..] + ).await?; + stream.flush().await?; + println!("Hello: {}", peer_addr); + } + + Ok(()) as io::Result<()> + }; + + pool.spawn(fut.unwrap_or_else(|err| eprintln!("{:?}", err))).unwrap(); + } + + Ok(()) + }; + + executor::block_on(fut) } diff --git a/src/common/mod.rs b/src/common/mod.rs index 585e6c9..4d65be7 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -127,7 +127,10 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> { }; match (self.eof, self.session.is_handshaking(), would_block) { - (true, true, _) => return Poll::Ready(Err(io::ErrorKind::UnexpectedEof.into())), + (true, true, _) => { + let err = io::Error::new(io::ErrorKind::UnexpectedEof, "tls handshake eof"); + return Poll::Ready(Err(err)); + }, (_, false, true) => { let would_block = match focus { Focus::Empty => rdlen == 0 && wrlen == 0, @@ -224,11 +227,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncWrite for Stream<' this.session.flush()?; while this.session.wants_write() { - match this.complete_inner_io(cx, Focus::Writable) { - Poll::Ready(Ok(_)) => (), - Poll::Pending => return Poll::Pending, - Poll::Ready(Err(err)) => return Poll::Ready(Err(err)) - } + try_ready!(this.complete_inner_io(cx, Focus::Writable)); } Pin::new(&mut this.io).poll_flush(cx) } @@ -237,11 +236,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncWrite for Stream<' let this = self.get_mut(); while this.session.wants_write() { - match this.complete_inner_io(cx, Focus::Writable) { - Poll::Ready(Ok(_)) => (), - Poll::Pending => return Poll::Pending, - Poll::Ready(Err(err)) => return Poll::Ready(Err(err)) - } + try_ready!(this.complete_inner_io(cx, Focus::Writable)); } Pin::new(&mut this.io).poll_close(cx) } diff --git a/src/lib.rs b/src/lib.rs index df3c259..cca1c85 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -26,6 +26,9 @@ use rustls::{ ClientConfig, ClientSession, ServerConfig, ServerSession }; use webpki::DNSNameRef; use common::Stream; +pub use rustls; +pub use webpki; + #[derive(Debug, Copy, Clone)] enum TlsState { #[cfg(feature = "early-data")] diff --git a/tests/test.rs b/tests/test.rs index a7fd2f2..acc67e3 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -26,7 +26,7 @@ lazy_static!{ let mut config = ServerConfig::new(rustls::NoClientAuth::new()); config.set_single_cert(cert, keys.pop().unwrap()) .expect("invalid key or certificate"); - let config = TlsAcceptor::from(Arc::new(config)); + let acceptor = TlsAcceptor::from(Arc::new(config)); let (send, recv) = channel(); @@ -40,11 +40,10 @@ lazy_static!{ let mut incoming = listener.incoming(); while let Some(stream) = incoming.next().await { - let config = config.clone(); + let acceptor = acceptor.clone(); pool.spawn( async move { - let stream = stream?; - let stream = config.accept(stream).await?; + let stream = acceptor.accept(stream?).await?; let (mut reader, mut write) = stream.split(); reader.copy_into(&mut write).await?; Ok(()) as io::Result<()> From b8e3fcb79e053ff2db492d3f92779d16c314f090 Mon Sep 17 00:00:00 2001 From: quininer Date: Wed, 22 May 2019 23:57:14 +0800 Subject: [PATCH 122/171] update client example --- examples/client/Cargo.toml | 7 ++- examples/client/src/main.rs | 115 +++++++++++++++++++----------------- examples/server/Cargo.toml | 2 +- 3 files changed, 65 insertions(+), 59 deletions(-) diff --git a/examples/client/Cargo.toml b/examples/client/Cargo.toml index 3765efc..feec249 100644 --- a/examples/client/Cargo.toml +++ b/examples/client/Cargo.toml @@ -2,11 +2,12 @@ name = "client" version = "0.1.0" authors = ["quininer "] +edition = "2018" [dependencies] -webpki = "0.19" +futures = { package = "futures-preview", version = "0.3.0-alpha.16", features = ["io-compat"] } +romio = "0.3.0-alpha.8" +structopt = "0.2" tokio-rustls = { path = "../.." } -tokio = "0.1" -clap = "2" webpki-roots = "0.16" tokio-stdin-stdout = "0.1" diff --git a/examples/client/src/main.rs b/examples/client/src/main.rs index ff6b315..6416db2 100644 --- a/examples/client/src/main.rs +++ b/examples/client/src/main.rs @@ -1,74 +1,79 @@ -extern crate clap; -extern crate tokio; -extern crate webpki; -extern crate webpki_roots; -extern crate tokio_rustls; - -extern crate tokio_stdin_stdout; +#![feature(async_await)] +use std::io; +use std::fs::File; +use std::path::PathBuf; use std::sync::Arc; use std::net::ToSocketAddrs; use std::io::BufReader; -use std::fs; -use tokio::io; -use tokio::net::TcpStream; -use tokio::prelude::*; -use clap::{ App, Arg }; -use tokio_rustls::{ TlsConnector, rustls::ClientConfig }; +use structopt::StructOpt; +use romio::TcpStream; +use futures::prelude::*; +use futures::executor; +use futures::compat::{ AsyncRead01CompatExt, AsyncWrite01CompatExt }; +use tokio_rustls::{ TlsConnector, rustls::ClientConfig, webpki::DNSNameRef }; use tokio_stdin_stdout::{ stdin as tokio_stdin, stdout as tokio_stdout }; -fn app() -> App<'static, 'static> { - App::new("client") - .about("tokio-rustls client example") - .arg(Arg::with_name("host").value_name("HOST").required(true)) - .arg(Arg::with_name("port").short("p").long("port").value_name("PORT").help("port, default `443`")) - .arg(Arg::with_name("domain").short("d").long("domain").value_name("DOMAIN").help("domain")) - .arg(Arg::with_name("cafile").short("c").long("cafile").value_name("FILE").help("CA certificate chain")) + +#[derive(StructOpt)] +struct Options { + host: String, + + /// port + #[structopt(short="p", long="port", default_value="443")] + port: u16, + + /// domain + #[structopt(short="d", long="domain")] + domain: Option, + + /// cafile + #[structopt(short="c", long="cafile", parse(from_os_str))] + cafile: Option } -fn main() { - let matches = app().get_matches(); +fn main() -> io::Result<()> { + let options = Options::from_args(); - let host = matches.value_of("host").unwrap(); - let port = matches.value_of("port") - .map(|port| port.parse().unwrap()) - .unwrap_or(443); - let domain = matches.value_of("domain").unwrap_or(host).to_owned(); - let cafile = matches.value_of("cafile"); - let text = format!("GET / HTTP/1.0\r\nHost: {}\r\n\r\n", domain); - - let addr = (host, port) - .to_socket_addrs().unwrap() - .next().unwrap(); + let addr = (options.host.as_str(), options.port) + .to_socket_addrs()? + .next() + .ok_or_else(|| io::Error::from(io::ErrorKind::NotFound))?; + let domain = options.domain.unwrap_or(options.host); + let content = format!( + "GET / HTTP/1.0\r\nHost: {}\r\n\r\n", + domain + ); let mut config = ClientConfig::new(); - if let Some(cafile) = cafile { - let mut pem = BufReader::new(fs::File::open(cafile).unwrap()); - config.root_store.add_pem_file(&mut pem).unwrap(); + if let Some(cafile) = &options.cafile { + let mut pem = BufReader::new(File::open(cafile)?); + config.root_store.add_pem_file(&mut pem) + .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid cert"))?; } else { config.root_store.add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS); } - let config = TlsConnector::from(Arc::new(config)); + let connector = TlsConnector::from(Arc::new(config)); - let socket = TcpStream::connect(&addr); - let (stdin, stdout) = (tokio_stdin(0), tokio_stdout(0)); + let fut = async { + let stream = TcpStream::connect(&addr).await?; + let (mut stdin, mut stdout) = (tokio_stdin(0).compat(), tokio_stdout(0).compat()); - let done = socket - .and_then(move |stream| { - let domain = webpki::DNSNameRef::try_from_ascii_str(&domain).unwrap(); - config.connect(domain, stream) - }) - .and_then(move |stream| io::write_all(stream, text)) - .and_then(move |(stream, _)| { - let (r, w) = stream.split(); - io::copy(r, stdout) - .map(drop) - .select2(io::copy(stdin, w).map(drop)) - .map_err(|res| res.split().0) - }) - .map(drop) - .map_err(|err| eprintln!("{:?}", err)); + let domain = DNSNameRef::try_from_ascii_str(&domain) + .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid dnsname"))?; - tokio::run(done); + let mut stream = connector.connect(domain, stream).await?; + stream.write_all(content.as_bytes()).await?; + + let (mut reader, mut writer) = stream.split(); + future::try_join( + reader.copy_into(&mut stdout), + stdin.copy_into(&mut writer) + ).await?; + + Ok(()) + }; + + executor::block_on(fut) } diff --git a/examples/server/Cargo.toml b/examples/server/Cargo.toml index 0392ffb..9da4423 100644 --- a/examples/server/Cargo.toml +++ b/examples/server/Cargo.toml @@ -7,5 +7,5 @@ edition = "2018" [dependencies] futures = { package = "futures-preview", version = "0.3.0-alpha.16" } romio = "0.3.0-alpha.8" -structopt = "*" +structopt = "0.2" tokio-rustls = { path = "../.." } From 183e30f4862cbd09af322283914c4e0bd3cb3bcb Mon Sep 17 00:00:00 2001 From: quininer Date: Thu, 23 May 2019 00:05:42 +0800 Subject: [PATCH 123/171] use futures::ready! --- .travis.yml | 2 ++ src/client.rs | 12 ++++++------ src/common/mod.rs | 4 ++-- src/common/test_stream.rs | 4 ++-- src/lib.rs | 10 ---------- src/server.rs | 4 ++-- 6 files changed, 14 insertions(+), 22 deletions(-) diff --git a/.travis.yml b/.travis.yml index 9efee9d..79678c8 100644 --- a/.travis.yml +++ b/.travis.yml @@ -11,6 +11,8 @@ matrix: os: osx - rust: nightly os: osx + allow_failures: + - rust: stable script: - cargo test diff --git a/src/client.rs b/src/client.rs index 9f1f7f6..7bce1b8 100644 --- a/src/client.rs +++ b/src/client.rs @@ -53,11 +53,11 @@ where let mut stream = Stream::new(io, session).set_eof(eof); if stream.session.is_handshaking() { - try_ready!(stream.complete_io(cx)); + futures::ready!(stream.complete_io(cx))?; } if stream.session.wants_write() { - try_ready!(stream.complete_io(cx)); + futures::ready!(stream.complete_io(cx))?; } } @@ -91,13 +91,13 @@ where // complete handshake if stream.session.is_handshaking() { - try_ready!(stream.complete_io(cx)); + futures::ready!(stream.complete_io(cx))?; } // write early data (fallback) if !stream.session.is_early_data_accepted() { while *pos < data.len() { - let len = try_ready!(stream.pin().poll_write(cx, &data[*pos..])); + let len = futures::ready!(stream.pin().poll_write(cx, &data[*pos..]))?; *pos += len; } } @@ -161,13 +161,13 @@ where // complete handshake if stream.session.is_handshaking() { - try_ready!(stream.complete_io(cx)); + futures::ready!(stream.complete_io(cx))?; } // write early data (fallback) if !stream.session.is_early_data_accepted() { while *pos < data.len() { - let len = try_ready!(stream.pin().poll_write(cx, &data[*pos..])); + let len = futures::ready!(stream.pin().poll_write(cx, &data[*pos..]))?; *pos += len; } } diff --git a/src/common/mod.rs b/src/common/mod.rs index 4d65be7..bff9990 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -227,7 +227,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncWrite for Stream<' this.session.flush()?; while this.session.wants_write() { - try_ready!(this.complete_inner_io(cx, Focus::Writable)); + futures::ready!(this.complete_inner_io(cx, Focus::Writable))?; } Pin::new(&mut this.io).poll_flush(cx) } @@ -236,7 +236,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncWrite for Stream<' let this = self.get_mut(); while this.session.wants_write() { - try_ready!(this.complete_inner_io(cx, Focus::Writable)); + futures::ready!(this.complete_inner_io(cx, Focus::Writable))?; } Pin::new(&mut this.io).poll_close(cx) } diff --git a/src/common/test_stream.rs b/src/common/test_stream.rs index 67b9146..1f7c14c 100644 --- a/src/common/test_stream.rs +++ b/src/common/test_stream.rs @@ -206,11 +206,11 @@ fn do_handshake(client: &mut ClientSession, server: &mut ServerSession, cx: &mut let mut stream = Stream::new(&mut good, client); if stream.session.is_handshaking() { - try_ready!(stream.complete_io(cx)); + futures::ready!(stream.complete_io(cx))?; } if stream.session.wants_write() { - try_ready!(stream.complete_io(cx)); + futures::ready!(stream.complete_io(cx))?; } Poll::Ready(Ok(())) diff --git a/src/lib.rs b/src/lib.rs index cca1c85..a928b87 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,16 +2,6 @@ #![feature(async_await)] -macro_rules! try_ready { - ( $e:expr ) => { - match $e { - Poll::Ready(Ok(output)) => output, - Poll::Ready(Err(err)) => return Poll::Ready(Err(err.into())), - Poll::Pending => return Poll::Pending - } - } -} - mod common; pub mod client; pub mod server; diff --git a/src/server.rs b/src/server.rs index 21cc5e6..2ed7ba9 100644 --- a/src/server.rs +++ b/src/server.rs @@ -48,11 +48,11 @@ where let mut stream = Stream::new(io, session).set_eof(eof); if stream.session.is_handshaking() { - try_ready!(stream.complete_io(cx)); + futures::ready!(stream.complete_io(cx))?; } if stream.session.wants_write() { - try_ready!(stream.complete_io(cx)); + futures::ready!(stream.complete_io(cx))?; } } From b7925003e27342c31d3ca2a4781ba6fa164d4a65 Mon Sep 17 00:00:00 2001 From: quininer Date: Sat, 1 Jun 2019 22:37:13 +0800 Subject: [PATCH 124/171] clean code --- src/client.rs | 22 +++++++++++++--------- src/common/mod.rs | 18 +++++++++++++----- src/common/test_stream.rs | 8 ++++---- src/lib.rs | 3 ++- src/server.rs | 12 ++++++------ 5 files changed, 38 insertions(+), 25 deletions(-) diff --git a/src/client.rs b/src/client.rs index 7bce1b8..11a0331 100644 --- a/src/client.rs +++ b/src/client.rs @@ -75,7 +75,6 @@ where IO: AsyncRead + AsyncWrite + Unpin, { unsafe fn initializer(&self) -> Initializer { - // TODO Initializer::nop() } @@ -97,7 +96,7 @@ where // write early data (fallback) if !stream.session.is_early_data_accepted() { while *pos < data.len() { - let len = futures::ready!(stream.pin().poll_write(cx, &data[*pos..]))?; + let len = futures::ready!(stream.as_mut_pin().poll_write(cx, &data[*pos..]))?; *pos += len; } } @@ -113,7 +112,7 @@ where let mut stream = Stream::new(&mut this.io, &mut this.session) .set_eof(!this.state.readable()); - match stream.pin().poll_read(cx, buf) { + match stream.as_mut_pin().poll_read(cx, buf) { Poll::Ready(Ok(0)) => { this.state.shutdown_read(); Poll::Ready(Ok(0)) @@ -154,7 +153,12 @@ where // write early data if let Some(mut early_data) = stream.session.early_data() { - let len = early_data.write(buf)?; // TODO check pending + let len = match early_data.write(buf) { + Ok(n) => n, + Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => + return Poll::Pending, + Err(err) => return Poll::Ready(Err(err)) + }; data.extend_from_slice(&buf[..len]); return Poll::Ready(Ok(len)); } @@ -167,7 +171,7 @@ where // write early data (fallback) if !stream.session.is_early_data_accepted() { while *pos < data.len() { - let len = futures::ready!(stream.pin().poll_write(cx, &data[*pos..]))?; + let len = futures::ready!(stream.as_mut_pin().poll_write(cx, &data[*pos..]))?; *pos += len; } } @@ -175,9 +179,9 @@ where // end this.state = TlsState::Stream; data.clear(); - stream.pin().poll_write(cx, buf) + stream.as_mut_pin().poll_write(cx, buf) } - _ => stream.pin().poll_write(cx, buf), + _ => stream.as_mut_pin().poll_write(cx, buf), } } @@ -185,7 +189,7 @@ where let this = self.get_mut(); let mut stream = Stream::new(&mut this.io, &mut this.session) .set_eof(!this.state.readable()); - stream.pin().poll_flush(cx) + stream.as_mut_pin().poll_flush(cx) } fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { @@ -197,6 +201,6 @@ where let this = self.get_mut(); let mut stream = Stream::new(&mut this.io, &mut this.session) .set_eof(!this.state.readable()); - stream.pin().poll_close(cx) + stream.as_mut_pin().poll_close(cx) } } diff --git a/src/common/mod.rs b/src/common/mod.rs index bff9990..2e648ae 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -14,7 +14,7 @@ pub struct Stream<'a, IO, S> { pub eof: bool } -pub trait WriteTls { +trait WriteTls { fn write_tls(&mut self, cx: &mut Context) -> io::Result; } @@ -41,7 +41,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> { self } - pub fn pin(&mut self) -> Pin<&mut Self> { + pub fn as_mut_pin(&mut self) -> Pin<&mut Self> { Pin::new(self) } @@ -191,8 +191,10 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncRead for Stream<'a } } - // FIXME rustls always ready ? - Poll::Ready(this.session.read(buf)) + match this.session.read(buf) { + Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending, + result => Poll::Ready(result) + } } } @@ -200,7 +202,12 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncWrite for Stream<' fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll> { let this = self.get_mut(); - let len = this.session.write(buf)?; + let len = match this.session.write(buf) { + Ok(n) => n, + Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => + return Poll::Pending, + Err(err) => return Poll::Ready(Err(err)) + }; while this.session.wants_write() { match this.complete_inner_io(cx, Focus::Writable) { Poll::Ready(Ok(_)) => (), @@ -217,6 +224,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncWrite for Stream<' match this.session.write(buf) { Ok(0) => Poll::Pending, Ok(n) => Poll::Ready(Ok(n)), + Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending, Err(err) => Poll::Ready(Err(err)) } } diff --git a/src/common/test_stream.rs b/src/common/test_stream.rs index 1f7c14c..e774778 100644 --- a/src/common/test_stream.rs +++ b/src/common/test_stream.rs @@ -105,13 +105,13 @@ fn stream_bad() -> io::Result<()> { let mut bad = Bad(true); let mut stream = Stream::new(&mut bad, &mut client); - assert_eq!(future::poll_fn(|cx| stream.pin().poll_write(cx, &[0x42; 8])).await?, 8); - assert_eq!(future::poll_fn(|cx| stream.pin().poll_write(cx, &[0x42; 8])).await?, 8); - let r = future::poll_fn(|cx| stream.pin().poll_write(cx, &[0x00; 1024])).await?; // fill buffer + assert_eq!(future::poll_fn(|cx| stream.as_mut_pin().poll_write(cx, &[0x42; 8])).await?, 8); + assert_eq!(future::poll_fn(|cx| stream.as_mut_pin().poll_write(cx, &[0x42; 8])).await?, 8); + let r = future::poll_fn(|cx| stream.as_mut_pin().poll_write(cx, &[0x00; 1024])).await?; // fill buffer assert!(r < 1024); let mut cx = Context::from_waker(noop_waker_ref()); - assert!(stream.pin().poll_write(&mut cx, &[0x01]).is_pending()); + assert!(stream.as_mut_pin().poll_write(&mut cx, &[0x01]).is_pending()); Ok(()) as io::Result<()> }; diff --git a/src/lib.rs b/src/lib.rs index a928b87..65c11f3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,6 +1,7 @@ //! Asynchronous TLS/SSL streams for Tokio using [Rustls](https://github.com/ctz/rustls). -#![feature(async_await)] +#![cfg_attr(test, feature(async_await))] + mod common; pub mod client; diff --git a/src/server.rs b/src/server.rs index 2ed7ba9..1e25145 100644 --- a/src/server.rs +++ b/src/server.rs @@ -68,7 +68,6 @@ where IO: AsyncRead + AsyncWrite + Unpin, { unsafe fn initializer(&self) -> Initializer { - // TODO Initializer::nop() } @@ -78,7 +77,7 @@ where .set_eof(!this.state.readable()); match this.state { - TlsState::Stream | TlsState::WriteShutdown => match stream.pin().poll_read(cx, buf) { + TlsState::Stream | TlsState::WriteShutdown => match stream.as_mut_pin().poll_read(cx, buf) { Poll::Ready(Ok(0)) => { this.state.shutdown_read(); Poll::Ready(Ok(0)) @@ -110,14 +109,14 @@ where let this = self.get_mut(); let mut stream = Stream::new(&mut this.io, &mut this.session) .set_eof(!this.state.readable()); - stream.pin().poll_write(cx, buf) + stream.as_mut_pin().poll_write(cx, buf) } fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let this = self.get_mut(); let mut stream = Stream::new(&mut this.io, &mut this.session) .set_eof(!this.state.readable()); - stream.pin().poll_flush(cx) + stream.as_mut_pin().poll_flush(cx) } fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { @@ -127,7 +126,8 @@ where } let this = self.get_mut(); - let mut stream = Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable()); - stream.pin().poll_close(cx) + let mut stream = Stream::new(&mut this.io, &mut this.session) + .set_eof(!this.state.readable()); + stream.as_mut_pin().poll_close(cx) } } From 2f4419b285142730f4af7a13132cccf88356f189 Mon Sep 17 00:00:00 2001 From: quininer Date: Sat, 13 Jul 2019 17:34:52 +0800 Subject: [PATCH 125/171] Switch to tokio-io 0.2 --- Cargo.toml | 5 +++-- src/client.rs | 8 ++++---- src/common/mod.rs | 41 ++++++++++++++++++++++------------------- src/lib.rs | 3 ++- src/server.rs | 8 ++++---- 5 files changed, 35 insertions(+), 30 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 84d6d63..f92dbef 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,7 +17,8 @@ appveyor = { repository = "quininer/tokio-rustls" } [dependencies] smallvec = "0.6" -futures = { package = "futures-preview", version = "0.3.0-alpha.16" } +tokio-io = { git = "https://github.com/tokio-rs/tokio" } +tokio-futures = { git = "https://github.com/tokio-rs/tokio" } rustls = "0.15" webpki = "0.19" @@ -25,6 +26,6 @@ webpki = "0.19" early-data = [] [dev-dependencies] -romio = "0.3.0-alpha.8" +tokio = { git = "https://github.com/tokio-rs/tokio" } lazy_static = "1" webpki-roots = "0.16" diff --git a/src/client.rs b/src/client.rs index 11a0331..ac961b2 100644 --- a/src/client.rs +++ b/src/client.rs @@ -74,8 +74,8 @@ impl AsyncRead for TlsStream where IO: AsyncRead + AsyncWrite + Unpin, { - unsafe fn initializer(&self) -> Initializer { - Initializer::nop() + unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool { + self.io.prepare_uninitialized_buffer(buf) } fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll> { @@ -192,7 +192,7 @@ where stream.as_mut_pin().poll_flush(cx) } - fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { if self.state.writeable() { self.session.send_close_notify(); self.state.shutdown_write(); @@ -201,6 +201,6 @@ where let this = self.get_mut(); let mut stream = Stream::new(&mut this.io, &mut this.session) .set_eof(!this.state.readable()); - stream.as_mut_pin().poll_close(cx) + stream.as_mut_pin().poll_shutdown(cx) } } diff --git a/src/common/mod.rs b/src/common/mod.rs index 2e648ae..228b1a3 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -1,11 +1,10 @@ use std::pin::Pin; -use std::task::Poll; +use std::task::{ Poll, Context }; use std::marker::Unpin; -use std::io::{ self, Read }; -use rustls::{ Session, WriteV }; -use futures::task::Context; -use futures::io::{ AsyncRead, AsyncWrite, IoSlice }; -use smallvec::SmallVec; +use std::io::{ self, Read, Write }; +use rustls::Session; +use tokio_io::{ AsyncRead, AsyncWrite }; +use tokio_futures as futures; pub struct Stream<'a, IO, S> { @@ -154,27 +153,31 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> { impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> WriteTls for Stream<'a, IO, S> { fn write_tls(&mut self, cx: &mut Context) -> io::Result { - struct Writer<'a, 'b, IO> { - io: &'a mut IO, + // TODO writev + + struct Writer<'a, 'b, T> { + io: &'a mut T, cx: &'a mut Context<'b> } - impl<'a, 'b, IO: AsyncWrite + Unpin> WriteV for Writer<'a, 'b, IO> { - fn writev(&mut self, vbytes: &[&[u8]]) -> io::Result { - let vbytes = vbytes - .into_iter() - .map(|v| IoSlice::new(v)) - .collect::; 64]>>(); + impl<'a, 'b, T: AsyncWrite + Unpin> Write for Writer<'a, 'b, T> { + fn write(&mut self, buf: &[u8]) -> io::Result { + match Pin::new(&mut self.io).poll_write(self.cx, buf) { + Poll::Ready(result) => result, + Poll::Pending => Err(io::ErrorKind::WouldBlock.into()) + } + } - match Pin::new(&mut self.io).poll_write_vectored(self.cx, &vbytes) { + fn flush(&mut self) -> io::Result<()> { + match Pin::new(&mut self.io).poll_flush(self.cx) { Poll::Ready(result) => result, Poll::Pending => Err(io::ErrorKind::WouldBlock.into()) } } } - let mut vecio = Writer { io: self.io, cx }; - self.session.writev_tls(&mut vecio) + let mut writer = Writer { io: self.io, cx }; + self.session.write_tls(&mut writer) } } @@ -240,13 +243,13 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncWrite for Stream<' Pin::new(&mut this.io).poll_flush(cx) } - fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let this = self.get_mut(); while this.session.wants_write() { futures::ready!(this.complete_inner_io(cx, Focus::Writable))?; } - Pin::new(&mut this.io).poll_close(cx) + Pin::new(&mut this.io).poll_shutdown(cx) } } diff --git a/src/lib.rs b/src/lib.rs index 65c11f3..59cfe44 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -12,7 +12,8 @@ use std::sync::Arc; use std::pin::Pin; use std::future::Future; use std::task::{ Poll, Context }; -use futures::io::{ AsyncRead, AsyncWrite, Initializer }; +use tokio_io::{ AsyncRead, AsyncWrite }; +use tokio_futures as futures; use rustls::{ ClientConfig, ClientSession, ServerConfig, ServerSession }; use webpki::DNSNameRef; use common::Stream; diff --git a/src/server.rs b/src/server.rs index 1e25145..6a94347 100644 --- a/src/server.rs +++ b/src/server.rs @@ -67,8 +67,8 @@ impl AsyncRead for TlsStream where IO: AsyncRead + AsyncWrite + Unpin, { - unsafe fn initializer(&self) -> Initializer { - Initializer::nop() + unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool { + self.io.prepare_uninitialized_buffer(buf) } fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll> { @@ -119,7 +119,7 @@ where stream.as_mut_pin().poll_flush(cx) } - fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { if self.state.writeable() { self.session.send_close_notify(); self.state.shutdown_write(); @@ -128,6 +128,6 @@ where let this = self.get_mut(); let mut stream = Stream::new(&mut this.io, &mut this.session) .set_eof(!this.state.readable()); - stream.as_mut_pin().poll_close(cx) + stream.as_mut_pin().poll_shutdown(cx) } } From 9daf87a17ab4e2804df92e2708313d838587e009 Mon Sep 17 00:00:00 2001 From: quininer Date: Sun, 21 Jul 2019 19:12:10 +0800 Subject: [PATCH 126/171] Update tokio --- Cargo.toml | 2 +- src/common/mod.rs | 2 +- src/lib.rs | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index f92dbef..c958a44 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,7 +18,7 @@ appveyor = { repository = "quininer/tokio-rustls" } [dependencies] smallvec = "0.6" tokio-io = { git = "https://github.com/tokio-rs/tokio" } -tokio-futures = { git = "https://github.com/tokio-rs/tokio" } +futures-core-preview = "0.3.0-alpha.17" rustls = "0.15" webpki = "0.19" diff --git a/src/common/mod.rs b/src/common/mod.rs index 228b1a3..60f0465 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -4,7 +4,7 @@ use std::marker::Unpin; use std::io::{ self, Read, Write }; use rustls::Session; use tokio_io::{ AsyncRead, AsyncWrite }; -use tokio_futures as futures; +use futures_core as futures; pub struct Stream<'a, IO, S> { diff --git a/src/lib.rs b/src/lib.rs index 59cfe44..f09e02d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -13,7 +13,7 @@ use std::pin::Pin; use std::future::Future; use std::task::{ Poll, Context }; use tokio_io::{ AsyncRead, AsyncWrite }; -use tokio_futures as futures; +use futures_core as futures; use rustls::{ ClientConfig, ClientSession, ServerConfig, ServerSession }; use webpki::DNSNameRef; use common::Stream; From 7fca72543c51b5392d4105f0b02466f4b60d75d2 Mon Sep 17 00:00:00 2001 From: Douman Date: Fri, 9 Aug 2019 12:37:52 +0200 Subject: [PATCH 127/171] Bump to tokio alpha (#43) --- Cargo.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index c958a44..2bff353 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,7 +17,7 @@ appveyor = { repository = "quininer/tokio-rustls" } [dependencies] smallvec = "0.6" -tokio-io = { git = "https://github.com/tokio-rs/tokio" } +tokio-io = "0.2.0-alpha.1" futures-core-preview = "0.3.0-alpha.17" rustls = "0.15" webpki = "0.19" @@ -26,6 +26,6 @@ webpki = "0.19" early-data = [] [dev-dependencies] -tokio = { git = "https://github.com/tokio-rs/tokio" } +tokio = "0.2.0-alpha.1" lazy_static = "1" webpki-roots = "0.16" From bbc66882929a2cb4a4b8870a60ddee48db8f948a Mon Sep 17 00:00:00 2001 From: quininer Date: Sat, 10 Aug 2019 23:43:19 +0800 Subject: [PATCH 128/171] Update some test --- .travis.yml | 12 ++++---- Cargo.toml | 7 +++-- src/common/test_stream.rs | 31 +++++++++----------- tests/test.rs | 61 ++++++++++++++++++++++++--------------- 4 files changed, 61 insertions(+), 50 deletions(-) diff --git a/.travis.yml b/.travis.yml index 79678c8..cab8ded 100644 --- a/.travis.yml +++ b/.travis.yml @@ -15,9 +15,9 @@ matrix: - rust: stable script: - - cargo test - - cargo test --features early-data - - cd examples/server - - cargo check - - cd ../../examples/client - - cargo check + - cargo test --test test + # - cargo test --features early-data + # - cd examples/server + # - cargo check + # - cd ../../examples/client + # - cargo check diff --git a/Cargo.toml b/Cargo.toml index 2bff353..cb318ff 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,13 +19,14 @@ appveyor = { repository = "quininer/tokio-rustls" } smallvec = "0.6" tokio-io = "0.2.0-alpha.1" futures-core-preview = "0.3.0-alpha.17" -rustls = "0.15" -webpki = "0.19" +rustls = "0.16" +webpki = "0.21" [features] early-data = [] [dev-dependencies] tokio = "0.2.0-alpha.1" +futures-util-preview = "0.3.0-alpha.17" lazy_static = "1" -webpki-roots = "0.16" +webpki-roots = "0.17" diff --git a/src/common/test_stream.rs b/src/common/test_stream.rs index e774778..335c67d 100644 --- a/src/common/test_stream.rs +++ b/src/common/test_stream.rs @@ -1,10 +1,8 @@ use std::pin::Pin; -use std::task::Poll; use std::sync::Arc; -use futures::prelude::*; -use futures::task::{ Context, noop_waker_ref }; -use futures::executor; -use futures::io::{ AsyncRead, AsyncWrite }; +use std::task::{ Poll, Context }; +use futures_util::future::poll_fn; +use tokio_io::{ AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt }; use std::io::{ self, Read, Write, BufReader, Cursor }; use webpki::DNSNameRef; use rustls::internal::pemfile::{ certs, rsa_private_keys }; @@ -73,7 +71,7 @@ fn stream_good() -> io::Result<()> { let fut = async { let (mut server, mut client) = make_pair(); - future::poll_fn(|cx| do_handshake(&mut client, &mut server, cx)).await?; + poll_fn(|cx| do_handshake(&mut client, &mut server, cx)).await?; io::copy(&mut Cursor::new(FILE), &mut server)?; { @@ -100,18 +98,18 @@ fn stream_good() -> io::Result<()> { fn stream_bad() -> io::Result<()> { let fut = async { let (mut server, mut client) = make_pair(); - future::poll_fn(|cx| do_handshake(&mut client, &mut server, cx)).await?; + poll_fn(|cx| do_handshake(&mut client, &mut server, cx)).await?; client.set_buffer_limit(1024); let mut bad = Bad(true); let mut stream = Stream::new(&mut bad, &mut client); - assert_eq!(future::poll_fn(|cx| stream.as_mut_pin().poll_write(cx, &[0x42; 8])).await?, 8); - assert_eq!(future::poll_fn(|cx| stream.as_mut_pin().poll_write(cx, &[0x42; 8])).await?, 8); - let r = future::poll_fn(|cx| stream.as_mut_pin().poll_write(cx, &[0x00; 1024])).await?; // fill buffer + assert_eq!(poll_fn(|cx| stream.as_mut_pin().poll_write(cx, &[0x42; 8])).await?, 8); + assert_eq!(poll_fn(|cx| stream.as_mut_pin().poll_write(cx, &[0x42; 8])).await?, 8); + let r = poll_fn(|cx| stream.as_mut_pin().poll_write(cx, &[0x00; 1024])).await?; // fill buffer assert!(r < 1024); - let mut cx = Context::from_waker(noop_waker_ref()); - assert!(stream.as_mut_pin().poll_write(&mut cx, &[0x01]).is_pending()); + let ret = poll_fn(|cx| stream.as_mut_pin().poll_write(cx, &[0x01])); + assert!(ret.is_pending()); Ok(()) as io::Result<()> }; @@ -127,12 +125,12 @@ fn stream_handshake() -> io::Result<()> { { let mut good = Good(&mut server); let mut stream = Stream::new(&mut good, &mut client); - let (r, w) = future::poll_fn(|cx| stream.complete_io(cx)).await?; + let (r, w) = poll_fn(|cx| stream.complete_io(cx)).await?; assert!(r > 0); assert!(w > 0); - future::poll_fn(|cx| stream.complete_io(cx)).await?; // finish server handshake + poll_fn(|cx| stream.complete_io(cx)).await?; // finish server handshake } assert!(!server.is_handshaking()); @@ -152,8 +150,7 @@ fn stream_handshake_eof() -> io::Result<()> { let mut bad = Bad(false); let mut stream = Stream::new(&mut bad, &mut client); - let mut cx = Context::from_waker(noop_waker_ref()); - let r = stream.complete_io(&mut cx); + let r = poll_fn(|cx| stream.complete_io(&mut cx)).await?; assert_eq!(r.map_err(|err| err.kind()), Poll::Ready(Err(io::ErrorKind::UnexpectedEof))); Ok(()) as io::Result<()> @@ -166,7 +163,7 @@ fn stream_handshake_eof() -> io::Result<()> { fn stream_eof() -> io::Result<()> { let fut = async { let (mut server, mut client) = make_pair(); - future::poll_fn(|cx| do_handshake(&mut client, &mut server, cx)).await?; + poll_fn(|cx| do_handshake(&mut client, &mut server, cx)).await?; let mut good = Good(&mut server); let mut stream = Stream::new(&mut good, &mut client).set_eof(true); diff --git a/tests/test.rs b/tests/test.rs index acc67e3..1db59ca 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -6,10 +6,10 @@ use std::sync::Arc; use std::sync::mpsc::channel; use std::net::SocketAddr; use lazy_static::lazy_static; -use futures::prelude::*; -use futures::executor; -use futures::task::SpawnExt; -use romio::tcp::{ TcpListener, TcpStream }; +use tokio::prelude::*; +use tokio::runtime::current_thread; +use tokio::net::{ TcpListener, TcpStream }; +use futures_util::try_future::TryFutureExt; use rustls::{ ServerConfig, ClientConfig }; use rustls::internal::pemfile::{ certs, rsa_private_keys }; use tokio_rustls::{ TlsConnector, TlsAcceptor }; @@ -31,31 +31,39 @@ lazy_static!{ let (send, recv) = channel(); thread::spawn(move || { - let done = async { + let mut runtime = current_thread::Runtime::new().unwrap(); + let handle = runtime.handle(); + + let done = async move { let addr = SocketAddr::from(([127, 0, 0, 1], 0)); - let mut pool = executor::ThreadPool::new()?; - let mut listener = TcpListener::bind(&addr)?; + let listener = TcpListener::bind(&addr)?; send.send(listener.local_addr()?).unwrap(); let mut incoming = listener.incoming(); while let Some(stream) = incoming.next().await { let acceptor = acceptor.clone(); - pool.spawn( - async move { - let stream = acceptor.accept(stream?).await?; - let (mut reader, mut write) = stream.split(); - reader.copy_into(&mut write).await?; - Ok(()) as io::Result<()> - } - .unwrap_or_else(|err| eprintln!("{:?}", err)) - ).unwrap(); + let fut = async move { + let mut stream = acceptor.accept(stream?).await?; + +// TODO split +// let (mut reader, mut write) = stream.split(); +// reader.copy(&mut write).await?; + + let mut buf = vec![0; 8192]; + let n = stream.read(&mut buf).await?; + stream.write(&buf[..n]).await?; + + Ok(()) as io::Result<()> + }; + + handle.spawn(fut.unwrap_or_else(|err| eprintln!("{:?}", err))).unwrap(); } Ok(()) as io::Result<()> }; - executor::block_on(done).unwrap(); + runtime.block_on(done.unwrap_or_else(|err| eprintln!("{:?}", err))); }); let addr = recv.recv().unwrap(); @@ -81,12 +89,12 @@ async fn start_client(addr: SocketAddr, domain: &str, config: Arc) assert_eq!(buf, FILE); - stream.close().await?; + stream.shutdown().await?; Ok(()) } -#[test] -fn pass() { +#[tokio::test] +async fn pass() -> io::Result<()> { let (addr, domain, chain) = start_server(); let mut config = ClientConfig::new(); @@ -94,11 +102,13 @@ fn pass() { config.root_store.add_pem_file(&mut chain).unwrap(); let config = Arc::new(config); - executor::block_on(start_client(addr.clone(), domain, config.clone())).unwrap(); + start_client(addr.clone(), domain, config.clone()).await?; + + Ok(()) } -#[test] -fn fail() { +#[tokio::test] +async fn fail() -> io::Result<()> { let (addr, domain, chain) = start_server(); let mut config = ClientConfig::new(); @@ -107,5 +117,8 @@ fn fail() { let config = Arc::new(config); assert_ne!(domain, &"google.com"); - assert!(executor::block_on(start_client(addr.clone(), "google.com", config)).is_err()); + let ret = start_client(addr.clone(), "google.com", config).await; + assert!(ret.is_err()); + + Ok(()) } From 0fceadf79919723564ef9033bb6b502fec28d8c1 Mon Sep 17 00:00:00 2001 From: quininer Date: Sun, 11 Aug 2019 00:00:49 +0800 Subject: [PATCH 129/171] publish 0.12.0-alpha.1 --- .travis.yml | 4 +- Cargo.toml | 2 +- src/common/test_stream.rs | 198 ++++++++++++++++++-------------------- src/test_0rtt.rs | 15 +-- 4 files changed, 102 insertions(+), 117 deletions(-) diff --git a/.travis.yml b/.travis.yml index cab8ded..b0b0082 100644 --- a/.travis.yml +++ b/.travis.yml @@ -15,8 +15,8 @@ matrix: - rust: stable script: - - cargo test --test test - # - cargo test --features early-data + - cargo test + - cargo test --features early-data # - cd examples/server # - cargo check # - cd ../../examples/client diff --git a/Cargo.toml b/Cargo.toml index cb318ff..96e1c79 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tokio-rustls" -version = "0.12.0-alpha" +version = "0.12.0-alpha.1" authors = ["quininer kel "] license = "MIT/Apache-2.0" repository = "https://github.com/quininer/tokio-rustls" diff --git a/src/common/test_stream.rs b/src/common/test_stream.rs index 335c67d..d109369 100644 --- a/src/common/test_stream.rs +++ b/src/common/test_stream.rs @@ -1,7 +1,9 @@ use std::pin::Pin; use std::sync::Arc; use std::task::{ Poll, Context }; +use futures_core::ready; use futures_util::future::poll_fn; +use futures_util::task::noop_waker_ref; use tokio_io::{ AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt }; use std::io::{ self, Read, Write, BufReader, Cursor }; use webpki::DNSNameRef; @@ -14,7 +16,7 @@ use rustls::{ use super::Stream; -struct Good<'a>(&'a mut Session); +struct Good<'a>(&'a mut dyn Session); impl<'a> AsyncRead for Good<'a> { fn poll_read(mut self: Pin<&mut Self>, _cx: &mut Context<'_>, mut buf: &mut [u8]) -> Poll> { @@ -34,7 +36,7 @@ impl<'a> AsyncWrite for Good<'a> { Poll::Ready(Ok(())) } - fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { Poll::Ready(Ok(())) } } @@ -60,122 +62,104 @@ impl AsyncWrite for Bad { Poll::Ready(Ok(())) } - fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { Poll::Ready(Ok(())) } } -#[test] -fn stream_good() -> io::Result<()> { +#[tokio::test] +async fn stream_good() -> io::Result<()> { const FILE: &'static [u8] = include_bytes!("../../README.md"); - let fut = async { - let (mut server, mut client) = make_pair(); - poll_fn(|cx| do_handshake(&mut client, &mut server, cx)).await?; - io::copy(&mut Cursor::new(FILE), &mut server)?; - - { - let mut good = Good(&mut server); - let mut stream = Stream::new(&mut good, &mut client); - - let mut buf = Vec::new(); - stream.read_to_end(&mut buf).await?; - assert_eq!(buf, FILE); - stream.write_all(b"Hello World!").await?; - } - - let mut buf = String::new(); - server.read_to_string(&mut buf)?; - assert_eq!(buf, "Hello World!"); - - Ok(()) as io::Result<()> - }; - - executor::block_on(fut) -} - -#[test] -fn stream_bad() -> io::Result<()> { - let fut = async { - let (mut server, mut client) = make_pair(); - poll_fn(|cx| do_handshake(&mut client, &mut server, cx)).await?; - client.set_buffer_limit(1024); - - let mut bad = Bad(true); - let mut stream = Stream::new(&mut bad, &mut client); - assert_eq!(poll_fn(|cx| stream.as_mut_pin().poll_write(cx, &[0x42; 8])).await?, 8); - assert_eq!(poll_fn(|cx| stream.as_mut_pin().poll_write(cx, &[0x42; 8])).await?, 8); - let r = poll_fn(|cx| stream.as_mut_pin().poll_write(cx, &[0x00; 1024])).await?; // fill buffer - assert!(r < 1024); - - let ret = poll_fn(|cx| stream.as_mut_pin().poll_write(cx, &[0x01])); - assert!(ret.is_pending()); - - Ok(()) as io::Result<()> - }; - - executor::block_on(fut) -} - -#[test] -fn stream_handshake() -> io::Result<()> { - let fut = async { - let (mut server, mut client) = make_pair(); - - { - let mut good = Good(&mut server); - let mut stream = Stream::new(&mut good, &mut client); - let (r, w) = poll_fn(|cx| stream.complete_io(cx)).await?; - - assert!(r > 0); - assert!(w > 0); - - poll_fn(|cx| stream.complete_io(cx)).await?; // finish server handshake - } - - assert!(!server.is_handshaking()); - assert!(!client.is_handshaking()); - - Ok(()) as io::Result<()> - }; - - executor::block_on(fut) -} - -#[test] -fn stream_handshake_eof() -> io::Result<()> { - let fut = async { - let (_, mut client) = make_pair(); - - let mut bad = Bad(false); - let mut stream = Stream::new(&mut bad, &mut client); - - let r = poll_fn(|cx| stream.complete_io(&mut cx)).await?; - assert_eq!(r.map_err(|err| err.kind()), Poll::Ready(Err(io::ErrorKind::UnexpectedEof))); - - Ok(()) as io::Result<()> - }; - - executor::block_on(fut) -} - -#[test] -fn stream_eof() -> io::Result<()> { - let fut = async { - let (mut server, mut client) = make_pair(); - poll_fn(|cx| do_handshake(&mut client, &mut server, cx)).await?; + let (mut server, mut client) = make_pair(); + poll_fn(|cx| do_handshake(&mut client, &mut server, cx)).await?; + io::copy(&mut Cursor::new(FILE), &mut server)?; + { let mut good = Good(&mut server); - let mut stream = Stream::new(&mut good, &mut client).set_eof(true); + let mut stream = Stream::new(&mut good, &mut client); let mut buf = Vec::new(); stream.read_to_end(&mut buf).await?; - assert_eq!(buf.len(), 0); + assert_eq!(buf, FILE); + stream.write_all(b"Hello World!").await?; + } - Ok(()) as io::Result<()> - }; + let mut buf = String::new(); + server.read_to_string(&mut buf)?; + assert_eq!(buf, "Hello World!"); - executor::block_on(fut) + Ok(()) as io::Result<()> +} + +#[tokio::test] +async fn stream_bad() -> io::Result<()> { + let (mut server, mut client) = make_pair(); + poll_fn(|cx| do_handshake(&mut client, &mut server, cx)).await?; + client.set_buffer_limit(1024); + + let mut bad = Bad(true); + let mut stream = Stream::new(&mut bad, &mut client); + assert_eq!(poll_fn(|cx| stream.as_mut_pin().poll_write(cx, &[0x42; 8])).await?, 8); + assert_eq!(poll_fn(|cx| stream.as_mut_pin().poll_write(cx, &[0x42; 8])).await?, 8); + let r = poll_fn(|cx| stream.as_mut_pin().poll_write(cx, &[0x00; 1024])).await?; // fill buffer + assert!(r < 1024); + + let mut cx = Context::from_waker(noop_waker_ref()); + let ret = stream.as_mut_pin().poll_write(&mut cx, &[0x01]); + assert!(ret.is_pending()); + + Ok(()) as io::Result<()> +} + +#[tokio::test] +async fn stream_handshake() -> io::Result<()> { + let (mut server, mut client) = make_pair(); + + { + let mut good = Good(&mut server); + let mut stream = Stream::new(&mut good, &mut client); + let (r, w) = poll_fn(|cx| stream.complete_io(cx)).await?; + + assert!(r > 0); + assert!(w > 0); + + poll_fn(|cx| stream.complete_io(cx)).await?; // finish server handshake + } + + assert!(!server.is_handshaking()); + assert!(!client.is_handshaking()); + + Ok(()) as io::Result<()> +} + +#[tokio::test] +async fn stream_handshake_eof() -> io::Result<()> { + let (_, mut client) = make_pair(); + + let mut bad = Bad(false); + let mut stream = Stream::new(&mut bad, &mut client); + + let mut cx = Context::from_waker(noop_waker_ref()); + let r = stream.complete_io(&mut cx); + assert_eq!(r.map_err(|err| err.kind()), Poll::Ready(Err(io::ErrorKind::UnexpectedEof))); + + Ok(()) as io::Result<()> +} + +#[tokio::test] +async fn stream_eof() -> io::Result<()> { + let (mut server, mut client) = make_pair(); + poll_fn(|cx| do_handshake(&mut client, &mut server, cx)).await?; + + let mut good = Good(&mut server); + let mut stream = Stream::new(&mut good, &mut client).set_eof(true); + + let mut buf = Vec::new(); + stream.read_to_end(&mut buf).await?; + assert_eq!(buf.len(), 0); + + Ok(()) as io::Result<()> } fn make_pair() -> (ServerSession, ClientSession) { @@ -203,11 +187,11 @@ fn do_handshake(client: &mut ClientSession, server: &mut ServerSession, cx: &mut let mut stream = Stream::new(&mut good, client); if stream.session.is_handshaking() { - futures::ready!(stream.complete_io(cx))?; + ready!(stream.complete_io(cx))?; } if stream.session.wants_write() { - futures::ready!(stream.complete_io(cx))?; + ready!(stream.complete_io(cx))?; } Poll::Ready(Ok(())) diff --git a/src/test_0rtt.rs b/src/test_0rtt.rs index 8c8db6c..cb3e94b 100644 --- a/src/test_0rtt.rs +++ b/src/test_0rtt.rs @@ -1,9 +1,8 @@ use std::io; use std::sync::Arc; use std::net::ToSocketAddrs; -use futures::executor; -use futures::prelude::*; -use romio::tcp::TcpStream; +use tokio::prelude::*; +use tokio::net::TcpStream; use rustls::ClientConfig; use crate::{ TlsConnector, client::TlsStream }; @@ -28,19 +27,21 @@ async fn get(config: Arc, domain: &str, rtt0: bool) Ok((stream, String::from_utf8(buf).unwrap())) } -#[test] -fn test_0rtt() { +#[tokio::test] +async fn test_0rtt() -> io::Result<()> { let mut config = ClientConfig::new(); config.root_store.add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS); config.enable_early_data = true; let config = Arc::new(config); let domain = "mozilla-modern.badssl.com"; - let (_, output) = executor::block_on(get(config.clone(), domain, false)).unwrap(); + let (_, output) = get(config.clone(), domain, false).await?; assert!(output.contains("mozilla-modern.badssl.com")); - let (io, output) = executor::block_on(get(config.clone(), domain, true)).unwrap(); + let (io, output) = get(config.clone(), domain, true).await?; assert!(output.contains("mozilla-modern.badssl.com")); assert_eq!(io.early_data.0, 0); + + Ok(()) } From c1900317619ad51846fe89136cf84d563f19a743 Mon Sep 17 00:00:00 2001 From: Lucio Franco Date: Sun, 1 Sep 2019 01:29:07 -0400 Subject: [PATCH 130/171] Update to tokio 0.2.0-alpha.4 (#45) --- Cargo.toml | 8 ++++---- src/lib.rs | 23 ++++++++++------------- tests/test.rs | 11 ++++++++--- 3 files changed, 22 insertions(+), 20 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 96e1c79..cd8a8de 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,8 +17,8 @@ appveyor = { repository = "quininer/tokio-rustls" } [dependencies] smallvec = "0.6" -tokio-io = "0.2.0-alpha.1" -futures-core-preview = "0.3.0-alpha.17" +tokio-io = "=0.2.0-alpha.4" +futures-core-preview = "=0.3.0-alpha.18" rustls = "0.16" webpki = "0.21" @@ -26,7 +26,7 @@ webpki = "0.21" early-data = [] [dev-dependencies] -tokio = "0.2.0-alpha.1" -futures-util-preview = "0.3.0-alpha.17" +tokio = "=0.2.0-alpha.4" +futures-util-preview = "0.3.0-alpha.18" lazy_static = "1" webpki-roots = "0.17" diff --git a/src/lib.rs b/src/lib.rs index f09e02d..f631a09 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,22 +1,19 @@ //! Asynchronous TLS/SSL streams for Tokio using [Rustls](https://github.com/ctz/rustls). -#![cfg_attr(test, feature(async_await))] - - -mod common; pub mod client; +mod common; pub mod server; -use std::{ io, mem }; -use std::sync::Arc; -use std::pin::Pin; -use std::future::Future; -use std::task::{ Poll, Context }; -use tokio_io::{ AsyncRead, AsyncWrite }; -use futures_core as futures; -use rustls::{ ClientConfig, ClientSession, ServerConfig, ServerSession }; -use webpki::DNSNameRef; use common::Stream; +use futures_core as futures; +use rustls::{ClientConfig, ClientSession, ServerConfig, ServerSession}; +use std::future::Future; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; +use std::{io, mem}; +use tokio_io::{AsyncRead, AsyncWrite}; +use webpki::DNSNameRef; pub use rustls; pub use webpki; diff --git a/tests/test.rs b/tests/test.rs index 1db59ca..42dd9ed 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -1,5 +1,3 @@ -#![feature(async_await)] - use std::{ io, thread }; use std::io::{ BufReader, Cursor }; use std::sync::Arc; @@ -36,7 +34,7 @@ lazy_static!{ let done = async move { let addr = SocketAddr::from(([127, 0, 0, 1], 0)); - let listener = TcpListener::bind(&addr)?; + let listener = TcpListener::bind(&addr).await?; send.send(listener.local_addr()?).unwrap(); @@ -97,6 +95,13 @@ async fn start_client(addr: SocketAddr, domain: &str, config: Arc) async fn pass() -> io::Result<()> { let (addr, domain, chain) = start_server(); + // TODO: not sure how to resolve this right now but since + // TcpStream::bind now returns a future it creates a race + // condition until its ready sometimes. + use std::time::*; + let deadline = Instant::now() + Duration::from_secs(1); + tokio::timer::delay(deadline); + let mut config = ClientConfig::new(); let mut chain = BufReader::new(Cursor::new(chain)); config.root_store.add_pem_file(&mut chain).unwrap(); From 0386abcee1036f9ae18a7d50226a223c74a21eab Mon Sep 17 00:00:00 2001 From: quininer Date: Sun, 1 Sep 2019 13:30:29 +0800 Subject: [PATCH 131/171] Fix test --- Cargo.toml | 2 +- src/common/mod.rs | 1 + tests/test.rs | 8 +++++--- 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index cd8a8de..723ffaa 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tokio-rustls" -version = "0.12.0-alpha.1" +version = "0.12.0-alpha.2" authors = ["quininer kel "] license = "MIT/Apache-2.0" repository = "https://github.com/quininer/tokio-rustls" diff --git a/src/common/mod.rs b/src/common/mod.rs index 60f0465..ac72608 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -249,6 +249,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncWrite for Stream<' while this.session.wants_write() { futures::ready!(this.complete_inner_io(cx, Focus::Writable))?; } + Pin::new(&mut this.io).poll_shutdown(cx) } } diff --git a/tests/test.rs b/tests/test.rs index 42dd9ed..5749efe 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -44,13 +44,15 @@ lazy_static!{ let fut = async move { let mut stream = acceptor.accept(stream?).await?; -// TODO split -// let (mut reader, mut write) = stream.split(); -// reader.copy(&mut write).await?; + // TODO split + // + // let (mut reader, mut write) = stream.split(); + // reader.copy(&mut write).await?; let mut buf = vec![0; 8192]; let n = stream.read(&mut buf).await?; stream.write(&buf[..n]).await?; + let _ = stream.read(&mut buf).await?; Ok(()) as io::Result<()> }; From 41b25ec74551cb61830ab9c6ae7ff9a0aa0bb0ff Mon Sep 17 00:00:00 2001 From: Jeb Rosen Date: Tue, 24 Sep 2019 20:02:23 -0700 Subject: [PATCH 132/171] Update to tokio '0.2.0-alpha.5'. (#48) --- Cargo.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 723ffaa..a4c0e67 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,7 +17,7 @@ appveyor = { repository = "quininer/tokio-rustls" } [dependencies] smallvec = "0.6" -tokio-io = "=0.2.0-alpha.4" +tokio-io = "=0.2.0-alpha.5" futures-core-preview = "=0.3.0-alpha.18" rustls = "0.16" webpki = "0.21" @@ -26,7 +26,7 @@ webpki = "0.21" early-data = [] [dev-dependencies] -tokio = "=0.2.0-alpha.4" +tokio = "=0.2.0-alpha.5" futures-util-preview = "0.3.0-alpha.18" lazy_static = "1" webpki-roots = "0.17" From aa6e8444def3b4aa62099583459fbdf73e4f2e38 Mon Sep 17 00:00:00 2001 From: quininer Date: Thu, 26 Sep 2019 00:24:32 +0800 Subject: [PATCH 133/171] Fix write behavior --- Cargo.toml | 2 +- src/common/mod.rs | 8 ++++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index a4c0e67..63b4122 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tokio-rustls" -version = "0.12.0-alpha.2" +version = "0.12.0-alpha.3" authors = ["quininer kel "] license = "MIT/Apache-2.0" repository = "https://github.com/quininer/tokio-rustls" diff --git a/src/common/mod.rs b/src/common/mod.rs index ac72608..1af5ecb 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -110,6 +110,14 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> { } } + if let Focus::Writable = focus { + if !write_would_block { + return Poll::Ready(Ok((rdlen, wrlen))); + } else { + return Poll::Pending; + } + } + if !self.eof && self.session.wants_read() { match self.complete_read_io(cx) { Poll::Ready(Ok(0)) => self.eof = true, From 4dd5e19a1909ef50aa98606b8cdb38db886ddd92 Mon Sep 17 00:00:00 2001 From: quininer Date: Tue, 1 Oct 2019 01:54:55 +0800 Subject: [PATCH 134/171] refactor: separate read and write --- src/client.rs | 55 +++++-------- src/common/mod.rs | 162 ++++++++++++++------------------------ src/common/test_stream.rs | 11 +-- src/server.rs | 4 +- src/test_0rtt.rs | 1 + tests/test.rs | 10 ++- 6 files changed, 96 insertions(+), 147 deletions(-) diff --git a/src/client.rs b/src/client.rs index ac961b2..c901043 100644 --- a/src/client.rs +++ b/src/client.rs @@ -53,11 +53,11 @@ where let mut stream = Stream::new(io, session).set_eof(eof); if stream.session.is_handshaking() { - futures::ready!(stream.complete_io(cx))?; + futures::ready!(stream.handshake(cx))?; } if stream.session.wants_write() { - futures::ready!(stream.complete_io(cx))?; + futures::ready!(stream.handshake(cx))?; } } @@ -81,32 +81,7 @@ where fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll> { match self.state { #[cfg(feature = "early-data")] - TlsState::EarlyData => { - let this = self.get_mut(); - - let mut stream = Stream::new(&mut this.io, &mut this.session) - .set_eof(!this.state.readable()); - let (pos, data) = &mut this.early_data; - - // complete handshake - if stream.session.is_handshaking() { - futures::ready!(stream.complete_io(cx))?; - } - - // write early data (fallback) - if !stream.session.is_early_data_accepted() { - while *pos < data.len() { - let len = futures::ready!(stream.as_mut_pin().poll_write(cx, &data[*pos..]))?; - *pos += len; - } - } - - // end - this.state = TlsState::Stream; - data.clear(); - - Pin::new(this).poll_read(cx, buf) - } + TlsState::EarlyData => Poll::Pending, TlsState::Stream | TlsState::WriteShutdown => { let this = self.get_mut(); let mut stream = Stream::new(&mut this.io, &mut this.session) @@ -116,7 +91,7 @@ where Poll::Ready(Ok(0)) => { this.state.shutdown_read(); Poll::Ready(Ok(0)) - } + }, Poll::Ready(Ok(n)) => Poll::Ready(Ok(n)), Poll::Ready(Err(ref e)) if e.kind() == io::ErrorKind::ConnectionAborted => { this.state.shutdown_read(); @@ -125,9 +100,8 @@ where this.state.shutdown_write(); } Poll::Ready(Ok(0)) - } - Poll::Ready(Err(err)) => Poll::Ready(Err(err)), - Poll::Pending => Poll::Pending + }, + output => output } } TlsState::ReadShutdown | TlsState::FullyShutdown => Poll::Ready(Ok(0)), @@ -153,7 +127,7 @@ where // write early data if let Some(mut early_data) = stream.session.early_data() { - let len = match early_data.write(buf) { + let len = match dbg!(early_data.write(buf)) { Ok(n) => n, Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => return Poll::Pending, @@ -165,7 +139,7 @@ where // complete handshake if stream.session.is_handshaking() { - futures::ready!(stream.complete_io(cx))?; + futures::ready!(stream.handshake(cx))?; } // write early data (fallback) @@ -189,6 +163,14 @@ where let this = self.get_mut(); let mut stream = Stream::new(&mut this.io, &mut this.session) .set_eof(!this.state.readable()); + + #[cfg(feature = "early-data")] { + // complete handshake + if stream.session.is_handshaking() { + futures::ready!(stream.handshake(cx))?; + } + } + stream.as_mut_pin().poll_flush(cx) } @@ -201,6 +183,11 @@ where let this = self.get_mut(); let mut stream = Stream::new(&mut this.io, &mut this.session) .set_eof(!this.state.readable()); + + // TODO + // + // should we complete the handshake? + stream.as_mut_pin().poll_shutdown(cx) } } diff --git a/src/common/mod.rs b/src/common/mod.rs index 1af5ecb..e9fc783 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -13,17 +13,6 @@ pub struct Stream<'a, IO, S> { pub eof: bool } -trait WriteTls { - fn write_tls(&mut self, cx: &mut Context) -> io::Result; -} - -#[derive(Clone, Copy)] -enum Focus { - Empty, - Readable, - Writable -} - impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> { pub fn new(io: &'a mut IO, session: &'a mut S) -> Self { Stream { @@ -44,11 +33,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> { Pin::new(self) } - pub fn complete_io(&mut self, cx: &mut Context) -> Poll> { - self.complete_inner_io(cx, Focus::Empty) - } - - fn complete_read_io(&mut self, cx: &mut Context) -> Poll> { + fn read_io(&mut self, cx: &mut Context) -> Poll> { struct Reader<'a, 'b, T> { io: &'a mut T, cx: &'a mut Context<'b> @@ -76,7 +61,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> { // In case we have an alert to send describing this error, // try a last-gasp write -- but don't predate the primary // error. - let _ = self.write_tls(cx); + let _ = self.write_io(cx); io::Error::new(io::ErrorKind::InvalidData, err) })?; @@ -84,85 +69,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> { Poll::Ready(Ok(n)) } - fn complete_write_io(&mut self, cx: &mut Context) -> Poll> { - match self.write_tls(cx) { - Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending, - result => Poll::Ready(result) - } - } - - fn complete_inner_io(&mut self, cx: &mut Context, focus: Focus) -> Poll> { - let mut wrlen = 0; - let mut rdlen = 0; - - loop { - let mut write_would_block = false; - let mut read_would_block = false; - - while self.session.wants_write() { - match self.complete_write_io(cx) { - Poll::Ready(Ok(n)) => wrlen += n, - Poll::Pending => { - write_would_block = true; - break - }, - Poll::Ready(Err(err)) => return Poll::Ready(Err(err)) - } - } - - if let Focus::Writable = focus { - if !write_would_block { - return Poll::Ready(Ok((rdlen, wrlen))); - } else { - return Poll::Pending; - } - } - - if !self.eof && self.session.wants_read() { - match self.complete_read_io(cx) { - Poll::Ready(Ok(0)) => self.eof = true, - Poll::Ready(Ok(n)) => rdlen += n, - Poll::Pending => read_would_block = true, - Poll::Ready(Err(err)) => return Poll::Ready(Err(err)) - } - } - - let would_block = match focus { - Focus::Empty => write_would_block || read_would_block, - Focus::Readable => read_would_block, - Focus::Writable => write_would_block, - }; - - match (self.eof, self.session.is_handshaking(), would_block) { - (true, true, _) => { - let err = io::Error::new(io::ErrorKind::UnexpectedEof, "tls handshake eof"); - return Poll::Ready(Err(err)); - }, - (_, false, true) => { - let would_block = match focus { - Focus::Empty => rdlen == 0 && wrlen == 0, - Focus::Readable => rdlen == 0, - Focus::Writable => wrlen == 0 - }; - - return if would_block { - Poll::Pending - } else { - Poll::Ready(Ok((rdlen, wrlen))) - }; - }, - (_, false, _) => return Poll::Ready(Ok((rdlen, wrlen))), - (_, true, true) => return Poll::Pending, - (..) => () - } - } - } -} - -impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> WriteTls for Stream<'a, IO, S> { - fn write_tls(&mut self, cx: &mut Context) -> io::Result { - // TODO writev - + fn write_io(&mut self, cx: &mut Context) -> Poll> { struct Writer<'a, 'b, T> { io: &'a mut T, cx: &'a mut Context<'b> @@ -185,7 +92,58 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> WriteTls for Str } let mut writer = Writer { io: self.io, cx }; - self.session.write_tls(&mut writer) + + match self.session.write_tls(&mut writer) { + Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending, + result => Poll::Ready(result) + } + } + + pub fn handshake(&mut self, cx: &mut Context) -> Poll> { + let mut wrlen = 0; + let mut rdlen = 0; + + loop { + let mut write_would_block = false; + let mut read_would_block = false; + + while self.session.wants_write() { + match self.write_io(cx) { + Poll::Ready(Ok(n)) => wrlen += n, + Poll::Pending => { + write_would_block = true; + break + }, + Poll::Ready(Err(err)) => return Poll::Ready(Err(err)) + } + } + + if !self.eof && self.session.wants_read() { + match self.read_io(cx) { + Poll::Ready(Ok(0)) => self.eof = true, + Poll::Ready(Ok(n)) => rdlen += n, + Poll::Pending => read_would_block = true, + Poll::Ready(Err(err)) => return Poll::Ready(Err(err)) + } + } + + let would_block = write_would_block || read_would_block; + + return match (self.eof, self.session.is_handshaking(), would_block) { + (true, true, _) => { + let err = io::Error::new(io::ErrorKind::UnexpectedEof, "tls handshake eof"); + Poll::Ready(Err(err)) + }, + (_, false, true) => if rdlen != 0 || wrlen != 0 { + Poll::Ready(Ok((rdlen, wrlen))) + } else { + Poll::Pending + }, + (_, false, _) => Poll::Ready(Ok((rdlen, wrlen))), + (_, true, true) => Poll::Pending, + (..) => continue + } + } } } @@ -194,8 +152,8 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncRead for Stream<'a let this = self.get_mut(); while this.session.wants_read() { - match this.complete_inner_io(cx, Focus::Readable) { - Poll::Ready(Ok((0, _))) => break, + match this.read_io(cx) { + Poll::Ready(Ok(0)) => break, Poll::Ready(Ok(_)) => (), Poll::Pending => return Poll::Pending, Poll::Ready(Err(err)) => return Poll::Ready(Err(err)) @@ -220,7 +178,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncWrite for Stream<' Err(err) => return Poll::Ready(Err(err)) }; while this.session.wants_write() { - match this.complete_inner_io(cx, Focus::Writable) { + match this.write_io(cx) { Poll::Ready(Ok(_)) => (), Poll::Pending if len != 0 => break, Poll::Pending => return Poll::Pending, @@ -246,7 +204,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncWrite for Stream<' this.session.flush()?; while this.session.wants_write() { - futures::ready!(this.complete_inner_io(cx, Focus::Writable))?; + futures::ready!(this.write_io(cx))?; } Pin::new(&mut this.io).poll_flush(cx) } @@ -255,7 +213,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncWrite for Stream<' let this = self.get_mut(); while this.session.wants_write() { - futures::ready!(this.complete_inner_io(cx, Focus::Writable))?; + futures::ready!(this.write_io(cx))?; } Pin::new(&mut this.io).poll_shutdown(cx) diff --git a/src/common/test_stream.rs b/src/common/test_stream.rs index d109369..20cc4eb 100644 --- a/src/common/test_stream.rs +++ b/src/common/test_stream.rs @@ -83,6 +83,7 @@ async fn stream_good() -> io::Result<()> { stream.read_to_end(&mut buf).await?; assert_eq!(buf, FILE); stream.write_all(b"Hello World!").await?; + stream.flush().await?; } let mut buf = String::new(); @@ -119,12 +120,12 @@ async fn stream_handshake() -> io::Result<()> { { let mut good = Good(&mut server); let mut stream = Stream::new(&mut good, &mut client); - let (r, w) = poll_fn(|cx| stream.complete_io(cx)).await?; + let (r, w) = poll_fn(|cx| stream.handshake(cx)).await?; assert!(r > 0); assert!(w > 0); - poll_fn(|cx| stream.complete_io(cx)).await?; // finish server handshake + poll_fn(|cx| stream.handshake(cx)).await?; // finish server handshake } assert!(!server.is_handshaking()); @@ -141,7 +142,7 @@ async fn stream_handshake_eof() -> io::Result<()> { let mut stream = Stream::new(&mut bad, &mut client); let mut cx = Context::from_waker(noop_waker_ref()); - let r = stream.complete_io(&mut cx); + let r = stream.handshake(&mut cx); assert_eq!(r.map_err(|err| err.kind()), Poll::Ready(Err(io::ErrorKind::UnexpectedEof))); Ok(()) as io::Result<()> @@ -187,11 +188,11 @@ fn do_handshake(client: &mut ClientSession, server: &mut ServerSession, cx: &mut let mut stream = Stream::new(&mut good, client); if stream.session.is_handshaking() { - ready!(stream.complete_io(cx))?; + ready!(stream.handshake(cx))?; } if stream.session.wants_write() { - ready!(stream.complete_io(cx))?; + ready!(stream.handshake(cx))?; } Poll::Ready(Ok(())) diff --git a/src/server.rs b/src/server.rs index 6a94347..92043c9 100644 --- a/src/server.rs +++ b/src/server.rs @@ -48,11 +48,11 @@ where let mut stream = Stream::new(io, session).set_eof(eof); if stream.session.is_handshaking() { - futures::ready!(stream.complete_io(cx))?; + futures::ready!(stream.handshake(cx))?; } if stream.session.wants_write() { - futures::ready!(stream.complete_io(cx))?; + futures::ready!(stream.handshake(cx))?; } } diff --git a/src/test_0rtt.rs b/src/test_0rtt.rs index cb3e94b..898deef 100644 --- a/src/test_0rtt.rs +++ b/src/test_0rtt.rs @@ -22,6 +22,7 @@ async fn get(config: Arc, domain: &str, rtt0: bool) let stream = TcpStream::connect(&addr).await?; let mut stream = connector.connect(domain, stream).await?; stream.write_all(input.as_bytes()).await?; + stream.flush().await?; stream.read_to_end(&mut buf).await?; Ok((stream, String::from_utf8(buf).unwrap())) diff --git a/tests/test.rs b/tests/test.rs index 5749efe..74918ca 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -52,18 +52,19 @@ lazy_static!{ let mut buf = vec![0; 8192]; let n = stream.read(&mut buf).await?; stream.write(&buf[..n]).await?; + stream.flush().await?; let _ = stream.read(&mut buf).await?; Ok(()) as io::Result<()> - }; + }.unwrap_or_else(|err| eprintln!("server: {:?}", err)); - handle.spawn(fut.unwrap_or_else(|err| eprintln!("{:?}", err))).unwrap(); + handle.spawn(fut).unwrap(); } Ok(()) as io::Result<()> - }; + }.unwrap_or_else(|err| eprintln!("server: {:?}", err)); - runtime.block_on(done.unwrap_or_else(|err| eprintln!("{:?}", err))); + runtime.block_on(done); }); let addr = recv.recv().unwrap(); @@ -85,6 +86,7 @@ async fn start_client(addr: SocketAddr, domain: &str, config: Arc) let stream = TcpStream::connect(&addr).await?; let mut stream = config.connect(domain, stream).await?; stream.write_all(FILE).await?; + stream.flush().await?; stream.read_exact(&mut buf).await?; assert_eq!(buf, FILE); From 315d927473386d418d3fbf315a4ec44e53e320a1 Mon Sep 17 00:00:00 2001 From: Taiki Endo Date: Tue, 1 Oct 2019 11:04:43 +0900 Subject: [PATCH 135/171] Update to tokio 0.2.0-alpha.6 (#49) --- Cargo.toml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 63b4122..ff2145f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,8 +17,8 @@ appveyor = { repository = "quininer/tokio-rustls" } [dependencies] smallvec = "0.6" -tokio-io = "=0.2.0-alpha.5" -futures-core-preview = "=0.3.0-alpha.18" +tokio-io = "=0.2.0-alpha.6" +futures-core-preview = "=0.3.0-alpha.19" rustls = "0.16" webpki = "0.21" @@ -26,7 +26,7 @@ webpki = "0.21" early-data = [] [dev-dependencies] -tokio = "=0.2.0-alpha.5" -futures-util-preview = "0.3.0-alpha.18" +tokio = "=0.2.0-alpha.6" +futures-util-preview = "0.3.0-alpha.19" lazy_static = "1" webpki-roots = "0.17" From 4109c34207a2420ece6ac932816591ff5d115284 Mon Sep 17 00:00:00 2001 From: quininer Date: Tue, 1 Oct 2019 10:24:22 +0800 Subject: [PATCH 136/171] Revert "refactor: separate read and write" This reverts commit 4dd5e19a1909ef50aa98606b8cdb38db886ddd92. --- src/client.rs | 55 ++++++++----- src/common/mod.rs | 162 ++++++++++++++++++++++++-------------- src/common/test_stream.rs | 11 ++- src/server.rs | 4 +- src/test_0rtt.rs | 1 - tests/test.rs | 10 +-- 6 files changed, 147 insertions(+), 96 deletions(-) diff --git a/src/client.rs b/src/client.rs index c901043..ac961b2 100644 --- a/src/client.rs +++ b/src/client.rs @@ -53,11 +53,11 @@ where let mut stream = Stream::new(io, session).set_eof(eof); if stream.session.is_handshaking() { - futures::ready!(stream.handshake(cx))?; + futures::ready!(stream.complete_io(cx))?; } if stream.session.wants_write() { - futures::ready!(stream.handshake(cx))?; + futures::ready!(stream.complete_io(cx))?; } } @@ -81,7 +81,32 @@ where fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll> { match self.state { #[cfg(feature = "early-data")] - TlsState::EarlyData => Poll::Pending, + TlsState::EarlyData => { + let this = self.get_mut(); + + let mut stream = Stream::new(&mut this.io, &mut this.session) + .set_eof(!this.state.readable()); + let (pos, data) = &mut this.early_data; + + // complete handshake + if stream.session.is_handshaking() { + futures::ready!(stream.complete_io(cx))?; + } + + // write early data (fallback) + if !stream.session.is_early_data_accepted() { + while *pos < data.len() { + let len = futures::ready!(stream.as_mut_pin().poll_write(cx, &data[*pos..]))?; + *pos += len; + } + } + + // end + this.state = TlsState::Stream; + data.clear(); + + Pin::new(this).poll_read(cx, buf) + } TlsState::Stream | TlsState::WriteShutdown => { let this = self.get_mut(); let mut stream = Stream::new(&mut this.io, &mut this.session) @@ -91,7 +116,7 @@ where Poll::Ready(Ok(0)) => { this.state.shutdown_read(); Poll::Ready(Ok(0)) - }, + } Poll::Ready(Ok(n)) => Poll::Ready(Ok(n)), Poll::Ready(Err(ref e)) if e.kind() == io::ErrorKind::ConnectionAborted => { this.state.shutdown_read(); @@ -100,8 +125,9 @@ where this.state.shutdown_write(); } Poll::Ready(Ok(0)) - }, - output => output + } + Poll::Ready(Err(err)) => Poll::Ready(Err(err)), + Poll::Pending => Poll::Pending } } TlsState::ReadShutdown | TlsState::FullyShutdown => Poll::Ready(Ok(0)), @@ -127,7 +153,7 @@ where // write early data if let Some(mut early_data) = stream.session.early_data() { - let len = match dbg!(early_data.write(buf)) { + let len = match early_data.write(buf) { Ok(n) => n, Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => return Poll::Pending, @@ -139,7 +165,7 @@ where // complete handshake if stream.session.is_handshaking() { - futures::ready!(stream.handshake(cx))?; + futures::ready!(stream.complete_io(cx))?; } // write early data (fallback) @@ -163,14 +189,6 @@ where let this = self.get_mut(); let mut stream = Stream::new(&mut this.io, &mut this.session) .set_eof(!this.state.readable()); - - #[cfg(feature = "early-data")] { - // complete handshake - if stream.session.is_handshaking() { - futures::ready!(stream.handshake(cx))?; - } - } - stream.as_mut_pin().poll_flush(cx) } @@ -183,11 +201,6 @@ where let this = self.get_mut(); let mut stream = Stream::new(&mut this.io, &mut this.session) .set_eof(!this.state.readable()); - - // TODO - // - // should we complete the handshake? - stream.as_mut_pin().poll_shutdown(cx) } } diff --git a/src/common/mod.rs b/src/common/mod.rs index e9fc783..1af5ecb 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -13,6 +13,17 @@ pub struct Stream<'a, IO, S> { pub eof: bool } +trait WriteTls { + fn write_tls(&mut self, cx: &mut Context) -> io::Result; +} + +#[derive(Clone, Copy)] +enum Focus { + Empty, + Readable, + Writable +} + impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> { pub fn new(io: &'a mut IO, session: &'a mut S) -> Self { Stream { @@ -33,7 +44,11 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> { Pin::new(self) } - fn read_io(&mut self, cx: &mut Context) -> Poll> { + pub fn complete_io(&mut self, cx: &mut Context) -> Poll> { + self.complete_inner_io(cx, Focus::Empty) + } + + fn complete_read_io(&mut self, cx: &mut Context) -> Poll> { struct Reader<'a, 'b, T> { io: &'a mut T, cx: &'a mut Context<'b> @@ -61,7 +76,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> { // In case we have an alert to send describing this error, // try a last-gasp write -- but don't predate the primary // error. - let _ = self.write_io(cx); + let _ = self.write_tls(cx); io::Error::new(io::ErrorKind::InvalidData, err) })?; @@ -69,7 +84,85 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> { Poll::Ready(Ok(n)) } - fn write_io(&mut self, cx: &mut Context) -> Poll> { + fn complete_write_io(&mut self, cx: &mut Context) -> Poll> { + match self.write_tls(cx) { + Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending, + result => Poll::Ready(result) + } + } + + fn complete_inner_io(&mut self, cx: &mut Context, focus: Focus) -> Poll> { + let mut wrlen = 0; + let mut rdlen = 0; + + loop { + let mut write_would_block = false; + let mut read_would_block = false; + + while self.session.wants_write() { + match self.complete_write_io(cx) { + Poll::Ready(Ok(n)) => wrlen += n, + Poll::Pending => { + write_would_block = true; + break + }, + Poll::Ready(Err(err)) => return Poll::Ready(Err(err)) + } + } + + if let Focus::Writable = focus { + if !write_would_block { + return Poll::Ready(Ok((rdlen, wrlen))); + } else { + return Poll::Pending; + } + } + + if !self.eof && self.session.wants_read() { + match self.complete_read_io(cx) { + Poll::Ready(Ok(0)) => self.eof = true, + Poll::Ready(Ok(n)) => rdlen += n, + Poll::Pending => read_would_block = true, + Poll::Ready(Err(err)) => return Poll::Ready(Err(err)) + } + } + + let would_block = match focus { + Focus::Empty => write_would_block || read_would_block, + Focus::Readable => read_would_block, + Focus::Writable => write_would_block, + }; + + match (self.eof, self.session.is_handshaking(), would_block) { + (true, true, _) => { + let err = io::Error::new(io::ErrorKind::UnexpectedEof, "tls handshake eof"); + return Poll::Ready(Err(err)); + }, + (_, false, true) => { + let would_block = match focus { + Focus::Empty => rdlen == 0 && wrlen == 0, + Focus::Readable => rdlen == 0, + Focus::Writable => wrlen == 0 + }; + + return if would_block { + Poll::Pending + } else { + Poll::Ready(Ok((rdlen, wrlen))) + }; + }, + (_, false, _) => return Poll::Ready(Ok((rdlen, wrlen))), + (_, true, true) => return Poll::Pending, + (..) => () + } + } + } +} + +impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> WriteTls for Stream<'a, IO, S> { + fn write_tls(&mut self, cx: &mut Context) -> io::Result { + // TODO writev + struct Writer<'a, 'b, T> { io: &'a mut T, cx: &'a mut Context<'b> @@ -92,58 +185,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> { } let mut writer = Writer { io: self.io, cx }; - - match self.session.write_tls(&mut writer) { - Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending, - result => Poll::Ready(result) - } - } - - pub fn handshake(&mut self, cx: &mut Context) -> Poll> { - let mut wrlen = 0; - let mut rdlen = 0; - - loop { - let mut write_would_block = false; - let mut read_would_block = false; - - while self.session.wants_write() { - match self.write_io(cx) { - Poll::Ready(Ok(n)) => wrlen += n, - Poll::Pending => { - write_would_block = true; - break - }, - Poll::Ready(Err(err)) => return Poll::Ready(Err(err)) - } - } - - if !self.eof && self.session.wants_read() { - match self.read_io(cx) { - Poll::Ready(Ok(0)) => self.eof = true, - Poll::Ready(Ok(n)) => rdlen += n, - Poll::Pending => read_would_block = true, - Poll::Ready(Err(err)) => return Poll::Ready(Err(err)) - } - } - - let would_block = write_would_block || read_would_block; - - return match (self.eof, self.session.is_handshaking(), would_block) { - (true, true, _) => { - let err = io::Error::new(io::ErrorKind::UnexpectedEof, "tls handshake eof"); - Poll::Ready(Err(err)) - }, - (_, false, true) => if rdlen != 0 || wrlen != 0 { - Poll::Ready(Ok((rdlen, wrlen))) - } else { - Poll::Pending - }, - (_, false, _) => Poll::Ready(Ok((rdlen, wrlen))), - (_, true, true) => Poll::Pending, - (..) => continue - } - } + self.session.write_tls(&mut writer) } } @@ -152,8 +194,8 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncRead for Stream<'a let this = self.get_mut(); while this.session.wants_read() { - match this.read_io(cx) { - Poll::Ready(Ok(0)) => break, + match this.complete_inner_io(cx, Focus::Readable) { + Poll::Ready(Ok((0, _))) => break, Poll::Ready(Ok(_)) => (), Poll::Pending => return Poll::Pending, Poll::Ready(Err(err)) => return Poll::Ready(Err(err)) @@ -178,7 +220,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncWrite for Stream<' Err(err) => return Poll::Ready(Err(err)) }; while this.session.wants_write() { - match this.write_io(cx) { + match this.complete_inner_io(cx, Focus::Writable) { Poll::Ready(Ok(_)) => (), Poll::Pending if len != 0 => break, Poll::Pending => return Poll::Pending, @@ -204,7 +246,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncWrite for Stream<' this.session.flush()?; while this.session.wants_write() { - futures::ready!(this.write_io(cx))?; + futures::ready!(this.complete_inner_io(cx, Focus::Writable))?; } Pin::new(&mut this.io).poll_flush(cx) } @@ -213,7 +255,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncWrite for Stream<' let this = self.get_mut(); while this.session.wants_write() { - futures::ready!(this.write_io(cx))?; + futures::ready!(this.complete_inner_io(cx, Focus::Writable))?; } Pin::new(&mut this.io).poll_shutdown(cx) diff --git a/src/common/test_stream.rs b/src/common/test_stream.rs index 20cc4eb..d109369 100644 --- a/src/common/test_stream.rs +++ b/src/common/test_stream.rs @@ -83,7 +83,6 @@ async fn stream_good() -> io::Result<()> { stream.read_to_end(&mut buf).await?; assert_eq!(buf, FILE); stream.write_all(b"Hello World!").await?; - stream.flush().await?; } let mut buf = String::new(); @@ -120,12 +119,12 @@ async fn stream_handshake() -> io::Result<()> { { let mut good = Good(&mut server); let mut stream = Stream::new(&mut good, &mut client); - let (r, w) = poll_fn(|cx| stream.handshake(cx)).await?; + let (r, w) = poll_fn(|cx| stream.complete_io(cx)).await?; assert!(r > 0); assert!(w > 0); - poll_fn(|cx| stream.handshake(cx)).await?; // finish server handshake + poll_fn(|cx| stream.complete_io(cx)).await?; // finish server handshake } assert!(!server.is_handshaking()); @@ -142,7 +141,7 @@ async fn stream_handshake_eof() -> io::Result<()> { let mut stream = Stream::new(&mut bad, &mut client); let mut cx = Context::from_waker(noop_waker_ref()); - let r = stream.handshake(&mut cx); + let r = stream.complete_io(&mut cx); assert_eq!(r.map_err(|err| err.kind()), Poll::Ready(Err(io::ErrorKind::UnexpectedEof))); Ok(()) as io::Result<()> @@ -188,11 +187,11 @@ fn do_handshake(client: &mut ClientSession, server: &mut ServerSession, cx: &mut let mut stream = Stream::new(&mut good, client); if stream.session.is_handshaking() { - ready!(stream.handshake(cx))?; + ready!(stream.complete_io(cx))?; } if stream.session.wants_write() { - ready!(stream.handshake(cx))?; + ready!(stream.complete_io(cx))?; } Poll::Ready(Ok(())) diff --git a/src/server.rs b/src/server.rs index 92043c9..6a94347 100644 --- a/src/server.rs +++ b/src/server.rs @@ -48,11 +48,11 @@ where let mut stream = Stream::new(io, session).set_eof(eof); if stream.session.is_handshaking() { - futures::ready!(stream.handshake(cx))?; + futures::ready!(stream.complete_io(cx))?; } if stream.session.wants_write() { - futures::ready!(stream.handshake(cx))?; + futures::ready!(stream.complete_io(cx))?; } } diff --git a/src/test_0rtt.rs b/src/test_0rtt.rs index 898deef..cb3e94b 100644 --- a/src/test_0rtt.rs +++ b/src/test_0rtt.rs @@ -22,7 +22,6 @@ async fn get(config: Arc, domain: &str, rtt0: bool) let stream = TcpStream::connect(&addr).await?; let mut stream = connector.connect(domain, stream).await?; stream.write_all(input.as_bytes()).await?; - stream.flush().await?; stream.read_to_end(&mut buf).await?; Ok((stream, String::from_utf8(buf).unwrap())) diff --git a/tests/test.rs b/tests/test.rs index 74918ca..5749efe 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -52,19 +52,18 @@ lazy_static!{ let mut buf = vec![0; 8192]; let n = stream.read(&mut buf).await?; stream.write(&buf[..n]).await?; - stream.flush().await?; let _ = stream.read(&mut buf).await?; Ok(()) as io::Result<()> - }.unwrap_or_else(|err| eprintln!("server: {:?}", err)); + }; - handle.spawn(fut).unwrap(); + handle.spawn(fut.unwrap_or_else(|err| eprintln!("{:?}", err))).unwrap(); } Ok(()) as io::Result<()> - }.unwrap_or_else(|err| eprintln!("server: {:?}", err)); + }; - runtime.block_on(done); + runtime.block_on(done.unwrap_or_else(|err| eprintln!("{:?}", err))); }); let addr = recv.recv().unwrap(); @@ -86,7 +85,6 @@ async fn start_client(addr: SocketAddr, domain: &str, config: Arc) let stream = TcpStream::connect(&addr).await?; let mut stream = config.connect(domain, stream).await?; stream.write_all(FILE).await?; - stream.flush().await?; stream.read_exact(&mut buf).await?; assert_eq!(buf, FILE); From 64948af0a79e6e8baef5e025c59ca5cdd59d9351 Mon Sep 17 00:00:00 2001 From: quininer Date: Tue, 1 Oct 2019 10:24:39 +0800 Subject: [PATCH 137/171] publish 0.12.0-alpha.4 --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index ff2145f..92fd13b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tokio-rustls" -version = "0.12.0-alpha.3" +version = "0.12.0-alpha.4" authors = ["quininer kel "] license = "MIT/Apache-2.0" repository = "https://github.com/quininer/tokio-rustls" From 66f17e3b18ad93b2a42a53f66039366a92101336 Mon Sep 17 00:00:00 2001 From: quininer Date: Tue, 1 Oct 2019 10:53:11 +0800 Subject: [PATCH 138/171] Revert "Revert "refactor: separate read and write"" This reverts commit 4109c34207a2420ece6ac932816591ff5d115284. --- src/client.rs | 55 +++++-------- src/common/mod.rs | 162 ++++++++++++++------------------------ src/common/test_stream.rs | 11 +-- src/server.rs | 4 +- src/test_0rtt.rs | 1 + tests/test.rs | 10 ++- 6 files changed, 96 insertions(+), 147 deletions(-) diff --git a/src/client.rs b/src/client.rs index ac961b2..c901043 100644 --- a/src/client.rs +++ b/src/client.rs @@ -53,11 +53,11 @@ where let mut stream = Stream::new(io, session).set_eof(eof); if stream.session.is_handshaking() { - futures::ready!(stream.complete_io(cx))?; + futures::ready!(stream.handshake(cx))?; } if stream.session.wants_write() { - futures::ready!(stream.complete_io(cx))?; + futures::ready!(stream.handshake(cx))?; } } @@ -81,32 +81,7 @@ where fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll> { match self.state { #[cfg(feature = "early-data")] - TlsState::EarlyData => { - let this = self.get_mut(); - - let mut stream = Stream::new(&mut this.io, &mut this.session) - .set_eof(!this.state.readable()); - let (pos, data) = &mut this.early_data; - - // complete handshake - if stream.session.is_handshaking() { - futures::ready!(stream.complete_io(cx))?; - } - - // write early data (fallback) - if !stream.session.is_early_data_accepted() { - while *pos < data.len() { - let len = futures::ready!(stream.as_mut_pin().poll_write(cx, &data[*pos..]))?; - *pos += len; - } - } - - // end - this.state = TlsState::Stream; - data.clear(); - - Pin::new(this).poll_read(cx, buf) - } + TlsState::EarlyData => Poll::Pending, TlsState::Stream | TlsState::WriteShutdown => { let this = self.get_mut(); let mut stream = Stream::new(&mut this.io, &mut this.session) @@ -116,7 +91,7 @@ where Poll::Ready(Ok(0)) => { this.state.shutdown_read(); Poll::Ready(Ok(0)) - } + }, Poll::Ready(Ok(n)) => Poll::Ready(Ok(n)), Poll::Ready(Err(ref e)) if e.kind() == io::ErrorKind::ConnectionAborted => { this.state.shutdown_read(); @@ -125,9 +100,8 @@ where this.state.shutdown_write(); } Poll::Ready(Ok(0)) - } - Poll::Ready(Err(err)) => Poll::Ready(Err(err)), - Poll::Pending => Poll::Pending + }, + output => output } } TlsState::ReadShutdown | TlsState::FullyShutdown => Poll::Ready(Ok(0)), @@ -153,7 +127,7 @@ where // write early data if let Some(mut early_data) = stream.session.early_data() { - let len = match early_data.write(buf) { + let len = match dbg!(early_data.write(buf)) { Ok(n) => n, Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => return Poll::Pending, @@ -165,7 +139,7 @@ where // complete handshake if stream.session.is_handshaking() { - futures::ready!(stream.complete_io(cx))?; + futures::ready!(stream.handshake(cx))?; } // write early data (fallback) @@ -189,6 +163,14 @@ where let this = self.get_mut(); let mut stream = Stream::new(&mut this.io, &mut this.session) .set_eof(!this.state.readable()); + + #[cfg(feature = "early-data")] { + // complete handshake + if stream.session.is_handshaking() { + futures::ready!(stream.handshake(cx))?; + } + } + stream.as_mut_pin().poll_flush(cx) } @@ -201,6 +183,11 @@ where let this = self.get_mut(); let mut stream = Stream::new(&mut this.io, &mut this.session) .set_eof(!this.state.readable()); + + // TODO + // + // should we complete the handshake? + stream.as_mut_pin().poll_shutdown(cx) } } diff --git a/src/common/mod.rs b/src/common/mod.rs index 1af5ecb..e9fc783 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -13,17 +13,6 @@ pub struct Stream<'a, IO, S> { pub eof: bool } -trait WriteTls { - fn write_tls(&mut self, cx: &mut Context) -> io::Result; -} - -#[derive(Clone, Copy)] -enum Focus { - Empty, - Readable, - Writable -} - impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> { pub fn new(io: &'a mut IO, session: &'a mut S) -> Self { Stream { @@ -44,11 +33,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> { Pin::new(self) } - pub fn complete_io(&mut self, cx: &mut Context) -> Poll> { - self.complete_inner_io(cx, Focus::Empty) - } - - fn complete_read_io(&mut self, cx: &mut Context) -> Poll> { + fn read_io(&mut self, cx: &mut Context) -> Poll> { struct Reader<'a, 'b, T> { io: &'a mut T, cx: &'a mut Context<'b> @@ -76,7 +61,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> { // In case we have an alert to send describing this error, // try a last-gasp write -- but don't predate the primary // error. - let _ = self.write_tls(cx); + let _ = self.write_io(cx); io::Error::new(io::ErrorKind::InvalidData, err) })?; @@ -84,85 +69,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> { Poll::Ready(Ok(n)) } - fn complete_write_io(&mut self, cx: &mut Context) -> Poll> { - match self.write_tls(cx) { - Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending, - result => Poll::Ready(result) - } - } - - fn complete_inner_io(&mut self, cx: &mut Context, focus: Focus) -> Poll> { - let mut wrlen = 0; - let mut rdlen = 0; - - loop { - let mut write_would_block = false; - let mut read_would_block = false; - - while self.session.wants_write() { - match self.complete_write_io(cx) { - Poll::Ready(Ok(n)) => wrlen += n, - Poll::Pending => { - write_would_block = true; - break - }, - Poll::Ready(Err(err)) => return Poll::Ready(Err(err)) - } - } - - if let Focus::Writable = focus { - if !write_would_block { - return Poll::Ready(Ok((rdlen, wrlen))); - } else { - return Poll::Pending; - } - } - - if !self.eof && self.session.wants_read() { - match self.complete_read_io(cx) { - Poll::Ready(Ok(0)) => self.eof = true, - Poll::Ready(Ok(n)) => rdlen += n, - Poll::Pending => read_would_block = true, - Poll::Ready(Err(err)) => return Poll::Ready(Err(err)) - } - } - - let would_block = match focus { - Focus::Empty => write_would_block || read_would_block, - Focus::Readable => read_would_block, - Focus::Writable => write_would_block, - }; - - match (self.eof, self.session.is_handshaking(), would_block) { - (true, true, _) => { - let err = io::Error::new(io::ErrorKind::UnexpectedEof, "tls handshake eof"); - return Poll::Ready(Err(err)); - }, - (_, false, true) => { - let would_block = match focus { - Focus::Empty => rdlen == 0 && wrlen == 0, - Focus::Readable => rdlen == 0, - Focus::Writable => wrlen == 0 - }; - - return if would_block { - Poll::Pending - } else { - Poll::Ready(Ok((rdlen, wrlen))) - }; - }, - (_, false, _) => return Poll::Ready(Ok((rdlen, wrlen))), - (_, true, true) => return Poll::Pending, - (..) => () - } - } - } -} - -impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> WriteTls for Stream<'a, IO, S> { - fn write_tls(&mut self, cx: &mut Context) -> io::Result { - // TODO writev - + fn write_io(&mut self, cx: &mut Context) -> Poll> { struct Writer<'a, 'b, T> { io: &'a mut T, cx: &'a mut Context<'b> @@ -185,7 +92,58 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> WriteTls for Str } let mut writer = Writer { io: self.io, cx }; - self.session.write_tls(&mut writer) + + match self.session.write_tls(&mut writer) { + Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending, + result => Poll::Ready(result) + } + } + + pub fn handshake(&mut self, cx: &mut Context) -> Poll> { + let mut wrlen = 0; + let mut rdlen = 0; + + loop { + let mut write_would_block = false; + let mut read_would_block = false; + + while self.session.wants_write() { + match self.write_io(cx) { + Poll::Ready(Ok(n)) => wrlen += n, + Poll::Pending => { + write_would_block = true; + break + }, + Poll::Ready(Err(err)) => return Poll::Ready(Err(err)) + } + } + + if !self.eof && self.session.wants_read() { + match self.read_io(cx) { + Poll::Ready(Ok(0)) => self.eof = true, + Poll::Ready(Ok(n)) => rdlen += n, + Poll::Pending => read_would_block = true, + Poll::Ready(Err(err)) => return Poll::Ready(Err(err)) + } + } + + let would_block = write_would_block || read_would_block; + + return match (self.eof, self.session.is_handshaking(), would_block) { + (true, true, _) => { + let err = io::Error::new(io::ErrorKind::UnexpectedEof, "tls handshake eof"); + Poll::Ready(Err(err)) + }, + (_, false, true) => if rdlen != 0 || wrlen != 0 { + Poll::Ready(Ok((rdlen, wrlen))) + } else { + Poll::Pending + }, + (_, false, _) => Poll::Ready(Ok((rdlen, wrlen))), + (_, true, true) => Poll::Pending, + (..) => continue + } + } } } @@ -194,8 +152,8 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncRead for Stream<'a let this = self.get_mut(); while this.session.wants_read() { - match this.complete_inner_io(cx, Focus::Readable) { - Poll::Ready(Ok((0, _))) => break, + match this.read_io(cx) { + Poll::Ready(Ok(0)) => break, Poll::Ready(Ok(_)) => (), Poll::Pending => return Poll::Pending, Poll::Ready(Err(err)) => return Poll::Ready(Err(err)) @@ -220,7 +178,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncWrite for Stream<' Err(err) => return Poll::Ready(Err(err)) }; while this.session.wants_write() { - match this.complete_inner_io(cx, Focus::Writable) { + match this.write_io(cx) { Poll::Ready(Ok(_)) => (), Poll::Pending if len != 0 => break, Poll::Pending => return Poll::Pending, @@ -246,7 +204,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncWrite for Stream<' this.session.flush()?; while this.session.wants_write() { - futures::ready!(this.complete_inner_io(cx, Focus::Writable))?; + futures::ready!(this.write_io(cx))?; } Pin::new(&mut this.io).poll_flush(cx) } @@ -255,7 +213,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncWrite for Stream<' let this = self.get_mut(); while this.session.wants_write() { - futures::ready!(this.complete_inner_io(cx, Focus::Writable))?; + futures::ready!(this.write_io(cx))?; } Pin::new(&mut this.io).poll_shutdown(cx) diff --git a/src/common/test_stream.rs b/src/common/test_stream.rs index d109369..20cc4eb 100644 --- a/src/common/test_stream.rs +++ b/src/common/test_stream.rs @@ -83,6 +83,7 @@ async fn stream_good() -> io::Result<()> { stream.read_to_end(&mut buf).await?; assert_eq!(buf, FILE); stream.write_all(b"Hello World!").await?; + stream.flush().await?; } let mut buf = String::new(); @@ -119,12 +120,12 @@ async fn stream_handshake() -> io::Result<()> { { let mut good = Good(&mut server); let mut stream = Stream::new(&mut good, &mut client); - let (r, w) = poll_fn(|cx| stream.complete_io(cx)).await?; + let (r, w) = poll_fn(|cx| stream.handshake(cx)).await?; assert!(r > 0); assert!(w > 0); - poll_fn(|cx| stream.complete_io(cx)).await?; // finish server handshake + poll_fn(|cx| stream.handshake(cx)).await?; // finish server handshake } assert!(!server.is_handshaking()); @@ -141,7 +142,7 @@ async fn stream_handshake_eof() -> io::Result<()> { let mut stream = Stream::new(&mut bad, &mut client); let mut cx = Context::from_waker(noop_waker_ref()); - let r = stream.complete_io(&mut cx); + let r = stream.handshake(&mut cx); assert_eq!(r.map_err(|err| err.kind()), Poll::Ready(Err(io::ErrorKind::UnexpectedEof))); Ok(()) as io::Result<()> @@ -187,11 +188,11 @@ fn do_handshake(client: &mut ClientSession, server: &mut ServerSession, cx: &mut let mut stream = Stream::new(&mut good, client); if stream.session.is_handshaking() { - ready!(stream.complete_io(cx))?; + ready!(stream.handshake(cx))?; } if stream.session.wants_write() { - ready!(stream.complete_io(cx))?; + ready!(stream.handshake(cx))?; } Poll::Ready(Ok(())) diff --git a/src/server.rs b/src/server.rs index 6a94347..92043c9 100644 --- a/src/server.rs +++ b/src/server.rs @@ -48,11 +48,11 @@ where let mut stream = Stream::new(io, session).set_eof(eof); if stream.session.is_handshaking() { - futures::ready!(stream.complete_io(cx))?; + futures::ready!(stream.handshake(cx))?; } if stream.session.wants_write() { - futures::ready!(stream.complete_io(cx))?; + futures::ready!(stream.handshake(cx))?; } } diff --git a/src/test_0rtt.rs b/src/test_0rtt.rs index cb3e94b..898deef 100644 --- a/src/test_0rtt.rs +++ b/src/test_0rtt.rs @@ -22,6 +22,7 @@ async fn get(config: Arc, domain: &str, rtt0: bool) let stream = TcpStream::connect(&addr).await?; let mut stream = connector.connect(domain, stream).await?; stream.write_all(input.as_bytes()).await?; + stream.flush().await?; stream.read_to_end(&mut buf).await?; Ok((stream, String::from_utf8(buf).unwrap())) diff --git a/tests/test.rs b/tests/test.rs index 5749efe..74918ca 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -52,18 +52,19 @@ lazy_static!{ let mut buf = vec![0; 8192]; let n = stream.read(&mut buf).await?; stream.write(&buf[..n]).await?; + stream.flush().await?; let _ = stream.read(&mut buf).await?; Ok(()) as io::Result<()> - }; + }.unwrap_or_else(|err| eprintln!("server: {:?}", err)); - handle.spawn(fut.unwrap_or_else(|err| eprintln!("{:?}", err))).unwrap(); + handle.spawn(fut).unwrap(); } Ok(()) as io::Result<()> - }; + }.unwrap_or_else(|err| eprintln!("server: {:?}", err)); - runtime.block_on(done.unwrap_or_else(|err| eprintln!("{:?}", err))); + runtime.block_on(done); }); let addr = recv.recv().unwrap(); @@ -85,6 +86,7 @@ async fn start_client(addr: SocketAddr, domain: &str, config: Arc) let stream = TcpStream::connect(&addr).await?; let mut stream = config.connect(domain, stream).await?; stream.write_all(FILE).await?; + stream.flush().await?; stream.read_exact(&mut buf).await?; assert_eq!(buf, FILE); From 86171d34a8c8d9e630b9ca5c2af999389ae400e1 Mon Sep 17 00:00:00 2001 From: quininer Date: Tue, 1 Oct 2019 14:36:16 +0800 Subject: [PATCH 139/171] refactor: more read an write --- src/client.rs | 8 +-- src/common/mod.rs | 133 +++++++++++++++++++++++--------------- src/common/test_stream.rs | 2 +- src/server.rs | 2 +- 4 files changed, 88 insertions(+), 57 deletions(-) diff --git a/src/client.rs b/src/client.rs index c901043..4803410 100644 --- a/src/client.rs +++ b/src/client.rs @@ -52,7 +52,7 @@ where let (io, session) = stream.get_mut(); let mut stream = Stream::new(io, session).set_eof(eof); - if stream.session.is_handshaking() { + while stream.session.is_handshaking() { futures::ready!(stream.handshake(cx))?; } @@ -127,7 +127,7 @@ where // write early data if let Some(mut early_data) = stream.session.early_data() { - let len = match dbg!(early_data.write(buf)) { + let len = match early_data.write(buf) { Ok(n) => n, Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => return Poll::Pending, @@ -138,7 +138,7 @@ where } // complete handshake - if stream.session.is_handshaking() { + while stream.session.is_handshaking() { futures::ready!(stream.handshake(cx))?; } @@ -166,7 +166,7 @@ where #[cfg(feature = "early-data")] { // complete handshake - if stream.session.is_handshaking() { + while stream.session.is_handshaking() { futures::ready!(stream.handshake(cx))?; } } diff --git a/src/common/mod.rs b/src/common/mod.rs index e9fc783..195d0da 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -33,6 +33,18 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> { Pin::new(self) } + fn process_new_packets(&mut self, cx: &mut Context) -> io::Result<()> { + self.session.process_new_packets() + .map_err(|err| { + // In case we have an alert to send describing this error, + // try a last-gasp write -- but don't predate the primary + // error. + let _ = self.write_io(cx); + + io::Error::new(io::ErrorKind::InvalidData, err) + }) + } + fn read_io(&mut self, cx: &mut Context) -> Poll> { struct Reader<'a, 'b, T> { io: &'a mut T, @@ -56,16 +68,6 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> { Err(err) => return Poll::Ready(Err(err)) }; - self.session.process_new_packets() - .map_err(|err| { - // In case we have an alert to send describing this error, - // try a last-gasp write -- but don't predate the primary - // error. - let _ = self.write_io(cx); - - io::Error::new(io::ErrorKind::InvalidData, err) - })?; - Poll::Ready(Ok(n)) } @@ -118,29 +120,31 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> { } } - if !self.eof && self.session.wants_read() { + while !self.eof && self.session.wants_read() { match self.read_io(cx) { Poll::Ready(Ok(0)) => self.eof = true, Poll::Ready(Ok(n)) => rdlen += n, - Poll::Pending => read_would_block = true, + Poll::Pending => { + read_would_block = true; + break + }, Poll::Ready(Err(err)) => return Poll::Ready(Err(err)) } } - let would_block = write_would_block || read_would_block; + self.process_new_packets(cx)?; - return match (self.eof, self.session.is_handshaking(), would_block) { - (true, true, _) => { + return match (self.eof, self.session.is_handshaking()) { + (true, true) => { let err = io::Error::new(io::ErrorKind::UnexpectedEof, "tls handshake eof"); Poll::Ready(Err(err)) }, - (_, false, true) => if rdlen != 0 || wrlen != 0 { + (_, false) => Poll::Ready(Ok((rdlen, wrlen))), + (_, true) if write_would_block || read_would_block => if rdlen != 0 || wrlen != 0 { Poll::Ready(Ok((rdlen, wrlen))) } else { Poll::Pending }, - (_, false, _) => Poll::Ready(Ok((rdlen, wrlen))), - (_, true, true) => Poll::Pending, (..) => continue } } @@ -150,53 +154,80 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> { impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncRead for Stream<'a, IO, S> { fn poll_read(self: Pin<&mut Self>, cx: &mut Context, buf: &mut [u8]) -> Poll> { let this = self.get_mut(); + let mut pos = 0; - while this.session.wants_read() { - match this.read_io(cx) { - Poll::Ready(Ok(0)) => break, - Poll::Ready(Ok(_)) => (), - Poll::Pending => return Poll::Pending, - Poll::Ready(Err(err)) => return Poll::Ready(Err(err)) + while pos != buf.len() { + let mut would_block = false; + + // read a packet + while this.session.wants_read() { + match this.read_io(cx) { + Poll::Ready(Ok(0)) => { + this.eof = true; + break + }, + Poll::Ready(Ok(_)) => (), + Poll::Pending => { + would_block = true; + break + }, + Poll::Ready(Err(err)) => return Poll::Ready(Err(err)) + } + } + + this.process_new_packets(cx)?; + + return match this.session.read(&mut buf[pos..]) { + Ok(0) if pos == 0 && would_block => Poll::Pending, + Ok(n) if this.eof || would_block => Poll::Ready(Ok(pos + n)), + Ok(n) => { + pos += n; + continue + }, + Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending, + Err(ref err) if err.kind() == io::ErrorKind::ConnectionAborted && pos != 0 => + Poll::Ready(Ok(pos)), + Err(err) => Poll::Ready(Err(err)) } } - match this.session.read(buf) { - Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending, - result => Poll::Ready(result) - } + Poll::Ready(Ok(pos)) } } impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncWrite for Stream<'a, IO, S> { fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll> { let this = self.get_mut(); + let mut pos = 0; - let len = match this.session.write(buf) { - Ok(n) => n, - Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => - return Poll::Pending, - Err(err) => return Poll::Ready(Err(err)) - }; - while this.session.wants_write() { - match this.write_io(cx) { - Poll::Ready(Ok(_)) => (), - Poll::Pending if len != 0 => break, - Poll::Pending => return Poll::Pending, - Poll::Ready(Err(err)) => return Poll::Ready(Err(err)) + while pos != buf.len() { + let mut would_block = false; + + match this.session.write(&buf[pos..]) { + Ok(n) => pos += n, + Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => (), + Err(err) => return Poll::Ready(Err(err)) + }; + + while this.session.wants_write() { + match this.write_io(cx) { + Poll::Ready(Ok(0)) | Poll::Pending => { + would_block = true; + break + }, + Poll::Ready(Ok(_)) => (), + Poll::Ready(Err(err)) => return Poll::Ready(Err(err)) + } + } + + return match (pos, would_block) { + (0, true) => Poll::Pending, + (n, true) => Poll::Ready(Ok(n)), + (_, false) => continue } } - if len != 0 || buf.is_empty() { - Poll::Ready(Ok(len)) - } else { - // not write zero - match this.session.write(buf) { - Ok(0) => Poll::Pending, - Ok(n) => Poll::Ready(Ok(n)), - Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending, - Err(err) => Poll::Ready(Err(err)) - } - } + Poll::Ready(Ok(pos)) } fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { diff --git a/src/common/test_stream.rs b/src/common/test_stream.rs index 20cc4eb..8be20f4 100644 --- a/src/common/test_stream.rs +++ b/src/common/test_stream.rs @@ -187,7 +187,7 @@ fn do_handshake(client: &mut ClientSession, server: &mut ServerSession, cx: &mut let mut good = Good(server); let mut stream = Stream::new(&mut good, client); - if stream.session.is_handshaking() { + while stream.session.is_handshaking() { ready!(stream.handshake(cx))?; } diff --git a/src/server.rs b/src/server.rs index 92043c9..91c1cb4 100644 --- a/src/server.rs +++ b/src/server.rs @@ -47,7 +47,7 @@ where let (io, session) = stream.get_mut(); let mut stream = Stream::new(io, session).set_eof(eof); - if stream.session.is_handshaking() { + while stream.session.is_handshaking() { futures::ready!(stream.handshake(cx))?; } From 821d1c129f88e9847b28071c4efe8ef842eb0351 Mon Sep 17 00:00:00 2001 From: quininer Date: Tue, 1 Oct 2019 16:27:42 +0800 Subject: [PATCH 140/171] move badssl test --- src/test_0rtt.rs => tests/badssl.rs | 42 +++++++++++++++++++---------- tests/test.rs | 2 -- 2 files changed, 28 insertions(+), 16 deletions(-) rename src/test_0rtt.rs => tests/badssl.rs (53%) diff --git a/src/test_0rtt.rs b/tests/badssl.rs similarity index 53% rename from src/test_0rtt.rs rename to tests/badssl.rs index 898deef..74bd294 100644 --- a/src/test_0rtt.rs +++ b/tests/badssl.rs @@ -4,16 +4,16 @@ use std::net::ToSocketAddrs; use tokio::prelude::*; use tokio::net::TcpStream; use rustls::ClientConfig; -use crate::{ TlsConnector, client::TlsStream }; +use tokio_rustls::{ TlsConnector, client::TlsStream }; -async fn get(config: Arc, domain: &str, rtt0: bool) +async fn get(config: Arc, domain: &str, port: u16) -> io::Result<(TlsStream, String)> { - let connector = TlsConnector::from(config).early_data(rtt0); + let connector = TlsConnector::from(config); let input = format!("GET / HTTP/1.0\r\nHost: {}\r\n\r\n", domain); - let addr = (domain, 443) + let addr = (domain, port) .to_socket_addrs()? .next().unwrap(); let domain = webpki::DNSNameRef::try_from_ascii_str(&domain).unwrap(); @@ -29,20 +29,34 @@ async fn get(config: Arc, domain: &str, rtt0: bool) } #[tokio::test] -async fn test_0rtt() -> io::Result<()> { +async fn test_tls12() -> io::Result<()> { let mut config = ClientConfig::new(); config.root_store.add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS); - config.enable_early_data = true; + config.versions = vec![rustls::ProtocolVersion::TLSv1_2]; let config = Arc::new(config); - let domain = "mozilla-modern.badssl.com"; + let domain = "tls-v1-2.badssl.com"; - let (_, output) = get(config.clone(), domain, false).await?; - assert!(output.contains("mozilla-modern.badssl.com")); - - let (io, output) = get(config.clone(), domain, true).await?; - assert!(output.contains("mozilla-modern.badssl.com")); - - assert_eq!(io.early_data.0, 0); + let (_, output) = get(config.clone(), domain, 1012).await?; + assert!(output.contains("tls-v1-2.badssl.com")); + + Ok(()) +} + +#[should_panic] +#[test] +fn test_tls13() { + unimplemented!("todo https://github.com/chromium/badssl.com/pull/373"); +} + +#[tokio::test] +async fn test_modern() -> io::Result<()> { + let mut config = ClientConfig::new(); + config.root_store.add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS); + let config = Arc::new(config); + let domain = "mozilla-modern.badssl.com"; + + let (_, output) = get(config.clone(), domain, 443).await?; + assert!(output.contains("mozilla-modern.badssl.com")); Ok(()) } diff --git a/tests/test.rs b/tests/test.rs index 74918ca..6ebdee9 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -53,7 +53,6 @@ lazy_static!{ let n = stream.read(&mut buf).await?; stream.write(&buf[..n]).await?; stream.flush().await?; - let _ = stream.read(&mut buf).await?; Ok(()) as io::Result<()> }.unwrap_or_else(|err| eprintln!("server: {:?}", err)); @@ -91,7 +90,6 @@ async fn start_client(addr: SocketAddr, domain: &str, config: Arc) assert_eq!(buf, FILE); - stream.shutdown().await?; Ok(()) } From 369c13d6a5794bd73a5f53dc288d77b70d517086 Mon Sep 17 00:00:00 2001 From: quininer Date: Tue, 1 Oct 2019 23:00:49 +0800 Subject: [PATCH 141/171] add 0-RTT test --- src/client.rs | 2 + src/lib.rs | 4 -- src/server.rs | 2 + tests/early-data.rs | 118 ++++++++++++++++++++++++++++++++++++++++++++ tests/test.rs | 16 ++---- 5 files changed, 127 insertions(+), 15 deletions(-) create mode 100644 tests/early-data.rs diff --git a/src/client.rs b/src/client.rs index 4803410..26607b3 100644 --- a/src/client.rs +++ b/src/client.rs @@ -113,6 +113,8 @@ impl AsyncWrite for TlsStream where IO: AsyncRead + AsyncWrite + Unpin, { + /// Note: that it does not guarantee the final data to be sent. + /// To be cautious, you must manually call `flush`. fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { let this = self.get_mut(); let mut stream = Stream::new(&mut this.io, &mut this.session) diff --git a/src/lib.rs b/src/lib.rs index f631a09..382e43a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -195,7 +195,3 @@ impl Future for Accept { Pin::new(&mut self.0).poll(cx) } } - -#[cfg(feature = "early-data")] -#[cfg(test)] -mod test_0rtt; diff --git a/src/server.rs b/src/server.rs index 91c1cb4..4dad3f6 100644 --- a/src/server.rs +++ b/src/server.rs @@ -105,6 +105,8 @@ impl AsyncWrite for TlsStream where IO: AsyncRead + AsyncWrite + Unpin, { + /// Note: that it does not guarantee the final data to be sent. + /// To be cautious, you must manually call `flush`. fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { let this = self.get_mut(); let mut stream = Stream::new(&mut this.io, &mut this.session) diff --git a/tests/early-data.rs b/tests/early-data.rs new file mode 100644 index 0000000..57645d1 --- /dev/null +++ b/tests/early-data.rs @@ -0,0 +1,118 @@ +#![cfg(feature = "early-data")] + +use std::io::{ self, BufReader, BufRead, Cursor }; +use std::process::{ Command, Child, Stdio }; +use std::net::SocketAddr; +use std::sync::Arc; +use std::marker::Unpin; +use std::pin::{ Pin }; +use std::task::{ Context, Poll }; +use std::time::Duration; +use tokio::prelude::*; +use tokio::net::TcpStream; +use tokio::io::split; +use tokio::timer::delay_for; +use futures_util::{ future, ready }; +use rustls::ClientConfig; +use tokio_rustls::{ TlsConnector, client::TlsStream }; + + +struct Read1(T); + +impl Future for Read1 { + type Output = io::Result<()>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let mut buf = [0]; + ready!(Pin::new(&mut self.0).poll_read(cx, &mut buf))?; + Poll::Pending + } +} + +async fn send(config: Arc, addr: SocketAddr, data: &[u8]) + -> io::Result> +{ + let connector = TlsConnector::from(config) + .early_data(true); + let stream = TcpStream::connect(&addr).await?; + let domain = webpki::DNSNameRef::try_from_ascii_str("testserver.com").unwrap(); + + let mut stream = connector.connect(domain, stream).await?; + stream.write_all(data).await?; + stream.flush().await?; + + let (r, mut w) = split(stream); + let fut = Read1(r); + let fut2 = async move { + // sleep 3s + // + // see https://www.mail-archive.com/openssl-users@openssl.org/msg84451.html + delay_for(Duration::from_secs(3)).await; + w.shutdown().await?; + Ok(w) as io::Result<_> + }; + + let stream = match future::select(fut, fut2.boxed()).await { + future::Either::Left(_) => unreachable!(), + future::Either::Right((Ok(w), Read1(r))) => r.unsplit(w), + future::Either::Right((Err(err), _)) => return Err(err) + }; + + Ok(stream) +} + +struct DropKill(Child); + +impl Drop for DropKill { + fn drop(&mut self) { + self.0.kill().unwrap(); + } +} + +#[tokio::test] +async fn test_0rtt() -> io::Result<()> { + let mut handle = Command::new("openssl") + .arg("s_server") + .arg("-early_data") + .arg("-tls1_3") + .args(&["-cert", "./tests/end.cert"]) + .args(&["-key", "./tests/end.rsa"]) + .args(&["-port", "12354"]) + .stdout(Stdio::piped()) + .spawn() + .map(DropKill)?; + + // wait openssl server + delay_for(Duration::from_secs(3)).await; + + let mut config = ClientConfig::new(); + let mut chain = BufReader::new(Cursor::new(include_str!("end.chain"))); + config.root_store.add_pem_file(&mut chain).unwrap(); + config.versions = vec![rustls::ProtocolVersion::TLSv1_3]; + config.enable_early_data = true; + let config = Arc::new(config); + let addr = SocketAddr::from(([127, 0, 0, 1], 12354)); + + let io = send(config.clone(), addr, b"hello").await?; + assert!(!io.get_ref().1.is_early_data_accepted()); + + let io = send(config, addr, b"world!").await?; + assert!(io.get_ref().1.is_early_data_accepted()); + + let stdout = handle.0.stdout.as_mut().unwrap(); + let mut lines = BufReader::new(stdout).lines(); + + for line in lines.by_ref() { + if line?.contains("hello") { + break + } + } + + for line in lines.by_ref() { + if line?.contains("world!") { + break + } + } + + Ok(()) +} diff --git a/tests/test.rs b/tests/test.rs index 6ebdee9..30f8e8a 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -7,6 +7,7 @@ use lazy_static::lazy_static; use tokio::prelude::*; use tokio::runtime::current_thread; use tokio::net::{ TcpListener, TcpStream }; +use tokio::io::split; use futures_util::try_future::TryFutureExt; use rustls::{ ServerConfig, ClientConfig }; use rustls::internal::pemfile::{ certs, rsa_private_keys }; @@ -42,17 +43,10 @@ lazy_static!{ while let Some(stream) = incoming.next().await { let acceptor = acceptor.clone(); let fut = async move { - let mut stream = acceptor.accept(stream?).await?; + let stream = acceptor.accept(stream?).await?; - // TODO split - // - // let (mut reader, mut write) = stream.split(); - // reader.copy(&mut write).await?; - - let mut buf = vec![0; 8192]; - let n = stream.read(&mut buf).await?; - stream.write(&buf[..n]).await?; - stream.flush().await?; + let (mut reader, mut writer) = split(stream); + reader.copy(&mut writer).await?; Ok(()) as io::Result<()> }.unwrap_or_else(|err| eprintln!("server: {:?}", err)); @@ -67,7 +61,7 @@ lazy_static!{ }); let addr = recv.recv().unwrap(); - (addr, "localhost", CHAIN) + (addr, "testserver.com", CHAIN) }; } From d8235071cda0e4c970e4e818b5351f444c393fd7 Mon Sep 17 00:00:00 2001 From: quininer Date: Thu, 10 Oct 2019 22:52:31 +0800 Subject: [PATCH 142/171] move sleep --- tests/early-data.rs | 24 +++++++++--------------- 1 file changed, 9 insertions(+), 15 deletions(-) diff --git a/tests/early-data.rs b/tests/early-data.rs index 57645d1..ae0d614 100644 --- a/tests/early-data.rs +++ b/tests/early-data.rs @@ -10,7 +10,6 @@ use std::task::{ Context, Poll }; use std::time::Duration; use tokio::prelude::*; use tokio::net::TcpStream; -use tokio::io::split; use tokio::timer::delay_for; use futures_util::{ future, ready }; use rustls::ClientConfig; @@ -41,22 +40,17 @@ async fn send(config: Arc, addr: SocketAddr, data: &[u8]) stream.write_all(data).await?; stream.flush().await?; - let (r, mut w) = split(stream); - let fut = Read1(r); - let fut2 = async move { - // sleep 3s - // - // see https://www.mail-archive.com/openssl-users@openssl.org/msg84451.html - delay_for(Duration::from_secs(3)).await; - w.shutdown().await?; - Ok(w) as io::Result<_> + // sleep 3s + // + // see https://www.mail-archive.com/openssl-users@openssl.org/msg84451.html + let sleep3 = delay_for(Duration::from_secs(3)); + let mut stream = match future::select(Read1(stream), sleep3).await { + future::Either::Right((_, Read1(stream))) => stream, + future::Either::Left((Err(err), _)) => return Err(err), + future::Either::Left((Ok(_), _)) => unreachable!(), }; - let stream = match future::select(fut, fut2.boxed()).await { - future::Either::Left(_) => unreachable!(), - future::Either::Right((Ok(w), Read1(r))) => r.unsplit(w), - future::Either::Right((Err(err), _)) => return Err(err) - }; + stream.shutdown().await?; Ok(stream) } From 7864945694d8c43b4f14f3a9aba726dbe27528df Mon Sep 17 00:00:00 2001 From: quininer Date: Thu, 10 Oct 2019 23:25:17 +0800 Subject: [PATCH 143/171] ci: try install openssl --- .travis.yml | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/.travis.yml b/.travis.yml index b0b0082..6546945 100644 --- a/.travis.yml +++ b/.travis.yml @@ -14,6 +14,15 @@ matrix: allow_failures: - rust: stable +install: + - wget https://github.com/openssl/openssl/archive/OpenSSL_1_1_1d.tar.gz + - tar xvfz OpenSSL_1_1_1d.tar.gz && cd openssl-OpenSSL_1_1_1d + - if [ $TRAVIS_OS_NAME = "linux" ]; then ./Configure linux-x86_64 --prefix=$HOME/installed_openssl; fi + - if [ $TRAVIS_OS_NAME = "osx" ]; then ./Configure darwin64-x86_64-cc --prefix=$HOME/installed_openssl; fi + - make && make install_sw && cd .. + - export PATH=$HOME/installed_openssl/bin:$PATH + - export LD_LIBRARY_PATH=$HOME/installed_openssl/lib:$LD_LIBRARY_PATH + script: - cargo test - cargo test --features early-data From 9a161beb873b582a5a66545c18e966def04229d0 Mon Sep 17 00:00:00 2001 From: quininer Date: Fri, 11 Oct 2019 01:01:27 +0800 Subject: [PATCH 144/171] use `write_io` instead of `handshake` --- src/client.rs | 4 ++-- src/common/mod.rs | 6 +++--- src/common/test_stream.rs | 4 ++-- src/server.rs | 4 ++-- tests/early-data.rs | 8 ++++---- 5 files changed, 13 insertions(+), 13 deletions(-) diff --git a/src/client.rs b/src/client.rs index 26607b3..9843f54 100644 --- a/src/client.rs +++ b/src/client.rs @@ -56,8 +56,8 @@ where futures::ready!(stream.handshake(cx))?; } - if stream.session.wants_write() { - futures::ready!(stream.handshake(cx))?; + while stream.session.wants_write() { + futures::ready!(stream.write_io(cx))?; } } diff --git a/src/common/mod.rs b/src/common/mod.rs index 195d0da..c176ad5 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -33,7 +33,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> { Pin::new(self) } - fn process_new_packets(&mut self, cx: &mut Context) -> io::Result<()> { + pub fn process_new_packets(&mut self, cx: &mut Context) -> io::Result<()> { self.session.process_new_packets() .map_err(|err| { // In case we have an alert to send describing this error, @@ -45,7 +45,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> { }) } - fn read_io(&mut self, cx: &mut Context) -> Poll> { + pub fn read_io(&mut self, cx: &mut Context) -> Poll> { struct Reader<'a, 'b, T> { io: &'a mut T, cx: &'a mut Context<'b> @@ -71,7 +71,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> { Poll::Ready(Ok(n)) } - fn write_io(&mut self, cx: &mut Context) -> Poll> { + pub fn write_io(&mut self, cx: &mut Context) -> Poll> { struct Writer<'a, 'b, T> { io: &'a mut T, cx: &'a mut Context<'b> diff --git a/src/common/test_stream.rs b/src/common/test_stream.rs index 8be20f4..84bcbe6 100644 --- a/src/common/test_stream.rs +++ b/src/common/test_stream.rs @@ -191,8 +191,8 @@ fn do_handshake(client: &mut ClientSession, server: &mut ServerSession, cx: &mut ready!(stream.handshake(cx))?; } - if stream.session.wants_write() { - ready!(stream.handshake(cx))?; + while stream.session.wants_write() { + ready!(stream.write_io(cx))?; } Poll::Ready(Ok(())) diff --git a/src/server.rs b/src/server.rs index 4dad3f6..ac72904 100644 --- a/src/server.rs +++ b/src/server.rs @@ -51,8 +51,8 @@ where futures::ready!(stream.handshake(cx))?; } - if stream.session.wants_write() { - futures::ready!(stream.handshake(cx))?; + while stream.session.wants_write() { + futures::ready!(stream.write_io(cx))?; } } diff --git a/tests/early-data.rs b/tests/early-data.rs index ae0d614..9dd6b5e 100644 --- a/tests/early-data.rs +++ b/tests/early-data.rs @@ -40,11 +40,11 @@ async fn send(config: Arc, addr: SocketAddr, data: &[u8]) stream.write_all(data).await?; stream.flush().await?; - // sleep 3s + // sleep 1s // // see https://www.mail-archive.com/openssl-users@openssl.org/msg84451.html - let sleep3 = delay_for(Duration::from_secs(3)); - let mut stream = match future::select(Read1(stream), sleep3).await { + let sleep1 = delay_for(Duration::from_secs(1)); + let mut stream = match future::select(Read1(stream), sleep1).await { future::Either::Right((_, Read1(stream))) => stream, future::Either::Left((Err(err), _)) => return Err(err), future::Either::Left((Ok(_), _)) => unreachable!(), @@ -77,7 +77,7 @@ async fn test_0rtt() -> io::Result<()> { .map(DropKill)?; // wait openssl server - delay_for(Duration::from_secs(3)).await; + delay_for(Duration::from_secs(1)).await; let mut config = ClientConfig::new(); let mut chain = BufReader::new(Cursor::new(include_str!("end.chain"))); From 10c139df082db6c7c49d05260b1645004d0a296f Mon Sep 17 00:00:00 2001 From: quininer Date: Fri, 11 Oct 2019 01:24:27 +0800 Subject: [PATCH 145/171] test: split bad channel --- src/common/test_stream.rs | 41 +++++++++++++++++++++++++++++---------- src/lib.rs | 2 +- tests/early-data.rs | 7 +++++++ 3 files changed, 39 insertions(+), 11 deletions(-) diff --git a/src/common/test_stream.rs b/src/common/test_stream.rs index 84bcbe6..d706e37 100644 --- a/src/common/test_stream.rs +++ b/src/common/test_stream.rs @@ -32,6 +32,31 @@ impl<'a> AsyncWrite for Good<'a> { Poll::Ready(Ok(len)) } + fn poll_flush(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + self.0.process_new_packets() + .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?; + Poll::Ready(Ok(())) + } + + fn poll_shutdown(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + self.0.send_close_notify(); + Poll::Ready(Ok(())) + } +} + +struct Pending; + +impl AsyncRead for Pending { + fn poll_read(self: Pin<&mut Self>, _cx: &mut Context<'_>, _: &mut [u8]) -> Poll> { + Poll::Pending + } +} + +impl AsyncWrite for Pending { + fn poll_write(self: Pin<&mut Self>, _cx: &mut Context<'_>, _buf: &[u8]) -> Poll> { + Poll::Pending + } + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { Poll::Ready(Ok(())) } @@ -41,21 +66,17 @@ impl<'a> AsyncWrite for Good<'a> { } } -struct Bad(bool); +struct Eof; -impl AsyncRead for Bad { +impl AsyncRead for Eof { fn poll_read(self: Pin<&mut Self>, _cx: &mut Context<'_>, _: &mut [u8]) -> Poll> { Poll::Ready(Ok(0)) } } -impl AsyncWrite for Bad { +impl AsyncWrite for Eof { fn poll_write(self: Pin<&mut Self>, _cx: &mut Context<'_>, buf: &[u8]) -> Poll> { - if self.0 { - Poll::Pending - } else { - Poll::Ready(Ok(buf.len())) - } + Poll::Ready(Ok(buf.len())) } fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { @@ -99,7 +120,7 @@ async fn stream_bad() -> io::Result<()> { poll_fn(|cx| do_handshake(&mut client, &mut server, cx)).await?; client.set_buffer_limit(1024); - let mut bad = Bad(true); + let mut bad = Pending; let mut stream = Stream::new(&mut bad, &mut client); assert_eq!(poll_fn(|cx| stream.as_mut_pin().poll_write(cx, &[0x42; 8])).await?, 8); assert_eq!(poll_fn(|cx| stream.as_mut_pin().poll_write(cx, &[0x42; 8])).await?, 8); @@ -138,7 +159,7 @@ async fn stream_handshake() -> io::Result<()> { async fn stream_handshake_eof() -> io::Result<()> { let (_, mut client) = make_pair(); - let mut bad = Bad(false); + let mut bad = Eof; let mut stream = Stream::new(&mut bad, &mut client); let mut cx = Context::from_waker(noop_waker_ref()); diff --git a/src/lib.rs b/src/lib.rs index 382e43a..3dea67f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -146,6 +146,7 @@ impl TlsConnector { } impl TlsAcceptor { + #[inline] pub fn accept(&self, stream: IO) -> Accept where IO: AsyncRead + AsyncWrite + Unpin, @@ -153,7 +154,6 @@ impl TlsAcceptor { self.accept_with(stream, |_| ()) } - #[inline] pub fn accept_with(&self, stream: IO, f: F) -> Accept where IO: AsyncRead + AsyncWrite + Unpin, diff --git a/tests/early-data.rs b/tests/early-data.rs index 9dd6b5e..7a43034 100644 --- a/tests/early-data.rs +++ b/tests/early-data.rs @@ -96,17 +96,24 @@ async fn test_0rtt() -> io::Result<()> { let stdout = handle.0.stdout.as_mut().unwrap(); let mut lines = BufReader::new(stdout).lines(); + let mut f1 = false; + let mut f2 = false; + for line in lines.by_ref() { if line?.contains("hello") { + f1 = true; break } } for line in lines.by_ref() { if line?.contains("world!") { + f2 = true; break } } + assert!(f1 && f2); + Ok(()) } From 086758837f855adf0ba5f7b14855e4a336219429 Mon Sep 17 00:00:00 2001 From: quininer Date: Fri, 11 Oct 2019 01:31:00 +0800 Subject: [PATCH 146/171] remove unnecessary `get_mut()` --- src/common/mod.rs | 46 ++++++++++++++++++++-------------------------- 1 file changed, 20 insertions(+), 26 deletions(-) diff --git a/src/common/mod.rs b/src/common/mod.rs index c176ad5..a870131 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -152,18 +152,17 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> { } impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncRead for Stream<'a, IO, S> { - fn poll_read(self: Pin<&mut Self>, cx: &mut Context, buf: &mut [u8]) -> Poll> { - let this = self.get_mut(); + fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context, buf: &mut [u8]) -> Poll> { let mut pos = 0; while pos != buf.len() { let mut would_block = false; // read a packet - while this.session.wants_read() { - match this.read_io(cx) { + while self.session.wants_read() { + match self.read_io(cx) { Poll::Ready(Ok(0)) => { - this.eof = true; + self.eof = true; break }, Poll::Ready(Ok(_)) => (), @@ -175,11 +174,11 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncRead for Stream<'a } } - this.process_new_packets(cx)?; + self.process_new_packets(cx)?; - return match this.session.read(&mut buf[pos..]) { + return match self.session.read(&mut buf[pos..]) { Ok(0) if pos == 0 && would_block => Poll::Pending, - Ok(n) if this.eof || would_block => Poll::Ready(Ok(pos + n)), + Ok(n) if self.eof || would_block => Poll::Ready(Ok(pos + n)), Ok(n) => { pos += n; continue @@ -196,21 +195,20 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncRead for Stream<'a } impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncWrite for Stream<'a, IO, S> { - fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll> { - let this = self.get_mut(); + fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll> { let mut pos = 0; while pos != buf.len() { let mut would_block = false; - match this.session.write(&buf[pos..]) { + match self.session.write(&buf[pos..]) { Ok(n) => pos += n, Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => (), Err(err) => return Poll::Ready(Err(err)) }; - while this.session.wants_write() { - match this.write_io(cx) { + while self.session.wants_write() { + match self.write_io(cx) { Poll::Ready(Ok(0)) | Poll::Pending => { would_block = true; break @@ -230,24 +228,20 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncWrite for Stream<' Poll::Ready(Ok(pos)) } - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - let this = self.get_mut(); - - this.session.flush()?; - while this.session.wants_write() { - futures::ready!(this.write_io(cx))?; + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + self.session.flush()?; + while self.session.wants_write() { + futures::ready!(self.write_io(cx))?; } - Pin::new(&mut this.io).poll_flush(cx) + Pin::new(&mut self.io).poll_flush(cx) } - fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let this = self.get_mut(); - - while this.session.wants_write() { - futures::ready!(this.write_io(cx))?; + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + while self.session.wants_write() { + futures::ready!(self.write_io(cx))?; } - Pin::new(&mut this.io).poll_shutdown(cx) + Pin::new(&mut self.io).poll_shutdown(cx) } } From 9f6d3c74bf954af0dc79b78e2d4adf0bb27ab347 Mon Sep 17 00:00:00 2001 From: quininer Date: Wed, 23 Oct 2019 19:33:38 +0800 Subject: [PATCH 147/171] release 0.12.0-alpha.5 --- Cargo.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 92fd13b..494b339 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tokio-rustls" -version = "0.12.0-alpha.4" +version = "0.12.0-alpha.5" authors = ["quininer kel "] license = "MIT/Apache-2.0" repository = "https://github.com/quininer/tokio-rustls" @@ -29,4 +29,4 @@ early-data = [] tokio = "=0.2.0-alpha.6" futures-util-preview = "0.3.0-alpha.19" lazy_static = "1" -webpki-roots = "0.17" +webpki-roots = "0.18" From 4b0dd05e86928a1b5aa380c75afdf643ff4ec3fe Mon Sep 17 00:00:00 2001 From: quininer Date: Fri, 1 Nov 2019 00:04:06 +0800 Subject: [PATCH 148/171] Move ci to github actions --- .github/actions-rs/grcov.yml | 4 ++++ .github/workflows/ci.yml | 32 ++++++++++++++++++++++++++++++++ .travis.yml | 32 -------------------------------- Cargo.toml | 3 +-- README.md | 3 +-- appveyor.yml | 21 --------------------- tests/badssl.rs | 1 + 7 files changed, 39 insertions(+), 57 deletions(-) create mode 100644 .github/actions-rs/grcov.yml create mode 100644 .github/workflows/ci.yml delete mode 100644 .travis.yml delete mode 100644 appveyor.yml diff --git a/.github/actions-rs/grcov.yml b/.github/actions-rs/grcov.yml new file mode 100644 index 0000000..477274f --- /dev/null +++ b/.github/actions-rs/grcov.yml @@ -0,0 +1,4 @@ +branch: true +llvm: true +output-type: lcov +output-file: ./lcov.info diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..12acc8c --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,32 @@ +name: Rust + +on: [push, pull_request] + +jobs: + cov: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v1 + + - uses: actions-rs/toolchain@v1 + with: + toolchain: nightly + override: true + + - uses: actions-rs/cargo@v1 + with: + command: test + args: --features early-data + env: + 'CARGO_INCREMENTAL': '0' + 'RUSTFLAGS': '-Zprofile -Ccodegen-units=1 -Cinline-threshold=0 -Clink-dead-code -Coverflow-checks=off -Zno-landing-pads' + + - id: grcov + uses: actions-rs/grcov@v0.1 + + - name: Update Codecov + uses: codecov/codecov-action@v1.0.3 + with: + token: ${{ secrets.CODECOV_TOKEN }} + file: ${{ steps.grcov.outputs.report }} diff --git a/.travis.yml b/.travis.yml deleted file mode 100644 index 6546945..0000000 --- a/.travis.yml +++ /dev/null @@ -1,32 +0,0 @@ -language: rust -cache: cargo - -matrix: - include: - - rust: stable - os: linux - - rust: nightly - os: linux - - rust: stable - os: osx - - rust: nightly - os: osx - allow_failures: - - rust: stable - -install: - - wget https://github.com/openssl/openssl/archive/OpenSSL_1_1_1d.tar.gz - - tar xvfz OpenSSL_1_1_1d.tar.gz && cd openssl-OpenSSL_1_1_1d - - if [ $TRAVIS_OS_NAME = "linux" ]; then ./Configure linux-x86_64 --prefix=$HOME/installed_openssl; fi - - if [ $TRAVIS_OS_NAME = "osx" ]; then ./Configure darwin64-x86_64-cc --prefix=$HOME/installed_openssl; fi - - make && make install_sw && cd .. - - export PATH=$HOME/installed_openssl/bin:$PATH - - export LD_LIBRARY_PATH=$HOME/installed_openssl/lib:$LD_LIBRARY_PATH - -script: - - cargo test - - cargo test --features early-data - # - cd examples/server - # - cargo check - # - cd ../../examples/client - # - cargo check diff --git a/Cargo.toml b/Cargo.toml index 494b339..552fecc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,8 +12,7 @@ categories = ["asynchronous", "cryptography", "network-programming"] edition = "2018" [badges] -travis-ci = { repository = "quininer/tokio-rustls" } -appveyor = { repository = "quininer/tokio-rustls" } +github-actions = { repository = "quininer/tokio-rustls", workflow = "ci" } [dependencies] smallvec = "0.6" diff --git a/README.md b/README.md index d1e7a30..0511849 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,5 @@ # tokio-rustls -[![travis-ci](https://travis-ci.org/quininer/tokio-rustls.svg?branch=master)](https://travis-ci.org/quininer/tokio-rustls) -[![appveyor](https://ci.appveyor.com/api/projects/status/4ukw15enii50suqi?svg=true)](https://ci.appveyor.com/project/quininer/tokio-rustls) +[![github actions](https://github.com/quininer/tokio-rustls/workflows/ci/badge.svg)](https://github.com/quininer/tokio-rustls/actions) [![crates](https://img.shields.io/crates/v/tokio-rustls.svg)](https://crates.io/crates/tokio-rustls) [![license](https://img.shields.io/badge/License-MIT-blue.svg)](https://github.com/quininer/tokio-rustls/blob/master/LICENSE-MIT) [![license](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://github.com/quininer/tokio-rustls/blob/master/LICENSE-APACHE) diff --git a/appveyor.yml b/appveyor.yml deleted file mode 100644 index 26db365..0000000 --- a/appveyor.yml +++ /dev/null @@ -1,21 +0,0 @@ -environment: - matrix: - - TARGET: x86_64-pc-windows-msvc - - TARGET: i686-pc-windows-msvc - -install: - - appveyor DownloadFile https://win.rustup.rs/ -FileName rustup-init.exe - - rustup-init.exe -y --default-host %TARGET% - - set PATH=%PATH%;%USERPROFILE%\.cargo\bin - - rustc --version - - cargo --version - -build: false - -test_script: - - 'cargo test' - - 'cargo test --features early-data' - - 'cd examples/server' - - 'cargo check' - - 'cd ../../examples/client' - - 'cargo check' diff --git a/tests/badssl.rs b/tests/badssl.rs index 74bd294..3a02e86 100644 --- a/tests/badssl.rs +++ b/tests/badssl.rs @@ -42,6 +42,7 @@ async fn test_tls12() -> io::Result<()> { Ok(()) } +#[ignore] #[should_panic] #[test] fn test_tls13() { From 03b1f3b45444dd18ee54cbd7ee5f2c8d941eb5cd Mon Sep 17 00:00:00 2001 From: quininer Date: Fri, 1 Nov 2019 01:36:50 +0800 Subject: [PATCH 149/171] Add codecov badge --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 0511849..c851af3 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,6 @@ # tokio-rustls [![github actions](https://github.com/quininer/tokio-rustls/workflows/ci/badge.svg)](https://github.com/quininer/tokio-rustls/actions) +[![codecov](https://codecov.io/gh/quininer/tokio-rustls/branch/master/graph/badge.svg)](https://codecov.io/gh/quininer/tokio-rustls) [![crates](https://img.shields.io/crates/v/tokio-rustls.svg)](https://crates.io/crates/tokio-rustls) [![license](https://img.shields.io/badge/License-MIT-blue.svg)](https://github.com/quininer/tokio-rustls/blob/master/LICENSE-MIT) [![license](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://github.com/quininer/tokio-rustls/blob/master/LICENSE-APACHE) From 7cccd9c3b3d0cd9695b50da78e77afb760298b9f Mon Sep 17 00:00:00 2001 From: gvallat Date: Wed, 6 Nov 2019 09:18:30 +0100 Subject: [PATCH 150/171] Export rustls dangerous_configuration feature --- Cargo.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/Cargo.toml b/Cargo.toml index 552fecc..c2bfee9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,6 +23,7 @@ webpki = "0.21" [features] early-data = [] +dangerous_configuration = ["rustls/dangerous_configuration"] [dev-dependencies] tokio = "=0.2.0-alpha.6" From 3e2c0446a41cca4873f16e4909527c5f49a21f35 Mon Sep 17 00:00:00 2001 From: Dirkjan Ochtman Date: Wed, 6 Nov 2019 11:41:59 +0100 Subject: [PATCH 151/171] Port unified TLS stream type to tokio-0.2 --- Cargo.toml | 1 + src/lib.rs | 113 ++++++++++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 113 insertions(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index c2bfee9..70b7968 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,6 +18,7 @@ github-actions = { repository = "quininer/tokio-rustls", workflow = "ci" } smallvec = "0.6" tokio-io = "=0.2.0-alpha.6" futures-core-preview = "=0.3.0-alpha.19" +pin-project = "0.4" rustls = "0.16" webpki = "0.21" diff --git a/src/lib.rs b/src/lib.rs index 3dea67f..9b9fd58 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -6,7 +6,8 @@ pub mod server; use common::Stream; use futures_core as futures; -use rustls::{ClientConfig, ClientSession, ServerConfig, ServerSession}; +use pin_project::{pin_project, project}; +use rustls::{ClientConfig, ClientSession, ServerConfig, ServerSession, Session}; use std::future::Future; use std::pin::Pin; use std::sync::Arc; @@ -195,3 +196,113 @@ impl Future for Accept { Pin::new(&mut self.0).poll(cx) } } + +/// Unified TLS stream type +/// +/// This abstracts over the inner `client::TlsStream` and `server::TlsStream`, so you can use +/// a single type to keep both client- and server-initiated TLS-encrypted connections. +#[pin_project] +pub enum TlsStream { + Client(#[pin] client::TlsStream), + Server(#[pin] server::TlsStream), +} + +impl TlsStream { + pub fn get_ref(&self) -> (&T, &dyn Session) { + use TlsStream::*; + match self { + Client(io) => { + let (io, session) = io.get_ref(); + (io, &*session) + } + Server(io) => { + let (io, session) = io.get_ref(); + (io, &*session) + } + } + } + + pub fn get_mut(&mut self) -> (&mut T, &mut dyn Session) { + use TlsStream::*; + match self { + Client(io) => { + let (io, session) = io.get_mut(); + (io, &mut *session) + } + Server(io) => { + let (io, session) = io.get_mut(); + (io, &mut *session) + } + } + } +} + +impl From> for TlsStream { + fn from(s: client::TlsStream) -> Self { + Self::Client(s) + } +} + +impl From> for TlsStream { + fn from(s: server::TlsStream) -> Self { + Self::Server(s) + } +} + +impl AsyncRead for TlsStream +where + T: AsyncRead + AsyncWrite + Unpin, +{ + #[project] + #[inline] + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + #[project] + match self.project() { + TlsStream::Client(x) => x.poll_read(cx, buf), + TlsStream::Server(x) => x.poll_read(cx, buf), + } + } +} + +impl AsyncWrite for TlsStream +where + T: AsyncRead + AsyncWrite + Unpin, +{ + #[project] + #[inline] + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + #[project] + match self.project() { + TlsStream::Client(x) => x.poll_write(cx, buf), + TlsStream::Server(x) => x.poll_write(cx, buf), + } + } + + #[project] + #[inline] + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + #[project] + match self.project() { + TlsStream::Client(x) => x.poll_flush(cx), + TlsStream::Server(x) => x.poll_flush(cx), + } + } + + #[project] + #[inline] + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + #[project] + match self.project() { + TlsStream::Client(x) => x.poll_shutdown(cx), + TlsStream::Server(x) => x.poll_shutdown(cx), + } + } +} From 872510bd65949afff0c76b9218c0c8db8263e7d5 Mon Sep 17 00:00:00 2001 From: quininer Date: Wed, 6 Nov 2019 21:43:50 +0800 Subject: [PATCH 152/171] Fix 0-RTT flush --- src/client.rs | 6 +++++- src/lib.rs | 18 +++++++++--------- 2 files changed, 14 insertions(+), 10 deletions(-) diff --git a/src/client.rs b/src/client.rs index 9843f54..a8447bb 100644 --- a/src/client.rs +++ b/src/client.rs @@ -154,7 +154,7 @@ where // end this.state = TlsState::Stream; - data.clear(); + *data = Vec::new(); stream.as_mut_pin().poll_write(cx, buf) } _ => stream.as_mut_pin().poll_write(cx, buf), @@ -171,6 +171,10 @@ where while stream.session.is_handshaking() { futures::ready!(stream.handshake(cx))?; } + + this.state = TlsState::Stream; + let (_, data) = &mut this.early_data; + *data = Vec::new(); } stream.as_mut_pin().poll_flush(cx) diff --git a/src/lib.rs b/src/lib.rs index 9b9fd58..31545dc 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,20 +1,20 @@ //! Asynchronous TLS/SSL streams for Tokio using [Rustls](https://github.com/ctz/rustls). -pub mod client; mod common; +pub mod client; pub mod server; -use common::Stream; -use futures_core as futures; -use pin_project::{pin_project, project}; -use rustls::{ClientConfig, ClientSession, ServerConfig, ServerSession, Session}; -use std::future::Future; +use std::{ io, mem }; use std::pin::Pin; use std::sync::Arc; -use std::task::{Context, Poll}; -use std::{io, mem}; -use tokio_io::{AsyncRead, AsyncWrite}; +use std::future::Future; +use std::task::{ Context, Poll }; +use futures_core as futures; +use pin_project::{ pin_project, project }; +use tokio_io::{ AsyncRead, AsyncWrite }; use webpki::DNSNameRef; +use rustls::{ ClientConfig, ClientSession, ServerConfig, ServerSession, Session }; +use common::Stream; pub use rustls; pub use webpki; From ff3d0a4de3f0493b9c108baf9555c2108f406f22 Mon Sep 17 00:00:00 2001 From: quininer Date: Wed, 6 Nov 2019 21:44:21 +0800 Subject: [PATCH 153/171] bump 0.12.0-alpha.6 --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 70b7968..3cf2ef3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tokio-rustls" -version = "0.12.0-alpha.5" +version = "0.12.0-alpha.6" authors = ["quininer kel "] license = "MIT/Apache-2.0" repository = "https://github.com/quininer/tokio-rustls" From ba909ed95ea7352cba17ce0493d3aff0253e609f Mon Sep 17 00:00:00 2001 From: quininer Date: Thu, 7 Nov 2019 10:57:14 +0800 Subject: [PATCH 154/171] Fix 0-RTT fallback --- src/client.rs | 38 ++++++++++++++++++++++++++------------ src/common/mod.rs | 1 - src/lib.rs | 2 +- 3 files changed, 27 insertions(+), 14 deletions(-) diff --git a/src/client.rs b/src/client.rs index a8447bb..2c6229b 100644 --- a/src/client.rs +++ b/src/client.rs @@ -167,14 +167,26 @@ where .set_eof(!this.state.readable()); #[cfg(feature = "early-data")] { - // complete handshake - while stream.session.is_handshaking() { - futures::ready!(stream.handshake(cx))?; - } + if let TlsState::EarlyData = this.state { + let (pos, data) = &mut this.early_data; - this.state = TlsState::Stream; - let (_, data) = &mut this.early_data; - *data = Vec::new(); + // complete handshake + while stream.session.is_handshaking() { + futures::ready!(stream.handshake(cx))?; + } + + // write early data (fallback) + if !stream.session.is_early_data_accepted() { + while *pos < data.len() { + let len = futures::ready!(stream.as_mut_pin().poll_write(cx, &data[*pos..]))?; + *pos += len; + } + } + + this.state = TlsState::Stream; + let (_, data) = &mut this.early_data; + *data = Vec::new(); + } } stream.as_mut_pin().poll_flush(cx) @@ -186,14 +198,16 @@ where self.state.shutdown_write(); } + #[cfg(feature = "early-data")] { + // we skip the handshake + if let TlsState::EarlyData = self.state { + return Pin::new(&mut self.io).poll_shutdown(cx); + } + } + let this = self.get_mut(); let mut stream = Stream::new(&mut this.io, &mut this.session) .set_eof(!this.state.readable()); - - // TODO - // - // should we complete the handshake? - stream.as_mut_pin().poll_shutdown(cx) } } diff --git a/src/common/mod.rs b/src/common/mod.rs index a870131..3083534 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -240,7 +240,6 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncWrite for Stream<' while self.session.wants_write() { futures::ready!(self.write_io(cx))?; } - Pin::new(&mut self.io).poll_shutdown(cx) } } diff --git a/src/lib.rs b/src/lib.rs index 31545dc..1c09cef 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -127,7 +127,7 @@ impl TlsConnector { #[cfg(feature = "early-data")] { - Connect(if self.early_data { + Connect(if self.early_data && session.early_data().is_some() { client::MidHandshake::EarlyData(client::TlsStream { session, io: stream, From fe113dc6b00012f82b98370cc57aa7bbd71e384c Mon Sep 17 00:00:00 2001 From: quininer Date: Thu, 7 Nov 2019 10:58:06 +0800 Subject: [PATCH 155/171] bump 0.12.0-alpha.7 --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 3cf2ef3..8b44e63 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tokio-rustls" -version = "0.12.0-alpha.6" +version = "0.12.0-alpha.7" authors = ["quininer kel "] license = "MIT/Apache-2.0" repository = "https://github.com/quininer/tokio-rustls" From 262796af396737266d055b76643201a523182cbc Mon Sep 17 00:00:00 2001 From: quininer Date: Thu, 7 Nov 2019 22:52:27 +0800 Subject: [PATCH 156/171] Clean TlsState --- Cargo.toml | 1 - src/client.rs | 40 ++++++++++++---------------------- src/lib.rs | 59 ++++++++++++++++++++++++++------------------------- src/server.rs | 2 +- 4 files changed, 45 insertions(+), 57 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 8b44e63..f205f6e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,7 +15,6 @@ edition = "2018" github-actions = { repository = "quininer/tokio-rustls", workflow = "ci" } [dependencies] -smallvec = "0.6" tokio-io = "=0.2.0-alpha.6" futures-core-preview = "=0.3.0-alpha.19" pin-project = "0.4" diff --git a/src/client.rs b/src/client.rs index 2c6229b..632e4ed 100644 --- a/src/client.rs +++ b/src/client.rs @@ -8,15 +8,10 @@ pub struct TlsStream { pub(crate) io: IO, pub(crate) session: ClientSession, pub(crate) state: TlsState, - - #[cfg(feature = "early-data")] - pub(crate) early_data: (usize, Vec), } pub(crate) enum MidHandshake { Handshaking(TlsStream), - #[cfg(feature = "early-data")] - EarlyData(TlsStream), End, } @@ -48,23 +43,23 @@ where let this = self.get_mut(); if let MidHandshake::Handshaking(stream) = this { - let eof = !stream.state.readable(); - let (io, session) = stream.get_mut(); - let mut stream = Stream::new(io, session).set_eof(eof); + if !stream.state.is_early_data() { + let eof = !stream.state.readable(); + let (io, session) = stream.get_mut(); + let mut stream = Stream::new(io, session).set_eof(eof); - while stream.session.is_handshaking() { - futures::ready!(stream.handshake(cx))?; - } + while stream.session.is_handshaking() { + futures::ready!(stream.handshake(cx))?; + } - while stream.session.wants_write() { - futures::ready!(stream.write_io(cx))?; + while stream.session.wants_write() { + futures::ready!(stream.write_io(cx))?; + } } } match mem::replace(this, MidHandshake::End) { MidHandshake::Handshaking(stream) => Poll::Ready(Ok(stream)), - #[cfg(feature = "early-data")] - MidHandshake::EarlyData(stream) => Poll::Ready(Ok(stream)), MidHandshake::End => panic!(), } } @@ -81,7 +76,7 @@ where fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll> { match self.state { #[cfg(feature = "early-data")] - TlsState::EarlyData => Poll::Pending, + TlsState::EarlyData(..) => Poll::Pending, TlsState::Stream | TlsState::WriteShutdown => { let this = self.get_mut(); let mut stream = Stream::new(&mut this.io, &mut this.session) @@ -122,11 +117,9 @@ where match this.state { #[cfg(feature = "early-data")] - TlsState::EarlyData => { + TlsState::EarlyData(ref mut pos, ref mut data) => { use std::io::Write; - let (pos, data) = &mut this.early_data; - // write early data if let Some(mut early_data) = stream.session.early_data() { let len = match early_data.write(buf) { @@ -154,7 +147,6 @@ where // end this.state = TlsState::Stream; - *data = Vec::new(); stream.as_mut_pin().poll_write(cx, buf) } _ => stream.as_mut_pin().poll_write(cx, buf), @@ -167,9 +159,7 @@ where .set_eof(!this.state.readable()); #[cfg(feature = "early-data")] { - if let TlsState::EarlyData = this.state { - let (pos, data) = &mut this.early_data; - + if let TlsState::EarlyData(ref mut pos, ref mut data) = this.state { // complete handshake while stream.session.is_handshaking() { futures::ready!(stream.handshake(cx))?; @@ -184,8 +174,6 @@ where } this.state = TlsState::Stream; - let (_, data) = &mut this.early_data; - *data = Vec::new(); } } @@ -200,7 +188,7 @@ where #[cfg(feature = "early-data")] { // we skip the handshake - if let TlsState::EarlyData = self.state { + if let TlsState::EarlyData(..) = self.state { return Pin::new(&mut self.io).poll_shutdown(cx); } } diff --git a/src/lib.rs b/src/lib.rs index 1c09cef..d5113e5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -19,10 +19,10 @@ use common::Stream; pub use rustls; pub use webpki; -#[derive(Debug, Copy, Clone)] +#[derive(Debug)] enum TlsState { #[cfg(feature = "early-data")] - EarlyData, + EarlyData(usize, Vec), Stream, ReadShutdown, WriteShutdown, @@ -51,12 +51,25 @@ impl TlsState { } } - fn readable(self) -> bool { + fn readable(&self) -> bool { match self { TlsState::ReadShutdown | TlsState::FullyShutdown => false, _ => true, } } + + #[cfg(feature = "early-data")] + fn is_early_data(&self) -> bool { + match self { + TlsState::EarlyData(..) => true, + _ => false + } + } + + #[cfg(not(feature = "early-data"))] + const fn is_early_data(&self) -> bool { + false + } } /// A wrapper around a `rustls::ClientConfig`, providing an async `connect` method. @@ -100,6 +113,7 @@ impl TlsConnector { self } + #[inline] pub fn connect(&self, domain: DNSNameRef, stream: IO) -> Connect where IO: AsyncRead + AsyncWrite + Unpin, @@ -107,7 +121,6 @@ impl TlsConnector { self.connect_with(domain, stream, |_| ()) } - #[inline] pub fn connect_with(&self, domain: DNSNameRef, stream: IO, f: F) -> Connect where IO: AsyncRead + AsyncWrite + Unpin, @@ -116,33 +129,21 @@ impl TlsConnector { let mut session = ClientSession::new(&self.inner, domain); f(&mut session); - #[cfg(not(feature = "early-data"))] - { - Connect(client::MidHandshake::Handshaking(client::TlsStream { - session, - io: stream, - state: TlsState::Stream, - })) - } + Connect(client::MidHandshake::Handshaking(client::TlsStream { + io: stream, - #[cfg(feature = "early-data")] - { - Connect(if self.early_data && session.early_data().is_some() { - client::MidHandshake::EarlyData(client::TlsStream { - session, - io: stream, - state: TlsState::EarlyData, - early_data: (0, Vec::new()), - }) + #[cfg(not(feature = "early-data"))] + state: TlsState::Stream, + + #[cfg(feature = "early-data")] + state: if self.early_data && session.early_data().is_some() { + TlsState::EarlyData(0, Vec::new()) } else { - client::MidHandshake::Handshaking(client::TlsStream { - session, - io: stream, - state: TlsState::Stream, - early_data: (0, Vec::new()), - }) - }) - } + TlsState::Stream + }, + + session + })) } } diff --git a/src/server.rs b/src/server.rs index ac72904..87ce3f8 100644 --- a/src/server.rs +++ b/src/server.rs @@ -76,7 +76,7 @@ where let mut stream = Stream::new(&mut this.io, &mut this.session) .set_eof(!this.state.readable()); - match this.state { + match &this.state { TlsState::Stream | TlsState::WriteShutdown => match stream.as_mut_pin().poll_read(cx, buf) { Poll::Ready(Ok(0)) => { this.state.shutdown_read(); From 8b3bf3a2b6c2e4427e2ca6173307e4ec62ca6966 Mon Sep 17 00:00:00 2001 From: quininer Date: Thu, 7 Nov 2019 23:46:49 +0800 Subject: [PATCH 157/171] Remove pin-project We always constrain T is Unpin, so we don't need pin-project. --- Cargo.toml | 1 - src/lib.rs | 38 ++++++++++++++------------------------ 2 files changed, 14 insertions(+), 25 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index f205f6e..d4f3448 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,7 +17,6 @@ github-actions = { repository = "quininer/tokio-rustls", workflow = "ci" } [dependencies] tokio-io = "=0.2.0-alpha.6" futures-core-preview = "=0.3.0-alpha.19" -pin-project = "0.4" rustls = "0.16" webpki = "0.21" diff --git a/src/lib.rs b/src/lib.rs index d5113e5..8caf1e5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -10,7 +10,6 @@ use std::sync::Arc; use std::future::Future; use std::task::{ Context, Poll }; use futures_core as futures; -use pin_project::{ pin_project, project }; use tokio_io::{ AsyncRead, AsyncWrite }; use webpki::DNSNameRef; use rustls::{ ClientConfig, ClientSession, ServerConfig, ServerSession, Session }; @@ -202,10 +201,9 @@ impl Future for Accept { /// /// This abstracts over the inner `client::TlsStream` and `server::TlsStream`, so you can use /// a single type to keep both client- and server-initiated TLS-encrypted connections. -#[pin_project] pub enum TlsStream { - Client(#[pin] client::TlsStream), - Server(#[pin] server::TlsStream), + Client(client::TlsStream), + Server(server::TlsStream), } impl TlsStream { @@ -254,17 +252,15 @@ impl AsyncRead for TlsStream where T: AsyncRead + AsyncWrite + Unpin, { - #[project] #[inline] fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8], ) -> Poll> { - #[project] - match self.project() { - TlsStream::Client(x) => x.poll_read(cx, buf), - TlsStream::Server(x) => x.poll_read(cx, buf), + match self.get_mut() { + TlsStream::Client(x) => Pin::new(x).poll_read(cx, buf), + TlsStream::Server(x) => Pin::new(x).poll_read(cx, buf), } } } @@ -273,37 +269,31 @@ impl AsyncWrite for TlsStream where T: AsyncRead + AsyncWrite + Unpin, { - #[project] #[inline] fn poll_write( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { - #[project] - match self.project() { - TlsStream::Client(x) => x.poll_write(cx, buf), - TlsStream::Server(x) => x.poll_write(cx, buf), + match self.get_mut() { + TlsStream::Client(x) => Pin::new(x).poll_write(cx, buf), + TlsStream::Server(x) => Pin::new(x).poll_write(cx, buf), } } - #[project] #[inline] fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - #[project] - match self.project() { - TlsStream::Client(x) => x.poll_flush(cx), - TlsStream::Server(x) => x.poll_flush(cx), + match self.get_mut() { + TlsStream::Client(x) => Pin::new(x).poll_flush(cx), + TlsStream::Server(x) => Pin::new(x).poll_flush(cx), } } - #[project] #[inline] fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - #[project] - match self.project() { - TlsStream::Client(x) => x.poll_shutdown(cx), - TlsStream::Server(x) => x.poll_shutdown(cx), + match self.get_mut() { + TlsStream::Client(x) => Pin::new(x).poll_shutdown(cx), + TlsStream::Server(x) => Pin::new(x).poll_shutdown(cx), } } } From 07c51665da3797643f9e1c7350ccaba8e61fa6d2 Mon Sep 17 00:00:00 2001 From: quininer Date: Fri, 8 Nov 2019 01:39:20 +0800 Subject: [PATCH 158/171] Fix 0-RTT write zero --- src/client.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/client.rs b/src/client.rs index 632e4ed..8671776 100644 --- a/src/client.rs +++ b/src/client.rs @@ -129,7 +129,9 @@ where Err(err) => return Poll::Ready(Err(err)) }; data.extend_from_slice(&buf[..len]); - return Poll::Ready(Ok(len)); + if len != 0 { + return Poll::Ready(Ok(len)); + } } // complete handshake From 314625390756e6ff73af9d984913b49b54d1a92e Mon Sep 17 00:00:00 2001 From: quininer Date: Fri, 8 Nov 2019 01:40:00 +0800 Subject: [PATCH 159/171] bump 0.12.0-alpha.8 --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index d4f3448..daec434 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tokio-rustls" -version = "0.12.0-alpha.7" +version = "0.12.0-alpha.8" authors = ["quininer kel "] license = "MIT/Apache-2.0" repository = "https://github.com/quininer/tokio-rustls" From 61b2f5b3bc7b06a0f1b6985af86dd2160411b915 Mon Sep 17 00:00:00 2001 From: Gleb Pomykalov Date: Wed, 27 Nov 2019 01:37:00 +0300 Subject: [PATCH 160/171] Migrate to tokio 0.2 and futures 0.3 --- Cargo.toml | 8 ++++---- src/client.rs | 2 +- src/common/mod.rs | 2 +- src/common/test_stream.rs | 2 +- src/lib.rs | 2 +- src/server.rs | 2 +- tests/test.rs | 35 +++++++++++++++++++---------------- 7 files changed, 28 insertions(+), 25 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index daec434..3bc10f3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,8 +15,8 @@ edition = "2018" github-actions = { repository = "quininer/tokio-rustls", workflow = "ci" } [dependencies] -tokio-io = "=0.2.0-alpha.6" -futures-core-preview = "=0.3.0-alpha.19" +tokio = "0.2.0" +futures-core = "0.3.1" rustls = "0.16" webpki = "0.21" @@ -25,7 +25,7 @@ early-data = [] dangerous_configuration = ["rustls/dangerous_configuration"] [dev-dependencies] -tokio = "=0.2.0-alpha.6" -futures-util-preview = "0.3.0-alpha.19" +tokio = { version = "0.2.0", features = ["macros", "net", "io-util", "rt-core", "time"] } +futures-util = "0.3.1" lazy_static = "1" webpki-roots = "0.18" diff --git a/src/client.rs b/src/client.rs index 8671776..470194a 100644 --- a/src/client.rs +++ b/src/client.rs @@ -69,7 +69,7 @@ impl AsyncRead for TlsStream where IO: AsyncRead + AsyncWrite + Unpin, { - unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool { + unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [std::mem::MaybeUninit]) -> bool { self.io.prepare_uninitialized_buffer(buf) } diff --git a/src/common/mod.rs b/src/common/mod.rs index 3083534..1deb21d 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -3,7 +3,7 @@ use std::task::{ Poll, Context }; use std::marker::Unpin; use std::io::{ self, Read, Write }; use rustls::Session; -use tokio_io::{ AsyncRead, AsyncWrite }; +use tokio::io::{ AsyncRead, AsyncWrite }; use futures_core as futures; diff --git a/src/common/test_stream.rs b/src/common/test_stream.rs index d706e37..0055014 100644 --- a/src/common/test_stream.rs +++ b/src/common/test_stream.rs @@ -4,7 +4,7 @@ use std::task::{ Poll, Context }; use futures_core::ready; use futures_util::future::poll_fn; use futures_util::task::noop_waker_ref; -use tokio_io::{ AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt }; +use tokio::io::{ AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt }; use std::io::{ self, Read, Write, BufReader, Cursor }; use webpki::DNSNameRef; use rustls::internal::pemfile::{ certs, rsa_private_keys }; diff --git a/src/lib.rs b/src/lib.rs index 8caf1e5..706e972 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -10,7 +10,7 @@ use std::sync::Arc; use std::future::Future; use std::task::{ Context, Poll }; use futures_core as futures; -use tokio_io::{ AsyncRead, AsyncWrite }; +use tokio::io::{ AsyncRead, AsyncWrite }; use webpki::DNSNameRef; use rustls::{ ClientConfig, ClientSession, ServerConfig, ServerSession, Session }; use common::Stream; diff --git a/src/server.rs b/src/server.rs index 87ce3f8..9066c27 100644 --- a/src/server.rs +++ b/src/server.rs @@ -67,7 +67,7 @@ impl AsyncRead for TlsStream where IO: AsyncRead + AsyncWrite + Unpin, { - unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool { + unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [std::mem::MaybeUninit]) -> bool { self.io.prepare_uninitialized_buffer(buf) } diff --git a/tests/test.rs b/tests/test.rs index 30f8e8a..d0d30f4 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -3,12 +3,12 @@ use std::io::{ BufReader, Cursor }; use std::sync::Arc; use std::sync::mpsc::channel; use std::net::SocketAddr; +use futures_util::future::TryFutureExt; use lazy_static::lazy_static; use tokio::prelude::*; -use tokio::runtime::current_thread; +use tokio::runtime; use tokio::net::{ TcpListener, TcpStream }; -use tokio::io::split; -use futures_util::try_future::TryFutureExt; +use tokio::io::{copy, split}; use rustls::{ ServerConfig, ClientConfig }; use rustls::internal::pemfile::{ certs, rsa_private_keys }; use tokio_rustls::{ TlsConnector, TlsAcceptor }; @@ -30,32 +30,36 @@ lazy_static!{ let (send, recv) = channel(); thread::spawn(move || { - let mut runtime = current_thread::Runtime::new().unwrap(); - let handle = runtime.handle(); + let mut runtime = runtime::Builder::new() + .basic_scheduler() + .enable_io() + .build() + .unwrap(); + + let handle = runtime.handle().clone(); let done = async move { let addr = SocketAddr::from(([127, 0, 0, 1], 0)); - let listener = TcpListener::bind(&addr).await?; + let mut listener = TcpListener::bind(&addr).await?; send.send(listener.local_addr()?).unwrap(); - let mut incoming = listener.incoming(); - while let Some(stream) = incoming.next().await { + loop { + let (stream, _) = listener.accept().await?; + let acceptor = acceptor.clone(); let fut = async move { - let stream = acceptor.accept(stream?).await?; + let stream = acceptor.accept(stream).await?; let (mut reader, mut writer) = split(stream); - reader.copy(&mut writer).await?; + copy(&mut reader, &mut writer).await?; Ok(()) as io::Result<()> }.unwrap_or_else(|err| eprintln!("server: {:?}", err)); - handle.spawn(fut).unwrap(); + handle.spawn(fut); } - - Ok(()) as io::Result<()> - }.unwrap_or_else(|err| eprintln!("server: {:?}", err)); + }.unwrap_or_else(|err: io::Error| eprintln!("server: {:?}", err)); runtime.block_on(done); }); @@ -95,8 +99,7 @@ async fn pass() -> io::Result<()> { // TcpStream::bind now returns a future it creates a race // condition until its ready sometimes. use std::time::*; - let deadline = Instant::now() + Duration::from_secs(1); - tokio::timer::delay(deadline); + tokio::time::delay_for(Duration::from_secs(1)).await; let mut config = ClientConfig::new(); let mut chain = BufReader::new(Cursor::new(chain)); From 078f6c0e738c411015b0cb7e1de4b19a0dd50db9 Mon Sep 17 00:00:00 2001 From: Gleb Pomykalov Date: Wed, 27 Nov 2019 01:57:32 +0300 Subject: [PATCH 161/171] Fix early-data --- tests/early-data.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/early-data.rs b/tests/early-data.rs index 7a43034..c69750e 100644 --- a/tests/early-data.rs +++ b/tests/early-data.rs @@ -10,10 +10,11 @@ use std::task::{ Context, Poll }; use std::time::Duration; use tokio::prelude::*; use tokio::net::TcpStream; -use tokio::timer::delay_for; +use tokio::time::delay_for; use futures_util::{ future, ready }; use rustls::ClientConfig; use tokio_rustls::{ TlsConnector, client::TlsStream }; +use std::future::Future; struct Read1(T); From d42540f52f405555c278ecf65063f9a8fb93cd48 Mon Sep 17 00:00:00 2001 From: quininer Date: Wed, 27 Nov 2019 22:23:10 +0800 Subject: [PATCH 162/171] release 0.12.0 --- Cargo.toml | 3 ++- src/client.rs | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 3bc10f3..cf10494 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tokio-rustls" -version = "0.12.0-alpha.8" +version = "0.12.0" authors = ["quininer kel "] license = "MIT/Apache-2.0" repository = "https://github.com/quininer/tokio-rustls" @@ -15,6 +15,7 @@ edition = "2018" github-actions = { repository = "quininer/tokio-rustls", workflow = "ci" } [dependencies] +bytes = "0.5" tokio = "0.2.0" futures-core = "0.3.1" rustls = "0.16" diff --git a/src/client.rs b/src/client.rs index 470194a..779ddc5 100644 --- a/src/client.rs +++ b/src/client.rs @@ -128,8 +128,8 @@ where return Poll::Pending, Err(err) => return Poll::Ready(Err(err)) }; - data.extend_from_slice(&buf[..len]); if len != 0 { + data.extend_from_slice(&buf[..len]); return Poll::Ready(Ok(len)); } } From 34b1bc9c830010909f1cbaade66494444caf5b74 Mon Sep 17 00:00:00 2001 From: quininer Date: Thu, 28 Nov 2019 00:11:02 +0800 Subject: [PATCH 163/171] Update example --- README.md | 4 ++-- examples/client/Cargo.toml | 6 +++--- examples/client/src/main.rs | 24 ++++++++++++++---------- examples/server/Cargo.toml | 4 ++-- examples/server/src/main.rs | 35 +++++++++++++++++------------------ tests/test.rs | 2 +- 6 files changed, 39 insertions(+), 36 deletions(-) diff --git a/README.md b/README.md index c851af3..70ab646 100644 --- a/README.md +++ b/README.md @@ -22,8 +22,8 @@ config.root_store.add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS); let config = TlsConnector::from(Arc::new(config)); let dnsname = DNSNameRef::try_from_ascii_str("www.rust-lang.org").unwrap(); -TcpStream::connect(&addr) - .and_then(move |socket| config.connect(dnsname, socket)) +let stream = TcpStream::connect(&addr).await?; +let mut stream = config.connect(dnsname, stream).await?; // ... ``` diff --git a/examples/client/Cargo.toml b/examples/client/Cargo.toml index feec249..4e29373 100644 --- a/examples/client/Cargo.toml +++ b/examples/client/Cargo.toml @@ -5,9 +5,9 @@ authors = ["quininer "] edition = "2018" [dependencies] -futures = { package = "futures-preview", version = "0.3.0-alpha.16", features = ["io-compat"] } -romio = "0.3.0-alpha.8" +futures-util = "0.3" +tokio = { version = "0.2", features = [ "net", "io-util", "rt-threaded" ] } structopt = "0.2" tokio-rustls = { path = "../.." } -webpki-roots = "0.16" +webpki-roots = "0.18" tokio-stdin-stdout = "0.1" diff --git a/examples/client/src/main.rs b/examples/client/src/main.rs index 6416db2..8f130d2 100644 --- a/examples/client/src/main.rs +++ b/examples/client/src/main.rs @@ -1,16 +1,14 @@ -#![feature(async_await)] - use std::io; use std::fs::File; use std::path::PathBuf; use std::sync::Arc; use std::net::ToSocketAddrs; use std::io::BufReader; +use futures_util::future; use structopt::StructOpt; -use romio::TcpStream; -use futures::prelude::*; -use futures::executor; -use futures::compat::{ AsyncRead01CompatExt, AsyncWrite01CompatExt }; +use tokio::runtime; +use tokio::net::TcpStream; +use tokio::io::{ AsyncWriteExt, copy, split }; use tokio_rustls::{ TlsConnector, rustls::ClientConfig, webpki::DNSNameRef }; use tokio_stdin_stdout::{ stdin as tokio_stdin, stdout as tokio_stdout }; @@ -46,6 +44,10 @@ fn main() -> io::Result<()> { domain ); + let mut runtime = runtime::Builder::new() + .basic_scheduler() + .enable_io() + .build()?; let mut config = ClientConfig::new(); if let Some(cafile) = &options.cafile { let mut pem = BufReader::new(File::open(cafile)?); @@ -58,6 +60,8 @@ fn main() -> io::Result<()> { let fut = async { let stream = TcpStream::connect(&addr).await?; + + // TODO tokio-compat let (mut stdin, mut stdout) = (tokio_stdin(0).compat(), tokio_stdout(0).compat()); let domain = DNSNameRef::try_from_ascii_str(&domain) @@ -66,14 +70,14 @@ fn main() -> io::Result<()> { let mut stream = connector.connect(domain, stream).await?; stream.write_all(content.as_bytes()).await?; - let (mut reader, mut writer) = stream.split(); + let (mut reader, mut writer) = split(stream); future::try_join( - reader.copy_into(&mut stdout), - stdin.copy_into(&mut writer) + copy(&mut reader, &mut stdout), + copy(&mut stdin, &mut writer) ).await?; Ok(()) }; - executor::block_on(fut) + runtime.block_on(fut) } diff --git a/examples/server/Cargo.toml b/examples/server/Cargo.toml index 9da4423..cd42662 100644 --- a/examples/server/Cargo.toml +++ b/examples/server/Cargo.toml @@ -5,7 +5,7 @@ authors = ["quininer "] edition = "2018" [dependencies] -futures = { package = "futures-preview", version = "0.3.0-alpha.16" } -romio = "0.3.0-alpha.8" +futures-util = "0.3" +tokio = { version = "0.2", features = [ "net", "io-util", "rt-threaded" ] } structopt = "0.2" tokio-rustls = { path = "../.." } diff --git a/examples/server/src/main.rs b/examples/server/src/main.rs index a2a3b13..1fcb0e4 100644 --- a/examples/server/src/main.rs +++ b/examples/server/src/main.rs @@ -1,15 +1,13 @@ -#![feature(async_await)] - use std::fs::File; use std::sync::Arc; use std::net::ToSocketAddrs; use std::path::{ PathBuf, Path }; use std::io::{ self, BufReader }; +use futures_util::future::TryFutureExt; use structopt::StructOpt; -use futures::task::SpawnExt; -use futures::prelude::*; -use futures::executor; -use romio::TcpListener; +use tokio::runtime; +use tokio::net::TcpListener; +use tokio::io::{ AsyncWriteExt, copy, split }; use tokio_rustls::rustls::{ Certificate, NoClientAuth, PrivateKey, ServerConfig }; use tokio_rustls::rustls::internal::pemfile::{ certs, rsa_private_keys }; use tokio_rustls::TlsAcceptor; @@ -53,27 +51,30 @@ fn main() -> io::Result<()> { let mut keys = load_keys(&options.key)?; let flag_echo = options.echo; - let mut pool = executor::ThreadPool::new()?; + let mut runtime = runtime::Builder::new() + .threaded_scheduler() + .enable_io() + .build()?; + let handle = runtime.handle().clone(); let mut config = ServerConfig::new(NoClientAuth::new()); config.set_single_cert(certs, keys.remove(0)) .map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?; let acceptor = TlsAcceptor::from(Arc::new(config)); let fut = async { - let mut listener = TcpListener::bind(&addr)?; - let mut incoming = listener.incoming(); + let mut listener = TcpListener::bind(&addr).await?; - while let Some(stream) = incoming.next().await { + loop { + let (stream, peer_addr) = listener.accept().await?; let acceptor = acceptor.clone(); let fut = async move { - let stream = stream?; - let peer_addr = stream.peer_addr()?; let mut stream = acceptor.accept(stream).await?; if flag_echo { - let (mut reader, mut writer) = stream.split(); - let n = reader.copy_into(&mut writer).await?; + let (mut reader, mut writer) = split(stream); + let n = copy(&mut reader, &mut writer).await?; + writer.flush().await?; println!("Echo: {} - {}", peer_addr, n); } else { stream.write_all( @@ -90,11 +91,9 @@ fn main() -> io::Result<()> { Ok(()) as io::Result<()> }; - pool.spawn(fut.unwrap_or_else(|err| eprintln!("{:?}", err))).unwrap(); + handle.spawn(fut.unwrap_or_else(|err| eprintln!("{:?}", err))); } - - Ok(()) }; - executor::block_on(fut) + runtime.block_on(fut) } diff --git a/tests/test.rs b/tests/test.rs index d0d30f4..9b98688 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -7,8 +7,8 @@ use futures_util::future::TryFutureExt; use lazy_static::lazy_static; use tokio::prelude::*; use tokio::runtime; +use tokio::io::{ copy, split }; use tokio::net::{ TcpListener, TcpStream }; -use tokio::io::{copy, split}; use rustls::{ ServerConfig, ClientConfig }; use rustls::internal::pemfile::{ certs, rsa_private_keys }; use tokio_rustls::{ TlsConnector, TlsAcceptor }; From 02028c54b8d5e99a1cab1857b964f17180f023b9 Mon Sep 17 00:00:00 2001 From: quininer Date: Mon, 2 Dec 2019 23:43:49 +0800 Subject: [PATCH 164/171] Fix client example --- examples/client/Cargo.toml | 2 +- examples/client/src/main.rs | 17 +++++++++++------ 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/examples/client/Cargo.toml b/examples/client/Cargo.toml index 4e29373..20313b9 100644 --- a/examples/client/Cargo.toml +++ b/examples/client/Cargo.toml @@ -6,7 +6,7 @@ edition = "2018" [dependencies] futures-util = "0.3" -tokio = { version = "0.2", features = [ "net", "io-util", "rt-threaded" ] } +tokio = { version = "0.2", features = [ "net", "io-std", "io-util", "rt-threaded" ] } structopt = "0.2" tokio-rustls = { path = "../.." } webpki-roots = "0.18" diff --git a/examples/client/src/main.rs b/examples/client/src/main.rs index 8f130d2..6012c7e 100644 --- a/examples/client/src/main.rs +++ b/examples/client/src/main.rs @@ -8,9 +8,12 @@ use futures_util::future; use structopt::StructOpt; use tokio::runtime; use tokio::net::TcpStream; -use tokio::io::{ AsyncWriteExt, copy, split }; +use tokio::io::{ + AsyncWriteExt, + copy, split, + stdin as tokio_stdin, stdout as tokio_stdout +}; use tokio_rustls::{ TlsConnector, rustls::ClientConfig, webpki::DNSNameRef }; -use tokio_stdin_stdout::{ stdin as tokio_stdin, stdout as tokio_stdout }; #[derive(StructOpt)] @@ -61,8 +64,7 @@ fn main() -> io::Result<()> { let fut = async { let stream = TcpStream::connect(&addr).await?; - // TODO tokio-compat - let (mut stdin, mut stdout) = (tokio_stdin(0).compat(), tokio_stdout(0).compat()); + let (mut stdin, mut stdout) = (tokio_stdin(), tokio_stdout()); let domain = DNSNameRef::try_from_ascii_str(&domain) .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid dnsname"))?; @@ -71,10 +73,13 @@ fn main() -> io::Result<()> { stream.write_all(content.as_bytes()).await?; let (mut reader, mut writer) = split(stream); - future::try_join( + future::select( copy(&mut reader, &mut stdout), copy(&mut stdin, &mut writer) - ).await?; + ) + .await + .factor_first() + .0?; Ok(()) }; From a9b20c509c61c907a3bf019a3d89e5588f64693f Mon Sep 17 00:00:00 2001 From: quininer Date: Mon, 2 Dec 2019 23:55:11 +0800 Subject: [PATCH 165/171] Update ci --- .github/workflows/ci.yml | 17 +++++++++++++++-- Cargo.toml | 2 +- README.md | 2 +- examples/client/Cargo.toml | 1 - 4 files changed, 17 insertions(+), 5 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 12acc8c..d354dd7 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -3,7 +3,7 @@ name: Rust on: [push, pull_request] jobs: - cov: + test: runs-on: ubuntu-latest steps: @@ -12,8 +12,14 @@ jobs: - uses: actions-rs/toolchain@v1 with: toolchain: nightly + profile: minimal override: true + - uses: actions/cache@v1 + with: + path: ~/.cargo + key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.toml') }} + - uses: actions-rs/cargo@v1 with: command: test @@ -22,11 +28,18 @@ jobs: 'CARGO_INCREMENTAL': '0' 'RUSTFLAGS': '-Zprofile -Ccodegen-units=1 -Cinline-threshold=0 -Clink-dead-code -Coverflow-checks=off -Zno-landing-pads' + - name: Check + run: | + cd examples/client + cargo check + cd ../server + cargo check + - id: grcov uses: actions-rs/grcov@v0.1 - name: Update Codecov - uses: codecov/codecov-action@v1.0.3 + uses: codecov/codecov-action@v1 with: token: ${{ secrets.CODECOV_TOKEN }} file: ${{ steps.grcov.outputs.report }} diff --git a/Cargo.toml b/Cargo.toml index cf10494..e30cb12 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,7 +12,7 @@ categories = ["asynchronous", "cryptography", "network-programming"] edition = "2018" [badges] -github-actions = { repository = "quininer/tokio-rustls", workflow = "ci" } +github-actions = { repository = "quininer/tokio-rustls", workflow = "Rust" } [dependencies] bytes = "0.5" diff --git a/README.md b/README.md index 70ab646..751da83 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,5 @@ # tokio-rustls -[![github actions](https://github.com/quininer/tokio-rustls/workflows/ci/badge.svg)](https://github.com/quininer/tokio-rustls/actions) +[![github actions](https://github.com/quininer/tokio-rustls/workflows/Rust/badge.svg)](https://github.com/quininer/tokio-rustls/actions) [![codecov](https://codecov.io/gh/quininer/tokio-rustls/branch/master/graph/badge.svg)](https://codecov.io/gh/quininer/tokio-rustls) [![crates](https://img.shields.io/crates/v/tokio-rustls.svg)](https://crates.io/crates/tokio-rustls) [![license](https://img.shields.io/badge/License-MIT-blue.svg)](https://github.com/quininer/tokio-rustls/blob/master/LICENSE-MIT) diff --git a/examples/client/Cargo.toml b/examples/client/Cargo.toml index 20313b9..40162f8 100644 --- a/examples/client/Cargo.toml +++ b/examples/client/Cargo.toml @@ -10,4 +10,3 @@ tokio = { version = "0.2", features = [ "net", "io-std", "io-util", "rt-threaded structopt = "0.2" tokio-rustls = { path = "../.." } webpki-roots = "0.18" -tokio-stdin-stdout = "0.1" From 074fe4a5ac0bf6174136b572394e2f0d1963c26c Mon Sep 17 00:00:00 2001 From: quininer Date: Sun, 8 Dec 2019 00:52:55 +0800 Subject: [PATCH 166/171] Move TlsState to common --- src/common/mod.rs | 53 +++++++++++++++++++++++++++++++++++++++++++++ src/lib.rs | 55 +---------------------------------------------- 2 files changed, 54 insertions(+), 54 deletions(-) diff --git a/src/common/mod.rs b/src/common/mod.rs index 1deb21d..a53f548 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -7,6 +7,59 @@ use tokio::io::{ AsyncRead, AsyncWrite }; use futures_core as futures; +#[derive(Debug)] +pub enum TlsState { + #[cfg(feature = "early-data")] + EarlyData(usize, Vec), + Stream, + ReadShutdown, + WriteShutdown, + FullyShutdown, +} + +impl TlsState { + pub fn shutdown_read(&mut self) { + match *self { + TlsState::WriteShutdown | TlsState::FullyShutdown => *self = TlsState::FullyShutdown, + _ => *self = TlsState::ReadShutdown, + } + } + + pub fn shutdown_write(&mut self) { + match *self { + TlsState::ReadShutdown | TlsState::FullyShutdown => *self = TlsState::FullyShutdown, + _ => *self = TlsState::WriteShutdown, + } + } + + pub fn writeable(&self) -> bool { + match *self { + TlsState::WriteShutdown | TlsState::FullyShutdown => false, + _ => true, + } + } + + pub fn readable(&self) -> bool { + match self { + TlsState::ReadShutdown | TlsState::FullyShutdown => false, + _ => true, + } + } + + #[cfg(feature = "early-data")] + pub fn is_early_data(&self) -> bool { + match self { + TlsState::EarlyData(..) => true, + _ => false + } + } + + #[cfg(not(feature = "early-data"))] + pub const fn is_early_data(&self) -> bool { + false + } +} + pub struct Stream<'a, IO, S> { pub io: &'a mut IO, pub session: &'a mut S, diff --git a/src/lib.rs b/src/lib.rs index 706e972..09c10e2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -13,64 +13,11 @@ use futures_core as futures; use tokio::io::{ AsyncRead, AsyncWrite }; use webpki::DNSNameRef; use rustls::{ ClientConfig, ClientSession, ServerConfig, ServerSession, Session }; -use common::Stream; +use common::{ Stream, TlsState }; pub use rustls; pub use webpki; -#[derive(Debug)] -enum TlsState { - #[cfg(feature = "early-data")] - EarlyData(usize, Vec), - Stream, - ReadShutdown, - WriteShutdown, - FullyShutdown, -} - -impl TlsState { - fn shutdown_read(&mut self) { - match *self { - TlsState::WriteShutdown | TlsState::FullyShutdown => *self = TlsState::FullyShutdown, - _ => *self = TlsState::ReadShutdown, - } - } - - fn shutdown_write(&mut self) { - match *self { - TlsState::ReadShutdown | TlsState::FullyShutdown => *self = TlsState::FullyShutdown, - _ => *self = TlsState::WriteShutdown, - } - } - - fn writeable(&self) -> bool { - match *self { - TlsState::WriteShutdown | TlsState::FullyShutdown => false, - _ => true, - } - } - - fn readable(&self) -> bool { - match self { - TlsState::ReadShutdown | TlsState::FullyShutdown => false, - _ => true, - } - } - - #[cfg(feature = "early-data")] - fn is_early_data(&self) -> bool { - match self { - TlsState::EarlyData(..) => true, - _ => false - } - } - - #[cfg(not(feature = "early-data"))] - const fn is_early_data(&self) -> bool { - false - } -} - /// A wrapper around a `rustls::ClientConfig`, providing an async `connect` method. #[derive(Clone)] pub struct TlsConnector { From 7f69e889a409237ea72db561bda65dc34c296ec4 Mon Sep 17 00:00:00 2001 From: quininer Date: Sun, 8 Dec 2019 00:59:15 +0800 Subject: [PATCH 167/171] Fix incorrect prepare_uninitialized_buffer --- Cargo.toml | 1 + src/client.rs | 3 ++- src/server.rs | 6 +++++- 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index e30cb12..e0822cf 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,6 +24,7 @@ webpki = "0.21" [features] early-data = [] dangerous_configuration = ["rustls/dangerous_configuration"] +unstable = [] [dev-dependencies] tokio = { version = "0.2.0", features = ["macros", "net", "io-util", "rt-core", "time"] } diff --git a/src/client.rs b/src/client.rs index 779ddc5..7807f12 100644 --- a/src/client.rs +++ b/src/client.rs @@ -69,8 +69,9 @@ impl AsyncRead for TlsStream where IO: AsyncRead + AsyncWrite + Unpin, { + #[cfg(feature = "unstable")] unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [std::mem::MaybeUninit]) -> bool { - self.io.prepare_uninitialized_buffer(buf) + false } fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll> { diff --git a/src/server.rs b/src/server.rs index 9066c27..0563341 100644 --- a/src/server.rs +++ b/src/server.rs @@ -67,8 +67,12 @@ impl AsyncRead for TlsStream where IO: AsyncRead + AsyncWrite + Unpin, { + #[cfg(feature = "unstable")] unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [std::mem::MaybeUninit]) -> bool { - self.io.prepare_uninitialized_buffer(buf) + // TODO + // + // https://doc.rust-lang.org/nightly/std/io/trait.Read.html#method.initializer + false } fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll> { From 368f32ea9fceb3bc55303f102437d87adaa815db Mon Sep 17 00:00:00 2001 From: quininer Date: Sun, 8 Dec 2019 16:41:47 +0800 Subject: [PATCH 168/171] Add Failable{Connect,Accept} --- src/client.rs | 57 +++++++++++---------------- src/common/handshake.rs | 84 ++++++++++++++++++++++++++++++++++++++++ src/common/mod.rs | 10 ++++- src/lib.rs | 85 +++++++++++++++++++++++++++++++++++++---- src/server.rs | 43 +++++++-------------- 5 files changed, 206 insertions(+), 73 deletions(-) create mode 100644 src/common/handshake.rs diff --git a/src/client.rs b/src/client.rs index 7807f12..25d5874 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,5 +1,7 @@ use super::*; use rustls::Session; +use crate::common::IoSession; + /// A wrapper around an underlying raw stream which implements the TLS or SSL /// protocol. @@ -10,11 +12,6 @@ pub struct TlsStream { pub(crate) state: TlsState, } -pub(crate) enum MidHandshake { - Handshaking(TlsStream), - End, -} - impl TlsStream { #[inline] pub fn get_ref(&self) -> (&IO, &ClientSession) { @@ -32,36 +29,23 @@ impl TlsStream { } } -impl Future for MidHandshake -where - IO: AsyncRead + AsyncWrite + Unpin, -{ - type Output = io::Result>; +impl IoSession for TlsStream { + type Io = IO; + type Session = ClientSession; #[inline] - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let this = self.get_mut(); + fn skip_handshake(&self) -> bool { + self.state.is_early_data() + } - if let MidHandshake::Handshaking(stream) = this { - if !stream.state.is_early_data() { - let eof = !stream.state.readable(); - let (io, session) = stream.get_mut(); - let mut stream = Stream::new(io, session).set_eof(eof); + #[inline] + fn get_mut(&mut self) -> (&mut TlsState, &mut Self::Io, &mut Self::Session) { + (&mut self.state, &mut self.io, &mut self.session) + } - while stream.session.is_handshaking() { - futures::ready!(stream.handshake(cx))?; - } - - while stream.session.wants_write() { - futures::ready!(stream.write_io(cx))?; - } - } - } - - match mem::replace(this, MidHandshake::End) { - MidHandshake::Handshaking(stream) => Poll::Ready(Ok(stream)), - MidHandshake::End => panic!(), - } + #[inline] + fn into_io(self) -> Self::Io { + self.io } } @@ -119,6 +103,7 @@ where match this.state { #[cfg(feature = "early-data")] TlsState::EarlyData(ref mut pos, ref mut data) => { + use futures_core::ready; use std::io::Write; // write early data @@ -137,13 +122,13 @@ where // complete handshake while stream.session.is_handshaking() { - futures::ready!(stream.handshake(cx))?; + ready!(stream.handshake(cx))?; } // write early data (fallback) if !stream.session.is_early_data_accepted() { while *pos < data.len() { - let len = futures::ready!(stream.as_mut_pin().poll_write(cx, &data[*pos..]))?; + let len = ready!(stream.as_mut_pin().poll_write(cx, &data[*pos..]))?; *pos += len; } } @@ -162,16 +147,18 @@ where .set_eof(!this.state.readable()); #[cfg(feature = "early-data")] { + use futures_core::ready; + if let TlsState::EarlyData(ref mut pos, ref mut data) = this.state { // complete handshake while stream.session.is_handshaking() { - futures::ready!(stream.handshake(cx))?; + ready!(stream.handshake(cx))?; } // write early data (fallback) if !stream.session.is_early_data_accepted() { while *pos < data.len() { - let len = futures::ready!(stream.as_mut_pin().poll_write(cx, &data[*pos..]))?; + let len = ready!(stream.as_mut_pin().poll_write(cx, &data[*pos..]))?; *pos += len; } } diff --git a/src/common/handshake.rs b/src/common/handshake.rs new file mode 100644 index 0000000..0006b56 --- /dev/null +++ b/src/common/handshake.rs @@ -0,0 +1,84 @@ +use std::{ io, mem }; +use std::pin::Pin; +use std::future::Future; +use std::task::{ Context, Poll }; +use futures_core::future::FusedFuture; +use tokio::io::{ AsyncRead, AsyncWrite }; +use rustls::Session; +use crate::common::{ TlsState, Stream }; + + +pub(crate) trait IoSession { + type Io; + type Session; + + fn skip_handshake(&self) -> bool; + fn get_mut(&mut self) -> (&mut TlsState, &mut Self::Io, &mut Self::Session); + fn into_io(self) -> Self::Io; +} + +pub(crate) enum MidHandshake { + Handshaking(IS), + End, +} + +impl FusedFuture for MidHandshake +where + IS: IoSession + Unpin, + IS::Io: AsyncRead + AsyncWrite + Unpin, + IS::Session: Session + Unpin +{ + fn is_terminated(&self) -> bool { + if let MidHandshake::End = self { + true + } else { + false + } + } +} + +impl Future for MidHandshake +where + IS: IoSession + Unpin, + IS::Io: AsyncRead + AsyncWrite + Unpin, + IS::Session: Session + Unpin +{ + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.get_mut(); + + if let MidHandshake::Handshaking(mut stream) = mem::replace(this, MidHandshake::End) { + if !stream.skip_handshake() { + let (state, io, session) = stream.get_mut(); + let mut tls_stream = Stream::new(io, session) + .set_eof(!state.readable()); + + macro_rules! try_poll { + ( $e:expr ) => { + match $e { + Poll::Ready(Ok(_)) => (), + Poll::Ready(Err(err)) => return Poll::Ready(Err((err, stream.into_io()))), + Poll::Pending => { + *this = MidHandshake::Handshaking(stream); + return Poll::Pending; + } + } + } + } + + while tls_stream.session.is_handshaking() { + try_poll!(tls_stream.handshake(cx)); + } + + while tls_stream.session.wants_write() { + try_poll!(tls_stream.write_io(cx)); + } + } + + Poll::Ready(Ok(stream)) + } else { + panic!() + } + } +} diff --git a/src/common/mod.rs b/src/common/mod.rs index a53f548..9f6d9ac 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -1,10 +1,12 @@ +mod handshake; + use std::pin::Pin; use std::task::{ Poll, Context }; -use std::marker::Unpin; use std::io::{ self, Read, Write }; use rustls::Session; use tokio::io::{ AsyncRead, AsyncWrite }; use futures_core as futures; +pub(crate) use handshake::{ IoSession, MidHandshake }; #[derive(Debug)] @@ -18,6 +20,7 @@ pub enum TlsState { } impl TlsState { + #[inline] pub fn shutdown_read(&mut self) { match *self { TlsState::WriteShutdown | TlsState::FullyShutdown => *self = TlsState::FullyShutdown, @@ -25,6 +28,7 @@ impl TlsState { } } + #[inline] pub fn shutdown_write(&mut self) { match *self { TlsState::ReadShutdown | TlsState::FullyShutdown => *self = TlsState::FullyShutdown, @@ -32,6 +36,7 @@ impl TlsState { } } + #[inline] pub fn writeable(&self) -> bool { match *self { TlsState::WriteShutdown | TlsState::FullyShutdown => false, @@ -39,6 +44,7 @@ impl TlsState { } } + #[inline] pub fn readable(&self) -> bool { match self { TlsState::ReadShutdown | TlsState::FullyShutdown => false, @@ -46,6 +52,7 @@ impl TlsState { } } + #[inline] #[cfg(feature = "early-data")] pub fn is_early_data(&self) -> bool { match self { @@ -54,6 +61,7 @@ impl TlsState { } } + #[inline] #[cfg(not(feature = "early-data"))] pub const fn is_early_data(&self) -> bool { false diff --git a/src/lib.rs b/src/lib.rs index 09c10e2..28d9de1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -4,16 +4,16 @@ mod common; pub mod client; pub mod server; -use std::{ io, mem }; +use std::io; use std::pin::Pin; use std::sync::Arc; use std::future::Future; use std::task::{ Context, Poll }; -use futures_core as futures; +use futures_core::future::FusedFuture; use tokio::io::{ AsyncRead, AsyncWrite }; use webpki::DNSNameRef; use rustls::{ ClientConfig, ClientSession, ServerConfig, ServerSession, Session }; -use common::{ Stream, TlsState }; +use common::{ Stream, TlsState, MidHandshake }; pub use rustls; pub use webpki; @@ -75,7 +75,7 @@ impl TlsConnector { let mut session = ClientSession::new(&self.inner, domain); f(&mut session); - Connect(client::MidHandshake::Handshaking(client::TlsStream { + Connect(MidHandshake::Handshaking(client::TlsStream { io: stream, #[cfg(not(feature = "early-data"))] @@ -110,7 +110,7 @@ impl TlsAcceptor { let mut session = ServerSession::new(&self.inner); f(&mut session); - Accept(server::MidHandshake::Handshaking(server::TlsStream { + Accept(MidHandshake::Handshaking(server::TlsStream { session, io: stream, state: TlsState::Stream, @@ -120,30 +120,99 @@ impl TlsAcceptor { /// Future returned from `TlsConnector::connect` which will resolve /// once the connection handshake has finished. -pub struct Connect(client::MidHandshake); +pub struct Connect(MidHandshake>); /// Future returned from `TlsAcceptor::accept` which will resolve /// once the accept handshake has finished. -pub struct Accept(server::MidHandshake); +pub struct Accept(MidHandshake>); + +/// Like [Connect], but returns `IO` on failure. +pub struct FailableConnect(MidHandshake>); + +/// Like [Accept], but returns `IO` on failure. +pub struct FailableAccept(MidHandshake>); + +impl Connect { + #[inline] + pub fn into_failable(self) -> FailableConnect { + FailableConnect(self.0) + } +} + +impl Accept { + #[inline] + pub fn into_failable(self) -> FailableAccept { + FailableAccept(self.0) + } +} impl Future for Connect { type Output = io::Result>; #[inline] fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - Pin::new(&mut self.0).poll(cx) + Pin::new(&mut self.0) + .poll(cx) + .map_err(|(err, _)| err) + } +} + +impl FusedFuture for Connect { + fn is_terminated(&self) -> bool { + self.0.is_terminated() } } impl Future for Accept { type Output = io::Result>; + #[inline] + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + Pin::new(&mut self.0) + .poll(cx) + .map_err(|(err, _)| err) + } +} + +impl FusedFuture for Accept { + #[inline] + fn is_terminated(&self) -> bool { + self.0.is_terminated() + } +} + +impl Future for FailableConnect { + type Output = Result, (io::Error, IO)>; + #[inline] fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { Pin::new(&mut self.0).poll(cx) } } +impl FusedFuture for FailableConnect { + #[inline] + fn is_terminated(&self) -> bool { + self.0.is_terminated() + } +} + +impl Future for FailableAccept { + type Output = Result, (io::Error, IO)>; + + #[inline] + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + Pin::new(&mut self.0).poll(cx) + } +} + +impl FusedFuture for FailableAccept { + #[inline] + fn is_terminated(&self) -> bool { + self.0.is_terminated() + } +} + /// Unified TLS stream type /// /// This abstracts over the inner `client::TlsStream` and `server::TlsStream`, so you can use diff --git a/src/server.rs b/src/server.rs index 0563341..aa7164e 100644 --- a/src/server.rs +++ b/src/server.rs @@ -1,5 +1,6 @@ use super::*; use rustls::Session; +use crate::common::IoSession; /// A wrapper around an underlying raw stream which implements the TLS or SSL /// protocol. @@ -10,11 +11,6 @@ pub struct TlsStream { pub(crate) state: TlsState, } -pub(crate) enum MidHandshake { - Handshaking(TlsStream), - End, -} - impl TlsStream { #[inline] pub fn get_ref(&self) -> (&IO, &ServerSession) { @@ -32,34 +28,23 @@ impl TlsStream { } } -impl Future for MidHandshake -where - IO: AsyncRead + AsyncWrite + Unpin, -{ - type Output = io::Result>; +impl IoSession for TlsStream { + type Io = IO; + type Session = ServerSession; #[inline] - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let this = self.get_mut(); + fn skip_handshake(&self) -> bool { + false + } - if let MidHandshake::Handshaking(stream) = this { - let eof = !stream.state.readable(); - let (io, session) = stream.get_mut(); - let mut stream = Stream::new(io, session).set_eof(eof); + #[inline] + fn get_mut(&mut self) -> (&mut TlsState, &mut Self::Io, &mut Self::Session) { + (&mut self.state, &mut self.io, &mut self.session) + } - while stream.session.is_handshaking() { - futures::ready!(stream.handshake(cx))?; - } - - while stream.session.wants_write() { - futures::ready!(stream.write_io(cx))?; - } - } - - match mem::replace(this, MidHandshake::End) { - MidHandshake::Handshaking(stream) => Poll::Ready(Ok(stream)), - MidHandshake::End => panic!(), - } + #[inline] + fn into_io(self) -> Self::Io { + self.io } } From 7530e2f7396d4c5d72be5693281338a78566380d Mon Sep 17 00:00:00 2001 From: quininer Date: Sun, 8 Dec 2019 16:54:37 +0800 Subject: [PATCH 169/171] publish 0.12.1 --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index e0822cf..6ed6429 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tokio-rustls" -version = "0.12.0" +version = "0.12.1" authors = ["quininer kel "] license = "MIT/Apache-2.0" repository = "https://github.com/quininer/tokio-rustls" From ce16555b13c556277edec825310696be2ad29930 Mon Sep 17 00:00:00 2001 From: quininer Date: Tue, 7 Jan 2020 23:57:00 +0800 Subject: [PATCH 170/171] implement WriteV close https://github.com/quininer/tokio-rustls/issues/57 --- Cargo.toml | 5 +- src/client.rs | 2 +- src/common/handshake.rs | 2 +- src/common/mod.rs | 44 +++++++++++++- src/common/vecbuf.rs | 128 ++++++++++++++++++++++++++++++++++++++++ src/lib.rs | 5 +- src/server.rs | 2 +- 7 files changed, 178 insertions(+), 10 deletions(-) create mode 100644 src/common/vecbuf.rs diff --git a/Cargo.toml b/Cargo.toml index 6ed6429..52964d7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,16 +15,17 @@ edition = "2018" github-actions = { repository = "quininer/tokio-rustls", workflow = "Rust" } [dependencies] -bytes = "0.5" tokio = "0.2.0" futures-core = "0.3.1" rustls = "0.16" webpki = "0.21" +bytes = { version = "0.5", optional = true } + [features] early-data = [] dangerous_configuration = ["rustls/dangerous_configuration"] -unstable = [] +unstable = ["bytes"] [dev-dependencies] tokio = { version = "0.2.0", features = ["macros", "net", "io-util", "rt-core", "time"] } diff --git a/src/client.rs b/src/client.rs index 25d5874..5007aa8 100644 --- a/src/client.rs +++ b/src/client.rs @@ -54,7 +54,7 @@ where IO: AsyncRead + AsyncWrite + Unpin, { #[cfg(feature = "unstable")] - unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [std::mem::MaybeUninit]) -> bool { + unsafe fn prepare_uninitialized_buffer(&self, _buf: &mut [std::mem::MaybeUninit]) -> bool { false } diff --git a/src/common/handshake.rs b/src/common/handshake.rs index 0006b56..c59541e 100644 --- a/src/common/handshake.rs +++ b/src/common/handshake.rs @@ -78,7 +78,7 @@ where Poll::Ready(Ok(stream)) } else { - panic!() + panic!("unexpected polling after handshake") } } } diff --git a/src/common/mod.rs b/src/common/mod.rs index 9f6d9ac..1d0dd07 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -1,8 +1,11 @@ mod handshake; +#[cfg(feature = "unstable")] +mod vecbuf; + use std::pin::Pin; use std::task::{ Poll, Context }; -use std::io::{ self, Read, Write }; +use std::io::{ self, Read }; use rustls::Session; use tokio::io::{ AsyncRead, AsyncWrite }; use futures_core as futures; @@ -23,7 +26,8 @@ impl TlsState { #[inline] pub fn shutdown_read(&mut self) { match *self { - TlsState::WriteShutdown | TlsState::FullyShutdown => *self = TlsState::FullyShutdown, + TlsState::WriteShutdown | TlsState::FullyShutdown => + *self = TlsState::FullyShutdown, _ => *self = TlsState::ReadShutdown, } } @@ -31,7 +35,8 @@ impl TlsState { #[inline] pub fn shutdown_write(&mut self) { match *self { - TlsState::ReadShutdown | TlsState::FullyShutdown => *self = TlsState::FullyShutdown, + TlsState::ReadShutdown | TlsState::FullyShutdown => + *self = TlsState::FullyShutdown, _ => *self = TlsState::WriteShutdown, } } @@ -132,7 +137,10 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> { Poll::Ready(Ok(n)) } + #[cfg(not(feature = "unstable"))] pub fn write_io(&mut self, cx: &mut Context) -> Poll> { + use std::io::Write; + struct Writer<'a, 'b, T> { io: &'a mut T, cx: &'a mut Context<'b> @@ -162,6 +170,36 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> { } } + #[cfg(feature = "unstable")] + pub fn write_io(&mut self, cx: &mut Context) -> Poll> { + use rustls::WriteV; + + struct Writer<'a, 'b, T> { + io: &'a mut T, + cx: &'a mut Context<'b> + } + + impl<'a, 'b, T: AsyncWrite + Unpin> WriteV for Writer<'a, 'b, T> { + fn writev(&mut self, vbuf: &[&[u8]]) -> io::Result { + use vecbuf::VecBuf; + + let mut vbuf = VecBuf::new(vbuf); + + match Pin::new(&mut self.io).poll_write_buf(self.cx, &mut vbuf) { + Poll::Ready(result) => result, + Poll::Pending => Err(io::ErrorKind::WouldBlock.into()) + } + } + } + + let mut writer = Writer { io: self.io, cx }; + + match self.session.writev_tls(&mut writer) { + Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending, + result => Poll::Ready(result) + } + } + pub fn handshake(&mut self, cx: &mut Context) -> Poll> { let mut wrlen = 0; let mut rdlen = 0; diff --git a/src/common/vecbuf.rs b/src/common/vecbuf.rs new file mode 100644 index 0000000..6ea19e3 --- /dev/null +++ b/src/common/vecbuf.rs @@ -0,0 +1,128 @@ +use std::io::IoSlice; +use std::cmp::{ self, Ordering }; +use bytes::Buf; + + +pub struct VecBuf<'a, 'b: 'a> { + pos: usize, + cur: usize, + inner: &'a [&'b [u8]] +} + +impl<'a, 'b> VecBuf<'a, 'b> { + pub fn new(vbytes: &'a [&'b [u8]]) -> Self { + VecBuf { pos: 0, cur: 0, inner: vbytes } + } +} + +impl<'a, 'b> Buf for VecBuf<'a, 'b> { + fn remaining(&self) -> usize { + let sum = self.inner + .iter() + .skip(self.pos) + .map(|bytes| bytes.len()) + .sum::(); + sum - self.cur + } + + fn bytes(&self) -> &[u8] { + &self.inner[self.pos][self.cur..] + } + + fn advance(&mut self, cnt: usize) { + let current = self.inner[self.pos].len(); + match (self.cur + cnt).cmp(¤t) { + Ordering::Equal => if self.pos + 1 < self.inner.len() { + self.pos += 1; + self.cur = 0; + } else { + self.cur += cnt; + }, + Ordering::Greater => { + if self.pos + 1 < self.inner.len() { + self.pos += 1; + } + let remaining = self.cur + cnt - current; + self.advance(remaining); + }, + Ordering::Less => self.cur += cnt, + } + } + + #[allow(clippy::needless_range_loop)] + fn bytes_vectored<'c>(&'c self, dst: &mut [IoSlice<'c>]) -> usize { + let len = cmp::min(self.inner.len() - self.pos, dst.len()); + + if len > 0 { + dst[0] = IoSlice::new(self.bytes()); + } + + for i in 1..len { + dst[i] = IoSlice::new(&self.inner[self.pos + i]); + } + + len + } +} + +#[cfg(test)] +mod test_vecbuf { + use super::*; + + #[test] + fn test_fresh_cursor_vec() { + let mut buf = VecBuf::new(&[b"he", b"llo"]); + + assert_eq!(buf.remaining(), 5); + assert_eq!(buf.bytes(), b"he"); + + buf.advance(1); + + assert_eq!(buf.remaining(), 4); + assert_eq!(buf.bytes(), b"e"); + + buf.advance(1); + + assert_eq!(buf.remaining(), 3); + assert_eq!(buf.bytes(), b"llo"); + + buf.advance(3); + + assert_eq!(buf.remaining(), 0); + assert_eq!(buf.bytes(), b""); + } + + #[test] + fn test_get_u8() { + let mut buf = VecBuf::new(&[b"\x21z", b"omg"]); + assert_eq!(0x21, buf.get_u8()); + } + + #[test] + fn test_get_u16() { + let mut buf = VecBuf::new(&[b"\x21\x54z", b"omg"]); + assert_eq!(0x2154, buf.get_u16()); + let mut buf = VecBuf::new(&[b"\x21\x54z", b"omg"]); + assert_eq!(0x5421, buf.get_u16_le()); + } + + #[test] + #[should_panic] + fn test_get_u16_buffer_underflow() { + let mut buf = VecBuf::new(&[b"\x21"]); + buf.get_u16(); + } + + #[test] + fn test_bufs_vec() { + let buf = VecBuf::new(&[b"he", b"llo"]); + + let b1: &[u8] = &mut [0]; + let b2: &[u8] = &mut [0]; + + let mut dst: [IoSlice; 2] = + [IoSlice::new(b1), IoSlice::new(b2)]; + + assert_eq!(2, buf.bytes_vectored(&mut dst[..])); + } +} diff --git a/src/lib.rs b/src/lib.rs index 28d9de1..db34b07 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -51,8 +51,8 @@ impl From> for TlsAcceptor { impl TlsConnector { /// Enable 0-RTT. /// - /// Note that you want to use 0-RTT. - /// You must set `enable_early_data` to `true` in `ClientConfig`. + /// If you want to use 0-RTT, + /// You must also set `ClientConfig.enable_early_data` to `true`. #[cfg(feature = "early-data")] pub fn early_data(mut self, flag: bool) -> TlsConnector { self.early_data = flag; @@ -158,6 +158,7 @@ impl Future for Connect { } impl FusedFuture for Connect { + #[inline] fn is_terminated(&self) -> bool { self.0.is_terminated() } diff --git a/src/server.rs b/src/server.rs index aa7164e..abf86d6 100644 --- a/src/server.rs +++ b/src/server.rs @@ -53,7 +53,7 @@ where IO: AsyncRead + AsyncWrite + Unpin, { #[cfg(feature = "unstable")] - unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [std::mem::MaybeUninit]) -> bool { + unsafe fn prepare_uninitialized_buffer(&self, _buf: &mut [std::mem::MaybeUninit]) -> bool { // TODO // // https://doc.rust-lang.org/nightly/std/io/trait.Read.html#method.initializer From d7862fae8ae6870ad27bd62b7960798825998a62 Mon Sep 17 00:00:00 2001 From: quininer Date: Wed, 8 Jan 2020 00:34:59 +0800 Subject: [PATCH 171/171] bump version --- .github/workflows/ci.yml | 2 +- Cargo.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index d354dd7..96eb2c0 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -23,7 +23,7 @@ jobs: - uses: actions-rs/cargo@v1 with: command: test - args: --features early-data + args: --features early-data,unstable env: 'CARGO_INCREMENTAL': '0' 'RUSTFLAGS': '-Zprofile -Ccodegen-units=1 -Cinline-threshold=0 -Clink-dead-code -Coverflow-checks=off -Zno-landing-pads' diff --git a/Cargo.toml b/Cargo.toml index 52964d7..47052bc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tokio-rustls" -version = "0.12.1" +version = "0.12.2" authors = ["quininer kel "] license = "MIT/Apache-2.0" repository = "https://github.com/quininer/tokio-rustls"