diff --git a/src/lib.rs b/src/lib.rs index d61caf0..43bd51f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -28,10 +28,12 @@ use rustls::{ use common::Stream; +#[derive(Clone)] pub struct TlsConnector { inner: Arc } +#[derive(Clone)] pub struct TlsAcceptor { inner: Arc } @@ -49,16 +51,16 @@ impl From> for TlsAcceptor { } impl TlsConnector { - pub fn connect(&self, domain: DNSNameRef, stream: S) -> Connect - where S: io::Read + io::Write + pub fn connect(&self, domain: DNSNameRef, stream: IO) -> Connect + where IO: io::Read + io::Write { Self::connect_with_session(stream, ClientSession::new(&self.inner, domain)) } #[inline] - pub fn connect_with_session(stream: S, session: ClientSession) - -> Connect - where S: io::Read + io::Write + pub fn connect_with_session(stream: IO, session: ClientSession) + -> Connect + where IO: io::Read + io::Write { Connect(MidHandshake { inner: Some(TlsStream { session, io: stream, is_shutdown: false, eof: false }) @@ -67,15 +69,15 @@ impl TlsConnector { } impl TlsAcceptor { - pub fn accept(&self, stream: S) -> Accept - where S: io::Read + io::Write, + pub fn accept(&self, stream: IO) -> Accept + where IO: io::Read + io::Write, { Self::accept_with_session(stream, ServerSession::new(&self.inner)) } #[inline] - pub fn accept_with_session(stream: S, session: ServerSession) -> Accept - where S: io::Read + io::Write + pub fn accept_with_session(stream: IO, session: ServerSession) -> Accept + where IO: io::Read + io::Write { Accept(MidHandshake { inner: Some(TlsStream { session, io: stream, is_shutdown: false, eof: false }) @@ -86,43 +88,64 @@ impl TlsAcceptor { /// Future returned from `ClientConfigExt::connect_async` which will resolve /// once the connection handshake has finished. -pub struct Connect(MidHandshake); +pub struct Connect(MidHandshake); /// Future returned from `ServerConfigExt::accept_async` which will resolve /// once the accept handshake has finished. -pub struct Accept(MidHandshake); +pub struct Accept(MidHandshake); -struct MidHandshake { - inner: Option> +struct MidHandshake { + inner: Option> } /// A wrapper around an underlying raw stream which implements the TLS or SSL /// protocol. #[derive(Debug)] -pub struct TlsStream { +pub struct TlsStream { is_shutdown: bool, eof: bool, - io: S, - session: C + io: IO, + session: S } -impl TlsStream { +impl TlsStream { #[inline] - pub fn get_ref(&self) -> (&S, &C) { + pub fn get_ref(&self) -> (&IO, &S) { (&self.io, &self.session) } #[inline] - pub fn get_mut(&mut self) -> (&mut S, &mut C) { + pub fn get_mut(&mut self) -> (&mut IO, &mut S) { (&mut self.io, &mut self.session) } + + #[inline] + pub fn into_inner(self) -> (IO, S) { + (self.io, self.session) + } } -impl io::Read for TlsStream - where S: io::Read + io::Write, C: Session +impl From<(IO, S)> for TlsStream { + #[inline] + fn from((io, session): (IO, S)) -> TlsStream { + TlsStream { + is_shutdown: false, + eof: false, + io, session + } + } +} + +impl io::Read for TlsStream + where IO: io::Read + io::Write, S: Session { + #[cfg(feature = "nightly")] + unsafe fn initializer(&self) -> Initializer { + Initializer::nop() + } + fn read(&mut self, buf: &mut [u8]) -> io::Result { if self.eof { return Ok(0); @@ -142,8 +165,8 @@ impl io::Read for TlsStream } } -impl io::Write for TlsStream - where S: io::Read + io::Write, C: Session +impl io::Write for TlsStream + where IO: io::Read + io::Write, S: Session { fn write(&mut self, buf: &[u8]) -> io::Result { Stream::new(&mut self.session, &mut self.io).write(buf) diff --git a/src/tokio_impl.rs b/src/tokio_impl.rs index 11179dc..00b4722 100644 --- a/src/tokio_impl.rs +++ b/src/tokio_impl.rs @@ -5,8 +5,8 @@ use tokio::prelude::Poll; use common::Stream; -impl Future for Connect { - type Item = TlsStream; +impl Future for Connect { + type Item = TlsStream; type Error = io::Error; fn poll(&mut self) -> Poll { @@ -14,8 +14,8 @@ impl Future for Connect { } } -impl Future for Accept { - type Item = TlsStream; +impl Future for Accept { + type Item = TlsStream; type Error = io::Error; fn poll(&mut self) -> Poll { @@ -23,10 +23,10 @@ impl Future for Accept { } } -impl Future for MidHandshake - where S: io::Read + io::Write, C: Session +impl Future for MidHandshake + where IO: io::Read + io::Write, S: Session { - type Item = TlsStream; + type Item = TlsStream; type Error = io::Error; fn poll(&mut self) -> Poll { @@ -48,20 +48,20 @@ impl Future for MidHandshake } } -impl AsyncRead for TlsStream +impl AsyncRead for TlsStream where - S: AsyncRead + AsyncWrite, - C: Session + IO: AsyncRead + AsyncWrite, + S: Session { unsafe fn prepare_uninitialized_buffer(&self, _: &mut [u8]) -> bool { false } } -impl AsyncWrite for TlsStream +impl AsyncWrite for TlsStream where - S: AsyncRead + AsyncWrite, - C: Session + IO: AsyncRead + AsyncWrite, + S: Session { fn shutdown(&mut self) -> Poll<(), io::Error> { if !self.is_shutdown {