diff --git a/tokio-native-tls/src/lib.rs b/tokio-native-tls/src/lib.rs index 6e650ac..d3ed938 100644 --- a/tokio-native-tls/src/lib.rs +++ b/tokio-native-tls/src/lib.rs @@ -35,6 +35,10 @@ use std::fmt; use std::future::Future; use std::io::{self, Read, Write}; use std::marker::Unpin; +#[cfg(unix)] +use std::os::unix::io::{AsRawFd, RawFd}; +#[cfg(windows)] +use std::os::windows::io::{AsRawSocket, RawSocket}; use std::pin::Pin; use std::ptr::null_mut; use std::task::{Context, Poll}; @@ -167,18 +171,12 @@ impl TlsStream { } /// Returns a shared reference to the inner stream. - pub fn get_ref(&self) -> &native_tls::TlsStream> - where - S: AsyncRead + AsyncWrite + Unpin, - { + pub fn get_ref(&self) -> &native_tls::TlsStream> { &self.0 } /// Returns a mutable reference to the inner stream. - pub fn get_mut(&mut self) -> &mut native_tls::TlsStream> - where - S: AsyncRead + AsyncWrite + Unpin, - { + pub fn get_mut(&mut self) -> &mut native_tls::TlsStream> { &mut self.0 } } @@ -221,6 +219,26 @@ where } } +#[cfg(unix)] +impl AsRawFd for TlsStream +where + S: AsRawFd, +{ + fn as_raw_fd(&self) -> RawFd { + self.get_ref().get_ref().get_ref().as_raw_fd() + } +} + +#[cfg(windows)] +impl AsRawSocket for TlsStream +where + S: AsRawSocket, +{ + fn as_raw_socket(&self) -> RawSocket { + self.get_ref().get_ref().get_ref().as_raw_socket() + } +} + async fn handshake(f: F, stream: S) -> Result, Error> where F: FnOnce( diff --git a/tokio-rustls/src/client.rs b/tokio-rustls/src/client.rs index 3bd0e1f..7292b1a 100644 --- a/tokio-rustls/src/client.rs +++ b/tokio-rustls/src/client.rs @@ -1,5 +1,9 @@ use super::*; use crate::common::IoSession; +#[cfg(unix)] +use std::os::unix::io::{AsRawFd, RawFd}; +#[cfg(windows)] +use std::os::windows::io::{AsRawSocket, RawSocket}; /// A wrapper around an underlying raw stream which implements the TLS or SSL /// protocol. @@ -27,6 +31,26 @@ impl TlsStream { } } +#[cfg(unix)] +impl AsRawFd for TlsStream +where + S: AsRawFd, +{ + fn as_raw_fd(&self) -> RawFd { + self.get_ref().0.as_raw_fd() + } +} + +#[cfg(windows)] +impl AsRawSocket for TlsStream +where + S: AsRawSocket, +{ + fn as_raw_socket(&self) -> RawSocket { + self.get_ref().0.as_raw_socket() + } +} + impl IoSession for TlsStream { type Io = IO; type Session = ClientConnection; diff --git a/tokio-rustls/src/lib.rs b/tokio-rustls/src/lib.rs index a8e7302..fee3cce 100644 --- a/tokio-rustls/src/lib.rs +++ b/tokio-rustls/src/lib.rs @@ -17,6 +17,10 @@ use common::{MidHandshake, Stream, TlsState}; use rustls::{ClientConfig, ClientConnection, CommonState, ServerConfig, ServerConnection}; use std::future::Future; use std::io; +#[cfg(unix)] +use std::os::unix::io::{AsRawFd, RawFd}; +#[cfg(windows)] +use std::os::windows::io::{AsRawSocket, RawSocket}; use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; @@ -261,6 +265,26 @@ impl From> for TlsStream { } } +#[cfg(unix)] +impl AsRawFd for TlsStream +where + S: AsRawFd, +{ + fn as_raw_fd(&self) -> RawFd { + self.get_ref().0.as_raw_fd() + } +} + +#[cfg(windows)] +impl AsRawSocket for TlsStream +where + S: AsRawSocket, +{ + fn as_raw_socket(&self) -> RawSocket { + self.get_ref().0.as_raw_socket() + } +} + impl AsyncRead for TlsStream where T: AsyncRead + AsyncWrite + Unpin, diff --git a/tokio-rustls/src/server.rs b/tokio-rustls/src/server.rs index cf30b11..4b8ec49 100644 --- a/tokio-rustls/src/server.rs +++ b/tokio-rustls/src/server.rs @@ -1,3 +1,8 @@ +#[cfg(unix)] +use std::os::unix::io::{AsRawFd, RawFd}; +#[cfg(windows)] +use std::os::windows::io::{AsRawSocket, RawSocket}; + use super::*; use crate::common::IoSession; @@ -123,3 +128,23 @@ where stream.as_mut_pin().poll_shutdown(cx) } } + +#[cfg(unix)] +impl AsRawFd for TlsStream +where + IO: AsRawFd, +{ + fn as_raw_fd(&self) -> RawFd { + self.get_ref().0.as_raw_fd() + } +} + +#[cfg(windows)] +impl AsRawSocket for TlsStream +where + IO: AsRawSocket, +{ + fn as_raw_socket(&self) -> RawSocket { + self.get_ref().0.as_raw_socket() + } +}