#34 properly implement TLS-1.3 shutdown behavior

This commit is contained in:
Yan Zhai 2019-04-19 21:08:18 +00:00
parent b6e39450ce
commit 87916dade6
3 changed files with 131 additions and 96 deletions

View File

@ -1,7 +1,6 @@
use super::*; use super::*;
use std::io::Write;
use rustls::Session; use rustls::Session;
use std::io::Write;
/// A wrapper around an underlying raw stream which implements the TLS or SSL /// A wrapper around an underlying raw stream which implements the TLS or SSL
/// protocol. /// protocol.
@ -12,21 +11,14 @@ pub struct TlsStream<IO> {
pub(crate) state: TlsState, pub(crate) state: TlsState,
#[cfg(feature = "early-data")] #[cfg(feature = "early-data")]
pub(crate) early_data: (usize, Vec<u8>) pub(crate) early_data: (usize, Vec<u8>),
}
#[derive(Debug)]
pub(crate) enum TlsState {
#[cfg(feature = "early-data")] EarlyData,
Stream,
Eof,
Shutdown
} }
pub(crate) enum MidHandshake<IO> { pub(crate) enum MidHandshake<IO> {
Handshaking(TlsStream<IO>), Handshaking(TlsStream<IO>),
#[cfg(feature = "early-data")] EarlyData(TlsStream<IO>), #[cfg(feature = "early-data")]
End EarlyData(TlsStream<IO>),
End,
} }
impl<IO> TlsStream<IO> { impl<IO> TlsStream<IO> {
@ -47,7 +39,8 @@ impl<IO> TlsStream<IO> {
} }
impl<IO> Future for MidHandshake<IO> impl<IO> Future for MidHandshake<IO>
where IO: AsyncRead + AsyncWrite, where
IO: AsyncRead + AsyncWrite,
{ {
type Item = TlsStream<IO>; type Item = TlsStream<IO>;
type Error = io::Error; type Error = io::Error;
@ -71,13 +64,14 @@ where IO: AsyncRead + AsyncWrite,
MidHandshake::Handshaking(stream) => Ok(Async::Ready(stream)), MidHandshake::Handshaking(stream) => Ok(Async::Ready(stream)),
#[cfg(feature = "early-data")] #[cfg(feature = "early-data")]
MidHandshake::EarlyData(stream) => Ok(Async::Ready(stream)), MidHandshake::EarlyData(stream) => Ok(Async::Ready(stream)),
MidHandshake::End => panic!() MidHandshake::End => panic!(),
} }
} }
} }
impl<IO> io::Read for TlsStream<IO> impl<IO> io::Read for TlsStream<IO>
where IO: AsyncRead + AsyncWrite where
IO: AsyncRead + AsyncWrite,
{ {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
match self.state { match self.state {
@ -106,31 +100,35 @@ where IO: AsyncRead + AsyncWrite
} }
self.read(buf) self.read(buf)
}, }
TlsState::Stream => { TlsState::Stream | TlsState::WriteShutdown => {
let mut stream = Stream::new(&mut self.io, &mut self.session); let mut stream = Stream::new(&mut self.io, &mut self.session);
match stream.read(buf) { match stream.read(buf) {
Ok(0) => { Ok(0) => {
self.state = TlsState::Eof; self.state.shutdown_read();
Ok(0) Ok(0)
}, }
Ok(n) => Ok(n), Ok(n) => Ok(n),
Err(ref e) if e.kind() == io::ErrorKind::ConnectionAborted => { Err(ref e) if e.kind() == io::ErrorKind::ConnectionAborted => {
self.state = TlsState::Shutdown; self.state.shutdown_read();
if self.state.writeable() {
stream.session.send_close_notify(); stream.session.send_close_notify();
Ok(0) self.state.shutdown_write();
},
Err(e) => Err(e)
} }
}, Ok(0)
TlsState::Eof | TlsState::Shutdown => Ok(0), }
Err(e) => Err(e),
}
}
TlsState::ReadShutdown | TlsState::FullyShutdown => Ok(0),
} }
} }
} }
impl<IO> io::Write for TlsStream<IO> impl<IO> io::Write for TlsStream<IO>
where IO: AsyncRead + AsyncWrite where
IO: AsyncRead + AsyncWrite,
{ {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> { fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
let mut stream = Stream::new(&mut self.io, &mut self.session); let mut stream = Stream::new(&mut self.io, &mut self.session);
@ -164,8 +162,8 @@ where IO: AsyncRead + AsyncWrite
self.state = TlsState::Stream; self.state = TlsState::Stream;
data.clear(); data.clear();
stream.write(buf) stream.write(buf)
}, }
_ => stream.write(buf) _ => stream.write(buf),
} }
} }
@ -176,7 +174,8 @@ where IO: AsyncRead + AsyncWrite
} }
impl<IO> AsyncRead for TlsStream<IO> impl<IO> AsyncRead for TlsStream<IO>
where IO: AsyncRead + AsyncWrite where
IO: AsyncRead + AsyncWrite,
{ {
unsafe fn prepare_uninitialized_buffer(&self, _: &mut [u8]) -> bool { unsafe fn prepare_uninitialized_buffer(&self, _: &mut [u8]) -> bool {
false false
@ -184,14 +183,15 @@ where IO: AsyncRead + AsyncWrite
} }
impl<IO> AsyncWrite for TlsStream<IO> impl<IO> AsyncWrite for TlsStream<IO>
where IO: AsyncRead + AsyncWrite where
IO: AsyncRead + AsyncWrite,
{ {
fn shutdown(&mut self) -> Poll<(), io::Error> { fn shutdown(&mut self) -> Poll<(), io::Error> {
match self.state { match self.state {
TlsState::Shutdown => (), s if !s.writeable() => (),
_ => { _ => {
self.session.send_close_notify(); self.session.send_close_notify();
self.state = TlsState::Shutdown; self.state.shutdown_write();
} }
} }

View File

@ -3,39 +3,68 @@
pub extern crate rustls; pub extern crate rustls;
pub extern crate webpki; pub extern crate webpki;
extern crate futures;
extern crate tokio_io;
extern crate bytes; extern crate bytes;
extern crate futures;
extern crate iovec; extern crate iovec;
extern crate tokio_io;
mod common;
pub mod client; pub mod client;
mod common;
pub mod server; pub mod server;
use std::{ io, mem };
use std::sync::Arc;
use webpki::DNSNameRef;
use rustls::{
ClientSession, ServerSession,
ClientConfig, ServerConfig
};
use futures::{Async, Future, Poll};
use tokio_io::{ AsyncRead, AsyncWrite, try_nb };
use common::Stream; use common::Stream;
use futures::{Async, Future, Poll};
use rustls::{ClientConfig, ClientSession, ServerConfig, ServerSession};
use std::sync::Arc;
use std::{io, mem};
use tokio_io::{try_nb, AsyncRead, AsyncWrite};
use webpki::DNSNameRef;
#[derive(Debug, Copy, Clone)]
pub enum TlsState {
#[cfg(feature = "early-data")]
EarlyData,
Stream,
ReadShutdown,
WriteShutdown,
FullyShutdown,
}
impl TlsState {
pub(crate) fn shutdown_read(&mut self) {
match *self {
TlsState::WriteShutdown | TlsState::FullyShutdown => *self = TlsState::FullyShutdown,
_ => *self = TlsState::ReadShutdown,
}
}
pub(crate) fn shutdown_write(&mut self) {
match *self {
TlsState::ReadShutdown | TlsState::FullyShutdown => *self = TlsState::FullyShutdown,
_ => *self = TlsState::WriteShutdown,
}
}
pub(crate) fn writeable(&self) -> bool {
match *self {
TlsState::WriteShutdown | TlsState::FullyShutdown => true,
_ => false,
}
}
}
/// A wrapper around a `rustls::ClientConfig`, providing an async `connect` method. /// A wrapper around a `rustls::ClientConfig`, providing an async `connect` method.
#[derive(Clone)] #[derive(Clone)]
pub struct TlsConnector { pub struct TlsConnector {
inner: Arc<ClientConfig>, inner: Arc<ClientConfig>,
#[cfg(feature = "early-data")] #[cfg(feature = "early-data")]
early_data: bool early_data: bool,
} }
/// A wrapper around a `rustls::ServerConfig`, providing an async `accept` method. /// A wrapper around a `rustls::ServerConfig`, providing an async `accept` method.
#[derive(Clone)] #[derive(Clone)]
pub struct TlsAcceptor { pub struct TlsAcceptor {
inner: Arc<ServerConfig> inner: Arc<ServerConfig>,
} }
impl From<Arc<ClientConfig>> for TlsConnector { impl From<Arc<ClientConfig>> for TlsConnector {
@ -43,7 +72,7 @@ impl From<Arc<ClientConfig>> for TlsConnector {
TlsConnector { TlsConnector {
inner, inner,
#[cfg(feature = "early-data")] #[cfg(feature = "early-data")]
early_data: false early_data: false,
} }
} }
} }
@ -66,40 +95,45 @@ impl TlsConnector {
} }
pub fn connect<IO>(&self, domain: DNSNameRef, stream: IO) -> Connect<IO> pub fn connect<IO>(&self, domain: DNSNameRef, stream: IO) -> Connect<IO>
where IO: AsyncRead + AsyncWrite where
IO: AsyncRead + AsyncWrite,
{ {
self.connect_with(domain, stream, |_| ()) self.connect_with(domain, stream, |_| ())
} }
#[inline] #[inline]
pub fn connect_with<IO, F>(&self, domain: DNSNameRef, stream: IO, f: F) pub fn connect_with<IO, F>(&self, domain: DNSNameRef, stream: IO, f: F) -> Connect<IO>
-> Connect<IO>
where where
IO: AsyncRead + AsyncWrite, IO: AsyncRead + AsyncWrite,
F: FnOnce(&mut ClientSession) F: FnOnce(&mut ClientSession),
{ {
let mut session = ClientSession::new(&self.inner, domain); let mut session = ClientSession::new(&self.inner, domain);
f(&mut session); f(&mut session);
#[cfg(not(feature = "early-data"))] { #[cfg(not(feature = "early-data"))]
{
Connect(client::MidHandshake::Handshaking(client::TlsStream { Connect(client::MidHandshake::Handshaking(client::TlsStream {
session, io: stream, session,
state: client::TlsState::Stream, io: stream,
state: TlsState::Stream,
})) }))
} }
#[cfg(feature = "early-data")] { #[cfg(feature = "early-data")]
{
Connect(if self.early_data { Connect(if self.early_data {
client::MidHandshake::EarlyData(client::TlsStream { client::MidHandshake::EarlyData(client::TlsStream {
session, io: stream, session,
state: client::TlsState::EarlyData, io: stream,
early_data: (0, Vec::new()) state: TlsState::EarlyData,
early_data: (0, Vec::new()),
}) })
} else { } else {
client::MidHandshake::Handshaking(client::TlsStream { client::MidHandshake::Handshaking(client::TlsStream {
session, io: stream, session,
state: client::TlsState::Stream, io: stream,
early_data: (0, Vec::new()) state: TlsState::Stream,
early_data: (0, Vec::new()),
}) })
}) })
} }
@ -108,29 +142,29 @@ impl TlsConnector {
impl TlsAcceptor { impl TlsAcceptor {
pub fn accept<IO>(&self, stream: IO) -> Accept<IO> pub fn accept<IO>(&self, stream: IO) -> Accept<IO>
where IO: AsyncRead + AsyncWrite, where
IO: AsyncRead + AsyncWrite,
{ {
self.accept_with(stream, |_| ()) self.accept_with(stream, |_| ())
} }
#[inline] #[inline]
pub fn accept_with<IO, F>(&self, stream: IO, f: F) pub fn accept_with<IO, F>(&self, stream: IO, f: F) -> Accept<IO>
-> Accept<IO>
where where
IO: AsyncRead + AsyncWrite, IO: AsyncRead + AsyncWrite,
F: FnOnce(&mut ServerSession) F: FnOnce(&mut ServerSession),
{ {
let mut session = ServerSession::new(&self.inner); let mut session = ServerSession::new(&self.inner);
f(&mut session); f(&mut session);
Accept(server::MidHandshake::Handshaking(server::TlsStream { Accept(server::MidHandshake::Handshaking(server::TlsStream {
session, io: stream, session,
state: server::TlsState::Stream, io: stream,
state: TlsState::Stream,
})) }))
} }
} }
/// Future returned from `ClientConfigExt::connect_async` which will resolve /// Future returned from `ClientConfigExt::connect_async` which will resolve
/// once the connection handshake has finished. /// once the connection handshake has finished.
pub struct Connect<IO>(client::MidHandshake<IO>); pub struct Connect<IO>(client::MidHandshake<IO>);
@ -139,7 +173,6 @@ pub struct Connect<IO>(client::MidHandshake<IO>);
/// once the accept handshake has finished. /// once the accept handshake has finished.
pub struct Accept<IO>(server::MidHandshake<IO>); pub struct Accept<IO>(server::MidHandshake<IO>);
impl<IO: AsyncRead + AsyncWrite> Future for Connect<IO> { impl<IO: AsyncRead + AsyncWrite> Future for Connect<IO> {
type Item = client::TlsStream<IO>; type Item = client::TlsStream<IO>;
type Error = io::Error; type Error = io::Error;

View File

@ -1,26 +1,18 @@
use super::*; use super::*;
use rustls::Session; use rustls::Session;
/// A wrapper around an underlying raw stream which implements the TLS or SSL /// A wrapper around an underlying raw stream which implements the TLS or SSL
/// protocol. /// protocol.
#[derive(Debug)] #[derive(Debug)]
pub struct TlsStream<IO> { pub struct TlsStream<IO> {
pub(crate) io: IO, pub(crate) io: IO,
pub(crate) session: ServerSession, pub(crate) session: ServerSession,
pub(crate) state: TlsState pub(crate) state: TlsState,
}
#[derive(Debug)]
pub(crate) enum TlsState {
Stream,
Eof,
Shutdown
} }
pub(crate) enum MidHandshake<IO> { pub(crate) enum MidHandshake<IO> {
Handshaking(TlsStream<IO>), Handshaking(TlsStream<IO>),
End End,
} }
impl<IO> TlsStream<IO> { impl<IO> TlsStream<IO> {
@ -41,7 +33,8 @@ impl<IO> TlsStream<IO> {
} }
impl<IO> Future for MidHandshake<IO> impl<IO> Future for MidHandshake<IO>
where IO: AsyncRead + AsyncWrite, where
IO: AsyncRead + AsyncWrite,
{ {
type Item = TlsStream<IO>; type Item = TlsStream<IO>;
type Error = io::Error; type Error = io::Error;
@ -63,38 +56,45 @@ where IO: AsyncRead + AsyncWrite,
match mem::replace(self, MidHandshake::End) { match mem::replace(self, MidHandshake::End) {
MidHandshake::Handshaking(stream) => Ok(Async::Ready(stream)), MidHandshake::Handshaking(stream) => Ok(Async::Ready(stream)),
MidHandshake::End => panic!() MidHandshake::End => panic!(),
} }
} }
} }
impl<IO> io::Read for TlsStream<IO> impl<IO> io::Read for TlsStream<IO>
where IO: AsyncRead + AsyncWrite where
IO: AsyncRead + AsyncWrite,
{ {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
let mut stream = Stream::new(&mut self.io, &mut self.session); let mut stream = Stream::new(&mut self.io, &mut self.session);
match self.state { match self.state {
TlsState::Stream => match stream.read(buf) { TlsState::Stream | TlsState::WriteShutdown => match stream.read(buf) {
Ok(0) => { Ok(0) => {
self.state = TlsState::Eof; self.state.shutdown_read();
Ok(0) Ok(0)
}, }
Ok(n) => Ok(n), Ok(n) => Ok(n),
Err(ref e) if e.kind() == io::ErrorKind::ConnectionAborted => { Err(ref e) if e.kind() == io::ErrorKind::ConnectionAborted => {
self.state = TlsState::Shutdown; self.state.shutdown_read();
if self.state.writeable() {
stream.session.send_close_notify(); stream.session.send_close_notify();
self.state.shutdown_write();
}
Ok(0) Ok(0)
}
Err(e) => Err(e),
}, },
Err(e) => Err(e) TlsState::ReadShutdown | TlsState::FullyShutdown => Ok(0),
}, #[cfg(feature = "early-data")]
TlsState::Eof | TlsState::Shutdown => Ok(0) s => unreachable!("server TLS can not hit this state: {:?}", s),
} }
} }
} }
impl<IO> io::Write for TlsStream<IO> impl<IO> io::Write for TlsStream<IO>
where IO: AsyncRead + AsyncWrite where
IO: AsyncRead + AsyncWrite,
{ {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> { fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
let mut stream = Stream::new(&mut self.io, &mut self.session); let mut stream = Stream::new(&mut self.io, &mut self.session);
@ -108,7 +108,8 @@ where IO: AsyncRead + AsyncWrite
} }
impl<IO> AsyncRead for TlsStream<IO> impl<IO> AsyncRead for TlsStream<IO>
where IO: AsyncRead + AsyncWrite where
IO: AsyncRead + AsyncWrite,
{ {
unsafe fn prepare_uninitialized_buffer(&self, _: &mut [u8]) -> bool { unsafe fn prepare_uninitialized_buffer(&self, _: &mut [u8]) -> bool {
false false
@ -116,14 +117,15 @@ where IO: AsyncRead + AsyncWrite
} }
impl<IO> AsyncWrite for TlsStream<IO> impl<IO> AsyncWrite for TlsStream<IO>
where IO: AsyncRead + AsyncWrite, where
IO: AsyncRead + AsyncWrite,
{ {
fn shutdown(&mut self) -> Poll<(), io::Error> { fn shutdown(&mut self) -> Poll<(), io::Error> {
match self.state { match self.state {
TlsState::Shutdown => (), s if !s.writeable() => (),
_ => { _ => {
self.session.send_close_notify(); self.session.send_close_notify();
self.state = TlsState::Shutdown; self.state.shutdown_write();
} }
} }