diff --git a/src/lib.rs b/src/lib.rs index 293a112..3337e0d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -8,6 +8,7 @@ extern crate webpki; use std::io; use std::sync::Arc; +use webpki::DNSNameRef; use rustls::{ Session, ClientSession, ServerSession, ClientConfig, ServerConfig, @@ -17,7 +18,7 @@ use rustls::{ /// Extension trait for the `Arc` type in the `rustls` crate. pub trait ClientConfigExt: sealed::Sealed { - fn connect_async(&self, domain: webpki::DNSNameRef, stream: S) + fn connect_async(&self, domain: DNSNameRef, stream: S) -> ConnectAsync where S: io::Read + io::Write; } @@ -41,7 +42,7 @@ pub struct AcceptAsync(MidHandshake); impl sealed::Sealed for Arc {} impl ClientConfigExt for Arc { - fn connect_async(&self, domain: webpki::DNSNameRef, stream: S) + fn connect_async(&self, domain: DNSNameRef, stream: S) -> ConnectAsync where S: io::Read + io::Write { @@ -54,7 +55,9 @@ pub fn connect_async_with_session(stream: S, session: ClientSession) -> ConnectAsync where S: io::Read + io::Write { - ConnectAsync(MidHandshake { inner: Some(TlsStream::new(stream, session)) }) + ConnectAsync(MidHandshake { + inner: Some(TlsStream { session, io: stream, is_shutdown: false }) + }) } impl sealed::Sealed for Arc {} @@ -73,7 +76,9 @@ pub fn accept_async_with_session(stream: S, session: ServerSession) -> AcceptAsync where S: io::Read + io::Write { - AcceptAsync(MidHandshake { inner: Some(TlsStream::new(stream, session)) }) + AcceptAsync(MidHandshake { + inner: Some(TlsStream { session, io: stream, is_shutdown: false }) + }) } @@ -87,7 +92,6 @@ struct MidHandshake { #[derive(Debug)] pub struct TlsStream { is_shutdown: bool, - eof: bool, io: S, session: C } @@ -104,78 +108,6 @@ impl TlsStream { } } -impl TlsStream - where S: io::Read + io::Write, C: Session -{ - #[inline] - fn new(io: S, session: C) -> TlsStream { - TlsStream { - is_shutdown: false, - eof: false, - io: io, - session: session - } - } - - fn do_read(session: &mut C, io: &mut S, eof: &mut bool) -> io::Result { - if !*eof && session.wants_read() { - if session.read_tls(io)? == 0 { - *eof = true; - } - - if let Err(err) = session.process_new_packets() { - // flush queued messages before returning an Err in - // order to send alerts instead of abruptly closing - // the socket - if session.wants_write() { - // ignore result to avoid masking original error - let _ = session.write_tls(io); - } - return Err(io::Error::new(io::ErrorKind::InvalidData, err)); - } - - Ok(true) - } else { - Ok(false) - } - } - - fn do_write(session: &mut C, io: &mut S) -> io::Result { - if session.wants_write() { - session.write_tls(io)?; - - Ok(true) - } else { - Ok(false) - } - } - - #[inline] - pub fn do_io(session: &mut C, io: &mut S, eof: &mut bool) -> io::Result<()> { - macro_rules! try_wouldblock { - ( $r:expr ) => { - match $r { - Ok(true) => continue, - Ok(false) => false, - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => true, - Err(e) => return Err(e) - } - }; - } - - loop { - let write_would_block = try_wouldblock!(Self::do_write(session, io)); - let read_would_block = try_wouldblock!(Self::do_read(session, io, eof)); - - if write_would_block || read_would_block { - return Err(io::Error::from(io::ErrorKind::WouldBlock)); - } else { - return Ok(()); - } - } - } -} - impl io::Read for TlsStream where S: io::Read + io::Write, C: Session {