more complete handshake

This commit is contained in:
quininer 2018-09-24 23:02:30 +08:00
parent 30cacd04a0
commit 1f98d87a62
3 changed files with 31 additions and 17 deletions

View File

@ -12,9 +12,10 @@ use rustls::WriteV;
#[cfg(feature = "tokio-support")] #[cfg(feature = "tokio-support")]
use tokio::io::AsyncWrite; use tokio::io::AsyncWrite;
pub struct Stream<'a, S: 'a, IO: 'a> { pub struct Stream<'a, S: 'a, IO: 'a> {
session: &'a mut S, pub session: &'a mut S,
io: &'a mut IO pub io: &'a mut IO
} }
pub trait WriteTls<'a, S: Session, IO: Read + Write>: Read + Write { pub trait WriteTls<'a, S: Session, IO: Read + Write>: Read + Write {

View File

@ -129,9 +129,11 @@ impl<IO, S> TlsStream<IO, S> {
} }
} }
impl<IO, S> From<(IO, S)> for TlsStream<IO, S> { impl<IO, S: Session> From<(IO, S)> for TlsStream<IO, S> {
#[inline] #[inline]
fn from((io, session): (IO, S)) -> TlsStream<IO, S> { fn from((io, session): (IO, S)) -> TlsStream<IO, S> {
assert!(!session.is_handshaking());
TlsStream { TlsStream {
is_shutdown: false, is_shutdown: false,
eof: false, eof: false,

View File

@ -5,6 +5,17 @@ use tokio::prelude::Poll;
use common::Stream; use common::Stream;
macro_rules! try_async {
( $e:expr ) => {
match $e {
Ok(n) => n,
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock =>
return Ok(Async::NotReady),
Err(e) => return Err(e)
}
}
}
impl<IO: AsyncRead + AsyncWrite> Future for Connect<IO> { impl<IO: AsyncRead + AsyncWrite> Future for Connect<IO> {
type Item = TlsStream<IO, ClientSession>; type Item = TlsStream<IO, ClientSession>;
type Error = io::Error; type Error = io::Error;
@ -24,7 +35,9 @@ impl<IO: AsyncRead + AsyncWrite> Future for Accept<IO> {
} }
impl<IO, S> Future for MidHandshake<IO, S> impl<IO, S> Future for MidHandshake<IO, S>
where IO: io::Read + io::Write, S: Session where
IO: io::Read + io::Write,
S: Session
{ {
type Item = TlsStream<IO, S>; type Item = TlsStream<IO, S>;
type Error = io::Error; type Error = io::Error;
@ -32,15 +45,15 @@ impl<IO, S> Future for MidHandshake<IO, S>
fn poll(&mut self) -> Poll<Self::Item, Self::Error> { fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
{ {
let stream = self.inner.as_mut().unwrap(); let stream = self.inner.as_mut().unwrap();
if stream.session.is_handshaking() {
let (io, session) = stream.get_mut(); let (io, session) = stream.get_mut();
let mut stream = Stream::new(session, io); let mut stream = Stream::new(session, io);
match stream.complete_io() { if stream.session.is_handshaking() {
Ok(_) => (), try_async!(stream.complete_io());
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => return Ok(Async::NotReady),
Err(e) => return Err(e)
} }
if stream.session.wants_write() {
try_async!(stream.complete_io());
} }
} }
@ -69,12 +82,10 @@ impl<IO, S> AsyncWrite for TlsStream<IO, S>
self.is_shutdown = true; self.is_shutdown = true;
} }
match self.session.complete_io(&mut self.io) { {
Ok(_) => (), let mut stream = Stream::new(&mut self.session, &mut self.io);
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => return Ok(Async::NotReady), try_async!(stream.complete_io());
Err(e) => return Err(e)
} }
self.io.shutdown() self.io.shutdown()
} }
} }