Async TLS for the Tokio runtime
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
tokio-rustls/src/common/mod.rs

363 lines
11 KiB

mod handshake;
pub(crate) use handshake::{IoSession, MidHandshake};
use rustls::{ConnectionCommon, SideData};
use std::io::{self, IoSlice, Read, Write};
use std::ops::{Deref, DerefMut};
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
/// Tracks which halves (read / write) of a TLS stream are still open.
#[derive(Debug)]
pub enum TlsState {
    /// Early-data phase (only with the `early-data` feature).
    /// NOTE(review): the fields look like a progress counter and the buffered
    /// early-data bytes — confirm against the code that constructs this variant.
    #[cfg(feature = "early-data")]
    EarlyData(usize, Vec<u8>),
    /// Both halves are open.
    Stream,
    /// The read half is closed; writes are still allowed.
    ReadShutdown,
    /// The write half is closed; reads are still allowed.
    WriteShutdown,
    /// Both halves are closed.
    FullyShutdown,
}

impl TlsState {
    /// Closes the read half.
    ///
    /// Becomes `FullyShutdown` if the write half was already closed, otherwise
    /// `ReadShutdown` (from any other state, including `EarlyData`).
    #[inline]
    pub fn shutdown_read(&mut self) {
        let write_half_closed = !self.writeable();
        *self = if write_half_closed {
            TlsState::FullyShutdown
        } else {
            TlsState::ReadShutdown
        };
    }

    /// Closes the write half.
    ///
    /// Becomes `FullyShutdown` if the read half was already closed, otherwise
    /// `WriteShutdown` (from any other state, including `EarlyData`).
    #[inline]
    pub fn shutdown_write(&mut self) {
        let read_half_closed = !self.readable();
        *self = if read_half_closed {
            TlsState::FullyShutdown
        } else {
            TlsState::WriteShutdown
        };
    }

    /// Returns `true` unless the write half has been shut down.
    #[inline]
    pub fn writeable(&self) -> bool {
        !matches!(self, TlsState::WriteShutdown | TlsState::FullyShutdown)
    }

    /// Returns `true` unless the read half has been shut down.
    #[inline]
    pub fn readable(&self) -> bool {
        !matches!(self, TlsState::ReadShutdown | TlsState::FullyShutdown)
    }

    /// Returns `true` while in the `EarlyData` state.
    #[inline]
    #[cfg(feature = "early-data")]
    pub fn is_early_data(&self) -> bool {
        matches!(self, TlsState::EarlyData(..))
    }

    /// Always `false`: the `EarlyData` state does not exist without the
    /// `early-data` feature.
    #[inline]
    #[cfg(not(feature = "early-data"))]
    pub const fn is_early_data(&self) -> bool {
        false
    }
}
/// A short-lived view pairing a transport `IO` with a rustls connection `C`,
/// used to move TLS records between the two from inside `poll_*` methods.
pub struct Stream<'a, IO, C> {
    /// The underlying transport.
    pub io: &'a mut IO,
    /// The rustls connection (derefs to `ConnectionCommon<SD>`).
    pub session: &'a mut C,
    /// Whether the transport has reported EOF.
    ///
    /// Set by `handshake` when a transport read yields 0 bytes, or by callers
    /// via `set_eof`. While set, no further transport reads are attempted.
    pub eof: bool,
}

impl<'a, IO: AsyncRead + AsyncWrite + Unpin, C, SD> Stream<'a, IO, C>
where
    C: DerefMut + Deref<Target = ConnectionCommon<SD>>,
    SD: SideData,
{
    /// Creates a stream over `io` and `session` with `eof` cleared.
    pub fn new(io: &'a mut IO, session: &'a mut C) -> Self {
        Stream {
            io,
            session,
            // Callers that already know the transport hit EOF record it with
            // `set_eof`.
            eof: false,
        }
    }

    /// Builder-style setter for the `eof` flag.
    pub fn set_eof(mut self, eof: bool) -> Self {
        self.eof = eof;
        self
    }

    /// Returns a pinned mutable reference to `self`.
    ///
    /// `Pin::new` is usable because `Stream` only stores references and a
    /// `bool`, which makes it `Unpin`.
    pub fn as_mut_pin(&mut self) -> Pin<&mut Self> {
        Pin::new(self)
    }

    /// Reads TLS bytes from the transport into the session and processes them.
    ///
    /// Returns `Poll::Ready(Ok(n))` with the number of raw bytes read; callers
    /// in this file treat `0` as transport EOF. Returns `Poll::Pending` if the
    /// transport is not ready.
    ///
    /// # Errors
    /// - Any transport read error.
    /// - `InvalidData` if rustls rejects the received records. Before returning,
    ///   a best-effort write is attempted so a pending alert can reach the peer.
    /// - `UnexpectedEof` ("tls handshake alert") if the peer closed the
    ///   connection while the handshake is still in progress.
    pub fn read_io(&mut self, cx: &mut Context) -> Poll<io::Result<usize>> {
        let mut reader = SyncReadAdapter { io: self.io, cx };

        let n = match self.session.read_tls(&mut reader) {
            Ok(n) => n,
            Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => return Poll::Pending,
            Err(err) => return Poll::Ready(Err(err)),
        };

        let stats = self.session.process_new_packets().map_err(|err| {
            // In case we have an alert to send describing this error, try a
            // last-gasp write -- but don't let its result replace the primary
            // error.
            let _ = self.write_io(cx);

            io::Error::new(io::ErrorKind::InvalidData, err)
        })?;

        if stats.peer_has_closed() && self.session.is_handshaking() {
            return Poll::Ready(Err(io::Error::new(
                io::ErrorKind::UnexpectedEof,
                "tls handshake alert",
            )));
        }

        Poll::Ready(Ok(n))
    }

    /// Writes queued TLS records from the session to the transport.
    ///
    /// Returns `Poll::Ready(Ok(n))` with the number of bytes the transport
    /// accepted (which may be `0`). Returns `Poll::Pending` if the transport
    /// is not ready. Other transport errors are returned as-is.
    pub fn write_io(&mut self, cx: &mut Context) -> Poll<io::Result<usize>> {
        /// Adapts an `AsyncWrite` + `Context` to the blocking `Write` trait,
        /// mapping `Poll::Pending` to `WouldBlock`.
        struct Writer<'a, 'b, T> {
            io: &'a mut T,
            cx: &'a mut Context<'b>,
        }

        impl<'a, 'b, T: Unpin> Writer<'a, 'b, T> {
            #[inline]
            fn poll_with<U>(
                &mut self,
                f: impl FnOnce(Pin<&mut T>, &mut Context<'_>) -> Poll<io::Result<U>>,
            ) -> io::Result<U> {
                match f(Pin::new(self.io), self.cx) {
                    Poll::Ready(result) => result,
                    Poll::Pending => Err(io::ErrorKind::WouldBlock.into()),
                }
            }
        }

        impl<'a, 'b, T: AsyncWrite + Unpin> Write for Writer<'a, 'b, T> {
            #[inline]
            fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
                self.poll_with(|io, cx| io.poll_write(cx, buf))
            }

            #[inline]
            fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result<usize> {
                self.poll_with(|io, cx| io.poll_write_vectored(cx, bufs))
            }

            fn flush(&mut self) -> io::Result<()> {
                self.poll_with(|io, cx| io.poll_flush(cx))
            }
        }

        let mut writer = Writer { io: self.io, cx };

        match self.session.write_tls(&mut writer) {
            Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
            result => Poll::Ready(result),
        }
    }

    /// Drives the TLS handshake as far as the transport currently allows.
    ///
    /// Returns `Poll::Ready(Ok((rdlen, wrlen)))` with the number of raw bytes
    /// read and written when either:
    /// - the handshake completes, or
    /// - the transport blocks after some progress was made.
    ///
    /// Returns `Poll::Pending` if the transport blocks before any bytes moved.
    ///
    /// # Errors
    /// - `UnexpectedEof` ("tls handshake eof") if the transport reaches EOF
    ///   mid-handshake.
    /// - `WriteZero` if the transport accepts zero bytes while TLS data is
    ///   still queued.
    /// - Any error from `read_io`, `write_io` or the transport flush.
    pub fn handshake(&mut self, cx: &mut Context) -> Poll<io::Result<(usize, usize)>> {
        let mut wrlen = 0;
        let mut rdlen = 0;

        loop {
            let mut write_would_block = false;
            let mut read_would_block = false;
            let mut need_flush = false;

            while self.session.wants_write() {
                match self.write_io(cx) {
                    // Nothing was drained although rustls still has data
                    // queued: looping again would spin forever, so report it.
                    Poll::Ready(Ok(0)) => {
                        return Poll::Ready(Err(io::ErrorKind::WriteZero.into()));
                    }
                    Poll::Ready(Ok(n)) => {
                        wrlen += n;
                        need_flush = true;
                    }
                    Poll::Pending => {
                        write_would_block = true;
                        break;
                    }
                    Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
                }
            }

            if need_flush {
                match Pin::new(&mut self.io).poll_flush(cx) {
                    Poll::Ready(Ok(())) => (),
                    Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
                    Poll::Pending => write_would_block = true,
                }
            }

            while !self.eof && self.session.wants_read() {
                match self.read_io(cx) {
                    Poll::Ready(Ok(0)) => self.eof = true,
                    Poll::Ready(Ok(n)) => rdlen += n,
                    Poll::Pending => {
                        read_would_block = true;
                        break;
                    }
                    Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
                }
            }

            return match (self.eof, self.session.is_handshaking()) {
                (true, true) => {
                    let err = io::Error::new(io::ErrorKind::UnexpectedEof, "tls handshake eof");
                    Poll::Ready(Err(err))
                }
                (_, false) => Poll::Ready(Ok((rdlen, wrlen))),
                (_, true) if write_would_block || read_would_block => {
                    if rdlen != 0 || wrlen != 0 {
                        Poll::Ready(Ok((rdlen, wrlen)))
                    } else {
                        Poll::Pending
                    }
                }
                // Progress was made without blocking: keep going.
                (..) => continue,
            };
        }
    }
}
impl<'a, IO: AsyncRead + AsyncWrite + Unpin, C, SD> AsyncRead for Stream<'a, IO, C>
where
    C: DerefMut + Deref<Target = ConnectionCommon<SD>>,
    SD: SideData,
{
    /// Reads decrypted application data into `buf`.
    ///
    /// Steps:
    /// 1. Pulls TLS records from the transport while rustls wants more input.
    ///    Stops on transport EOF, when the transport is not ready, or on error.
    /// 2. Copies whatever plaintext rustls has buffered into `buf`.
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        // True when the transport returned `Pending`, i.e. it has registered a
        // waker for us.
        let mut transport_pending = false;

        while !self.eof && self.session.wants_read() {
            match self.read_io(cx) {
                Poll::Ready(Ok(0)) => break,
                Poll::Ready(Ok(_)) => {}
                Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
                Poll::Pending => {
                    transport_pending = true;
                    break;
                }
            }
        }

        let plaintext = self.session.reader().read(buf.initialize_unfilled());
        match plaintext {
            // rustls produced `n` bytes. `Ok(0)` with a non-empty `buf` means the
            // peer sent `CloseNotify`, so no more data will come. `self.eof` is
            // only a temporary transport marker and does not need updating here.
            Ok(n) => {
                buf.advance(n);
                Poll::Ready(Ok(()))
            }
            // No plaintext buffered yet, but rustls considers the connection open.
            Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => {
                if !transport_pending {
                    // No waker was registered by the transport, so schedule a
                    // re-poll ourselves. An abnormal rustls state could cause
                    // repeated wakeups; tokio's cooperative budget bounds that.
                    cx.waker().wake_by_ref();
                }
                Poll::Pending
            }
            Err(err) => Poll::Ready(Err(err)),
        }
    }
}
impl<'a, IO: AsyncRead + AsyncWrite + Unpin, C, SD> AsyncWrite for Stream<'a, IO, C>
where
    C: DerefMut + Deref<Target = ConnectionCommon<SD>>,
    SD: SideData,
{
    /// Hands plaintext from `buf` to the session and pushes the resulting TLS
    /// records to the transport.
    ///
    /// Returns the number of plaintext bytes accepted by the session. Returns
    /// `Poll::Pending` only if no bytes were accepted and the transport could
    /// not take more data.
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut Context,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let mut pos = 0;

        while pos != buf.len() {
            let mut would_block = false;

            match self.session.writer().write(&buf[pos..]) {
                Ok(n) => pos += n,
                Err(err) => return Poll::Ready(Err(err)),
            };

            while self.session.wants_write() {
                match self.write_io(cx) {
                    // NOTE(review): on `Ok(0)` with `pos == 0` this path returns
                    // `Pending` without the transport having registered a waker.
                    Poll::Ready(Ok(0)) | Poll::Pending => {
                        would_block = true;
                        break;
                    }
                    Poll::Ready(Ok(_)) => (),
                    Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
                }
            }

            return match (pos, would_block) {
                (0, true) => Poll::Pending,
                (n, true) => Poll::Ready(Ok(n)),
                (_, false) => continue,
            };
        }

        Poll::Ready(Ok(pos))
    }

    /// Flushes the session's plaintext writer, drains all queued TLS records to
    /// the transport, then flushes the transport.
    ///
    /// # Errors
    /// - `WriteZero` if the transport accepts zero bytes while records remain.
    ///   Retrying would otherwise loop forever.
    /// - Any other session or transport error.
    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
        self.session.writer().flush()?;
        while self.session.wants_write() {
            if ready!(self.write_io(cx))? == 0 {
                return Poll::Ready(Err(io::ErrorKind::WriteZero.into()));
            }
        }
        Pin::new(&mut self.io).poll_flush(cx)
    }

    /// Drains all queued TLS records to the transport, then shuts the
    /// transport down.
    ///
    /// # Errors
    /// - `WriteZero` if the transport accepts zero bytes while records remain.
    ///   Retrying would otherwise loop forever.
    /// - Any other transport error.
    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        while self.session.wants_write() {
            if ready!(self.write_io(cx))? == 0 {
                return Poll::Ready(Err(io::ErrorKind::WriteZero.into()));
            }
        }
        Pin::new(&mut self.io).poll_shutdown(cx)
    }
}
/// Presents an [`AsyncRead`] transport as a synchronous [`Read`] for the
/// duration of one poll, using the supplied task [`Context`].
///
/// A `Poll::Pending` from the transport is surfaced as
/// [`io::ErrorKind::WouldBlock`]. This lets a synchronous consumer (such as
/// rustls's `read_tls`) back off so its caller can return `Poll::Pending`.
pub struct SyncReadAdapter<'a, 'b, T> {
    /// The asynchronous transport being read from.
    pub io: &'a mut T,
    /// The task context used when polling `io`.
    pub cx: &'a mut Context<'b>,
}

impl<'a, 'b, T: AsyncRead + Unpin> Read for SyncReadAdapter<'a, 'b, T> {
    #[inline]
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        let mut read_buf = ReadBuf::new(buf);
        let polled = Pin::new(&mut *self.io).poll_read(self.cx, &mut read_buf);
        match polled {
            Poll::Pending => Err(io::ErrorKind::WouldBlock.into()),
            Poll::Ready(outcome) => outcome.map(|()| read_buf.filled().len()),
        }
    }
}
#[cfg(test)]
mod test_stream;