diff --git a/src/common/mod.rs b/src/common/mod.rs index fde34c0..16e735c 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -359,5 +359,106 @@ impl<'a, 'b, T: AsyncRead + Unpin> Read for SyncReadAdapter<'a, 'b, T> { } } +/// Wraps an AsyncRead and AsyncWrite instance together to produce a single type which implements +/// AsyncRead + AsyncWrite. +pub struct AsyncReadWrite { + r: Pin>, + w: Pin>, +} + +impl AsyncReadWrite +where + R: Unpin, + W: Unpin, +{ + pub fn new(r: R, w: W) -> Self { + Self { + r: Box::pin(r), + w: Box::pin(w), + } + } + + pub fn into_inner(self) -> (R, W) { + (*Pin::into_inner(self.r), *Pin::into_inner(self.w)) + } +} + +impl AsyncRead for AsyncReadWrite +where + R: AsyncRead + Unpin, +{ + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + self.r.as_mut().poll_read(cx, buf) + } +} + +impl AsyncWrite for AsyncReadWrite +where + W: AsyncWrite + Unpin, +{ + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + self.w.as_mut().poll_write(cx, buf) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.w.as_mut().poll_flush(cx) + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.w.as_mut().poll_shutdown(cx) + } +} + +/// Wraps an AsyncRead in order to capture all bytes which have been read by it into an internal +/// buffer. +pub struct AsyncReadCapture { + r: Pin>, + buf: Vec, +} + +impl AsyncReadCapture +where + R: AsyncRead + Unpin, +{ + /// Initializes an AsyncReadCapture with an empty internal buffer of the given size. + pub fn with_capacity(r: R, cap: usize) -> Self { + Self { + r: Box::pin(r), + buf: Vec::with_capacity(cap), + } + } + + pub fn into_inner(self) -> (R, Vec) { + (*Pin::into_inner(self.r), self.buf) + } +} + +impl AsyncRead for AsyncReadCapture +where + R: AsyncRead, +{ + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + let res = self.r.as_mut().poll_read(cx, buf); + + if let Poll::Ready(Ok(())) = res { + self.buf.extend_from_slice(buf.filled()); + } + + res + } +} + #[cfg(test)] mod test_stream; diff --git a/src/lib.rs b/src/lib.rs index 000245c..576d247 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -49,7 +49,7 @@ pub mod client; mod common; pub mod server; -use common::{MidHandshake, Stream, TlsState}; +use common::{AsyncReadCapture, AsyncReadWrite, MidHandshake, Stream, TlsState}; use rustls::{ClientConfig, ClientConnection, CommonState, ServerConfig, ServerConnection}; use std::future::Future; use std::io; @@ -333,6 +333,44 @@ where } } +type IOWithCapture = AsyncReadWrite, IO>; + +pub struct TransparentConfigAcceptor { + acceptor: LazyConfigAcceptor>, +} + +impl TransparentConfigAcceptor +where + IO: AsyncRead + AsyncWrite + Unpin, +{ + pub fn new(acceptor: rustls::server::Acceptor, io: IO) -> Self { + let r = AsyncReadCapture::with_capacity(io, 512); + let rw = AsyncReadWrite::new(r, io); + Self { + acceptor: LazyConfigAcceptor::new(acceptor, rw), + } + } + + //pub fn into_lazy_config_acceptor(self) -> LazyConfigAcceptor { + // self.acceptor + //} +} + +//impl Future for TransparentConfigAcceptor +//where +// IO: AsyncRead + AsyncWrite + Unpin, +//{ +// type Output = io::Result>; +// +// fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { +// let this = self.get_mut(); +// } +//} + +pub struct TransparentStartHandshake { + h: StartHandshake, +} + /// Future returned from `TlsConnector::connect` which will resolve /// once the connection handshake has finished. pub struct Connect(MidHandshake>);