change: impl io::{Read,Write}

This commit is contained in:
quininer 2018-03-21 21:44:36 +08:00
parent 8c79329c7a
commit 72de25ebce
5 changed files with 107 additions and 89 deletions

View File

@ -1,6 +1,6 @@
[package]
name = "tokio-rustls"
version = "0.5.0"
version = "0.6.0-alpha"
authors = ["quininer kel <quininer@live.com>"]
license = "MIT/Apache-2.0"
repository = "https://github.com/quininer/tokio-rustls"

View File

@ -1,7 +1,5 @@
extern crate clap;
extern crate rustls;
extern crate futures;
extern crate tokio_io;
extern crate tokio;
extern crate webpki_roots;
extern crate tokio_rustls;
@ -10,12 +8,11 @@ use std::sync::Arc;
use std::net::ToSocketAddrs;
use std::io::BufReader;
use std::fs::File;
use futures::{ Future, Stream };
use rustls::{ Certificate, NoClientAuth, PrivateKey, ServerConfig };
use rustls::internal::pemfile::{ certs, rsa_private_keys };
use tokio_io::{ io, AsyncRead };
use tokio::prelude::{ Future, Stream };
use tokio::io::{ self, AsyncRead };
use tokio::net::TcpListener;
use tokio::executor::current_thread;
use clap::{ App, Arg };
use tokio_rustls::ServerConfigExt;
@ -64,7 +61,7 @@ fn main() {
})
.map(move |(n, ..)| println!("Echo: {} - {:?}", n, addr))
.map_err(move |err| println!("Error: {:?} - {:?}", err, addr2));
current_thread::spawn(done);
tokio::spawn(done);
Ok(())
} else {
@ -82,10 +79,10 @@ fn main() {
.and_then(|(stream, _)| io::flush(stream))
.map(move |_| println!("Accept: {:?}", addr))
.map_err(move |err| println!("Error: {:?} - {:?}", err, addr2));
current_thread::spawn(done);
tokio::spawn(done);
Ok(())
});
current_thread::run(|_| current_thread::spawn(done.map_err(drop)));
tokio::run(done.map_err(drop));
}

View File

@ -1,7 +1,9 @@
extern crate futures;
use super::*;
use futures::{ Future, Poll, Async };
use futures::io::{ Error, AsyncRead, AsyncWrite };
use futures::task::Context;
use self::futures::{ Future, Poll, Async };
use self::futures::io::{ Error, AsyncRead, AsyncWrite };
use self::futures::task::Context;
impl<S: io::Read + io::Write> Future for ConnectAsync<S> {

View File

@ -1,12 +1,10 @@
//! Asynchronous TLS/SSL streams for Tokio using [Rustls](https://github.com/ctz/rustls).
extern crate futures;
extern crate tokio;
extern crate rustls;
extern crate webpki;
mod tokio_impl;
mod futures_impl;
#[cfg(feature = "tokio")] mod tokio_impl;
#[cfg(feature = "futures")] mod futures_impl;
use std::io;
use std::sync::Arc;
@ -17,14 +15,14 @@ use rustls::{
/// Extension trait for the `Arc<ClientConfig>` type in the `rustls` crate.
pub trait ClientConfigExt {
pub trait ClientConfigExt: sealed::Sealed {
fn connect_async<S>(&self, domain: webpki::DNSNameRef, stream: S)
-> ConnectAsync<S>
where S: io::Read + io::Write;
}
/// Extension trait for the `Arc<ServerConfig>` type in the `rustls` crate.
pub trait ServerConfigExt {
pub trait ServerConfigExt: sealed::Sealed {
fn accept_async<S>(&self, stream: S)
-> AcceptAsync<S>
where S: io::Read + io::Write;
@ -39,6 +37,7 @@ pub struct ConnectAsync<S>(MidHandshake<S, ClientSession>);
/// once the accept handshake has finished.
pub struct AcceptAsync<S>(MidHandshake<S, ServerSession>);
impl sealed::Sealed for Arc<ClientConfig> {}
impl ClientConfigExt for Arc<ClientConfig> {
fn connect_async<S>(&self, domain: webpki::DNSNameRef, stream: S)
@ -54,11 +53,11 @@ pub fn connect_async_with_session<S>(stream: S, session: ClientSession)
-> ConnectAsync<S>
where S: io::Read + io::Write
{
ConnectAsync(MidHandshake {
inner: Some(TlsStream::new(stream, session))
})
ConnectAsync(MidHandshake { inner: Some(TlsStream::new(stream, session)) })
}
impl sealed::Sealed for Arc<ServerConfig> {}
impl ServerConfigExt for Arc<ServerConfig> {
fn accept_async<S>(&self, stream: S)
-> AcceptAsync<S>
@ -73,9 +72,7 @@ pub fn accept_async_with_session<S>(stream: S, session: ServerSession)
-> AcceptAsync<S>
where S: io::Read + io::Write
{
AcceptAsync(MidHandshake {
inner: Some(TlsStream::new(stream, session))
})
AcceptAsync(MidHandshake { inner: Some(TlsStream::new(stream, session)) })
}
@ -104,11 +101,30 @@ impl<S, C> TlsStream<S, C> {
}
}
macro_rules! try_wouldblock {
( continue $r:expr ) => {
match $r {
Ok(true) => continue,
Ok(false) => false,
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => true,
Err(e) => return Err(e)
}
};
( ignore $r:expr ) => {
match $r {
Ok(_) => (),
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => (),
Err(e) => return Err(e)
}
};
}
impl<S, C> TlsStream<S, C>
where S: io::Read + io::Write, C: Session
{
#[inline]
pub fn new(io: S, session: C) -> TlsStream<S, C> {
fn new(io: S, session: C) -> TlsStream<S, C> {
TlsStream {
is_shutdown: false,
eof: false,
@ -117,15 +133,12 @@ impl<S, C> TlsStream<S, C>
}
}
pub fn do_io(&mut self) -> io::Result<()> {
loop {
let read_would_block = if !self.eof && self.session.wants_read() {
match self.session.read_tls(&mut self.io) {
Ok(0) => {
fn do_read(&mut self) -> io::Result<bool> {
if !self.eof && self.session.wants_read() {
if self.session.read_tls(&mut self.io)? == 0 {
self.eof = true;
continue
},
Ok(_) => {
}
if let Err(err) = self.session.process_new_packets() {
// flush queued messages before returning an Err in
// order to send alerts instead of abruptly closing
@ -134,28 +147,32 @@ impl<S, C> TlsStream<S, C>
// ignore result to avoid masking original error
let _ = self.session.write_tls(&mut self.io);
}
return Err(io::Error::new(io::ErrorKind::Other, err));
return Err(io::Error::new(io::ErrorKind::InvalidData, err));
}
continue
},
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => true,
Err(e) => return Err(e)
}
} else {
false
};
let write_would_block = if self.session.wants_write() {
match self.session.write_tls(&mut self.io) {
Ok(_) => continue,
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => true,
Err(e) => return Err(e)
}
Ok(true)
} else {
false
};
Ok(false)
}
}
if read_would_block || write_would_block {
fn do_write(&mut self) -> io::Result<bool> {
if self.session.wants_write() {
self.session.write_tls(&mut self.io)?;
Ok(true)
} else {
Ok(false)
}
}
#[inline]
pub fn do_io(&mut self) -> io::Result<()> {
loop {
let write_would_block = try_wouldblock!(continue self.do_write());
let read_would_block = try_wouldblock!(continue self.do_read());
if write_would_block || read_would_block {
return Err(io::Error::from(io::ErrorKind::WouldBlock));
} else {
return Ok(());
@ -168,12 +185,14 @@ impl<S, C> io::Read for TlsStream<S, C>
where S: io::Read + io::Write, C: Session
{
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
try_wouldblock!(ignore self.do_io());
loop {
match self.session.read(buf) {
Ok(0) if !self.eof => self.do_io()?,
Ok(0) if !self.eof => while self.do_read()? {},
Ok(n) => return Ok(n),
Err(e) => if e.kind() == io::ErrorKind::ConnectionAborted {
self.do_io()?;
try_wouldblock!(ignore self.do_read());
return if self.eof { Ok(0) } else { Err(e) }
} else {
return Err(e)
@ -187,38 +206,39 @@ impl<S, C> io::Write for TlsStream<S, C>
where S: io::Read + io::Write, C: Session
{
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
if buf.is_empty() {
return Ok(0);
}
try_wouldblock!(ignore self.do_io());
let mut wlen = self.session.write(buf)?;
loop {
let output = self.session.write(buf)?;
while self.session.wants_write() {
match self.session.write_tls(&mut self.io) {
Ok(_) => (),
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => if output == 0 {
match self.do_write() {
Ok(true) => continue,
Ok(false) if wlen == 0 => (),
Ok(false) => break,
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock =>
if wlen == 0 {
// Both rustls buffer and IO buffer are blocking.
return Err(io::Error::from(io::ErrorKind::WouldBlock));
} else {
break;
continue
},
Err(e) => return Err(e)
}
assert_eq!(wlen, 0);
wlen = self.session.write(buf)?;
}
if output > 0 {
// Already wrote something out.
return Ok(output);
}
}
Ok(wlen)
}
fn flush(&mut self) -> io::Result<()> {
self.session.flush()?;
while self.session.wants_write() {
self.session.write_tls(&mut self.io)?;
}
while self.do_write()? {};
self.io.flush()
}
}
mod sealed {
pub trait Sealed {}
}

View File

@ -1,7 +1,9 @@
extern crate tokio;
use super::*;
use tokio::prelude::*;
use tokio::io::{ AsyncRead, AsyncWrite };
use tokio::prelude::Poll;
use self::tokio::prelude::*;
use self::tokio::io::{ AsyncRead, AsyncWrite };
use self::tokio::prelude::Poll;
impl<S: AsyncRead + AsyncWrite> Future for ConnectAsync<S> {
@ -67,10 +69,7 @@ impl<S, C> AsyncWrite for TlsStream<S, C>
self.session.send_close_notify();
self.is_shutdown = true;
}
while self.session.wants_write() {
self.session.write_tls(&mut self.io)?;
}
self.io.flush()?;
while self.do_write()? {};
self.io.shutdown()
}
}