Allow access to all inner streams [tokio-native-tls] (#6)

Related: https://github.com/tokio-rs/tokio/issues/1383
This commit is contained in:
aloucks 2020-02-28 10:31:17 -05:00 committed by GitHub
parent 7e41beaff4
commit 1c3aeb691e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 26 additions and 5 deletions

View File

@ -40,12 +40,25 @@ use std::pin::Pin;
use std::ptr::null_mut; use std::ptr::null_mut;
use std::task::{Context, Poll}; use std::task::{Context, Poll};
/// An intermediate wrapper for the inner stream `S`.
#[derive(Debug)] #[derive(Debug)]
struct AllowStd<S> { pub struct AllowStd<S> {
inner: S, inner: S,
context: *mut (), context: *mut (),
} }
impl<S> AllowStd<S> {
/// Returns a shared reference to the inner stream.
pub fn get_ref(&self) -> &S {
&self.inner
}
/// Returns a mutable reference to the inner stream.
pub fn get_mut(&mut self) -> &mut S {
&mut self.inner
}
}
/// A wrapper around an underlying raw stream which implements the TLS or SSL /// A wrapper around an underlying raw stream which implements the TLS or SSL
/// protocol. /// protocol.
/// ///
@ -163,19 +176,19 @@ impl<S> TlsStream<S> {
} }
/// Returns a shared reference to the inner stream. /// Returns a shared reference to the inner stream.
pub fn get_ref(&self) -> &S pub fn get_ref(&self) -> &native_tls::TlsStream<AllowStd<S>>
where where
S: AsyncRead + AsyncWrite + Unpin, S: AsyncRead + AsyncWrite + Unpin,
{ {
&self.0.get_ref().inner &self.0
} }
/// Returns a mutable reference to the inner stream. /// Returns a mutable reference to the inner stream.
pub fn get_mut(&mut self) -> &mut S pub fn get_mut(&mut self) -> &mut native_tls::TlsStream<AllowStd<S>>
where where
S: AsyncRead + AsyncWrite + Unpin, S: AsyncRead + AsyncWrite + Unpin,
{ {
&mut self.0.get_mut().inner &mut self.0
} }
} }

View File

@ -19,6 +19,14 @@ async fn client_to_server() {
let server = async move { let server = async move {
let (socket, _) = srv.accept().await.unwrap(); let (socket, _) = srv.accept().await.unwrap();
let mut socket = server_tls.accept(socket).await.unwrap(); let mut socket = server_tls.accept(socket).await.unwrap();
// Verify access to all of the nested inner streams (e.g. so that peer
// certificates can be accessed). This is just a compile check.
let native_tls_stream: &native_tls::TlsStream<_> = socket.get_ref();
let _peer_cert = native_tls_stream.peer_certificate().unwrap();
let allow_std_stream: &tokio_native_tls::AllowStd<_> = native_tls_stream.get_ref();
let _tokio_tcp_stream: &tokio::net::TcpStream = allow_std_stream.get_ref();
let mut data = Vec::new(); let mut data = Vec::new();
socket.read_to_end(&mut data).await.unwrap(); socket.read_to_end(&mut data).await.unwrap();
data data