change: use rustls Stream

This commit is contained in:
quininer 2018-03-23 17:14:27 +08:00
parent d4cb46e895
commit 64ca6e290c
4 changed files with 32 additions and 63 deletions

View File

@ -23,14 +23,15 @@ webpki = "0.18.0-alpha"
[dev-dependencies]
tokio = "0.1"
tokio-io = "0.1"
# tokio-core = "0.1"
# tokio-file-unix = "0.4"
tokio-core = "0.1"
tokio-file-unix = "0.4"
clap = "2.26"
webpki-roots = "0.14"
[features]
unstable-futures = [ "futures", "tokio/unstable-futures" ]
default = [ "unstable-futures", "tokio" ]
default = [ "tokio" ]
# unstable-futures = [ "futures", "tokio/unstable-futures" ]
# default = [ "unstable-futures", "tokio" ]
[patch.crates-io]
tokio = { path = "../ref/tokio" }
# tokio = { path = "../ref/tokio" }

View File

@ -10,7 +10,8 @@ use std::io;
use std::sync::Arc;
use rustls::{
Session, ClientSession, ServerSession,
ClientConfig, ServerConfig
ClientConfig, ServerConfig,
Stream
};
@ -92,10 +93,12 @@ pub struct TlsStream<S, C> {
}
impl<S, C> TlsStream<S, C> {
#[inline]
pub fn get_ref(&self) -> (&S, &C) {
(&self.io, &self.session)
}
#[inline]
pub fn get_mut(&mut self) -> (&mut S, &mut C) {
(&mut self.io, &mut self.session)
}
@ -187,19 +190,13 @@ impl<S, C> io::Read for TlsStream<S, C>
where S: io::Read + io::Write, C: Session
{
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
try_ignore!(Self::do_io(&mut self.session, &mut self.io, &mut self.eof));
let (io, session) = self.get_mut();
let mut stream = Stream::new(session, io);
loop {
match self.session.read(buf) {
Ok(0) if !self.eof => while Self::do_read(&mut self.session, &mut self.io, &mut self.eof)? {},
Ok(n) => return Ok(n),
Err(e) => if e.kind() == io::ErrorKind::ConnectionAborted {
try_ignore!(Self::do_read(&mut self.session, &mut self.io, &mut self.eof));
return if self.eof { Ok(0) } else { Err(e) }
} else {
return Err(e)
}
}
match stream.read(buf) {
Ok(n) => Ok(n),
Err(ref e) if e.kind() == io::ErrorKind::ConnectionAborted => Ok(0),
Err(e) => Err(e)
}
}
}
@ -208,35 +205,19 @@ impl<S, C> io::Write for TlsStream<S, C>
where S: io::Read + io::Write, C: Session
{
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
try_ignore!(Self::do_io(&mut self.session, &mut self.io, &mut self.eof));
let (io, session) = self.get_mut();
let mut stream = Stream::new(session, io);
let mut wlen = self.session.write(buf)?;
loop {
match Self::do_write(&mut self.session, &mut self.io) {
Ok(true) => continue,
Ok(false) if wlen == 0 => (),
Ok(false) => break,
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock =>
if wlen == 0 {
// Both rustls buffer and IO buffer are blocking.
return Err(io::Error::from(io::ErrorKind::WouldBlock));
} else {
continue
},
Err(e) => return Err(e)
}
assert_eq!(wlen, 0);
wlen = self.session.write(buf)?;
}
Ok(wlen)
stream.write(buf)
}
fn flush(&mut self) -> io::Result<()> {
self.session.flush()?;
while Self::do_write(&mut self.session, &mut self.io)? {};
{
let (io, session) = self.get_mut();
let mut stream = Stream::new(session, io);
stream.flush()?;
}
self.io.flush()
}
}

View File

@ -35,17 +35,10 @@ impl<S, C> Future for MidHandshake<S, C>
let stream = self.inner.as_mut().unwrap();
if !stream.session.is_handshaking() { break };
match TlsStream::do_io(&mut stream.session, &mut stream.io, &mut stream.eof) {
Ok(()) => match (stream.eof, stream.session.is_handshaking()) {
(true, true) => return Err(io::Error::from(io::ErrorKind::UnexpectedEof)),
(false, true) => continue,
(..) => break
},
Err(e) => match (e.kind(), stream.session.is_handshaking()) {
(io::ErrorKind::WouldBlock, true) => return Ok(Async::NotReady),
(io::ErrorKind::WouldBlock, false) => break,
(..) => return Err(e)
}
match stream.session.complete_io(&mut stream.io) {
Ok(_) => (),
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => return Ok(Async::NotReady),
Err(e) => return Err(e)
}
}

View File

@ -59,6 +59,7 @@ fn start_server(cert: Vec<Certificate>, rsa: PrivateKey) -> SocketAddr {
recv.recv().unwrap()
}
#[cfg(feature = "unstable-futures")]
fn start_server2(cert: Vec<Certificate>, rsa: PrivateKey) -> SocketAddr {
use futures::{ FutureExt, StreamExt };
use futures::io::{ AsyncReadExt, AsyncWriteExt };
@ -136,16 +137,9 @@ fn start_client2(addr: &SocketAddr, domain: &str, chain: Option<BufReader<Cursor
let done = TcpStream::connect(addr)
.and_then(|stream| config.connect_async(domain, stream))
.and_then(|stream| {
eprintln!("WRITE: {:?}", stream);
stream.write_all(HELLO_WORLD)
})
.and_then(|(stream, _)| {
eprintln!("READ: {:?}", stream);
stream.read_exact(vec![0; HELLO_WORLD.len()])
})
.and_then(|stream| stream.write_all(HELLO_WORLD))
.and_then(|(stream, _)| stream.read_exact(vec![0; HELLO_WORLD.len()]))
.and_then(|(stream, buf)| {
eprintln!("OK: {:?}", stream);
assert_eq!(buf, HELLO_WORLD);
Ok(())
});