#34 properly implement TLS-1.3 shutdown behavior

This commit is contained in:
Yan Zhai 2019-04-19 21:08:18 +00:00
parent b6e39450ce
commit 87916dade6
3 changed files with 131 additions and 96 deletions

View File

@ -1,7 +1,6 @@
use super::*;
use std::io::Write;
use rustls::Session;
use std::io::Write;
/// A wrapper around an underlying raw stream which implements the TLS or SSL
/// protocol.
@ -12,21 +11,14 @@ pub struct TlsStream<IO> {
pub(crate) state: TlsState,
#[cfg(feature = "early-data")]
pub(crate) early_data: (usize, Vec<u8>)
}
#[derive(Debug)]
pub(crate) enum TlsState {
#[cfg(feature = "early-data")] EarlyData,
Stream,
Eof,
Shutdown
pub(crate) early_data: (usize, Vec<u8>),
}
pub(crate) enum MidHandshake<IO> {
Handshaking(TlsStream<IO>),
#[cfg(feature = "early-data")] EarlyData(TlsStream<IO>),
End
#[cfg(feature = "early-data")]
EarlyData(TlsStream<IO>),
End,
}
impl<IO> TlsStream<IO> {
@ -47,7 +39,8 @@ impl<IO> TlsStream<IO> {
}
impl<IO> Future for MidHandshake<IO>
where IO: AsyncRead + AsyncWrite,
where
IO: AsyncRead + AsyncWrite,
{
type Item = TlsStream<IO>;
type Error = io::Error;
@ -71,13 +64,14 @@ where IO: AsyncRead + AsyncWrite,
MidHandshake::Handshaking(stream) => Ok(Async::Ready(stream)),
#[cfg(feature = "early-data")]
MidHandshake::EarlyData(stream) => Ok(Async::Ready(stream)),
MidHandshake::End => panic!()
MidHandshake::End => panic!(),
}
}
}
impl<IO> io::Read for TlsStream<IO>
where IO: AsyncRead + AsyncWrite
where
IO: AsyncRead + AsyncWrite,
{
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
match self.state {
@ -106,31 +100,35 @@ where IO: AsyncRead + AsyncWrite
}
self.read(buf)
},
TlsState::Stream => {
}
TlsState::Stream | TlsState::WriteShutdown => {
let mut stream = Stream::new(&mut self.io, &mut self.session);
match stream.read(buf) {
Ok(0) => {
self.state = TlsState::Eof;
self.state.shutdown_read();
Ok(0)
},
}
Ok(n) => Ok(n),
Err(ref e) if e.kind() == io::ErrorKind::ConnectionAborted => {
self.state = TlsState::Shutdown;
self.state.shutdown_read();
if self.state.writeable() {
stream.session.send_close_notify();
Ok(0)
},
Err(e) => Err(e)
self.state.shutdown_write();
}
},
TlsState::Eof | TlsState::Shutdown => Ok(0),
Ok(0)
}
Err(e) => Err(e),
}
}
TlsState::ReadShutdown | TlsState::FullyShutdown => Ok(0),
}
}
}
impl<IO> io::Write for TlsStream<IO>
where IO: AsyncRead + AsyncWrite
where
IO: AsyncRead + AsyncWrite,
{
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
let mut stream = Stream::new(&mut self.io, &mut self.session);
@ -164,8 +162,8 @@ where IO: AsyncRead + AsyncWrite
self.state = TlsState::Stream;
data.clear();
stream.write(buf)
},
_ => stream.write(buf)
}
_ => stream.write(buf),
}
}
@ -176,7 +174,8 @@ where IO: AsyncRead + AsyncWrite
}
impl<IO> AsyncRead for TlsStream<IO>
where IO: AsyncRead + AsyncWrite
where
IO: AsyncRead + AsyncWrite,
{
unsafe fn prepare_uninitialized_buffer(&self, _: &mut [u8]) -> bool {
false
@ -184,14 +183,15 @@ where IO: AsyncRead + AsyncWrite
}
impl<IO> AsyncWrite for TlsStream<IO>
where IO: AsyncRead + AsyncWrite
where
IO: AsyncRead + AsyncWrite,
{
fn shutdown(&mut self) -> Poll<(), io::Error> {
match self.state {
TlsState::Shutdown => (),
s if !s.writeable() => (),
_ => {
self.session.send_close_notify();
self.state = TlsState::Shutdown;
self.state.shutdown_write();
}
}

View File

@ -3,39 +3,68 @@
pub extern crate rustls;
pub extern crate webpki;
extern crate futures;
extern crate tokio_io;
extern crate bytes;
extern crate futures;
extern crate iovec;
extern crate tokio_io;
mod common;
pub mod client;
mod common;
pub mod server;
use std::{ io, mem };
use std::sync::Arc;
use webpki::DNSNameRef;
use rustls::{
ClientSession, ServerSession,
ClientConfig, ServerConfig
};
use futures::{Async, Future, Poll};
use tokio_io::{ AsyncRead, AsyncWrite, try_nb };
use common::Stream;
use futures::{Async, Future, Poll};
use rustls::{ClientConfig, ClientSession, ServerConfig, ServerSession};
use std::sync::Arc;
use std::{io, mem};
use tokio_io::{try_nb, AsyncRead, AsyncWrite};
use webpki::DNSNameRef;
#[derive(Debug, Copy, Clone)]
pub enum TlsState {
#[cfg(feature = "early-data")]
EarlyData,
Stream,
ReadShutdown,
WriteShutdown,
FullyShutdown,
}
impl TlsState {
pub(crate) fn shutdown_read(&mut self) {
match *self {
TlsState::WriteShutdown | TlsState::FullyShutdown => *self = TlsState::FullyShutdown,
_ => *self = TlsState::ReadShutdown,
}
}
pub(crate) fn shutdown_write(&mut self) {
match *self {
TlsState::ReadShutdown | TlsState::FullyShutdown => *self = TlsState::FullyShutdown,
_ => *self = TlsState::WriteShutdown,
}
}
pub(crate) fn writeable(&self) -> bool {
match *self {
TlsState::WriteShutdown | TlsState::FullyShutdown => true,
_ => false,
}
}
}
/// A wrapper around a `rustls::ClientConfig`, providing an async `connect` method.
#[derive(Clone)]
pub struct TlsConnector {
inner: Arc<ClientConfig>,
#[cfg(feature = "early-data")]
early_data: bool
early_data: bool,
}
/// A wrapper around a `rustls::ServerConfig`, providing an async `accept` method.
#[derive(Clone)]
pub struct TlsAcceptor {
inner: Arc<ServerConfig>
inner: Arc<ServerConfig>,
}
impl From<Arc<ClientConfig>> for TlsConnector {
@ -43,7 +72,7 @@ impl From<Arc<ClientConfig>> for TlsConnector {
TlsConnector {
inner,
#[cfg(feature = "early-data")]
early_data: false
early_data: false,
}
}
}
@ -66,40 +95,45 @@ impl TlsConnector {
}
pub fn connect<IO>(&self, domain: DNSNameRef, stream: IO) -> Connect<IO>
where IO: AsyncRead + AsyncWrite
where
IO: AsyncRead + AsyncWrite,
{
self.connect_with(domain, stream, |_| ())
}
#[inline]
pub fn connect_with<IO, F>(&self, domain: DNSNameRef, stream: IO, f: F)
-> Connect<IO>
pub fn connect_with<IO, F>(&self, domain: DNSNameRef, stream: IO, f: F) -> Connect<IO>
where
IO: AsyncRead + AsyncWrite,
F: FnOnce(&mut ClientSession)
F: FnOnce(&mut ClientSession),
{
let mut session = ClientSession::new(&self.inner, domain);
f(&mut session);
#[cfg(not(feature = "early-data"))] {
#[cfg(not(feature = "early-data"))]
{
Connect(client::MidHandshake::Handshaking(client::TlsStream {
session, io: stream,
state: client::TlsState::Stream,
session,
io: stream,
state: TlsState::Stream,
}))
}
#[cfg(feature = "early-data")] {
#[cfg(feature = "early-data")]
{
Connect(if self.early_data {
client::MidHandshake::EarlyData(client::TlsStream {
session, io: stream,
state: client::TlsState::EarlyData,
early_data: (0, Vec::new())
session,
io: stream,
state: TlsState::EarlyData,
early_data: (0, Vec::new()),
})
} else {
client::MidHandshake::Handshaking(client::TlsStream {
session, io: stream,
state: client::TlsState::Stream,
early_data: (0, Vec::new())
session,
io: stream,
state: TlsState::Stream,
early_data: (0, Vec::new()),
})
})
}
@ -108,29 +142,29 @@ impl TlsConnector {
impl TlsAcceptor {
pub fn accept<IO>(&self, stream: IO) -> Accept<IO>
where IO: AsyncRead + AsyncWrite,
where
IO: AsyncRead + AsyncWrite,
{
self.accept_with(stream, |_| ())
}
#[inline]
pub fn accept_with<IO, F>(&self, stream: IO, f: F)
-> Accept<IO>
pub fn accept_with<IO, F>(&self, stream: IO, f: F) -> Accept<IO>
where
IO: AsyncRead + AsyncWrite,
F: FnOnce(&mut ServerSession)
F: FnOnce(&mut ServerSession),
{
let mut session = ServerSession::new(&self.inner);
f(&mut session);
Accept(server::MidHandshake::Handshaking(server::TlsStream {
session, io: stream,
state: server::TlsState::Stream,
session,
io: stream,
state: TlsState::Stream,
}))
}
}
/// Future returned from `ClientConfigExt::connect_async` which will resolve
/// once the connection handshake has finished.
pub struct Connect<IO>(client::MidHandshake<IO>);
@ -139,7 +173,6 @@ pub struct Connect<IO>(client::MidHandshake<IO>);
/// once the accept handshake has finished.
pub struct Accept<IO>(server::MidHandshake<IO>);
impl<IO: AsyncRead + AsyncWrite> Future for Connect<IO> {
type Item = client::TlsStream<IO>;
type Error = io::Error;

View File

@ -1,26 +1,18 @@
use super::*;
use rustls::Session;
/// A wrapper around an underlying raw stream which implements the TLS or SSL
/// protocol.
#[derive(Debug)]
pub struct TlsStream<IO> {
pub(crate) io: IO,
pub(crate) session: ServerSession,
pub(crate) state: TlsState
}
#[derive(Debug)]
pub(crate) enum TlsState {
Stream,
Eof,
Shutdown
pub(crate) state: TlsState,
}
pub(crate) enum MidHandshake<IO> {
Handshaking(TlsStream<IO>),
End
End,
}
impl<IO> TlsStream<IO> {
@ -41,7 +33,8 @@ impl<IO> TlsStream<IO> {
}
impl<IO> Future for MidHandshake<IO>
where IO: AsyncRead + AsyncWrite,
where
IO: AsyncRead + AsyncWrite,
{
type Item = TlsStream<IO>;
type Error = io::Error;
@ -63,38 +56,45 @@ where IO: AsyncRead + AsyncWrite,
match mem::replace(self, MidHandshake::End) {
MidHandshake::Handshaking(stream) => Ok(Async::Ready(stream)),
MidHandshake::End => panic!()
MidHandshake::End => panic!(),
}
}
}
impl<IO> io::Read for TlsStream<IO>
where IO: AsyncRead + AsyncWrite
where
IO: AsyncRead + AsyncWrite,
{
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
let mut stream = Stream::new(&mut self.io, &mut self.session);
match self.state {
TlsState::Stream => match stream.read(buf) {
TlsState::Stream | TlsState::WriteShutdown => match stream.read(buf) {
Ok(0) => {
self.state = TlsState::Eof;
self.state.shutdown_read();
Ok(0)
},
}
Ok(n) => Ok(n),
Err(ref e) if e.kind() == io::ErrorKind::ConnectionAborted => {
self.state = TlsState::Shutdown;
self.state.shutdown_read();
if self.state.writeable() {
stream.session.send_close_notify();
self.state.shutdown_write();
}
Ok(0)
}
Err(e) => Err(e),
},
Err(e) => Err(e)
},
TlsState::Eof | TlsState::Shutdown => Ok(0)
TlsState::ReadShutdown | TlsState::FullyShutdown => Ok(0),
#[cfg(feature = "early-data")]
s => unreachable!("server TLS can not hit this state: {:?}", s),
}
}
}
impl<IO> io::Write for TlsStream<IO>
where IO: AsyncRead + AsyncWrite
where
IO: AsyncRead + AsyncWrite,
{
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
let mut stream = Stream::new(&mut self.io, &mut self.session);
@ -108,7 +108,8 @@ where IO: AsyncRead + AsyncWrite
}
impl<IO> AsyncRead for TlsStream<IO>
where IO: AsyncRead + AsyncWrite
where
IO: AsyncRead + AsyncWrite,
{
unsafe fn prepare_uninitialized_buffer(&self, _: &mut [u8]) -> bool {
false
@ -116,14 +117,15 @@ where IO: AsyncRead + AsyncWrite
}
impl<IO> AsyncWrite for TlsStream<IO>
where IO: AsyncRead + AsyncWrite,
where
IO: AsyncRead + AsyncWrite,
{
fn shutdown(&mut self) -> Poll<(), io::Error> {
match self.state {
TlsState::Shutdown => (),
s if !s.writeable() => (),
_ => {
self.session.send_close_notify();
self.state = TlsState::Shutdown;
self.state.shutdown_write();
}
}