Merge pull request #35 from jerryz920/yan/issue-34
#34 properly implement TLS-1.3 shutdown behavior
This commit is contained in:
commit
00f1022f88
@ -1,7 +1,6 @@
|
||||
use super::*;
|
||||
use std::io::Write;
|
||||
use rustls::Session;
|
||||
|
||||
use std::io::Write;
|
||||
|
||||
/// A wrapper around an underlying raw stream which implements the TLS or SSL
|
||||
/// protocol.
|
||||
@ -12,21 +11,14 @@ pub struct TlsStream<IO> {
|
||||
pub(crate) state: TlsState,
|
||||
|
||||
#[cfg(feature = "early-data")]
|
||||
pub(crate) early_data: (usize, Vec<u8>)
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) enum TlsState {
|
||||
#[cfg(feature = "early-data")] EarlyData,
|
||||
Stream,
|
||||
Eof,
|
||||
Shutdown
|
||||
pub(crate) early_data: (usize, Vec<u8>),
|
||||
}
|
||||
|
||||
pub(crate) enum MidHandshake<IO> {
|
||||
Handshaking(TlsStream<IO>),
|
||||
#[cfg(feature = "early-data")] EarlyData(TlsStream<IO>),
|
||||
End
|
||||
#[cfg(feature = "early-data")]
|
||||
EarlyData(TlsStream<IO>),
|
||||
End,
|
||||
}
|
||||
|
||||
impl<IO> TlsStream<IO> {
|
||||
@ -47,7 +39,8 @@ impl<IO> TlsStream<IO> {
|
||||
}
|
||||
|
||||
impl<IO> Future for MidHandshake<IO>
|
||||
where IO: AsyncRead + AsyncWrite,
|
||||
where
|
||||
IO: AsyncRead + AsyncWrite,
|
||||
{
|
||||
type Item = TlsStream<IO>;
|
||||
type Error = io::Error;
|
||||
@ -71,13 +64,14 @@ where IO: AsyncRead + AsyncWrite,
|
||||
MidHandshake::Handshaking(stream) => Ok(Async::Ready(stream)),
|
||||
#[cfg(feature = "early-data")]
|
||||
MidHandshake::EarlyData(stream) => Ok(Async::Ready(stream)),
|
||||
MidHandshake::End => panic!()
|
||||
MidHandshake::End => panic!(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<IO> io::Read for TlsStream<IO>
|
||||
where IO: AsyncRead + AsyncWrite
|
||||
where
|
||||
IO: AsyncRead + AsyncWrite,
|
||||
{
|
||||
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
|
||||
match self.state {
|
||||
@ -106,31 +100,35 @@ where IO: AsyncRead + AsyncWrite
|
||||
}
|
||||
|
||||
self.read(buf)
|
||||
},
|
||||
TlsState::Stream => {
|
||||
}
|
||||
TlsState::Stream | TlsState::WriteShutdown => {
|
||||
let mut stream = Stream::new(&mut self.io, &mut self.session);
|
||||
|
||||
match stream.read(buf) {
|
||||
Ok(0) => {
|
||||
self.state = TlsState::Eof;
|
||||
self.state.shutdown_read();
|
||||
Ok(0)
|
||||
},
|
||||
}
|
||||
Ok(n) => Ok(n),
|
||||
Err(ref e) if e.kind() == io::ErrorKind::ConnectionAborted => {
|
||||
self.state = TlsState::Shutdown;
|
||||
self.state.shutdown_read();
|
||||
if self.state.writeable() {
|
||||
stream.session.send_close_notify();
|
||||
Ok(0)
|
||||
},
|
||||
Err(e) => Err(e)
|
||||
self.state.shutdown_write();
|
||||
}
|
||||
},
|
||||
TlsState::Eof | TlsState::Shutdown => Ok(0),
|
||||
Ok(0)
|
||||
}
|
||||
Err(e) => Err(e),
|
||||
}
|
||||
}
|
||||
TlsState::ReadShutdown | TlsState::FullyShutdown => Ok(0),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<IO> io::Write for TlsStream<IO>
|
||||
where IO: AsyncRead + AsyncWrite
|
||||
where
|
||||
IO: AsyncRead + AsyncWrite,
|
||||
{
|
||||
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
|
||||
let mut stream = Stream::new(&mut self.io, &mut self.session);
|
||||
@ -164,8 +162,8 @@ where IO: AsyncRead + AsyncWrite
|
||||
self.state = TlsState::Stream;
|
||||
data.clear();
|
||||
stream.write(buf)
|
||||
},
|
||||
_ => stream.write(buf)
|
||||
}
|
||||
_ => stream.write(buf),
|
||||
}
|
||||
}
|
||||
|
||||
@ -176,7 +174,8 @@ where IO: AsyncRead + AsyncWrite
|
||||
}
|
||||
|
||||
impl<IO> AsyncRead for TlsStream<IO>
|
||||
where IO: AsyncRead + AsyncWrite
|
||||
where
|
||||
IO: AsyncRead + AsyncWrite,
|
||||
{
|
||||
unsafe fn prepare_uninitialized_buffer(&self, _: &mut [u8]) -> bool {
|
||||
false
|
||||
@ -184,15 +183,13 @@ where IO: AsyncRead + AsyncWrite
|
||||
}
|
||||
|
||||
impl<IO> AsyncWrite for TlsStream<IO>
|
||||
where IO: AsyncRead + AsyncWrite
|
||||
where
|
||||
IO: AsyncRead + AsyncWrite,
|
||||
{
|
||||
fn shutdown(&mut self) -> Poll<(), io::Error> {
|
||||
match self.state {
|
||||
TlsState::Shutdown => (),
|
||||
_ => {
|
||||
if self.state.writeable() {
|
||||
self.session.send_close_notify();
|
||||
self.state = TlsState::Shutdown;
|
||||
}
|
||||
self.state.shutdown_write();
|
||||
}
|
||||
|
||||
let mut stream = Stream::new(&mut self.io, &mut self.session);
|
||||
|
107
src/lib.rs
107
src/lib.rs
@ -3,39 +3,68 @@
|
||||
pub extern crate rustls;
|
||||
pub extern crate webpki;
|
||||
|
||||
extern crate futures;
|
||||
extern crate tokio_io;
|
||||
extern crate bytes;
|
||||
extern crate futures;
|
||||
extern crate iovec;
|
||||
extern crate tokio_io;
|
||||
|
||||
mod common;
|
||||
pub mod client;
|
||||
mod common;
|
||||
pub mod server;
|
||||
|
||||
use std::{ io, mem };
|
||||
use std::sync::Arc;
|
||||
use webpki::DNSNameRef;
|
||||
use rustls::{
|
||||
ClientSession, ServerSession,
|
||||
ClientConfig, ServerConfig
|
||||
};
|
||||
use futures::{Async, Future, Poll};
|
||||
use tokio_io::{ AsyncRead, AsyncWrite, try_nb };
|
||||
use common::Stream;
|
||||
use futures::{Async, Future, Poll};
|
||||
use rustls::{ClientConfig, ClientSession, ServerConfig, ServerSession};
|
||||
use std::sync::Arc;
|
||||
use std::{io, mem};
|
||||
use tokio_io::{try_nb, AsyncRead, AsyncWrite};
|
||||
use webpki::DNSNameRef;
|
||||
|
||||
#[derive(Debug, Copy, Clone)]
|
||||
pub enum TlsState {
|
||||
#[cfg(feature = "early-data")]
|
||||
EarlyData,
|
||||
Stream,
|
||||
ReadShutdown,
|
||||
WriteShutdown,
|
||||
FullyShutdown,
|
||||
}
|
||||
|
||||
impl TlsState {
|
||||
pub(crate) fn shutdown_read(&mut self) {
|
||||
match *self {
|
||||
TlsState::WriteShutdown | TlsState::FullyShutdown => *self = TlsState::FullyShutdown,
|
||||
_ => *self = TlsState::ReadShutdown,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn shutdown_write(&mut self) {
|
||||
match *self {
|
||||
TlsState::ReadShutdown | TlsState::FullyShutdown => *self = TlsState::FullyShutdown,
|
||||
_ => *self = TlsState::WriteShutdown,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn writeable(&self) -> bool {
|
||||
match *self {
|
||||
TlsState::WriteShutdown | TlsState::FullyShutdown => false,
|
||||
_ => true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A wrapper around a `rustls::ClientConfig`, providing an async `connect` method.
|
||||
#[derive(Clone)]
|
||||
pub struct TlsConnector {
|
||||
inner: Arc<ClientConfig>,
|
||||
#[cfg(feature = "early-data")]
|
||||
early_data: bool
|
||||
early_data: bool,
|
||||
}
|
||||
|
||||
/// A wrapper around a `rustls::ServerConfig`, providing an async `accept` method.
|
||||
#[derive(Clone)]
|
||||
pub struct TlsAcceptor {
|
||||
inner: Arc<ServerConfig>
|
||||
inner: Arc<ServerConfig>,
|
||||
}
|
||||
|
||||
impl From<Arc<ClientConfig>> for TlsConnector {
|
||||
@ -43,7 +72,7 @@ impl From<Arc<ClientConfig>> for TlsConnector {
|
||||
TlsConnector {
|
||||
inner,
|
||||
#[cfg(feature = "early-data")]
|
||||
early_data: false
|
||||
early_data: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -66,40 +95,45 @@ impl TlsConnector {
|
||||
}
|
||||
|
||||
pub fn connect<IO>(&self, domain: DNSNameRef, stream: IO) -> Connect<IO>
|
||||
where IO: AsyncRead + AsyncWrite
|
||||
where
|
||||
IO: AsyncRead + AsyncWrite,
|
||||
{
|
||||
self.connect_with(domain, stream, |_| ())
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn connect_with<IO, F>(&self, domain: DNSNameRef, stream: IO, f: F)
|
||||
-> Connect<IO>
|
||||
pub fn connect_with<IO, F>(&self, domain: DNSNameRef, stream: IO, f: F) -> Connect<IO>
|
||||
where
|
||||
IO: AsyncRead + AsyncWrite,
|
||||
F: FnOnce(&mut ClientSession)
|
||||
F: FnOnce(&mut ClientSession),
|
||||
{
|
||||
let mut session = ClientSession::new(&self.inner, domain);
|
||||
f(&mut session);
|
||||
|
||||
#[cfg(not(feature = "early-data"))] {
|
||||
#[cfg(not(feature = "early-data"))]
|
||||
{
|
||||
Connect(client::MidHandshake::Handshaking(client::TlsStream {
|
||||
session, io: stream,
|
||||
state: client::TlsState::Stream,
|
||||
session,
|
||||
io: stream,
|
||||
state: TlsState::Stream,
|
||||
}))
|
||||
}
|
||||
|
||||
#[cfg(feature = "early-data")] {
|
||||
#[cfg(feature = "early-data")]
|
||||
{
|
||||
Connect(if self.early_data {
|
||||
client::MidHandshake::EarlyData(client::TlsStream {
|
||||
session, io: stream,
|
||||
state: client::TlsState::EarlyData,
|
||||
early_data: (0, Vec::new())
|
||||
session,
|
||||
io: stream,
|
||||
state: TlsState::EarlyData,
|
||||
early_data: (0, Vec::new()),
|
||||
})
|
||||
} else {
|
||||
client::MidHandshake::Handshaking(client::TlsStream {
|
||||
session, io: stream,
|
||||
state: client::TlsState::Stream,
|
||||
early_data: (0, Vec::new())
|
||||
session,
|
||||
io: stream,
|
||||
state: TlsState::Stream,
|
||||
early_data: (0, Vec::new()),
|
||||
})
|
||||
})
|
||||
}
|
||||
@ -108,29 +142,29 @@ impl TlsConnector {
|
||||
|
||||
impl TlsAcceptor {
|
||||
pub fn accept<IO>(&self, stream: IO) -> Accept<IO>
|
||||
where IO: AsyncRead + AsyncWrite,
|
||||
where
|
||||
IO: AsyncRead + AsyncWrite,
|
||||
{
|
||||
self.accept_with(stream, |_| ())
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn accept_with<IO, F>(&self, stream: IO, f: F)
|
||||
-> Accept<IO>
|
||||
pub fn accept_with<IO, F>(&self, stream: IO, f: F) -> Accept<IO>
|
||||
where
|
||||
IO: AsyncRead + AsyncWrite,
|
||||
F: FnOnce(&mut ServerSession)
|
||||
F: FnOnce(&mut ServerSession),
|
||||
{
|
||||
let mut session = ServerSession::new(&self.inner);
|
||||
f(&mut session);
|
||||
|
||||
Accept(server::MidHandshake::Handshaking(server::TlsStream {
|
||||
session, io: stream,
|
||||
state: server::TlsState::Stream,
|
||||
session,
|
||||
io: stream,
|
||||
state: TlsState::Stream,
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/// Future returned from `ClientConfigExt::connect_async` which will resolve
|
||||
/// once the connection handshake has finished.
|
||||
pub struct Connect<IO>(client::MidHandshake<IO>);
|
||||
@ -139,7 +173,6 @@ pub struct Connect<IO>(client::MidHandshake<IO>);
|
||||
/// once the accept handshake has finished.
|
||||
pub struct Accept<IO>(server::MidHandshake<IO>);
|
||||
|
||||
|
||||
impl<IO: AsyncRead + AsyncWrite> Future for Connect<IO> {
|
||||
type Item = client::TlsStream<IO>;
|
||||
type Error = io::Error;
|
||||
|
@ -1,26 +1,18 @@
|
||||
use super::*;
|
||||
use rustls::Session;
|
||||
|
||||
|
||||
/// A wrapper around an underlying raw stream which implements the TLS or SSL
|
||||
/// protocol.
|
||||
#[derive(Debug)]
|
||||
pub struct TlsStream<IO> {
|
||||
pub(crate) io: IO,
|
||||
pub(crate) session: ServerSession,
|
||||
pub(crate) state: TlsState
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) enum TlsState {
|
||||
Stream,
|
||||
Eof,
|
||||
Shutdown
|
||||
pub(crate) state: TlsState,
|
||||
}
|
||||
|
||||
pub(crate) enum MidHandshake<IO> {
|
||||
Handshaking(TlsStream<IO>),
|
||||
End
|
||||
End,
|
||||
}
|
||||
|
||||
impl<IO> TlsStream<IO> {
|
||||
@ -41,7 +33,8 @@ impl<IO> TlsStream<IO> {
|
||||
}
|
||||
|
||||
impl<IO> Future for MidHandshake<IO>
|
||||
where IO: AsyncRead + AsyncWrite,
|
||||
where
|
||||
IO: AsyncRead + AsyncWrite,
|
||||
{
|
||||
type Item = TlsStream<IO>;
|
||||
type Error = io::Error;
|
||||
@ -63,38 +56,45 @@ where IO: AsyncRead + AsyncWrite,
|
||||
|
||||
match mem::replace(self, MidHandshake::End) {
|
||||
MidHandshake::Handshaking(stream) => Ok(Async::Ready(stream)),
|
||||
MidHandshake::End => panic!()
|
||||
MidHandshake::End => panic!(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<IO> io::Read for TlsStream<IO>
|
||||
where IO: AsyncRead + AsyncWrite
|
||||
where
|
||||
IO: AsyncRead + AsyncWrite,
|
||||
{
|
||||
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
|
||||
let mut stream = Stream::new(&mut self.io, &mut self.session);
|
||||
|
||||
match self.state {
|
||||
TlsState::Stream => match stream.read(buf) {
|
||||
TlsState::Stream | TlsState::WriteShutdown => match stream.read(buf) {
|
||||
Ok(0) => {
|
||||
self.state = TlsState::Eof;
|
||||
self.state.shutdown_read();
|
||||
Ok(0)
|
||||
},
|
||||
}
|
||||
Ok(n) => Ok(n),
|
||||
Err(ref e) if e.kind() == io::ErrorKind::ConnectionAborted => {
|
||||
self.state = TlsState::Shutdown;
|
||||
self.state.shutdown_read();
|
||||
if self.state.writeable() {
|
||||
stream.session.send_close_notify();
|
||||
self.state.shutdown_write();
|
||||
}
|
||||
Ok(0)
|
||||
}
|
||||
Err(e) => Err(e),
|
||||
},
|
||||
Err(e) => Err(e)
|
||||
},
|
||||
TlsState::Eof | TlsState::Shutdown => Ok(0)
|
||||
TlsState::ReadShutdown | TlsState::FullyShutdown => Ok(0),
|
||||
#[cfg(feature = "early-data")]
|
||||
s => unreachable!("server TLS can not hit this state: {:?}", s),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<IO> io::Write for TlsStream<IO>
|
||||
where IO: AsyncRead + AsyncWrite
|
||||
where
|
||||
IO: AsyncRead + AsyncWrite,
|
||||
{
|
||||
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
|
||||
let mut stream = Stream::new(&mut self.io, &mut self.session);
|
||||
@ -108,7 +108,8 @@ where IO: AsyncRead + AsyncWrite
|
||||
}
|
||||
|
||||
impl<IO> AsyncRead for TlsStream<IO>
|
||||
where IO: AsyncRead + AsyncWrite
|
||||
where
|
||||
IO: AsyncRead + AsyncWrite,
|
||||
{
|
||||
unsafe fn prepare_uninitialized_buffer(&self, _: &mut [u8]) -> bool {
|
||||
false
|
||||
@ -116,15 +117,13 @@ where IO: AsyncRead + AsyncWrite
|
||||
}
|
||||
|
||||
impl<IO> AsyncWrite for TlsStream<IO>
|
||||
where IO: AsyncRead + AsyncWrite,
|
||||
where
|
||||
IO: AsyncRead + AsyncWrite,
|
||||
{
|
||||
fn shutdown(&mut self) -> Poll<(), io::Error> {
|
||||
match self.state {
|
||||
TlsState::Shutdown => (),
|
||||
_ => {
|
||||
if self.state.writeable() {
|
||||
self.session.send_close_notify();
|
||||
self.state = TlsState::Shutdown;
|
||||
}
|
||||
self.state.shutdown_write();
|
||||
}
|
||||
|
||||
let mut stream = Stream::new(&mut self.io, &mut self.session);
|
||||
|
Loading…
Reference in New Issue
Block a user