From 438cb8f9c82a45b46f8b7e4d4c4d6da264a24573 Mon Sep 17 00:00:00 2001 From: Jerome Gravel-Niquet Date: Fri, 1 Oct 2021 09:52:10 -0400 Subject: [PATCH] Implement AsRawFd for both tokio-rustls and tokio-native-tls TlsStream (#74) * implement AsRawFd for both tokio-rustls and tokio-native-tls TlsStream * implement windows' AsRawHandle * typo in cfg(windows) * use RawSocket, not RawHandle * implement AsRawFd & AsRawSocket for tokio_rustls::client::TlsStream and tokio_rustls::TlsStream enum --- tokio-native-tls/src/lib.rs | 34 ++++++++++++++++++++++++++-------- tokio-rustls/src/client.rs | 24 ++++++++++++++++++++++++ tokio-rustls/src/lib.rs | 24 ++++++++++++++++++++++++ tokio-rustls/src/server.rs | 25 +++++++++++++++++++++++++ 4 files changed, 99 insertions(+), 8 deletions(-) diff --git a/tokio-native-tls/src/lib.rs b/tokio-native-tls/src/lib.rs index 6e650ac..d3ed938 100644 --- a/tokio-native-tls/src/lib.rs +++ b/tokio-native-tls/src/lib.rs @@ -35,6 +35,10 @@ use std::fmt; use std::future::Future; use std::io::{self, Read, Write}; use std::marker::Unpin; +#[cfg(unix)] +use std::os::unix::io::{AsRawFd, RawFd}; +#[cfg(windows)] +use std::os::windows::io::{AsRawSocket, RawSocket}; use std::pin::Pin; use std::ptr::null_mut; use std::task::{Context, Poll}; @@ -167,18 +171,12 @@ impl TlsStream { } /// Returns a shared reference to the inner stream. - pub fn get_ref(&self) -> &native_tls::TlsStream> - where - S: AsyncRead + AsyncWrite + Unpin, - { + pub fn get_ref(&self) -> &native_tls::TlsStream> { &self.0 } /// Returns a mutable reference to the inner stream. - pub fn get_mut(&mut self) -> &mut native_tls::TlsStream> - where - S: AsyncRead + AsyncWrite + Unpin, - { + pub fn get_mut(&mut self) -> &mut native_tls::TlsStream> { &mut self.0 } } @@ -221,6 +219,26 @@ where } } +#[cfg(unix)] +impl AsRawFd for TlsStream +where + S: AsRawFd, +{ + fn as_raw_fd(&self) -> RawFd { + self.get_ref().get_ref().get_ref().as_raw_fd() + } +} + +#[cfg(windows)] +impl AsRawSocket for TlsStream +where + S: AsRawSocket, +{ + fn as_raw_socket(&self) -> RawSocket { + self.get_ref().get_ref().get_ref().as_raw_socket() + } +} + async fn handshake(f: F, stream: S) -> Result, Error> where F: FnOnce( diff --git a/tokio-rustls/src/client.rs b/tokio-rustls/src/client.rs index 3bd0e1f..7292b1a 100644 --- a/tokio-rustls/src/client.rs +++ b/tokio-rustls/src/client.rs @@ -1,5 +1,9 @@ use super::*; use crate::common::IoSession; +#[cfg(unix)] +use std::os::unix::io::{AsRawFd, RawFd}; +#[cfg(windows)] +use std::os::windows::io::{AsRawSocket, RawSocket}; /// A wrapper around an underlying raw stream which implements the TLS or SSL /// protocol. @@ -27,6 +31,26 @@ impl TlsStream { } } +#[cfg(unix)] +impl AsRawFd for TlsStream +where + S: AsRawFd, +{ + fn as_raw_fd(&self) -> RawFd { + self.get_ref().0.as_raw_fd() + } +} + +#[cfg(windows)] +impl AsRawSocket for TlsStream +where + S: AsRawSocket, +{ + fn as_raw_socket(&self) -> RawSocket { + self.get_ref().0.as_raw_socket() + } +} + impl IoSession for TlsStream { type Io = IO; type Session = ClientConnection; diff --git a/tokio-rustls/src/lib.rs b/tokio-rustls/src/lib.rs index a8e7302..fee3cce 100644 --- a/tokio-rustls/src/lib.rs +++ b/tokio-rustls/src/lib.rs @@ -17,6 +17,10 @@ use common::{MidHandshake, Stream, TlsState}; use rustls::{ClientConfig, ClientConnection, CommonState, ServerConfig, ServerConnection}; use std::future::Future; use std::io; +#[cfg(unix)] +use std::os::unix::io::{AsRawFd, RawFd}; +#[cfg(windows)] +use std::os::windows::io::{AsRawSocket, RawSocket}; use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; @@ -261,6 +265,26 @@ impl From> for TlsStream { } } +#[cfg(unix)] +impl AsRawFd for TlsStream +where + S: AsRawFd, +{ + fn as_raw_fd(&self) -> RawFd { + self.get_ref().0.as_raw_fd() + } +} + +#[cfg(windows)] +impl AsRawSocket for TlsStream +where + S: AsRawSocket, +{ + fn as_raw_socket(&self) -> RawSocket { + self.get_ref().0.as_raw_socket() + } +} + impl AsyncRead for TlsStream where T: AsyncRead + AsyncWrite + Unpin, diff --git a/tokio-rustls/src/server.rs b/tokio-rustls/src/server.rs index cf30b11..4b8ec49 100644 --- a/tokio-rustls/src/server.rs +++ b/tokio-rustls/src/server.rs @@ -1,3 +1,8 @@ +#[cfg(unix)] +use std::os::unix::io::{AsRawFd, RawFd}; +#[cfg(windows)] +use std::os::windows::io::{AsRawSocket, RawSocket}; + use super::*; use crate::common::IoSession; @@ -123,3 +128,23 @@ where stream.as_mut_pin().poll_shutdown(cx) } } + +#[cfg(unix)] +impl AsRawFd for TlsStream +where + IO: AsRawFd, +{ + fn as_raw_fd(&self) -> RawFd { + self.get_ref().0.as_raw_fd() + } +} + +#[cfg(windows)] +impl AsRawSocket for TlsStream +where + IO: AsRawSocket, +{ + fn as_raw_socket(&self) -> RawSocket { + self.get_ref().0.as_raw_socket() + } +}