change: impl io::{Read,Write}

This commit is contained in:
quininer 2018-03-21 21:44:36 +08:00
parent 8c79329c7a
commit 72de25ebce
5 changed files with 107 additions and 89 deletions

View File

@ -1,6 +1,6 @@
[package] [package]
name = "tokio-rustls" name = "tokio-rustls"
version = "0.5.0" version = "0.6.0-alpha"
authors = ["quininer kel <quininer@live.com>"] authors = ["quininer kel <quininer@live.com>"]
license = "MIT/Apache-2.0" license = "MIT/Apache-2.0"
repository = "https://github.com/quininer/tokio-rustls" repository = "https://github.com/quininer/tokio-rustls"

View File

@ -1,7 +1,5 @@
extern crate clap; extern crate clap;
extern crate rustls; extern crate rustls;
extern crate futures;
extern crate tokio_io;
extern crate tokio; extern crate tokio;
extern crate webpki_roots; extern crate webpki_roots;
extern crate tokio_rustls; extern crate tokio_rustls;
@ -10,12 +8,11 @@ use std::sync::Arc;
use std::net::ToSocketAddrs; use std::net::ToSocketAddrs;
use std::io::BufReader; use std::io::BufReader;
use std::fs::File; use std::fs::File;
use futures::{ Future, Stream };
use rustls::{ Certificate, NoClientAuth, PrivateKey, ServerConfig }; use rustls::{ Certificate, NoClientAuth, PrivateKey, ServerConfig };
use rustls::internal::pemfile::{ certs, rsa_private_keys }; use rustls::internal::pemfile::{ certs, rsa_private_keys };
use tokio_io::{ io, AsyncRead }; use tokio::prelude::{ Future, Stream };
use tokio::io::{ self, AsyncRead };
use tokio::net::TcpListener; use tokio::net::TcpListener;
use tokio::executor::current_thread;
use clap::{ App, Arg }; use clap::{ App, Arg };
use tokio_rustls::ServerConfigExt; use tokio_rustls::ServerConfigExt;
@ -64,7 +61,7 @@ fn main() {
}) })
.map(move |(n, ..)| println!("Echo: {} - {:?}", n, addr)) .map(move |(n, ..)| println!("Echo: {} - {:?}", n, addr))
.map_err(move |err| println!("Error: {:?} - {:?}", err, addr2)); .map_err(move |err| println!("Error: {:?} - {:?}", err, addr2));
current_thread::spawn(done); tokio::spawn(done);
Ok(()) Ok(())
} else { } else {
@ -82,10 +79,10 @@ fn main() {
.and_then(|(stream, _)| io::flush(stream)) .and_then(|(stream, _)| io::flush(stream))
.map(move |_| println!("Accept: {:?}", addr)) .map(move |_| println!("Accept: {:?}", addr))
.map_err(move |err| println!("Error: {:?} - {:?}", err, addr2)); .map_err(move |err| println!("Error: {:?} - {:?}", err, addr2));
current_thread::spawn(done); tokio::spawn(done);
Ok(()) Ok(())
}); });
current_thread::run(|_| current_thread::spawn(done.map_err(drop))); tokio::run(done.map_err(drop));
} }

View File

@ -1,7 +1,9 @@
extern crate futures;
use super::*; use super::*;
use futures::{ Future, Poll, Async }; use self::futures::{ Future, Poll, Async };
use futures::io::{ Error, AsyncRead, AsyncWrite }; use self::futures::io::{ Error, AsyncRead, AsyncWrite };
use futures::task::Context; use self::futures::task::Context;
impl<S: io::Read + io::Write> Future for ConnectAsync<S> { impl<S: io::Read + io::Write> Future for ConnectAsync<S> {

View File

@ -1,12 +1,10 @@
//! Asynchronous TLS/SSL streams for Tokio using [Rustls](https://github.com/ctz/rustls). //! Asynchronous TLS/SSL streams for Tokio using [Rustls](https://github.com/ctz/rustls).
extern crate futures;
extern crate tokio;
extern crate rustls; extern crate rustls;
extern crate webpki; extern crate webpki;
mod tokio_impl; #[cfg(feature = "tokio")] mod tokio_impl;
mod futures_impl; #[cfg(feature = "futures")] mod futures_impl;
use std::io; use std::io;
use std::sync::Arc; use std::sync::Arc;
@ -17,14 +15,14 @@ use rustls::{
/// Extension trait for the `Arc<ClientConfig>` type in the `rustls` crate. /// Extension trait for the `Arc<ClientConfig>` type in the `rustls` crate.
pub trait ClientConfigExt { pub trait ClientConfigExt: sealed::Sealed {
fn connect_async<S>(&self, domain: webpki::DNSNameRef, stream: S) fn connect_async<S>(&self, domain: webpki::DNSNameRef, stream: S)
-> ConnectAsync<S> -> ConnectAsync<S>
where S: io::Read + io::Write; where S: io::Read + io::Write;
} }
/// Extension trait for the `Arc<ServerConfig>` type in the `rustls` crate. /// Extension trait for the `Arc<ServerConfig>` type in the `rustls` crate.
pub trait ServerConfigExt { pub trait ServerConfigExt: sealed::Sealed {
fn accept_async<S>(&self, stream: S) fn accept_async<S>(&self, stream: S)
-> AcceptAsync<S> -> AcceptAsync<S>
where S: io::Read + io::Write; where S: io::Read + io::Write;
@ -39,6 +37,7 @@ pub struct ConnectAsync<S>(MidHandshake<S, ClientSession>);
/// once the accept handshake has finished. /// once the accept handshake has finished.
pub struct AcceptAsync<S>(MidHandshake<S, ServerSession>); pub struct AcceptAsync<S>(MidHandshake<S, ServerSession>);
impl sealed::Sealed for Arc<ClientConfig> {}
impl ClientConfigExt for Arc<ClientConfig> { impl ClientConfigExt for Arc<ClientConfig> {
fn connect_async<S>(&self, domain: webpki::DNSNameRef, stream: S) fn connect_async<S>(&self, domain: webpki::DNSNameRef, stream: S)
@ -54,11 +53,11 @@ pub fn connect_async_with_session<S>(stream: S, session: ClientSession)
-> ConnectAsync<S> -> ConnectAsync<S>
where S: io::Read + io::Write where S: io::Read + io::Write
{ {
ConnectAsync(MidHandshake { ConnectAsync(MidHandshake { inner: Some(TlsStream::new(stream, session)) })
inner: Some(TlsStream::new(stream, session))
})
} }
impl sealed::Sealed for Arc<ServerConfig> {}
impl ServerConfigExt for Arc<ServerConfig> { impl ServerConfigExt for Arc<ServerConfig> {
fn accept_async<S>(&self, stream: S) fn accept_async<S>(&self, stream: S)
-> AcceptAsync<S> -> AcceptAsync<S>
@ -73,9 +72,7 @@ pub fn accept_async_with_session<S>(stream: S, session: ServerSession)
-> AcceptAsync<S> -> AcceptAsync<S>
where S: io::Read + io::Write where S: io::Read + io::Write
{ {
AcceptAsync(MidHandshake { AcceptAsync(MidHandshake { inner: Some(TlsStream::new(stream, session)) })
inner: Some(TlsStream::new(stream, session))
})
} }
@ -104,11 +101,30 @@ impl<S, C> TlsStream<S, C> {
} }
} }
macro_rules! try_wouldblock {
( continue $r:expr ) => {
match $r {
Ok(true) => continue,
Ok(false) => false,
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => true,
Err(e) => return Err(e)
}
};
( ignore $r:expr ) => {
match $r {
Ok(_) => (),
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => (),
Err(e) => return Err(e)
}
};
}
impl<S, C> TlsStream<S, C> impl<S, C> TlsStream<S, C>
where S: io::Read + io::Write, C: Session where S: io::Read + io::Write, C: Session
{ {
#[inline] #[inline]
pub fn new(io: S, session: C) -> TlsStream<S, C> { fn new(io: S, session: C) -> TlsStream<S, C> {
TlsStream { TlsStream {
is_shutdown: false, is_shutdown: false,
eof: false, eof: false,
@ -117,15 +133,12 @@ impl<S, C> TlsStream<S, C>
} }
} }
pub fn do_io(&mut self) -> io::Result<()> { fn do_read(&mut self) -> io::Result<bool> {
loop { if !self.eof && self.session.wants_read() {
let read_would_block = if !self.eof && self.session.wants_read() { if self.session.read_tls(&mut self.io)? == 0 {
match self.session.read_tls(&mut self.io) {
Ok(0) => {
self.eof = true; self.eof = true;
continue }
},
Ok(_) => {
if let Err(err) = self.session.process_new_packets() { if let Err(err) = self.session.process_new_packets() {
// flush queued messages before returning an Err in // flush queued messages before returning an Err in
// order to send alerts instead of abruptly closing // order to send alerts instead of abruptly closing
@ -134,28 +147,32 @@ impl<S, C> TlsStream<S, C>
// ignore result to avoid masking original error // ignore result to avoid masking original error
let _ = self.session.write_tls(&mut self.io); let _ = self.session.write_tls(&mut self.io);
} }
return Err(io::Error::new(io::ErrorKind::Other, err)); return Err(io::Error::new(io::ErrorKind::InvalidData, err));
} }
continue
},
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => true,
Err(e) => return Err(e)
}
} else {
false
};
let write_would_block = if self.session.wants_write() { Ok(true)
match self.session.write_tls(&mut self.io) {
Ok(_) => continue,
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => true,
Err(e) => return Err(e)
}
} else { } else {
false Ok(false)
}; }
}
if read_would_block || write_would_block { fn do_write(&mut self) -> io::Result<bool> {
if self.session.wants_write() {
self.session.write_tls(&mut self.io)?;
Ok(true)
} else {
Ok(false)
}
}
#[inline]
pub fn do_io(&mut self) -> io::Result<()> {
loop {
let write_would_block = try_wouldblock!(continue self.do_write());
let read_would_block = try_wouldblock!(continue self.do_read());
if write_would_block || read_would_block {
return Err(io::Error::from(io::ErrorKind::WouldBlock)); return Err(io::Error::from(io::ErrorKind::WouldBlock));
} else { } else {
return Ok(()); return Ok(());
@ -168,12 +185,14 @@ impl<S, C> io::Read for TlsStream<S, C>
where S: io::Read + io::Write, C: Session where S: io::Read + io::Write, C: Session
{ {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
try_wouldblock!(ignore self.do_io());
loop { loop {
match self.session.read(buf) { match self.session.read(buf) {
Ok(0) if !self.eof => self.do_io()?, Ok(0) if !self.eof => while self.do_read()? {},
Ok(n) => return Ok(n), Ok(n) => return Ok(n),
Err(e) => if e.kind() == io::ErrorKind::ConnectionAborted { Err(e) => if e.kind() == io::ErrorKind::ConnectionAborted {
self.do_io()?; try_wouldblock!(ignore self.do_read());
return if self.eof { Ok(0) } else { Err(e) } return if self.eof { Ok(0) } else { Err(e) }
} else { } else {
return Err(e) return Err(e)
@ -187,38 +206,39 @@ impl<S, C> io::Write for TlsStream<S, C>
where S: io::Read + io::Write, C: Session where S: io::Read + io::Write, C: Session
{ {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> { fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
if buf.is_empty() { try_wouldblock!(ignore self.do_io());
return Ok(0);
} let mut wlen = self.session.write(buf)?;
loop { loop {
let output = self.session.write(buf)?; match self.do_write() {
Ok(true) => continue,
while self.session.wants_write() { Ok(false) if wlen == 0 => (),
match self.session.write_tls(&mut self.io) { Ok(false) => break,
Ok(_) => (), Err(ref e) if e.kind() == io::ErrorKind::WouldBlock =>
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => if output == 0 { if wlen == 0 {
// Both rustls buffer and IO buffer are blocking. // Both rustls buffer and IO buffer are blocking.
return Err(io::Error::from(io::ErrorKind::WouldBlock)); return Err(io::Error::from(io::ErrorKind::WouldBlock));
} else { } else {
break; continue
}, },
Err(e) => return Err(e) Err(e) => return Err(e)
} }
assert_eq!(wlen, 0);
wlen = self.session.write(buf)?;
} }
if output > 0 { Ok(wlen)
// Already wrote something out.
return Ok(output);
}
}
} }
fn flush(&mut self) -> io::Result<()> { fn flush(&mut self) -> io::Result<()> {
self.session.flush()?; self.session.flush()?;
while self.session.wants_write() { while self.do_write()? {};
self.session.write_tls(&mut self.io)?;
}
self.io.flush() self.io.flush()
} }
} }
mod sealed {
pub trait Sealed {}
}

View File

@ -1,7 +1,9 @@
extern crate tokio;
use super::*; use super::*;
use tokio::prelude::*; use self::tokio::prelude::*;
use tokio::io::{ AsyncRead, AsyncWrite }; use self::tokio::io::{ AsyncRead, AsyncWrite };
use tokio::prelude::Poll; use self::tokio::prelude::Poll;
impl<S: AsyncRead + AsyncWrite> Future for ConnectAsync<S> { impl<S: AsyncRead + AsyncWrite> Future for ConnectAsync<S> {
@ -67,10 +69,7 @@ impl<S, C> AsyncWrite for TlsStream<S, C>
self.session.send_close_notify(); self.session.send_close_notify();
self.is_shutdown = true; self.is_shutdown = true;
} }
while self.session.wants_write() { while self.do_write()? {};
self.session.write_tls(&mut self.io)?;
}
self.io.flush()?;
self.io.shutdown() self.io.shutdown()
} }
} }