change: use rustls Stream

This commit is contained in:
quininer 2018-03-23 17:14:27 +08:00
parent d4cb46e895
commit 64ca6e290c
4 changed files with 32 additions and 63 deletions

View File

@ -23,14 +23,15 @@ webpki = "0.18.0-alpha"
[dev-dependencies] [dev-dependencies]
tokio = "0.1" tokio = "0.1"
tokio-io = "0.1" tokio-io = "0.1"
# tokio-core = "0.1" tokio-core = "0.1"
# tokio-file-unix = "0.4" tokio-file-unix = "0.4"
clap = "2.26" clap = "2.26"
webpki-roots = "0.14" webpki-roots = "0.14"
[features] [features]
unstable-futures = [ "futures", "tokio/unstable-futures" ] default = [ "tokio" ]
default = [ "unstable-futures", "tokio" ] # unstable-futures = [ "futures", "tokio/unstable-futures" ]
# default = [ "unstable-futures", "tokio" ]
[patch.crates-io] [patch.crates-io]
tokio = { path = "../ref/tokio" } # tokio = { path = "../ref/tokio" }

View File

@ -10,7 +10,8 @@ use std::io;
use std::sync::Arc; use std::sync::Arc;
use rustls::{ use rustls::{
Session, ClientSession, ServerSession, Session, ClientSession, ServerSession,
ClientConfig, ServerConfig ClientConfig, ServerConfig,
Stream
}; };
@ -92,10 +93,12 @@ pub struct TlsStream<S, C> {
} }
impl<S, C> TlsStream<S, C> { impl<S, C> TlsStream<S, C> {
#[inline]
pub fn get_ref(&self) -> (&S, &C) { pub fn get_ref(&self) -> (&S, &C) {
(&self.io, &self.session) (&self.io, &self.session)
} }
#[inline]
pub fn get_mut(&mut self) -> (&mut S, &mut C) { pub fn get_mut(&mut self) -> (&mut S, &mut C) {
(&mut self.io, &mut self.session) (&mut self.io, &mut self.session)
} }
@ -187,19 +190,13 @@ impl<S, C> io::Read for TlsStream<S, C>
where S: io::Read + io::Write, C: Session where S: io::Read + io::Write, C: Session
{ {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
try_ignore!(Self::do_io(&mut self.session, &mut self.io, &mut self.eof)); let (io, session) = self.get_mut();
let mut stream = Stream::new(session, io);
loop { match stream.read(buf) {
match self.session.read(buf) { Ok(n) => Ok(n),
Ok(0) if !self.eof => while Self::do_read(&mut self.session, &mut self.io, &mut self.eof)? {}, Err(ref e) if e.kind() == io::ErrorKind::ConnectionAborted => Ok(0),
Ok(n) => return Ok(n), Err(e) => Err(e)
Err(e) => if e.kind() == io::ErrorKind::ConnectionAborted {
try_ignore!(Self::do_read(&mut self.session, &mut self.io, &mut self.eof));
return if self.eof { Ok(0) } else { Err(e) }
} else {
return Err(e)
}
}
} }
} }
} }
@ -208,35 +205,19 @@ impl<S, C> io::Write for TlsStream<S, C>
where S: io::Read + io::Write, C: Session where S: io::Read + io::Write, C: Session
{ {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> { fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
try_ignore!(Self::do_io(&mut self.session, &mut self.io, &mut self.eof)); let (io, session) = self.get_mut();
let mut stream = Stream::new(session, io);
let mut wlen = self.session.write(buf)?; stream.write(buf)
loop {
match Self::do_write(&mut self.session, &mut self.io) {
Ok(true) => continue,
Ok(false) if wlen == 0 => (),
Ok(false) => break,
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock =>
if wlen == 0 {
// Both rustls buffer and IO buffer are blocking.
return Err(io::Error::from(io::ErrorKind::WouldBlock));
} else {
continue
},
Err(e) => return Err(e)
}
assert_eq!(wlen, 0);
wlen = self.session.write(buf)?;
}
Ok(wlen)
} }
fn flush(&mut self) -> io::Result<()> { fn flush(&mut self) -> io::Result<()> {
self.session.flush()?; {
while Self::do_write(&mut self.session, &mut self.io)? {}; let (io, session) = self.get_mut();
let mut stream = Stream::new(session, io);
stream.flush()?;
}
self.io.flush() self.io.flush()
} }
} }

View File

@ -35,17 +35,10 @@ impl<S, C> Future for MidHandshake<S, C>
let stream = self.inner.as_mut().unwrap(); let stream = self.inner.as_mut().unwrap();
if !stream.session.is_handshaking() { break }; if !stream.session.is_handshaking() { break };
match TlsStream::do_io(&mut stream.session, &mut stream.io, &mut stream.eof) { match stream.session.complete_io(&mut stream.io) {
Ok(()) => match (stream.eof, stream.session.is_handshaking()) { Ok(_) => (),
(true, true) => return Err(io::Error::from(io::ErrorKind::UnexpectedEof)), Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => return Ok(Async::NotReady),
(false, true) => continue, Err(e) => return Err(e)
(..) => break
},
Err(e) => match (e.kind(), stream.session.is_handshaking()) {
(io::ErrorKind::WouldBlock, true) => return Ok(Async::NotReady),
(io::ErrorKind::WouldBlock, false) => break,
(..) => return Err(e)
}
} }
} }

View File

@ -59,6 +59,7 @@ fn start_server(cert: Vec<Certificate>, rsa: PrivateKey) -> SocketAddr {
recv.recv().unwrap() recv.recv().unwrap()
} }
#[cfg(feature = "unstable-futures")]
fn start_server2(cert: Vec<Certificate>, rsa: PrivateKey) -> SocketAddr { fn start_server2(cert: Vec<Certificate>, rsa: PrivateKey) -> SocketAddr {
use futures::{ FutureExt, StreamExt }; use futures::{ FutureExt, StreamExt };
use futures::io::{ AsyncReadExt, AsyncWriteExt }; use futures::io::{ AsyncReadExt, AsyncWriteExt };
@ -136,16 +137,9 @@ fn start_client2(addr: &SocketAddr, domain: &str, chain: Option<BufReader<Cursor
let done = TcpStream::connect(addr) let done = TcpStream::connect(addr)
.and_then(|stream| config.connect_async(domain, stream)) .and_then(|stream| config.connect_async(domain, stream))
.and_then(|stream| { .and_then(|stream| stream.write_all(HELLO_WORLD))
eprintln!("WRITE: {:?}", stream); .and_then(|(stream, _)| stream.read_exact(vec![0; HELLO_WORLD.len()]))
stream.write_all(HELLO_WORLD)
})
.and_then(|(stream, _)| {
eprintln!("READ: {:?}", stream);
stream.read_exact(vec![0; HELLO_WORLD.len()])
})
.and_then(|(stream, buf)| { .and_then(|(stream, buf)| {
eprintln!("OK: {:?}", stream);
assert_eq!(buf, HELLO_WORLD); assert_eq!(buf, HELLO_WORLD);
Ok(()) Ok(())
}); });