Add 0-RTT support

This commit is contained in:
quininer 2019-02-18 16:51:31 +08:00
parent 7d6ed0acfc
commit 3e605aafe4
4 changed files with 247 additions and 90 deletions

View File

@ -6,18 +6,18 @@ use rustls::WriteV;
use tokio_io::{ AsyncRead, AsyncWrite }; use tokio_io::{ AsyncRead, AsyncWrite };
pub struct Stream<'a, S: 'a, IO: 'a> { pub struct Stream<'a, IO: 'a, S: 'a> {
pub session: &'a mut S, pub io: &'a mut IO,
pub io: &'a mut IO pub session: &'a mut S
} }
pub trait WriteTls<'a, S: Session, IO: AsyncRead + AsyncWrite>: Read + Write { pub trait WriteTls<'a, IO: AsyncRead + AsyncWrite, S: Session>: Read + Write {
fn write_tls(&mut self) -> io::Result<usize>; fn write_tls(&mut self) -> io::Result<usize>;
} }
impl<'a, S: Session, IO: AsyncRead + AsyncWrite> Stream<'a, S, IO> { impl<'a, IO: AsyncRead + AsyncWrite, S: Session> Stream<'a, IO, S> {
pub fn new(session: &'a mut S, io: &'a mut IO) -> Self { pub fn new(io: &'a mut IO, session: &'a mut S) -> Self {
Stream { session, io } Stream { io, session }
} }
pub fn complete_io(&mut self) -> io::Result<(usize, usize)> { pub fn complete_io(&mut self) -> io::Result<(usize, usize)> {
@ -66,7 +66,7 @@ impl<'a, S: Session, IO: AsyncRead + AsyncWrite> Stream<'a, S, IO> {
} }
} }
impl<'a, S: Session, IO: AsyncRead + AsyncWrite> WriteTls<'a, S, IO> for Stream<'a, S, IO> { impl<'a, IO: AsyncRead + AsyncWrite, S: Session> WriteTls<'a, IO, S> for Stream<'a, IO, S> {
fn write_tls(&mut self) -> io::Result<usize> { fn write_tls(&mut self) -> io::Result<usize> {
use futures::Async; use futures::Async;
use self::vecbuf::VecBuf; use self::vecbuf::VecBuf;
@ -89,7 +89,7 @@ impl<'a, S: Session, IO: AsyncRead + AsyncWrite> WriteTls<'a, S, IO> for Stream<
} }
} }
impl<'a, S: Session, IO: AsyncRead + AsyncWrite> Read for Stream<'a, S, IO> { impl<'a, IO: AsyncRead + AsyncWrite, S: Session> Read for Stream<'a, IO, S> {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
while self.session.wants_read() { while self.session.wants_read() {
if let (0, 0) = self.complete_io()? { if let (0, 0) = self.complete_io()? {
@ -100,7 +100,7 @@ impl<'a, S: Session, IO: AsyncRead + AsyncWrite> Read for Stream<'a, S, IO> {
} }
} }
impl<'a, S: Session, IO: AsyncRead + AsyncWrite> Write for Stream<'a, S, IO> { impl<'a, IO: AsyncRead + AsyncWrite, S: Session> Write for Stream<'a, IO, S> {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> { fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
let len = self.session.write(buf)?; let len = self.session.write(buf)?;
while self.session.wants_write() { while self.session.wants_write() {

View File

@ -80,7 +80,7 @@ fn stream_good() -> io::Result<()> {
{ {
let mut good = Good(&mut server); let mut good = Good(&mut server);
let mut stream = Stream::new(&mut client, &mut good); let mut stream = Stream::new(&mut good, &mut client);
let mut buf = Vec::new(); let mut buf = Vec::new();
stream.read_to_end(&mut buf)?; stream.read_to_end(&mut buf)?;
@ -102,7 +102,7 @@ fn stream_bad() -> io::Result<()> {
client.set_buffer_limit(1024); client.set_buffer_limit(1024);
let mut bad = Bad(true); let mut bad = Bad(true);
let mut stream = Stream::new(&mut client, &mut bad); let mut stream = Stream::new(&mut bad, &mut client);
assert_eq!(stream.write(&[0x42; 8])?, 8); assert_eq!(stream.write(&[0x42; 8])?, 8);
assert_eq!(stream.write(&[0x42; 8])?, 8); assert_eq!(stream.write(&[0x42; 8])?, 8);
let r = stream.write(&[0x00; 1024])?; // fill buffer let r = stream.write(&[0x00; 1024])?; // fill buffer
@ -121,7 +121,7 @@ fn stream_handshake() -> io::Result<()> {
{ {
let mut good = Good(&mut server); let mut good = Good(&mut server);
let mut stream = Stream::new(&mut client, &mut good); let mut stream = Stream::new(&mut good, &mut client);
let (r, w) = stream.complete_io()?; let (r, w) = stream.complete_io()?;
assert!(r > 0); assert!(r > 0);
@ -141,7 +141,7 @@ fn stream_handshake_eof() -> io::Result<()> {
let (_, mut client) = make_pair(); let (_, mut client) = make_pair();
let mut bad = Bad(false); let mut bad = Bad(false);
let mut stream = Stream::new(&mut client, &mut bad); let mut stream = Stream::new(&mut bad, &mut client);
let r = stream.complete_io(); let r = stream.complete_io();
assert_eq!(r.unwrap_err().kind(), io::ErrorKind::UnexpectedEof); assert_eq!(r.unwrap_err().kind(), io::ErrorKind::UnexpectedEof);
@ -171,7 +171,7 @@ fn make_pair() -> (ServerSession, ClientSession) {
fn do_handshake(client: &mut ClientSession, server: &mut ServerSession) { fn do_handshake(client: &mut ClientSession, server: &mut ServerSession) {
let mut good = Good(server); let mut good = Good(server);
let mut stream = Stream::new(client, &mut good); let mut stream = Stream::new(&mut good, client);
stream.complete_io().unwrap(); stream.complete_io().unwrap();
stream.complete_io().unwrap(); stream.complete_io().unwrap();
} }

View File

@ -12,12 +12,13 @@ extern crate iovec;
mod common; mod common;
mod tokio_impl; mod tokio_impl;
use std::io; use std::mem;
use std::io::{ self, Write };
use std::sync::Arc; use std::sync::Arc;
use webpki::DNSNameRef; use webpki::DNSNameRef;
use rustls::{ use rustls::{
Session, ClientSession, ServerSession, Session, ClientSession, ServerSession,
ClientConfig, ServerConfig, ClientConfig, ServerConfig
}; };
use tokio_io::{ AsyncRead, AsyncWrite }; use tokio_io::{ AsyncRead, AsyncWrite };
use common::Stream; use common::Stream;
@ -25,7 +26,8 @@ use common::Stream;
#[derive(Clone)] #[derive(Clone)]
pub struct TlsConnector { pub struct TlsConnector {
inner: Arc<ClientConfig> inner: Arc<ClientConfig>,
early_data: bool
} }
#[derive(Clone)] #[derive(Clone)]
@ -35,7 +37,7 @@ pub struct TlsAcceptor {
impl From<Arc<ClientConfig>> for TlsConnector { impl From<Arc<ClientConfig>> for TlsConnector {
fn from(inner: Arc<ClientConfig>) -> TlsConnector { fn from(inner: Arc<ClientConfig>) -> TlsConnector {
TlsConnector { inner } TlsConnector { inner, early_data: false }
} }
} }
@ -46,19 +48,39 @@ impl From<Arc<ServerConfig>> for TlsAcceptor {
} }
impl TlsConnector { impl TlsConnector {
pub fn early_data(mut self, flag: bool) -> TlsConnector {
self.early_data = flag;
self
}
pub fn connect<IO>(&self, domain: DNSNameRef, stream: IO) -> Connect<IO> pub fn connect<IO>(&self, domain: DNSNameRef, stream: IO) -> Connect<IO>
where IO: AsyncRead + AsyncWrite where IO: AsyncRead + AsyncWrite
{ {
Self::connect_with_session(stream, ClientSession::new(&self.inner, domain)) self.connect_with(domain, stream, |_| ())
} }
#[inline] #[inline]
pub fn connect_with_session<IO>(stream: IO, session: ClientSession) pub fn connect_with<IO, F>(&self, domain: DNSNameRef, stream: IO, f: F)
-> Connect<IO> -> Connect<IO>
where IO: AsyncRead + AsyncWrite where
IO: AsyncRead + AsyncWrite,
F: FnOnce(&mut ClientSession)
{ {
Connect(MidHandshake { let mut session = ClientSession::new(&self.inner, domain);
inner: Some(TlsStream { session, io: stream, is_shutdown: false, eof: false }) f(&mut session);
Connect(if self.early_data {
MidHandshake::EarlyData(TlsStream {
session, io: stream,
state: TlsState::EarlyData,
early_data: (0, Vec::new())
})
} else {
MidHandshake::Handshaking(TlsStream {
session, io: stream,
state: TlsState::Stream,
early_data: (0, Vec::new())
})
}) })
} }
} }
@ -67,16 +89,24 @@ impl TlsAcceptor {
pub fn accept<IO>(&self, stream: IO) -> Accept<IO> pub fn accept<IO>(&self, stream: IO) -> Accept<IO>
where IO: AsyncRead + AsyncWrite, where IO: AsyncRead + AsyncWrite,
{ {
Self::accept_with_session(stream, ServerSession::new(&self.inner)) self.accept_with(stream, |_| ())
} }
#[inline] #[inline]
pub fn accept_with_session<IO>(stream: IO, session: ServerSession) -> Accept<IO> pub fn accept_with<IO, F>(&self, stream: IO, f: F)
where IO: AsyncRead + AsyncWrite -> Accept<IO>
where
IO: AsyncRead + AsyncWrite,
F: FnOnce(&mut ServerSession)
{ {
Accept(MidHandshake { let mut session = ServerSession::new(&self.inner);
inner: Some(TlsStream { session, io: stream, is_shutdown: false, eof: false }) f(&mut session);
})
Accept(MidHandshake::Handshaking(TlsStream {
session, io: stream,
state: TlsState::Stream,
early_data: (0, Vec::new())
}))
} }
} }
@ -89,9 +119,10 @@ pub struct Connect<IO>(MidHandshake<IO, ClientSession>);
/// once the accept handshake has finished. /// once the accept handshake has finished.
pub struct Accept<IO>(MidHandshake<IO, ServerSession>); pub struct Accept<IO>(MidHandshake<IO, ServerSession>);
enum MidHandshake<IO, S> {
struct MidHandshake<IO, S> { Handshaking(TlsStream<IO, S>),
inner: Option<TlsStream<IO, S>> EarlyData(TlsStream<IO, S>),
End
} }
@ -99,10 +130,18 @@ struct MidHandshake<IO, S> {
/// protocol. /// protocol.
#[derive(Debug)] #[derive(Debug)]
pub struct TlsStream<IO, S> { pub struct TlsStream<IO, S> {
is_shutdown: bool,
eof: bool,
io: IO, io: IO,
session: S session: S,
state: TlsState,
early_data: (usize, Vec<u8>)
}
#[derive(Debug)]
enum TlsState {
EarlyData,
Stream,
Eof,
Shutdown
} }
impl<IO, S> TlsStream<IO, S> { impl<IO, S> TlsStream<IO, S> {
@ -122,50 +161,135 @@ impl<IO, S> TlsStream<IO, S> {
} }
} }
impl<IO, S: Session> From<(IO, S)> for TlsStream<IO, S> { impl<IO> io::Read for TlsStream<IO, ClientSession>
#[inline] where IO: AsyncRead + AsyncWrite
fn from((io, session): (IO, S)) -> TlsStream<IO, S> {
assert!(!session.is_handshaking());
TlsStream {
is_shutdown: false,
eof: false,
io, session
}
}
}
impl<IO, S> io::Read for TlsStream<IO, S>
where IO: AsyncRead + AsyncWrite, S: Session
{ {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
if self.eof { let mut stream = Stream::new(&mut self.io, &mut self.session);
return Ok(0);
match self.state {
TlsState::EarlyData => {
let (pos, data) = &mut self.early_data;
// complete handshake
if stream.session.is_handshaking() {
stream.complete_io()?;
} }
match Stream::new(&mut self.session, &mut self.io).read(buf) { // write early data (fallback)
Ok(0) => { self.eof = true; Ok(0) }, if !stream.session.is_early_data_accepted() {
while *pos < data.len() {
let len = stream.write(&data[*pos..])?;
*pos += len;
}
}
// end
self.state = TlsState::Stream;
*pos = 0;
data.clear();
stream.read(buf)
},
TlsState::Stream => match stream.read(buf) {
Ok(0) => {
self.state = TlsState::Eof;
Ok(0)
},
Ok(n) => Ok(n), Ok(n) => Ok(n),
Err(ref e) if e.kind() == io::ErrorKind::ConnectionAborted => { Err(ref e) if e.kind() == io::ErrorKind::ConnectionAborted => {
self.eof = true; self.state = TlsState::Shutdown;
self.is_shutdown = true; stream.session.send_close_notify();
self.session.send_close_notify();
Ok(0) Ok(0)
}, },
Err(e) => Err(e) Err(e) => Err(e)
},
TlsState::Eof | TlsState::Shutdown => Ok(0),
} }
} }
} }
impl<IO, S> io::Write for TlsStream<IO, S> impl<IO> io::Read for TlsStream<IO, ServerSession>
where IO: AsyncRead + AsyncWrite, S: Session where IO: AsyncRead + AsyncWrite
{
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
let mut stream = Stream::new(&mut self.io, &mut self.session);
match self.state {
TlsState::Stream => match stream.read(buf) {
Ok(0) => {
self.state = TlsState::Eof;
Ok(0)
},
Ok(n) => Ok(n),
Err(ref e) if e.kind() == io::ErrorKind::ConnectionAborted => {
self.state = TlsState::Shutdown;
stream.session.send_close_notify();
Ok(0)
},
Err(e) => Err(e)
},
TlsState::Eof | TlsState::Shutdown => Ok(0),
TlsState::EarlyData => unreachable!()
}
}
}
impl<IO> io::Write for TlsStream<IO, ClientSession>
where IO: AsyncRead + AsyncWrite
{ {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> { fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
Stream::new(&mut self.session, &mut self.io).write(buf) let mut stream = Stream::new(&mut self.io, &mut self.session);
match self.state {
TlsState::EarlyData => {
let (pos, data) = &mut self.early_data;
// write early data
if let Some(mut early_data) = stream.session.early_data() {
let len = early_data.write(buf)?;
data.extend_from_slice(&buf[..len]);
return Ok(len);
}
// complete handshake
if stream.session.is_handshaking() {
stream.complete_io()?;
}
// write early data (fallback)
if !stream.session.is_early_data_accepted() {
while *pos < data.len() {
let len = stream.write(&data[*pos..])?;
*pos += len;
}
}
// end
self.state = TlsState::Stream;
*pos = 0;
data.clear();
stream.write(buf)
},
_ => stream.write(buf)
}
} }
fn flush(&mut self) -> io::Result<()> { fn flush(&mut self) -> io::Result<()> {
Stream::new(&mut self.session, &mut self.io).flush()?; Stream::new(&mut self.io, &mut self.session).flush()?;
self.io.flush()
}
}
impl<IO> io::Write for TlsStream<IO, ServerSession>
where IO: AsyncRead + AsyncWrite
{
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
let mut stream = Stream::new(&mut self.io, &mut self.session);
stream.write(buf)
}
fn flush(&mut self) -> io::Result<()> {
Stream::new(&mut self.io, &mut self.session).flush()?;
self.io.flush() self.io.flush()
} }
} }

View File

@ -42,10 +42,10 @@ where
type Error = io::Error; type Error = io::Error;
fn poll(&mut self) -> Poll<Self::Item, Self::Error> { fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
{ match self {
let stream = self.inner.as_mut().unwrap(); MidHandshake::Handshaking(stream) => {
let (io, session) = stream.get_mut(); let (io, session) = stream.get_mut();
let mut stream = Stream::new(session, io); let mut stream = Stream::new(io, session);
if stream.session.is_handshaking() { if stream.session.is_handshaking() {
try_async!(stream.complete_io()); try_async!(stream.complete_io());
@ -54,35 +54,68 @@ where
if stream.session.wants_write() { if stream.session.wants_write() {
try_async!(stream.complete_io()); try_async!(stream.complete_io());
} }
},
_ => ()
} }
Ok(Async::Ready(self.inner.take().unwrap())) match mem::replace(self, MidHandshake::End) {
MidHandshake::Handshaking(stream)
| MidHandshake::EarlyData(stream) => Ok(Async::Ready(stream)),
MidHandshake::End => panic!()
}
} }
} }
impl<IO, S> AsyncRead for TlsStream<IO, S> impl<IO> AsyncRead for TlsStream<IO, ClientSession>
where where IO: AsyncRead + AsyncWrite
IO: AsyncRead + AsyncWrite,
S: Session
{ {
unsafe fn prepare_uninitialized_buffer(&self, _: &mut [u8]) -> bool { unsafe fn prepare_uninitialized_buffer(&self, _: &mut [u8]) -> bool {
false false
} }
} }
impl<IO, S> AsyncWrite for TlsStream<IO, S> impl<IO> AsyncRead for TlsStream<IO, ServerSession>
where where IO: AsyncRead + AsyncWrite
IO: AsyncRead + AsyncWrite, {
S: Session unsafe fn prepare_uninitialized_buffer(&self, _: &mut [u8]) -> bool {
false
}
}
impl<IO> AsyncWrite for TlsStream<IO, ClientSession>
where IO: AsyncRead + AsyncWrite,
{ {
fn shutdown(&mut self) -> Poll<(), io::Error> { fn shutdown(&mut self) -> Poll<(), io::Error> {
if !self.is_shutdown { match self.state {
TlsState::Shutdown => (),
_ => {
self.session.send_close_notify(); self.session.send_close_notify();
self.is_shutdown = true; self.state = TlsState::Shutdown;
}
} }
{ {
let mut stream = Stream::new(&mut self.session, &mut self.io); let mut stream = Stream::new(&mut self.io, &mut self.session);
try_async!(stream.complete_io());
}
self.io.shutdown()
}
}
impl<IO> AsyncWrite for TlsStream<IO, ServerSession>
where IO: AsyncRead + AsyncWrite,
{
fn shutdown(&mut self) -> Poll<(), io::Error> {
match self.state {
TlsState::Shutdown => (),
_ => {
self.session.send_close_notify();
self.state = TlsState::Shutdown;
}
}
{
let mut stream = Stream::new(&mut self.io, &mut self.session);
try_async!(stream.complete_io()); try_async!(stream.complete_io());
} }
self.io.shutdown() self.io.shutdown()