From 1c3aeb691e024ddaaee09e03994b1ed36ffc0546 Mon Sep 17 00:00:00 2001 From: aloucks Date: Fri, 28 Feb 2020 10:31:17 -0500 Subject: [PATCH] Allow access to all inner streams [tokio-native-tls] (#6) Related: https://github.com/tokio-rs/tokio/issues/1383 --- tokio-native-tls/src/lib.rs | 23 ++++++++++++++++++----- tokio-native-tls/tests/smoke.rs | 8 ++++++++ 2 files changed, 26 insertions(+), 5 deletions(-) diff --git a/tokio-native-tls/src/lib.rs b/tokio-native-tls/src/lib.rs index 2770650..ba1e67d 100644 --- a/tokio-native-tls/src/lib.rs +++ b/tokio-native-tls/src/lib.rs @@ -40,12 +40,25 @@ use std::pin::Pin; use std::ptr::null_mut; use std::task::{Context, Poll}; +/// An intermediate wrapper for the inner stream `S`. #[derive(Debug)] -struct AllowStd { +pub struct AllowStd { inner: S, context: *mut (), } +impl AllowStd { + /// Returns a shared reference to the inner stream. + pub fn get_ref(&self) -> &S { + &self.inner + } + + /// Returns a mutable reference to the inner stream. + pub fn get_mut(&mut self) -> &mut S { + &mut self.inner + } +} + /// A wrapper around an underlying raw stream which implements the TLS or SSL /// protocol. /// @@ -163,19 +176,19 @@ impl TlsStream { } /// Returns a shared reference to the inner stream. - pub fn get_ref(&self) -> &S + pub fn get_ref(&self) -> &native_tls::TlsStream> where S: AsyncRead + AsyncWrite + Unpin, { - &self.0.get_ref().inner + &self.0 } /// Returns a mutable reference to the inner stream. - pub fn get_mut(&mut self) -> &mut S + pub fn get_mut(&mut self) -> &mut native_tls::TlsStream> where S: AsyncRead + AsyncWrite + Unpin, { - &mut self.0.get_mut().inner + &mut self.0 } } diff --git a/tokio-native-tls/tests/smoke.rs b/tokio-native-tls/tests/smoke.rs index 48c29a0..193dc51 100644 --- a/tokio-native-tls/tests/smoke.rs +++ b/tokio-native-tls/tests/smoke.rs @@ -19,6 +19,14 @@ async fn client_to_server() { let server = async move { let (socket, _) = srv.accept().await.unwrap(); let mut socket = server_tls.accept(socket).await.unwrap(); + + // Verify access to all of the nested inner streams (e.g. so that peer + // certificates can be accessed). This is just a compile check. + let native_tls_stream: &native_tls::TlsStream<_> = socket.get_ref(); + let _peer_cert = native_tls_stream.peer_certificate().unwrap(); + let allow_std_stream: &tokio_native_tls::AllowStd<_> = native_tls_stream.get_ref(); + let _tokio_tcp_stream: &tokio::net::TcpStream = allow_std_stream.get_ref(); + let mut data = Vec::new(); socket.read_to_end(&mut data).await.unwrap(); data