You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
553 lines
16 KiB
553 lines
16 KiB
//! Asynchronous TLS/SSL streams for Tokio using [Rustls](https://github.com/rustls/rustls).
|
|
//!
|
|
//! # Why do I need to call `poll_flush`?
|
|
//!
|
|
//! Most TLS implementations will have an internal buffer to improve throughput,
|
|
//! and rustls is no exception.
|
|
//!
|
|
//! When we write data to `TlsStream`, we always write to the rustls buffer first,
|
|
//! then take the encrypted records out of rustls and write them to the data channel (such as a `TcpStream`).
|
|
//! When the data channel is pending, some data may remain in the rustls buffer.
|
|
//!
|
|
//! To keep `tokio-rustls` simple and correct, [TlsStream] behaves like a `BufWriter`.
|
|
//! For `TlsStream<TcpStream>`, this means that data written by `poll_write` is not guaranteed to be written to `TcpStream`.
|
|
//! You must call `poll_flush` to ensure that it is written to `TcpStream`.
|
|
//!
|
|
//! You should call `poll_flush` at the appropriate time,
|
|
//! such as when a series of `poll_write` calls is complete and there is no more data to write.
|
|
//!
|
|
//! ## Why don't we write during `poll_read`?
|
|
//!
|
|
//! We did this in the early days of `tokio-rustls`, but it caused some bugs.
|
|
//! These bugs can be worked around, but the workarounds degrade performance (spurious wakeups in the reverse direction).
|
|
//!
|
|
//! Writing from the read side would also prevent us from implementing full duplex in the future.
|
|
//!
|
|
//! See <https://github.com/tokio-rs/tls/issues/40>.
|
|
//!
|
|
//! ## Why can't we handle it like `native-tls`?
|
|
//!
|
|
//! When the data channel returns `Pending`, `native-tls` falsely reports the number of bytes it has consumed.
|
|
//! This means that if the data passed to `poll_write` has not actually been written to the data channel, `poll_write` does not return `Ready`.
|
|
//! This avoids the need to call `poll_flush`.
|
|
//!
|
|
//! However, this does not conform to the conventions of the `AsyncWrite` trait.
|
|
//! In particular, passing inconsistent data to two consecutive `poll_write` calls may cause unexpected behavior.
|
|
//!
|
|
//! See <https://github.com/tokio-rs/tls/issues/41>.
|
|
|
|
// Evaluates a `Poll<T>` expression: yields the inner `T` when it is `Ready`,
// otherwise returns `Poll::Pending` from the enclosing function.
//
// Because `macro_rules!` macros are textually scoped, defining it here makes it
// visible to the submodules declared after it in this file.
macro_rules! ready {
    ( $poll:expr ) => {
        match $poll {
            std::task::Poll::Pending => return std::task::Poll::Pending,
            std::task::Poll::Ready(value) => value,
        }
    };
}
|
|
|
|
pub mod client;
|
|
mod common;
|
|
pub mod server;
|
|
|
|
use common::{MidHandshake, Stream, TlsState};
|
|
use rustls::{ClientConfig, ClientConnection, CommonState, ServerConfig, ServerConnection};
|
|
use std::future::Future;
|
|
use std::io;
|
|
#[cfg(unix)]
|
|
use std::os::unix::io::{AsRawFd, RawFd};
|
|
#[cfg(windows)]
|
|
use std::os::windows::io::{AsRawSocket, RawSocket};
|
|
use std::pin::Pin;
|
|
use std::sync::Arc;
|
|
use std::task::{Context, Poll};
|
|
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
|
|
|
pub use rustls;
|
|
|
|
/// A wrapper around a `rustls::ClientConfig`, providing an async `connect` method.
|
|
#[derive(Clone)]
|
|
pub struct TlsConnector {
|
|
inner: Arc<ClientConfig>,
|
|
#[cfg(feature = "early-data")]
|
|
early_data: bool,
|
|
}
|
|
|
|
/// A wrapper around a `rustls::ServerConfig`, providing an async `accept` method.
|
|
#[derive(Clone)]
|
|
pub struct TlsAcceptor {
|
|
inner: Arc<ServerConfig>,
|
|
}
|
|
|
|
impl From<Arc<ClientConfig>> for TlsConnector {
    /// Wraps a shared client configuration.
    ///
    /// With the `early-data` feature enabled, 0-RTT starts out disabled; turn it on
    /// with `TlsConnector::early_data`.
    fn from(inner: Arc<ClientConfig>) -> Self {
        Self {
            inner,
            #[cfg(feature = "early-data")]
            early_data: false,
        }
    }
}
|
|
|
|
impl From<Arc<ServerConfig>> for TlsAcceptor {
    /// Wraps a shared server configuration.
    fn from(config: Arc<ServerConfig>) -> Self {
        Self { inner: config }
    }
}
|
|
|
|
impl TlsConnector {
|
|
/// Enable 0-RTT.
|
|
///
|
|
/// If you want to use 0-RTT,
|
|
/// You must also set `ClientConfig.enable_early_data` to `true`.
|
|
#[cfg(feature = "early-data")]
|
|
pub fn early_data(mut self, flag: bool) -> TlsConnector {
|
|
self.early_data = flag;
|
|
self
|
|
}
|
|
|
|
#[inline]
|
|
pub fn connect<IO>(&self, domain: rustls::ServerName, stream: IO) -> Connect<IO>
|
|
where
|
|
IO: AsyncRead + AsyncWrite + Unpin,
|
|
{
|
|
self.connect_with(domain, stream, |_| ())
|
|
}
|
|
|
|
pub fn connect_with<IO, F>(&self, domain: rustls::ServerName, stream: IO, f: F) -> Connect<IO>
|
|
where
|
|
IO: AsyncRead + AsyncWrite + Unpin,
|
|
F: FnOnce(&mut ClientConnection),
|
|
{
|
|
let mut session = match ClientConnection::new(self.inner.clone(), domain) {
|
|
Ok(session) => session,
|
|
Err(error) => {
|
|
return Connect(MidHandshake::Error {
|
|
io: stream,
|
|
// TODO(eliza): should this really return an `io::Error`?
|
|
// Probably not...
|
|
error: io::Error::new(io::ErrorKind::Other, error),
|
|
});
|
|
}
|
|
};
|
|
f(&mut session);
|
|
|
|
Connect(MidHandshake::Handshaking(client::TlsStream {
|
|
io: stream,
|
|
|
|
#[cfg(not(feature = "early-data"))]
|
|
state: TlsState::Stream,
|
|
|
|
#[cfg(feature = "early-data")]
|
|
state: if self.early_data && session.early_data().is_some() {
|
|
TlsState::EarlyData(0, Vec::new())
|
|
} else {
|
|
TlsState::Stream
|
|
},
|
|
|
|
#[cfg(feature = "early-data")]
|
|
early_waker: None,
|
|
|
|
session,
|
|
}))
|
|
}
|
|
}
|
|
|
|
impl TlsAcceptor {
|
|
#[inline]
|
|
pub fn accept<IO>(&self, stream: IO) -> Accept<IO>
|
|
where
|
|
IO: AsyncRead + AsyncWrite + Unpin,
|
|
{
|
|
self.accept_with(stream, |_| ())
|
|
}
|
|
|
|
pub fn accept_with<IO, F>(&self, stream: IO, f: F) -> Accept<IO>
|
|
where
|
|
IO: AsyncRead + AsyncWrite + Unpin,
|
|
F: FnOnce(&mut ServerConnection),
|
|
{
|
|
let mut session = match ServerConnection::new(self.inner.clone()) {
|
|
Ok(session) => session,
|
|
Err(error) => {
|
|
return Accept(MidHandshake::Error {
|
|
io: stream,
|
|
// TODO(eliza): should this really return an `io::Error`?
|
|
// Probably not...
|
|
error: io::Error::new(io::ErrorKind::Other, error),
|
|
});
|
|
}
|
|
};
|
|
f(&mut session);
|
|
|
|
Accept(MidHandshake::Handshaking(server::TlsStream {
|
|
session,
|
|
io: stream,
|
|
state: TlsState::Stream,
|
|
}))
|
|
}
|
|
}
|
|
|
|
pub struct LazyConfigAcceptor<IO> {
|
|
acceptor: rustls::server::Acceptor,
|
|
io: Option<IO>,
|
|
}
|
|
|
|
impl<IO> LazyConfigAcceptor<IO>
|
|
where
|
|
IO: AsyncRead + AsyncWrite + Unpin,
|
|
{
|
|
#[inline]
|
|
pub fn new(acceptor: rustls::server::Acceptor, io: IO) -> Self {
|
|
Self {
|
|
acceptor,
|
|
io: Some(io),
|
|
}
|
|
}
|
|
|
|
/// Takes back the client connection. Will return `None` if called more than once or if the
|
|
/// connection has been accepted.
|
|
///
|
|
/// # Example
|
|
///
|
|
/// ```no_run
|
|
/// # fn choose_server_config(
|
|
/// # _: rustls::server::ClientHello,
|
|
/// # ) -> std::sync::Arc<rustls::ServerConfig> {
|
|
/// # unimplemented!();
|
|
/// # }
|
|
/// # #[allow(unused_variables)]
|
|
/// # async fn listen() {
|
|
/// use tokio::io::AsyncWriteExt;
|
|
/// let listener = tokio::net::TcpListener::bind("127.0.0.1:4443").await.unwrap();
|
|
/// let (stream, _) = listener.accept().await.unwrap();
|
|
///
|
|
/// let acceptor = tokio_rustls::LazyConfigAcceptor::new(rustls::server::Acceptor::default(), stream);
|
|
/// futures_util::pin_mut!(acceptor);
|
|
///
|
|
/// match acceptor.as_mut().await {
|
|
/// Ok(start) => {
|
|
/// let clientHello = start.client_hello();
|
|
/// let config = choose_server_config(clientHello);
|
|
/// let stream = start.into_stream(config).await.unwrap();
|
|
/// // Proceed with handling the ServerConnection...
|
|
/// }
|
|
/// Err(err) => {
|
|
/// if let Some(mut stream) = acceptor.take_io() {
|
|
/// stream
|
|
/// .write_all(
|
|
/// format!("HTTP/1.1 400 Invalid Input\r\n\r\n\r\n{:?}\n", err)
|
|
/// .as_bytes()
|
|
/// )
|
|
/// .await
|
|
/// .unwrap();
|
|
/// }
|
|
/// }
|
|
/// }
|
|
/// # }
|
|
/// ```
|
|
pub fn take_io(&mut self) -> Option<IO> {
|
|
self.io.take()
|
|
}
|
|
}
|
|
|
|
impl<IO> Future for LazyConfigAcceptor<IO>
|
|
where
|
|
IO: AsyncRead + AsyncWrite + Unpin,
|
|
{
|
|
type Output = Result<StartHandshake<IO>, io::Error>;
|
|
|
|
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
|
let this = self.get_mut();
|
|
loop {
|
|
let io = match this.io.as_mut() {
|
|
Some(io) => io,
|
|
None => {
|
|
return Poll::Ready(Err(io::Error::new(
|
|
io::ErrorKind::Other,
|
|
"acceptor cannot be polled after acceptance",
|
|
)))
|
|
}
|
|
};
|
|
|
|
let mut reader = common::SyncReadAdapter { io, cx };
|
|
match this.acceptor.read_tls(&mut reader) {
|
|
Ok(0) => return Err(io::ErrorKind::UnexpectedEof.into()).into(),
|
|
Ok(_) => {}
|
|
Err(e) if e.kind() == io::ErrorKind::WouldBlock => return Poll::Pending,
|
|
Err(e) => return Err(e).into(),
|
|
}
|
|
|
|
match this.acceptor.accept() {
|
|
Ok(Some(accepted)) => {
|
|
let io = this.io.take().unwrap();
|
|
return Poll::Ready(Ok(StartHandshake { accepted, io }));
|
|
}
|
|
Ok(None) => continue,
|
|
Err(err) => {
|
|
return Poll::Ready(Err(io::Error::new(io::ErrorKind::InvalidInput, err)))
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
pub struct StartHandshake<IO> {
|
|
accepted: rustls::server::Accepted,
|
|
io: IO,
|
|
}
|
|
|
|
impl<IO> StartHandshake<IO>
|
|
where
|
|
IO: AsyncRead + AsyncWrite + Unpin,
|
|
{
|
|
pub fn client_hello(&self) -> rustls::server::ClientHello<'_> {
|
|
self.accepted.client_hello()
|
|
}
|
|
|
|
pub fn into_stream(self, config: Arc<ServerConfig>) -> Accept<IO> {
|
|
self.into_stream_with(config, |_| ())
|
|
}
|
|
|
|
pub fn into_stream_with<F>(self, config: Arc<ServerConfig>, f: F) -> Accept<IO>
|
|
where
|
|
F: FnOnce(&mut ServerConnection),
|
|
{
|
|
let mut conn = match self.accepted.into_connection(config) {
|
|
Ok(conn) => conn,
|
|
Err(error) => {
|
|
return Accept(MidHandshake::Error {
|
|
io: self.io,
|
|
// TODO(eliza): should this really return an `io::Error`?
|
|
// Probably not...
|
|
error: io::Error::new(io::ErrorKind::Other, error),
|
|
});
|
|
}
|
|
};
|
|
f(&mut conn);
|
|
|
|
Accept(MidHandshake::Handshaking(server::TlsStream {
|
|
session: conn,
|
|
io: self.io,
|
|
state: TlsState::Stream,
|
|
}))
|
|
}
|
|
}
|
|
|
|
/// Future returned from `TlsConnector::connect` which will resolve
|
|
/// once the connection handshake has finished.
|
|
pub struct Connect<IO>(MidHandshake<client::TlsStream<IO>>);
|
|
|
|
/// Future returned from `TlsAcceptor::accept` which will resolve
|
|
/// once the accept handshake has finished.
|
|
pub struct Accept<IO>(MidHandshake<server::TlsStream<IO>>);
|
|
|
|
/// Like [Connect], but returns `IO` on failure.
|
|
pub struct FallibleConnect<IO>(MidHandshake<client::TlsStream<IO>>);
|
|
|
|
/// Like [Accept], but returns `IO` on failure.
|
|
pub struct FallibleAccept<IO>(MidHandshake<server::TlsStream<IO>>);
|
|
|
|
impl<IO> Connect<IO> {
|
|
#[inline]
|
|
pub fn into_fallible(self) -> FallibleConnect<IO> {
|
|
FallibleConnect(self.0)
|
|
}
|
|
|
|
pub fn get_ref(&self) -> Option<&IO> {
|
|
match &self.0 {
|
|
MidHandshake::Handshaking(sess) => Some(sess.get_ref().0),
|
|
MidHandshake::Error { io, .. } => Some(io),
|
|
MidHandshake::End => None,
|
|
}
|
|
}
|
|
|
|
pub fn get_mut(&mut self) -> Option<&mut IO> {
|
|
match &mut self.0 {
|
|
MidHandshake::Handshaking(sess) => Some(sess.get_mut().0),
|
|
MidHandshake::Error { io, .. } => Some(io),
|
|
MidHandshake::End => None,
|
|
}
|
|
}
|
|
}
|
|
|
|
impl<IO> Accept<IO> {
|
|
#[inline]
|
|
pub fn into_fallible(self) -> FallibleAccept<IO> {
|
|
FallibleAccept(self.0)
|
|
}
|
|
|
|
pub fn get_ref(&self) -> Option<&IO> {
|
|
match &self.0 {
|
|
MidHandshake::Handshaking(sess) => Some(sess.get_ref().0),
|
|
MidHandshake::Error { io, .. } => Some(io),
|
|
MidHandshake::End => None,
|
|
}
|
|
}
|
|
|
|
pub fn get_mut(&mut self) -> Option<&mut IO> {
|
|
match &mut self.0 {
|
|
MidHandshake::Handshaking(sess) => Some(sess.get_mut().0),
|
|
MidHandshake::Error { io, .. } => Some(io),
|
|
MidHandshake::End => None,
|
|
}
|
|
}
|
|
}
|
|
|
|
impl<IO: AsyncRead + AsyncWrite + Unpin> Future for Connect<IO> {
|
|
type Output = io::Result<client::TlsStream<IO>>;
|
|
|
|
#[inline]
|
|
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
|
Pin::new(&mut self.0).poll(cx).map_err(|(err, _)| err)
|
|
}
|
|
}
|
|
|
|
impl<IO: AsyncRead + AsyncWrite + Unpin> Future for Accept<IO> {
|
|
type Output = io::Result<server::TlsStream<IO>>;
|
|
|
|
#[inline]
|
|
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
|
Pin::new(&mut self.0).poll(cx).map_err(|(err, _)| err)
|
|
}
|
|
}
|
|
|
|
impl<IO: AsyncRead + AsyncWrite + Unpin> Future for FallibleConnect<IO> {
|
|
type Output = Result<client::TlsStream<IO>, (io::Error, IO)>;
|
|
|
|
#[inline]
|
|
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
|
Pin::new(&mut self.0).poll(cx)
|
|
}
|
|
}
|
|
|
|
impl<IO: AsyncRead + AsyncWrite + Unpin> Future for FallibleAccept<IO> {
|
|
type Output = Result<server::TlsStream<IO>, (io::Error, IO)>;
|
|
|
|
#[inline]
|
|
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
|
Pin::new(&mut self.0).poll(cx)
|
|
}
|
|
}
|
|
|
|
/// Unified TLS stream type
|
|
///
|
|
/// This abstracts over the inner `client::TlsStream` and `server::TlsStream`, so you can use
|
|
/// a single type to keep both client- and server-initiated TLS-encrypted connections.
|
|
#[allow(clippy::large_enum_variant)] // https://github.com/rust-lang/rust-clippy/issues/9798
|
|
#[derive(Debug)]
|
|
pub enum TlsStream<T> {
|
|
Client(client::TlsStream<T>),
|
|
Server(server::TlsStream<T>),
|
|
}
|
|
|
|
impl<T> TlsStream<T> {
|
|
pub fn get_ref(&self) -> (&T, &CommonState) {
|
|
use TlsStream::*;
|
|
match self {
|
|
Client(io) => {
|
|
let (io, session) = io.get_ref();
|
|
(io, session)
|
|
}
|
|
Server(io) => {
|
|
let (io, session) = io.get_ref();
|
|
(io, session)
|
|
}
|
|
}
|
|
}
|
|
|
|
pub fn get_mut(&mut self) -> (&mut T, &mut CommonState) {
|
|
use TlsStream::*;
|
|
match self {
|
|
Client(io) => {
|
|
let (io, session) = io.get_mut();
|
|
(io, &mut *session)
|
|
}
|
|
Server(io) => {
|
|
let (io, session) = io.get_mut();
|
|
(io, &mut *session)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
impl<T> From<client::TlsStream<T>> for TlsStream<T> {
|
|
fn from(s: client::TlsStream<T>) -> Self {
|
|
Self::Client(s)
|
|
}
|
|
}
|
|
|
|
impl<T> From<server::TlsStream<T>> for TlsStream<T> {
|
|
fn from(s: server::TlsStream<T>) -> Self {
|
|
Self::Server(s)
|
|
}
|
|
}
|
|
|
|
#[cfg(unix)]
|
|
impl<S> AsRawFd for TlsStream<S>
|
|
where
|
|
S: AsRawFd,
|
|
{
|
|
fn as_raw_fd(&self) -> RawFd {
|
|
self.get_ref().0.as_raw_fd()
|
|
}
|
|
}
|
|
|
|
#[cfg(windows)]
impl<S> AsRawSocket for TlsStream<S>
where
    S: AsRawSocket,
{
    /// Returns the raw socket handle of the wrapped IO object.
    fn as_raw_socket(&self) -> RawSocket {
        let (io, _) = self.get_ref();
        io.as_raw_socket()
    }
}
|
|
|
|
impl<T> AsyncRead for TlsStream<T>
|
|
where
|
|
T: AsyncRead + AsyncWrite + Unpin,
|
|
{
|
|
#[inline]
|
|
fn poll_read(
|
|
self: Pin<&mut Self>,
|
|
cx: &mut Context<'_>,
|
|
buf: &mut ReadBuf<'_>,
|
|
) -> Poll<io::Result<()>> {
|
|
match self.get_mut() {
|
|
TlsStream::Client(x) => Pin::new(x).poll_read(cx, buf),
|
|
TlsStream::Server(x) => Pin::new(x).poll_read(cx, buf),
|
|
}
|
|
}
|
|
}
|
|
|
|
impl<T> AsyncWrite for TlsStream<T>
|
|
where
|
|
T: AsyncRead + AsyncWrite + Unpin,
|
|
{
|
|
#[inline]
|
|
fn poll_write(
|
|
self: Pin<&mut Self>,
|
|
cx: &mut Context<'_>,
|
|
buf: &[u8],
|
|
) -> Poll<io::Result<usize>> {
|
|
match self.get_mut() {
|
|
TlsStream::Client(x) => Pin::new(x).poll_write(cx, buf),
|
|
TlsStream::Server(x) => Pin::new(x).poll_write(cx, buf),
|
|
}
|
|
}
|
|
|
|
#[inline]
|
|
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
|
match self.get_mut() {
|
|
TlsStream::Client(x) => Pin::new(x).poll_flush(cx),
|
|
TlsStream::Server(x) => Pin::new(x).poll_flush(cx),
|
|
}
|
|
}
|
|
|
|
#[inline]
|
|
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
|
match self.get_mut() {
|
|
TlsStream::Client(x) => Pin::new(x).poll_shutdown(cx),
|
|
TlsStream::Server(x) => Pin::new(x).poll_shutdown(cx),
|
|
}
|
|
}
|
|
}
|
|
|