Compare commits
1 Commits
main
...
transparen
Author | SHA1 | Date | |
---|---|---|---|
|
18fd688b33 |
@ -359,5 +359,147 @@ impl<'a, 'b, T: AsyncRead + Unpin> Read for SyncReadAdapter<'a, 'b, T> {
|
||||
}
|
||||
}
|
||||
|
||||
/// Wraps an AsyncRead and AsyncWrite instance together to produce a single type which implements
|
||||
/// AsyncRead + AsyncWrite.
|
||||
pub struct AsyncReadWrite<R, W> {
|
||||
r: Pin<Box<R>>,
|
||||
w: Pin<Box<W>>,
|
||||
}
|
||||
|
||||
impl<R, W> AsyncReadWrite<R, W>
|
||||
where
|
||||
R: Unpin,
|
||||
W: Unpin,
|
||||
{
|
||||
pub fn new(r: R, w: W) -> Self {
|
||||
Self {
|
||||
r: Box::pin(r),
|
||||
w: Box::pin(w),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn into_inner(self) -> (R, W) {
|
||||
(*Pin::into_inner(self.r), *Pin::into_inner(self.w))
|
||||
}
|
||||
}
|
||||
|
||||
impl<R, W> AsyncRead for AsyncReadWrite<R, W>
|
||||
where
|
||||
R: AsyncRead + Unpin,
|
||||
{
|
||||
fn poll_read(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut ReadBuf<'_>,
|
||||
) -> Poll<io::Result<()>> {
|
||||
self.r.as_mut().poll_read(cx, buf)
|
||||
}
|
||||
}
|
||||
|
||||
impl<R, W> AsyncWrite for AsyncReadWrite<R, W>
|
||||
where
|
||||
W: AsyncWrite + Unpin,
|
||||
{
|
||||
fn poll_write(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<io::Result<usize>> {
|
||||
self.w.as_mut().poll_write(cx, buf)
|
||||
}
|
||||
|
||||
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
self.w.as_mut().poll_flush(cx)
|
||||
}
|
||||
|
||||
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
self.w.as_mut().poll_shutdown(cx)
|
||||
}
|
||||
}
|
||||
|
||||
/// Wraps an AsyncRead in order to capture all bytes which have been read by it into an internal
|
||||
/// buffer.
|
||||
pub struct AsyncReadCapture<R> {
|
||||
r: Pin<Box<R>>,
|
||||
buf: Vec<u8>,
|
||||
}
|
||||
|
||||
impl<R> AsyncReadCapture<R>
|
||||
where
|
||||
R: AsyncRead + Unpin,
|
||||
{
|
||||
/// Initializes an AsyncReadCapture with an empty internal buffer of the given size.
|
||||
pub fn with_capacity(r: R, cap: usize) -> Self {
|
||||
Self {
|
||||
r: Box::pin(r),
|
||||
buf: Vec::with_capacity(cap),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn into_inner(self) -> (R, Vec<u8>) {
|
||||
(*Pin::into_inner(self.r), self.buf)
|
||||
}
|
||||
}
|
||||
|
||||
impl<R> AsyncRead for AsyncReadCapture<R>
|
||||
where
|
||||
R: AsyncRead,
|
||||
{
|
||||
fn poll_read(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut ReadBuf<'_>,
|
||||
) -> Poll<io::Result<()>> {
|
||||
let res = self.r.as_mut().poll_read(cx, buf);
|
||||
|
||||
if let Poll::Ready(Ok(())) = res {
|
||||
self.buf.extend_from_slice(buf.filled());
|
||||
}
|
||||
|
||||
res
|
||||
}
|
||||
}
|
||||
|
||||
pub struct AsyncReadPrefixed<R> {
|
||||
r: Pin<Box<R>>,
|
||||
prefix: Vec<u8>,
|
||||
}
|
||||
|
||||
impl<R> AsyncReadPrefixed<R>
|
||||
where
|
||||
R: AsyncRead + Unpin,
|
||||
{
|
||||
pub fn new(r: R, prefix: Vec<u8>) -> Self {
|
||||
Self {
|
||||
r: Box::pin(r),
|
||||
prefix,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<R> AsyncRead for AsyncReadPrefixed<R>
|
||||
where
|
||||
R: AsyncRead,
|
||||
{
|
||||
fn poll_read(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut ReadBuf<'_>,
|
||||
) -> Poll<io::Result<()>> {
|
||||
let this = self.get_mut();
|
||||
|
||||
let prefix_len = this.prefix.len();
|
||||
if prefix_len == 0 {
|
||||
return this.r.as_mut().poll_read(cx, buf);
|
||||
}
|
||||
|
||||
let n = std::cmp::min(prefix_len, buf.remaining());
|
||||
let to_write = this.prefix.drain(..n);
|
||||
|
||||
buf.put_slice(to_write.as_slice());
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test_stream;
|
||||
|
96
src/lib.rs
96
src/lib.rs
@ -49,7 +49,7 @@ pub mod client;
|
||||
mod common;
|
||||
pub mod server;
|
||||
|
||||
use common::{MidHandshake, Stream, TlsState};
|
||||
use common::{AsyncReadCapture, AsyncReadPrefixed, AsyncReadWrite, MidHandshake, Stream, TlsState};
|
||||
use rustls::{ClientConfig, ClientConnection, CommonState, ServerConfig, ServerConnection};
|
||||
use std::future::Future;
|
||||
use std::io;
|
||||
@ -60,7 +60,7 @@ use std::os::windows::io::{AsRawSocket, RawSocket};
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::task::{Context, Poll};
|
||||
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
||||
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf, ReadHalf, WriteHalf};
|
||||
|
||||
pub use rustls;
|
||||
|
||||
@ -333,6 +333,98 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
type IOWithCapture<IO> = AsyncReadWrite<AsyncReadCapture<ReadHalf<IO>>, WriteHalf<IO>>;
|
||||
type IOWithPrefix<IO> = AsyncReadWrite<AsyncReadPrefixed<ReadHalf<IO>>, WriteHalf<IO>>;
|
||||
|
||||
fn unwrap_io_with_capture<IO>(
|
||||
io_with_capture: IOWithCapture<IO>,
|
||||
) -> (ReadHalf<IO>, WriteHalf<IO>, Vec<u8>)
|
||||
where
|
||||
IO: AsyncRead + AsyncWrite + Unpin,
|
||||
{
|
||||
let (r, w) = io_with_capture.into_inner();
|
||||
let (r, bytes_read) = r.into_inner();
|
||||
(r, w, bytes_read)
|
||||
}
|
||||
|
||||
pub struct TransparentConfigAcceptor<IO> {
|
||||
acceptor: Pin<Box<LazyConfigAcceptor<IOWithCapture<IO>>>>,
|
||||
}
|
||||
|
||||
impl<IO> TransparentConfigAcceptor<IO>
|
||||
where
|
||||
IO: AsyncRead + AsyncWrite + Unpin,
|
||||
{
|
||||
pub fn new(acceptor: rustls::server::Acceptor, io: IO) -> Self {
|
||||
let (r, w) = tokio::io::split(io);
|
||||
let r = AsyncReadCapture::with_capacity(r, 1024);
|
||||
let rw = AsyncReadWrite::new(r, w);
|
||||
Self {
|
||||
acceptor: Box::pin(LazyConfigAcceptor::new(acceptor, rw)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<IO> Future for TransparentConfigAcceptor<IO>
|
||||
where
|
||||
IO: AsyncRead + AsyncWrite + Unpin,
|
||||
{
|
||||
type Output = io::Result<TransparentStartHandshake<IO>>;
|
||||
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
match self.get_mut().acceptor.as_mut().poll(cx) {
|
||||
Poll::Ready(Ok(h)) => {
|
||||
let (r, w, bytes_read) = unwrap_io_with_capture(h.io);
|
||||
Poll::Ready(Ok(TransparentStartHandshake {
|
||||
accepted: h.accepted,
|
||||
r,
|
||||
w,
|
||||
bytes_read,
|
||||
}))
|
||||
}
|
||||
Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
|
||||
Poll::Pending => Poll::Pending,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct TransparentStartHandshake<IO> {
|
||||
accepted: rustls::server::Accepted,
|
||||
r: ReadHalf<IO>,
|
||||
w: WriteHalf<IO>,
|
||||
bytes_read: Vec<u8>,
|
||||
}
|
||||
|
||||
impl<IO> TransparentStartHandshake<IO>
|
||||
where
|
||||
IO: AsyncRead + AsyncWrite + Unpin,
|
||||
{
|
||||
pub fn client_hello(&self) -> rustls::server::ClientHello<'_> {
|
||||
self.accepted.client_hello()
|
||||
}
|
||||
|
||||
pub fn into_stream(self, config: Arc<ServerConfig>) -> Accept<IO> {
|
||||
self.into_stream_with(config, |_| ())
|
||||
}
|
||||
|
||||
pub fn into_stream_with<F>(self, config: Arc<ServerConfig>, f: F) -> Accept<IO>
|
||||
where
|
||||
F: FnOnce(&mut ServerConnection),
|
||||
{
|
||||
let start_handshake = StartHandshake {
|
||||
accepted: self.accepted,
|
||||
io: self.r.unsplit(self.w),
|
||||
};
|
||||
|
||||
start_handshake.into_stream_with(config, f)
|
||||
}
|
||||
|
||||
pub fn into_original_stream(self) -> IOWithPrefix<IO> {
|
||||
let r = AsyncReadPrefixed::new(self.r, self.bytes_read);
|
||||
AsyncReadWrite::new(r, self.w)
|
||||
}
|
||||
}
|
||||
|
||||
/// Future returned from `TlsConnector::connect` which will resolve
|
||||
/// once the connection handshake has finished.
|
||||
pub struct Connect<IO>(MidHandshake<client::TlsStream<IO>>);
|
||||
|
Loading…
Reference in New Issue
Block a user