Compare commits

...

1 Commits

Author SHA1 Message Date
Brian Picciano 18fd688b33 Implement TransparentConfigAcceptor 10 months ago
  1. 142
      src/common/mod.rs
  2. 96
      src/lib.rs

@ -359,5 +359,147 @@ impl<'a, 'b, T: AsyncRead + Unpin> Read for SyncReadAdapter<'a, 'b, T> {
}
}
/// Wraps an AsyncRead and AsyncWrite instance together to produce a single type which implements
/// AsyncRead + AsyncWrite.
pub struct AsyncReadWrite<R, W> {
r: Pin<Box<R>>,
w: Pin<Box<W>>,
}
impl<R, W> AsyncReadWrite<R, W>
where
R: Unpin,
W: Unpin,
{
pub fn new(r: R, w: W) -> Self {
Self {
r: Box::pin(r),
w: Box::pin(w),
}
}
pub fn into_inner(self) -> (R, W) {
(*Pin::into_inner(self.r), *Pin::into_inner(self.w))
}
}
impl<R, W> AsyncRead for AsyncReadWrite<R, W>
where
R: AsyncRead + Unpin,
{
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
self.r.as_mut().poll_read(cx, buf)
}
}
impl<R, W> AsyncWrite for AsyncReadWrite<R, W>
where
W: AsyncWrite + Unpin,
{
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
self.w.as_mut().poll_write(cx, buf)
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
self.w.as_mut().poll_flush(cx)
}
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
self.w.as_mut().poll_shutdown(cx)
}
}
/// Wraps an AsyncRead in order to capture all bytes which have been read by it into an internal
/// buffer.
pub struct AsyncReadCapture<R> {
r: Pin<Box<R>>,
buf: Vec<u8>,
}
impl<R> AsyncReadCapture<R>
where
R: AsyncRead + Unpin,
{
/// Initializes an AsyncReadCapture with an empty internal buffer of the given size.
pub fn with_capacity(r: R, cap: usize) -> Self {
Self {
r: Box::pin(r),
buf: Vec::with_capacity(cap),
}
}
pub fn into_inner(self) -> (R, Vec<u8>) {
(*Pin::into_inner(self.r), self.buf)
}
}
impl<R> AsyncRead for AsyncReadCapture<R>
where
R: AsyncRead,
{
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
let res = self.r.as_mut().poll_read(cx, buf);
if let Poll::Ready(Ok(())) = res {
self.buf.extend_from_slice(buf.filled());
}
res
}
}
pub struct AsyncReadPrefixed<R> {
r: Pin<Box<R>>,
prefix: Vec<u8>,
}
impl<R> AsyncReadPrefixed<R>
where
R: AsyncRead + Unpin,
{
pub fn new(r: R, prefix: Vec<u8>) -> Self {
Self {
r: Box::pin(r),
prefix,
}
}
}
impl<R> AsyncRead for AsyncReadPrefixed<R>
where
R: AsyncRead,
{
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
let this = self.get_mut();
let prefix_len = this.prefix.len();
if prefix_len == 0 {
return this.r.as_mut().poll_read(cx, buf);
}
let n = std::cmp::min(prefix_len, buf.remaining());
let to_write = this.prefix.drain(..n);
buf.put_slice(to_write.as_slice());
Poll::Ready(Ok(()))
}
}
#[cfg(test)]
mod test_stream;

@ -49,7 +49,7 @@ pub mod client;
mod common;
pub mod server;
use common::{MidHandshake, Stream, TlsState};
use common::{AsyncReadCapture, AsyncReadPrefixed, AsyncReadWrite, MidHandshake, Stream, TlsState};
use rustls::{ClientConfig, ClientConnection, CommonState, ServerConfig, ServerConnection};
use std::future::Future;
use std::io;
@ -60,7 +60,7 @@ use std::os::windows::io::{AsRawSocket, RawSocket};
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf, ReadHalf, WriteHalf};
pub use rustls;
@ -333,6 +333,98 @@ where
}
}
type IOWithCapture<IO> = AsyncReadWrite<AsyncReadCapture<ReadHalf<IO>>, WriteHalf<IO>>;
type IOWithPrefix<IO> = AsyncReadWrite<AsyncReadPrefixed<ReadHalf<IO>>, WriteHalf<IO>>;
fn unwrap_io_with_capture<IO>(
io_with_capture: IOWithCapture<IO>,
) -> (ReadHalf<IO>, WriteHalf<IO>, Vec<u8>)
where
IO: AsyncRead + AsyncWrite + Unpin,
{
let (r, w) = io_with_capture.into_inner();
let (r, bytes_read) = r.into_inner();
(r, w, bytes_read)
}
pub struct TransparentConfigAcceptor<IO> {
acceptor: Pin<Box<LazyConfigAcceptor<IOWithCapture<IO>>>>,
}
impl<IO> TransparentConfigAcceptor<IO>
where
IO: AsyncRead + AsyncWrite + Unpin,
{
pub fn new(acceptor: rustls::server::Acceptor, io: IO) -> Self {
let (r, w) = tokio::io::split(io);
let r = AsyncReadCapture::with_capacity(r, 1024);
let rw = AsyncReadWrite::new(r, w);
Self {
acceptor: Box::pin(LazyConfigAcceptor::new(acceptor, rw)),
}
}
}
impl<IO> Future for TransparentConfigAcceptor<IO>
where
IO: AsyncRead + AsyncWrite + Unpin,
{
type Output = io::Result<TransparentStartHandshake<IO>>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.get_mut().acceptor.as_mut().poll(cx) {
Poll::Ready(Ok(h)) => {
let (r, w, bytes_read) = unwrap_io_with_capture(h.io);
Poll::Ready(Ok(TransparentStartHandshake {
accepted: h.accepted,
r,
w,
bytes_read,
}))
}
Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
Poll::Pending => Poll::Pending,
}
}
}
pub struct TransparentStartHandshake<IO> {
accepted: rustls::server::Accepted,
r: ReadHalf<IO>,
w: WriteHalf<IO>,
bytes_read: Vec<u8>,
}
impl<IO> TransparentStartHandshake<IO>
where
IO: AsyncRead + AsyncWrite + Unpin,
{
pub fn client_hello(&self) -> rustls::server::ClientHello<'_> {
self.accepted.client_hello()
}
pub fn into_stream(self, config: Arc<ServerConfig>) -> Accept<IO> {
self.into_stream_with(config, |_| ())
}
pub fn into_stream_with<F>(self, config: Arc<ServerConfig>, f: F) -> Accept<IO>
where
F: FnOnce(&mut ServerConnection),
{
let start_handshake = StartHandshake {
accepted: self.accepted,
io: self.r.unsplit(self.w),
};
start_handshake.into_stream_with(config, f)
}
pub fn into_original_stream(self) -> IOWithPrefix<IO> {
let r = AsyncReadPrefixed::new(self.r, self.bytes_read);
AsyncReadWrite::new(r, self.w)
}
}
/// Future returned from `TlsConnector::connect` which will resolve
/// once the connection handshake has finished.
pub struct Connect<IO>(MidHandshake<client::TlsStream<IO>>);

Loading…
Cancel
Save