Implement AsRawFd for both tokio-rustls and tokio-native-tls TlsStream<S> (#74)

* implement AsRawFd for both tokio-rustls and tokio-native-tls TlsStream<S>

* implement windows' AsRawHandle

* typo in cfg(windows)

* use RawSocket, not RawHandle

* implement AsRawFd & AsRawSocket for tokio_rustls::client::TlsStream and tokio_rustls::TlsStream enum
This commit is contained in:
Jerome Gravel-Niquet 2021-10-01 09:52:10 -04:00 committed by GitHub
parent 8501aafae5
commit 438cb8f9c8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 99 additions and 8 deletions

View File

@ -35,6 +35,10 @@ use std::fmt;
use std::future::Future;
use std::io::{self, Read, Write};
use std::marker::Unpin;
#[cfg(unix)]
use std::os::unix::io::{AsRawFd, RawFd};
#[cfg(windows)]
use std::os::windows::io::{AsRawSocket, RawSocket};
use std::pin::Pin;
use std::ptr::null_mut;
use std::task::{Context, Poll};
@ -167,18 +171,12 @@ impl<S> TlsStream<S> {
}
/// Returns a shared reference to the inner stream.
pub fn get_ref(&self) -> &native_tls::TlsStream<AllowStd<S>>
where
S: AsyncRead + AsyncWrite + Unpin,
{
pub fn get_ref(&self) -> &native_tls::TlsStream<AllowStd<S>> {
&self.0
}
/// Returns a mutable reference to the inner stream.
pub fn get_mut(&mut self) -> &mut native_tls::TlsStream<AllowStd<S>>
where
S: AsyncRead + AsyncWrite + Unpin,
{
pub fn get_mut(&mut self) -> &mut native_tls::TlsStream<AllowStd<S>> {
&mut self.0
}
}
@ -221,6 +219,26 @@ where
}
}
#[cfg(unix)]
impl<S> AsRawFd for TlsStream<S>
where
S: AsRawFd,
{
fn as_raw_fd(&self) -> RawFd {
self.get_ref().get_ref().get_ref().as_raw_fd()
}
}
#[cfg(windows)]
impl<S> AsRawSocket for TlsStream<S>
where
S: AsRawSocket,
{
fn as_raw_socket(&self) -> RawSocket {
self.get_ref().get_ref().get_ref().as_raw_socket()
}
}
async fn handshake<F, S>(f: F, stream: S) -> Result<TlsStream<S>, Error>
where
F: FnOnce(

View File

@ -1,5 +1,9 @@
use super::*;
use crate::common::IoSession;
#[cfg(unix)]
use std::os::unix::io::{AsRawFd, RawFd};
#[cfg(windows)]
use std::os::windows::io::{AsRawSocket, RawSocket};
/// A wrapper around an underlying raw stream which implements the TLS or SSL
/// protocol.
@ -27,6 +31,26 @@ impl<IO> TlsStream<IO> {
}
}
#[cfg(unix)]
impl<S> AsRawFd for TlsStream<S>
where
S: AsRawFd,
{
fn as_raw_fd(&self) -> RawFd {
self.get_ref().0.as_raw_fd()
}
}
#[cfg(windows)]
impl<S> AsRawSocket for TlsStream<S>
where
S: AsRawSocket,
{
fn as_raw_socket(&self) -> RawSocket {
self.get_ref().0.as_raw_socket()
}
}
impl<IO> IoSession for TlsStream<IO> {
type Io = IO;
type Session = ClientConnection;

View File

@ -17,6 +17,10 @@ use common::{MidHandshake, Stream, TlsState};
use rustls::{ClientConfig, ClientConnection, CommonState, ServerConfig, ServerConnection};
use std::future::Future;
use std::io;
#[cfg(unix)]
use std::os::unix::io::{AsRawFd, RawFd};
#[cfg(windows)]
use std::os::windows::io::{AsRawSocket, RawSocket};
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
@ -261,6 +265,26 @@ impl<T> From<server::TlsStream<T>> for TlsStream<T> {
}
}
#[cfg(unix)]
impl<S> AsRawFd for TlsStream<S>
where
S: AsRawFd,
{
fn as_raw_fd(&self) -> RawFd {
self.get_ref().0.as_raw_fd()
}
}
#[cfg(windows)]
impl<S> AsRawSocket for TlsStream<S>
where
S: AsRawSocket,
{
fn as_raw_socket(&self) -> RawSocket {
self.get_ref().0.as_raw_socket()
}
}
impl<T> AsyncRead for TlsStream<T>
where
T: AsyncRead + AsyncWrite + Unpin,

View File

@ -1,3 +1,8 @@
#[cfg(unix)]
use std::os::unix::io::{AsRawFd, RawFd};
#[cfg(windows)]
use std::os::windows::io::{AsRawSocket, RawSocket};
use super::*;
use crate::common::IoSession;
@ -123,3 +128,23 @@ where
stream.as_mut_pin().poll_shutdown(cx)
}
}
#[cfg(unix)]
impl<IO> AsRawFd for TlsStream<IO>
where
IO: AsRawFd,
{
fn as_raw_fd(&self) -> RawFd {
self.get_ref().0.as_raw_fd()
}
}
#[cfg(windows)]
impl<IO> AsRawSocket for TlsStream<IO>
where
IO: AsRawSocket,
{
fn as_raw_socket(&self) -> RawSocket {
self.get_ref().0.as_raw_socket()
}
}