diff --git a/src/common/mod.rs b/src/common/mod.rs index fde34c0..1cc976a 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -359,5 +359,147 @@ impl<'a, 'b, T: AsyncRead + Unpin> Read for SyncReadAdapter<'a, 'b, T> { } } +/// Wraps an AsyncRead and AsyncWrite instance together to produce a single type which implements +/// AsyncRead + AsyncWrite. +pub struct AsyncReadWrite { + r: Pin>, + w: Pin>, +} + +impl AsyncReadWrite +where + R: Unpin, + W: Unpin, +{ + pub fn new(r: R, w: W) -> Self { + Self { + r: Box::pin(r), + w: Box::pin(w), + } + } + + pub fn into_inner(self) -> (R, W) { + (*Pin::into_inner(self.r), *Pin::into_inner(self.w)) + } +} + +impl AsyncRead for AsyncReadWrite +where + R: AsyncRead + Unpin, +{ + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + self.r.as_mut().poll_read(cx, buf) + } +} + +impl AsyncWrite for AsyncReadWrite +where + W: AsyncWrite + Unpin, +{ + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + self.w.as_mut().poll_write(cx, buf) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.w.as_mut().poll_flush(cx) + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.w.as_mut().poll_shutdown(cx) + } +} + +/// Wraps an AsyncRead in order to capture all bytes which have been read by it into an internal +/// buffer. +pub struct AsyncReadCapture { + r: Pin>, + buf: Vec, +} + +impl AsyncReadCapture +where + R: AsyncRead + Unpin, +{ + /// Initializes an AsyncReadCapture with an empty internal buffer of the given size. + pub fn with_capacity(r: R, cap: usize) -> Self { + Self { + r: Box::pin(r), + buf: Vec::with_capacity(cap), + } + } + + pub fn into_inner(self) -> (R, Vec) { + (*Pin::into_inner(self.r), self.buf) + } +} + +impl AsyncRead for AsyncReadCapture +where + R: AsyncRead, +{ + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + let res = self.r.as_mut().poll_read(cx, buf); + + if let Poll::Ready(Ok(())) = res { + self.buf.extend_from_slice(buf.filled()); + } + + res + } +} + +pub struct AsyncReadPrefixed { + r: Pin>, + prefix: Vec, +} + +impl AsyncReadPrefixed +where + R: AsyncRead + Unpin, +{ + pub fn new(r: R, prefix: Vec) -> Self { + Self { + r: Box::pin(r), + prefix, + } + } +} + +impl AsyncRead for AsyncReadPrefixed +where + R: AsyncRead, +{ + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + let this = self.get_mut(); + + let prefix_len = this.prefix.len(); + if prefix_len == 0 { + return this.r.as_mut().poll_read(cx, buf); + } + + let n = std::cmp::min(prefix_len, buf.remaining()); + let to_write = this.prefix.drain(..n); + + buf.put_slice(to_write.as_slice()); + Poll::Ready(Ok(())) + } +} + #[cfg(test)] mod test_stream; diff --git a/src/lib.rs b/src/lib.rs index 000245c..8574cda 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -49,7 +49,7 @@ pub mod client; mod common; pub mod server; -use common::{MidHandshake, Stream, TlsState}; +use common::{AsyncReadCapture, AsyncReadPrefixed, AsyncReadWrite, MidHandshake, Stream, TlsState}; use rustls::{ClientConfig, ClientConnection, CommonState, ServerConfig, ServerConnection}; use std::future::Future; use std::io; @@ -60,7 +60,7 @@ use std::os::windows::io::{AsRawSocket, RawSocket}; use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; -use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf, ReadHalf, WriteHalf}; pub use rustls; @@ -333,6 +333,98 @@ where } } +type IOWithCapture = AsyncReadWrite>, WriteHalf>; +type IOWithPrefix = AsyncReadWrite>, WriteHalf>; + +fn unwrap_io_with_capture( + io_with_capture: IOWithCapture, +) -> (ReadHalf, WriteHalf, Vec) +where + IO: AsyncRead + AsyncWrite + Unpin, +{ + let (r, w) = io_with_capture.into_inner(); + let (r, bytes_read) = r.into_inner(); + (r, w, bytes_read) +} + +pub struct TransparentConfigAcceptor { + acceptor: Pin>>>, +} + +impl TransparentConfigAcceptor +where + IO: AsyncRead + AsyncWrite + Unpin, +{ + pub fn new(acceptor: rustls::server::Acceptor, io: IO) -> Self { + let (r, w) = tokio::io::split(io); + let r = AsyncReadCapture::with_capacity(r, 1024); + let rw = AsyncReadWrite::new(r, w); + Self { + acceptor: Box::pin(LazyConfigAcceptor::new(acceptor, rw)), + } + } +} + +impl Future for TransparentConfigAcceptor +where + IO: AsyncRead + AsyncWrite + Unpin, +{ + type Output = io::Result>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match self.get_mut().acceptor.as_mut().poll(cx) { + Poll::Ready(Ok(h)) => { + let (r, w, bytes_read) = unwrap_io_with_capture(h.io); + Poll::Ready(Ok(TransparentStartHandshake { + accepted: h.accepted, + r, + w, + bytes_read, + })) + } + Poll::Ready(Err(e)) => Poll::Ready(Err(e)), + Poll::Pending => Poll::Pending, + } + } +} + +pub struct TransparentStartHandshake { + accepted: rustls::server::Accepted, + r: ReadHalf, + w: WriteHalf, + bytes_read: Vec, +} + +impl TransparentStartHandshake +where + IO: AsyncRead + AsyncWrite + Unpin, +{ + pub fn client_hello(&self) -> rustls::server::ClientHello<'_> { + self.accepted.client_hello() + } + + pub fn into_stream(self, config: Arc) -> Accept { + self.into_stream_with(config, |_| ()) + } + + pub fn into_stream_with(self, config: Arc, f: F) -> Accept + where + F: FnOnce(&mut ServerConnection), + { + let start_handshake = StartHandshake { + accepted: self.accepted, + io: self.r.unsplit(self.w), + }; + + start_handshake.into_stream_with(config, f) + } + + pub fn into_original_stream(self) -> IOWithPrefix { + let r = AsyncReadPrefixed::new(self.r, self.bytes_read); + AsyncReadWrite::new(r, self.w) + } +} + /// Future returned from `TlsConnector::connect` which will resolve /// once the connection handshake has finished. pub struct Connect(MidHandshake>);