Implement AsRawFd for both tokio-rustls and tokio-native-tls TlsStream<S> (#74)

* implement AsRawFd for both tokio-rustls and tokio-native-tls TlsStream<S>

* implement windows' AsRawHandle

* typo in cfg(windows)

* use RawSocket, not RawHandle

* implement AsRawFd & AsRawSocket for tokio_rustls::client::TlsStream and tokio_rustls::TlsStream enum
This commit is contained in:
Jerome Gravel-Niquet 2021-10-01 09:52:10 -04:00 committed by GitHub
parent 8501aafae5
commit 438cb8f9c8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 99 additions and 8 deletions

View File

@ -35,6 +35,10 @@ use std::fmt;
use std::future::Future; use std::future::Future;
use std::io::{self, Read, Write}; use std::io::{self, Read, Write};
use std::marker::Unpin; use std::marker::Unpin;
#[cfg(unix)]
use std::os::unix::io::{AsRawFd, RawFd};
#[cfg(windows)]
use std::os::windows::io::{AsRawSocket, RawSocket};
use std::pin::Pin; use std::pin::Pin;
use std::ptr::null_mut; use std::ptr::null_mut;
use std::task::{Context, Poll}; use std::task::{Context, Poll};
@ -167,18 +171,12 @@ impl<S> TlsStream<S> {
} }
/// Returns a shared reference to the inner stream. /// Returns a shared reference to the inner stream.
pub fn get_ref(&self) -> &native_tls::TlsStream<AllowStd<S>> pub fn get_ref(&self) -> &native_tls::TlsStream<AllowStd<S>> {
where
S: AsyncRead + AsyncWrite + Unpin,
{
&self.0 &self.0
} }
/// Returns a mutable reference to the inner stream. /// Returns a mutable reference to the inner stream.
pub fn get_mut(&mut self) -> &mut native_tls::TlsStream<AllowStd<S>> pub fn get_mut(&mut self) -> &mut native_tls::TlsStream<AllowStd<S>> {
where
S: AsyncRead + AsyncWrite + Unpin,
{
&mut self.0 &mut self.0
} }
} }
@ -221,6 +219,26 @@ where
} }
} }
#[cfg(unix)]
impl<S> AsRawFd for TlsStream<S>
where
S: AsRawFd,
{
fn as_raw_fd(&self) -> RawFd {
self.get_ref().get_ref().get_ref().as_raw_fd()
}
}
#[cfg(windows)]
impl<S> AsRawSocket for TlsStream<S>
where
S: AsRawSocket,
{
fn as_raw_socket(&self) -> RawSocket {
self.get_ref().get_ref().get_ref().as_raw_socket()
}
}
async fn handshake<F, S>(f: F, stream: S) -> Result<TlsStream<S>, Error> async fn handshake<F, S>(f: F, stream: S) -> Result<TlsStream<S>, Error>
where where
F: FnOnce( F: FnOnce(

View File

@ -1,5 +1,9 @@
use super::*; use super::*;
use crate::common::IoSession; use crate::common::IoSession;
#[cfg(unix)]
use std::os::unix::io::{AsRawFd, RawFd};
#[cfg(windows)]
use std::os::windows::io::{AsRawSocket, RawSocket};
/// A wrapper around an underlying raw stream which implements the TLS or SSL /// A wrapper around an underlying raw stream which implements the TLS or SSL
/// protocol. /// protocol.
@ -27,6 +31,26 @@ impl<IO> TlsStream<IO> {
} }
} }
#[cfg(unix)]
impl<S> AsRawFd for TlsStream<S>
where
S: AsRawFd,
{
fn as_raw_fd(&self) -> RawFd {
self.get_ref().0.as_raw_fd()
}
}
#[cfg(windows)]
impl<S> AsRawSocket for TlsStream<S>
where
S: AsRawSocket,
{
fn as_raw_socket(&self) -> RawSocket {
self.get_ref().0.as_raw_socket()
}
}
impl<IO> IoSession for TlsStream<IO> { impl<IO> IoSession for TlsStream<IO> {
type Io = IO; type Io = IO;
type Session = ClientConnection; type Session = ClientConnection;

View File

@ -17,6 +17,10 @@ use common::{MidHandshake, Stream, TlsState};
use rustls::{ClientConfig, ClientConnection, CommonState, ServerConfig, ServerConnection}; use rustls::{ClientConfig, ClientConnection, CommonState, ServerConfig, ServerConnection};
use std::future::Future; use std::future::Future;
use std::io; use std::io;
#[cfg(unix)]
use std::os::unix::io::{AsRawFd, RawFd};
#[cfg(windows)]
use std::os::windows::io::{AsRawSocket, RawSocket};
use std::pin::Pin; use std::pin::Pin;
use std::sync::Arc; use std::sync::Arc;
use std::task::{Context, Poll}; use std::task::{Context, Poll};
@ -261,6 +265,26 @@ impl<T> From<server::TlsStream<T>> for TlsStream<T> {
} }
} }
#[cfg(unix)]
impl<S> AsRawFd for TlsStream<S>
where
S: AsRawFd,
{
fn as_raw_fd(&self) -> RawFd {
self.get_ref().0.as_raw_fd()
}
}
#[cfg(windows)]
impl<S> AsRawSocket for TlsStream<S>
where
S: AsRawSocket,
{
fn as_raw_socket(&self) -> RawSocket {
self.get_ref().0.as_raw_socket()
}
}
impl<T> AsyncRead for TlsStream<T> impl<T> AsyncRead for TlsStream<T>
where where
T: AsyncRead + AsyncWrite + Unpin, T: AsyncRead + AsyncWrite + Unpin,

View File

@ -1,3 +1,8 @@
#[cfg(unix)]
use std::os::unix::io::{AsRawFd, RawFd};
#[cfg(windows)]
use std::os::windows::io::{AsRawSocket, RawSocket};
use super::*; use super::*;
use crate::common::IoSession; use crate::common::IoSession;
@ -123,3 +128,23 @@ where
stream.as_mut_pin().poll_shutdown(cx) stream.as_mut_pin().poll_shutdown(cx)
} }
} }
#[cfg(unix)]
impl<IO> AsRawFd for TlsStream<IO>
where
IO: AsRawFd,
{
fn as_raw_fd(&self) -> RawFd {
self.get_ref().0.as_raw_fd()
}
}
#[cfg(windows)]
impl<IO> AsRawSocket for TlsStream<IO>
where
IO: AsRawSocket,
{
fn as_raw_socket(&self) -> RawSocket {
self.get_ref().0.as_raw_socket()
}
}