From 3e2c0446a41cca4873f16e4909527c5f49a21f35 Mon Sep 17 00:00:00 2001 From: Dirkjan Ochtman Date: Wed, 6 Nov 2019 11:41:59 +0100 Subject: [PATCH] Port unified TLS stream type to tokio-0.2 --- Cargo.toml | 1 + src/lib.rs | 113 ++++++++++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 113 insertions(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index c2bfee9..70b7968 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,6 +18,7 @@ github-actions = { repository = "quininer/tokio-rustls", workflow = "ci" } smallvec = "0.6" tokio-io = "=0.2.0-alpha.6" futures-core-preview = "=0.3.0-alpha.19" +pin-project = "0.4" rustls = "0.16" webpki = "0.21" diff --git a/src/lib.rs b/src/lib.rs index 3dea67f..9b9fd58 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -6,7 +6,8 @@ pub mod server; use common::Stream; use futures_core as futures; -use rustls::{ClientConfig, ClientSession, ServerConfig, ServerSession}; +use pin_project::{pin_project, project}; +use rustls::{ClientConfig, ClientSession, ServerConfig, ServerSession, Session}; use std::future::Future; use std::pin::Pin; use std::sync::Arc; @@ -195,3 +196,113 @@ impl Future for Accept { Pin::new(&mut self.0).poll(cx) } } + +/// Unified TLS stream type +/// +/// This abstracts over the inner `client::TlsStream` and `server::TlsStream`, so you can use +/// a single type to keep both client- and server-initiated TLS-encrypted connections. +#[pin_project] +pub enum TlsStream { + Client(#[pin] client::TlsStream), + Server(#[pin] server::TlsStream), +} + +impl TlsStream { + pub fn get_ref(&self) -> (&T, &dyn Session) { + use TlsStream::*; + match self { + Client(io) => { + let (io, session) = io.get_ref(); + (io, &*session) + } + Server(io) => { + let (io, session) = io.get_ref(); + (io, &*session) + } + } + } + + pub fn get_mut(&mut self) -> (&mut T, &mut dyn Session) { + use TlsStream::*; + match self { + Client(io) => { + let (io, session) = io.get_mut(); + (io, &mut *session) + } + Server(io) => { + let (io, session) = io.get_mut(); + (io, &mut *session) + } + } + } +} + +impl From> for TlsStream { + fn from(s: client::TlsStream) -> Self { + Self::Client(s) + } +} + +impl From> for TlsStream { + fn from(s: server::TlsStream) -> Self { + Self::Server(s) + } +} + +impl AsyncRead for TlsStream +where + T: AsyncRead + AsyncWrite + Unpin, +{ + #[project] + #[inline] + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + #[project] + match self.project() { + TlsStream::Client(x) => x.poll_read(cx, buf), + TlsStream::Server(x) => x.poll_read(cx, buf), + } + } +} + +impl AsyncWrite for TlsStream +where + T: AsyncRead + AsyncWrite + Unpin, +{ + #[project] + #[inline] + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + #[project] + match self.project() { + TlsStream::Client(x) => x.poll_write(cx, buf), + TlsStream::Server(x) => x.poll_write(cx, buf), + } + } + + #[project] + #[inline] + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + #[project] + match self.project() { + TlsStream::Client(x) => x.poll_flush(cx), + TlsStream::Server(x) => x.poll_flush(cx), + } + } + + #[project] + #[inline] + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + #[project] + match self.project() { + TlsStream::Client(x) => x.poll_shutdown(cx), + TlsStream::Server(x) => x.poll_shutdown(cx), + } + } +}