From 485cf8463989c25f0faa8d66c9c3dfdb7bce0063 Mon Sep 17 00:00:00 2001 From: quininer Date: Mon, 25 Feb 2019 23:48:06 +0800 Subject: [PATCH] make 0-RTT optional --- .travis.yml | 1 + Cargo.toml | 3 +++ appveyor.yml | 1 + src/client.rs | 16 +++++++++++----- src/lib.rs | 38 +++++++++++++++++++++++++++----------- 5 files changed, 43 insertions(+), 16 deletions(-) diff --git a/.travis.yml b/.travis.yml index 3653f1f..9efee9d 100644 --- a/.travis.yml +++ b/.travis.yml @@ -14,6 +14,7 @@ matrix: script: - cargo test + - cargo test --features early-data - cd examples/server - cargo check - cd ../../examples/client diff --git a/Cargo.toml b/Cargo.toml index b15bea2..ff95d5f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,6 +22,9 @@ iovec = "0.1" rustls = "0.15" webpki = "0.19" +[features] +early-data = [] + [dev-dependencies] tokio = "0.1.6" lazy_static = "1" diff --git a/appveyor.yml b/appveyor.yml index 038274b..26db365 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -14,6 +14,7 @@ build: false test_script: - 'cargo test' + - 'cargo test --features early-data' - 'cd examples/server' - 'cargo check' - 'cd ../../examples/client' diff --git a/src/client.rs b/src/client.rs index 3e9d73f..91a65aa 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,5 +1,4 @@ use super::*; -use std::io::Write; use rustls::Session; @@ -10,12 +9,14 @@ pub struct TlsStream { pub(crate) io: IO, pub(crate) session: ClientSession, pub(crate) state: TlsState, + + #[cfg(feature = "early-data")] pub(crate) early_data: (usize, Vec) } #[derive(Debug)] pub(crate) enum TlsState { - EarlyData, + #[cfg(feature = "early-data")] EarlyData, Stream, Eof, Shutdown @@ -23,7 +24,7 @@ pub(crate) enum TlsState { pub(crate) enum MidHandshake { Handshaking(TlsStream), - EarlyData(TlsStream), + #[cfg(feature = "early-data")] EarlyData(TlsStream), End } @@ -66,8 +67,9 @@ where IO: AsyncRead + AsyncWrite, } match mem::replace(self, MidHandshake::End) { - MidHandshake::Handshaking(stream) - | MidHandshake::EarlyData(stream) => Ok(Async::Ready(stream)), + MidHandshake::Handshaking(stream) => Ok(Async::Ready(stream)), + #[cfg(feature = "early-data")] + MidHandshake::EarlyData(stream) => Ok(Async::Ready(stream)), MidHandshake::End => panic!() } } @@ -80,7 +82,10 @@ where IO: AsyncRead + AsyncWrite let mut stream = Stream::new(&mut self.io, &mut self.session); match self.state { + #[cfg(feature = "early-data")] TlsState::EarlyData => { + use std::io::Write; + let (pos, data) = &mut self.early_data; // complete handshake @@ -126,6 +131,7 @@ where IO: AsyncRead + AsyncWrite let mut stream = Stream::new(&mut self.io, &mut self.session); match self.state { + #[cfg(feature = "early-data")] TlsState::EarlyData => { let (pos, data) = &mut self.early_data; diff --git a/src/lib.rs b/src/lib.rs index 511cdd9..6a77fbb 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -28,6 +28,7 @@ use common::Stream; #[derive(Clone)] pub struct TlsConnector { inner: Arc, + #[cfg(feature = "early-data")] early_data: bool } @@ -39,7 +40,11 @@ pub struct TlsAcceptor { impl From> for TlsConnector { fn from(inner: Arc) -> TlsConnector { - TlsConnector { inner, early_data: false } + TlsConnector { + inner, + #[cfg(feature = "early-data")] + early_data: false + } } } @@ -54,6 +59,7 @@ impl TlsConnector { /// /// Note that you want to use 0-RTT. /// You must set `enable_early_data` to `true` in `ClientConfig`. + #[cfg(feature = "early-data")] pub fn early_data(mut self, flag: bool) -> TlsConnector { self.early_data = flag; self @@ -75,19 +81,28 @@ impl TlsConnector { let mut session = ClientSession::new(&self.inner, domain); f(&mut session); - Connect(if self.early_data { - client::MidHandshake::EarlyData(client::TlsStream { - session, io: stream, - state: client::TlsState::EarlyData, - early_data: (0, Vec::new()) - }) - } else { - client::MidHandshake::Handshaking(client::TlsStream { + #[cfg(not(feature = "early-data"))] { + Connect(client::MidHandshake::Handshaking(client::TlsStream { session, io: stream, state: client::TlsState::Stream, - early_data: (0, Vec::new()) + })) + } + + #[cfg(feature = "early-data")] { + Connect(if self.early_data { + client::MidHandshake::EarlyData(client::TlsStream { + session, io: stream, + state: client::TlsState::EarlyData, + early_data: (0, Vec::new()) + }) + } else { + client::MidHandshake::Handshaking(client::TlsStream { + session, io: stream, + state: client::TlsState::Stream, + early_data: (0, Vec::new()) + }) }) - }) + } } } @@ -143,5 +158,6 @@ impl Future for Accept { } } +#[cfg(feature = "early-data")] #[cfg(test)] mod test_0rtt;