commit
8b8647b43d
@ -25,3 +25,4 @@ webpki = "0.19"
|
||||
[dev-dependencies]
|
||||
tokio = "0.1.6"
|
||||
lazy_static = "1"
|
||||
webpki-roots = "0.16"
|
||||
|
196
src/client.rs
Normal file
196
src/client.rs
Normal file
@ -0,0 +1,196 @@
|
||||
use super::*;
|
||||
use std::io::Write;
|
||||
use rustls::Session;
|
||||
|
||||
|
||||
/// A wrapper around an underlying raw stream which implements the TLS or SSL
|
||||
/// protocol.
|
||||
#[derive(Debug)]
|
||||
pub struct TlsStream<IO> {
|
||||
pub(crate) io: IO,
|
||||
pub(crate) session: ClientSession,
|
||||
pub(crate) state: TlsState,
|
||||
pub(crate) early_data: (usize, Vec<u8>)
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) enum TlsState {
|
||||
EarlyData,
|
||||
Stream,
|
||||
Eof,
|
||||
Shutdown
|
||||
}
|
||||
|
||||
pub(crate) enum MidHandshake<IO> {
|
||||
Handshaking(TlsStream<IO>),
|
||||
EarlyData(TlsStream<IO>),
|
||||
End
|
||||
}
|
||||
|
||||
impl<IO> TlsStream<IO> {
|
||||
#[inline]
|
||||
pub fn get_ref(&self) -> (&IO, &ClientSession) {
|
||||
(&self.io, &self.session)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn get_mut(&mut self) -> (&mut IO, &mut ClientSession) {
|
||||
(&mut self.io, &mut self.session)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn into_inner(self) -> (IO, ClientSession) {
|
||||
(self.io, self.session)
|
||||
}
|
||||
}
|
||||
|
||||
impl<IO> Future for MidHandshake<IO>
|
||||
where IO: AsyncRead + AsyncWrite,
|
||||
{
|
||||
type Item = TlsStream<IO>;
|
||||
type Error = io::Error;
|
||||
|
||||
#[inline]
|
||||
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
|
||||
match self {
|
||||
MidHandshake::Handshaking(stream) => {
|
||||
let (io, session) = stream.get_mut();
|
||||
let mut stream = Stream::new(io, session);
|
||||
|
||||
if stream.session.is_handshaking() {
|
||||
try_nb!(stream.complete_io());
|
||||
}
|
||||
|
||||
if stream.session.wants_write() {
|
||||
try_nb!(stream.complete_io());
|
||||
}
|
||||
},
|
||||
_ => ()
|
||||
}
|
||||
|
||||
match mem::replace(self, MidHandshake::End) {
|
||||
MidHandshake::Handshaking(stream)
|
||||
| MidHandshake::EarlyData(stream) => Ok(Async::Ready(stream)),
|
||||
MidHandshake::End => panic!()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<IO> io::Read for TlsStream<IO>
|
||||
where IO: AsyncRead + AsyncWrite
|
||||
{
|
||||
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
|
||||
let mut stream = Stream::new(&mut self.io, &mut self.session);
|
||||
|
||||
match self.state {
|
||||
TlsState::EarlyData => {
|
||||
let (pos, data) = &mut self.early_data;
|
||||
|
||||
// complete handshake
|
||||
if stream.session.is_handshaking() {
|
||||
stream.complete_io()?;
|
||||
}
|
||||
|
||||
// write early data (fallback)
|
||||
if !stream.session.is_early_data_accepted() {
|
||||
while *pos < data.len() {
|
||||
let len = stream.write(&data[*pos..])?;
|
||||
*pos += len;
|
||||
}
|
||||
}
|
||||
|
||||
// end
|
||||
self.state = TlsState::Stream;
|
||||
data.clear();
|
||||
stream.read(buf)
|
||||
},
|
||||
TlsState::Stream => match stream.read(buf) {
|
||||
Ok(0) => {
|
||||
self.state = TlsState::Eof;
|
||||
Ok(0)
|
||||
},
|
||||
Ok(n) => Ok(n),
|
||||
Err(ref e) if e.kind() == io::ErrorKind::ConnectionAborted => {
|
||||
self.state = TlsState::Shutdown;
|
||||
stream.session.send_close_notify();
|
||||
Ok(0)
|
||||
},
|
||||
Err(e) => Err(e)
|
||||
},
|
||||
TlsState::Eof | TlsState::Shutdown => Ok(0),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<IO> io::Write for TlsStream<IO>
|
||||
where IO: AsyncRead + AsyncWrite
|
||||
{
|
||||
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
|
||||
let mut stream = Stream::new(&mut self.io, &mut self.session);
|
||||
|
||||
match self.state {
|
||||
TlsState::EarlyData => {
|
||||
let (pos, data) = &mut self.early_data;
|
||||
|
||||
// write early data
|
||||
if let Some(mut early_data) = stream.session.early_data() {
|
||||
let len = early_data.write(buf)?;
|
||||
data.extend_from_slice(&buf[..len]);
|
||||
return Ok(len);
|
||||
}
|
||||
|
||||
// complete handshake
|
||||
if stream.session.is_handshaking() {
|
||||
stream.complete_io()?;
|
||||
}
|
||||
|
||||
// write early data (fallback)
|
||||
if !stream.session.is_early_data_accepted() {
|
||||
while *pos < data.len() {
|
||||
let len = stream.write(&data[*pos..])?;
|
||||
*pos += len;
|
||||
}
|
||||
}
|
||||
|
||||
// end
|
||||
self.state = TlsState::Stream;
|
||||
data.clear();
|
||||
stream.write(buf)
|
||||
},
|
||||
_ => stream.write(buf)
|
||||
}
|
||||
}
|
||||
|
||||
fn flush(&mut self) -> io::Result<()> {
|
||||
Stream::new(&mut self.io, &mut self.session).flush()?;
|
||||
self.io.flush()
|
||||
}
|
||||
}
|
||||
|
||||
impl<IO> AsyncRead for TlsStream<IO>
|
||||
where IO: AsyncRead + AsyncWrite
|
||||
{
|
||||
unsafe fn prepare_uninitialized_buffer(&self, _: &mut [u8]) -> bool {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
impl<IO> AsyncWrite for TlsStream<IO>
|
||||
where IO: AsyncRead + AsyncWrite
|
||||
{
|
||||
fn shutdown(&mut self) -> Poll<(), io::Error> {
|
||||
match self.state {
|
||||
TlsState::Shutdown => (),
|
||||
_ => {
|
||||
self.session.send_close_notify();
|
||||
self.state = TlsState::Shutdown;
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
let mut stream = Stream::new(&mut self.io, &mut self.session);
|
||||
try_nb!(stream.complete_io());
|
||||
}
|
||||
self.io.shutdown()
|
||||
}
|
||||
}
|
@ -6,18 +6,18 @@ use rustls::WriteV;
|
||||
use tokio_io::{ AsyncRead, AsyncWrite };
|
||||
|
||||
|
||||
pub struct Stream<'a, S: 'a, IO: 'a> {
|
||||
pub session: &'a mut S,
|
||||
pub io: &'a mut IO
|
||||
pub struct Stream<'a, IO: 'a, S: 'a> {
|
||||
pub io: &'a mut IO,
|
||||
pub session: &'a mut S
|
||||
}
|
||||
|
||||
pub trait WriteTls<'a, S: Session, IO: AsyncRead + AsyncWrite>: Read + Write {
|
||||
pub trait WriteTls<'a, IO: AsyncRead + AsyncWrite, S: Session>: Read + Write {
|
||||
fn write_tls(&mut self) -> io::Result<usize>;
|
||||
}
|
||||
|
||||
impl<'a, S: Session, IO: AsyncRead + AsyncWrite> Stream<'a, S, IO> {
|
||||
pub fn new(session: &'a mut S, io: &'a mut IO) -> Self {
|
||||
Stream { session, io }
|
||||
impl<'a, IO: AsyncRead + AsyncWrite, S: Session> Stream<'a, IO, S> {
|
||||
pub fn new(io: &'a mut IO, session: &'a mut S) -> Self {
|
||||
Stream { io, session }
|
||||
}
|
||||
|
||||
pub fn complete_io(&mut self) -> io::Result<(usize, usize)> {
|
||||
@ -66,7 +66,7 @@ impl<'a, S: Session, IO: AsyncRead + AsyncWrite> Stream<'a, S, IO> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, S: Session, IO: AsyncRead + AsyncWrite> WriteTls<'a, S, IO> for Stream<'a, S, IO> {
|
||||
impl<'a, IO: AsyncRead + AsyncWrite, S: Session> WriteTls<'a, IO, S> for Stream<'a, IO, S> {
|
||||
fn write_tls(&mut self) -> io::Result<usize> {
|
||||
use futures::Async;
|
||||
use self::vecbuf::VecBuf;
|
||||
@ -89,7 +89,7 @@ impl<'a, S: Session, IO: AsyncRead + AsyncWrite> WriteTls<'a, S, IO> for Stream<
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, S: Session, IO: AsyncRead + AsyncWrite> Read for Stream<'a, S, IO> {
|
||||
impl<'a, IO: AsyncRead + AsyncWrite, S: Session> Read for Stream<'a, IO, S> {
|
||||
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
|
||||
while self.session.wants_read() {
|
||||
if let (0, 0) = self.complete_io()? {
|
||||
@ -100,7 +100,7 @@ impl<'a, S: Session, IO: AsyncRead + AsyncWrite> Read for Stream<'a, S, IO> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, S: Session, IO: AsyncRead + AsyncWrite> Write for Stream<'a, S, IO> {
|
||||
impl<'a, IO: AsyncRead + AsyncWrite, S: Session> Write for Stream<'a, IO, S> {
|
||||
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
|
||||
let len = self.session.write(buf)?;
|
||||
while self.session.wants_write() {
|
||||
|
@ -80,7 +80,7 @@ fn stream_good() -> io::Result<()> {
|
||||
|
||||
{
|
||||
let mut good = Good(&mut server);
|
||||
let mut stream = Stream::new(&mut client, &mut good);
|
||||
let mut stream = Stream::new(&mut good, &mut client);
|
||||
|
||||
let mut buf = Vec::new();
|
||||
stream.read_to_end(&mut buf)?;
|
||||
@ -102,7 +102,7 @@ fn stream_bad() -> io::Result<()> {
|
||||
client.set_buffer_limit(1024);
|
||||
|
||||
let mut bad = Bad(true);
|
||||
let mut stream = Stream::new(&mut client, &mut bad);
|
||||
let mut stream = Stream::new(&mut bad, &mut client);
|
||||
assert_eq!(stream.write(&[0x42; 8])?, 8);
|
||||
assert_eq!(stream.write(&[0x42; 8])?, 8);
|
||||
let r = stream.write(&[0x00; 1024])?; // fill buffer
|
||||
@ -121,7 +121,7 @@ fn stream_handshake() -> io::Result<()> {
|
||||
|
||||
{
|
||||
let mut good = Good(&mut server);
|
||||
let mut stream = Stream::new(&mut client, &mut good);
|
||||
let mut stream = Stream::new(&mut good, &mut client);
|
||||
let (r, w) = stream.complete_io()?;
|
||||
|
||||
assert!(r > 0);
|
||||
@ -141,7 +141,7 @@ fn stream_handshake_eof() -> io::Result<()> {
|
||||
let (_, mut client) = make_pair();
|
||||
|
||||
let mut bad = Bad(false);
|
||||
let mut stream = Stream::new(&mut client, &mut bad);
|
||||
let mut stream = Stream::new(&mut bad, &mut client);
|
||||
let r = stream.complete_io();
|
||||
|
||||
assert_eq!(r.unwrap_err().kind(), io::ErrorKind::UnexpectedEof);
|
||||
@ -171,7 +171,7 @@ fn make_pair() -> (ServerSession, ClientSession) {
|
||||
|
||||
fn do_handshake(client: &mut ClientSession, server: &mut ServerSession) {
|
||||
let mut good = Good(server);
|
||||
let mut stream = Stream::new(client, &mut good);
|
||||
let mut stream = Stream::new(&mut good, client);
|
||||
stream.complete_io().unwrap();
|
||||
stream.complete_io().unwrap();
|
||||
}
|
||||
|
158
src/lib.rs
158
src/lib.rs
@ -8,24 +8,26 @@ extern crate tokio_io;
|
||||
extern crate bytes;
|
||||
extern crate iovec;
|
||||
|
||||
|
||||
mod common;
|
||||
mod tokio_impl;
|
||||
pub mod client;
|
||||
pub mod server;
|
||||
|
||||
use std::io;
|
||||
use std::{ io, mem };
|
||||
use std::sync::Arc;
|
||||
use webpki::DNSNameRef;
|
||||
use rustls::{
|
||||
Session, ClientSession, ServerSession,
|
||||
ClientConfig, ServerConfig,
|
||||
ClientSession, ServerSession,
|
||||
ClientConfig, ServerConfig
|
||||
};
|
||||
use tokio_io::{ AsyncRead, AsyncWrite };
|
||||
use futures::{Async, Future, Poll};
|
||||
use tokio_io::{ AsyncRead, AsyncWrite, try_nb };
|
||||
use common::Stream;
|
||||
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct TlsConnector {
|
||||
inner: Arc<ClientConfig>
|
||||
inner: Arc<ClientConfig>,
|
||||
early_data: bool
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
@ -35,7 +37,7 @@ pub struct TlsAcceptor {
|
||||
|
||||
impl From<Arc<ClientConfig>> for TlsConnector {
|
||||
fn from(inner: Arc<ClientConfig>) -> TlsConnector {
|
||||
TlsConnector { inner }
|
||||
TlsConnector { inner, early_data: false }
|
||||
}
|
||||
}
|
||||
|
||||
@ -46,19 +48,43 @@ impl From<Arc<ServerConfig>> for TlsAcceptor {
|
||||
}
|
||||
|
||||
impl TlsConnector {
|
||||
/// Enable 0-RTT.
|
||||
///
|
||||
/// Note that you want to use 0-RTT.
|
||||
/// You must set `enable_early_data` to `true` in `ClientConfig`.
|
||||
pub fn early_data(mut self, flag: bool) -> TlsConnector {
|
||||
self.early_data = flag;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn connect<IO>(&self, domain: DNSNameRef, stream: IO) -> Connect<IO>
|
||||
where IO: AsyncRead + AsyncWrite
|
||||
{
|
||||
Self::connect_with_session(stream, ClientSession::new(&self.inner, domain))
|
||||
self.connect_with(domain, stream, |_| ())
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn connect_with_session<IO>(stream: IO, session: ClientSession)
|
||||
pub fn connect_with<IO, F>(&self, domain: DNSNameRef, stream: IO, f: F)
|
||||
-> Connect<IO>
|
||||
where IO: AsyncRead + AsyncWrite
|
||||
where
|
||||
IO: AsyncRead + AsyncWrite,
|
||||
F: FnOnce(&mut ClientSession)
|
||||
{
|
||||
Connect(MidHandshake {
|
||||
inner: Some(TlsStream { session, io: stream, is_shutdown: false, eof: false })
|
||||
let mut session = ClientSession::new(&self.inner, domain);
|
||||
f(&mut session);
|
||||
|
||||
Connect(if self.early_data {
|
||||
client::MidHandshake::EarlyData(client::TlsStream {
|
||||
session, io: stream,
|
||||
state: client::TlsState::EarlyData,
|
||||
early_data: (0, Vec::new())
|
||||
})
|
||||
} else {
|
||||
client::MidHandshake::Handshaking(client::TlsStream {
|
||||
session, io: stream,
|
||||
state: client::TlsState::Stream,
|
||||
early_data: (0, Vec::new())
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -67,105 +93,53 @@ impl TlsAcceptor {
|
||||
pub fn accept<IO>(&self, stream: IO) -> Accept<IO>
|
||||
where IO: AsyncRead + AsyncWrite,
|
||||
{
|
||||
Self::accept_with_session(stream, ServerSession::new(&self.inner))
|
||||
self.accept_with(stream, |_| ())
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn accept_with_session<IO>(stream: IO, session: ServerSession) -> Accept<IO>
|
||||
where IO: AsyncRead + AsyncWrite
|
||||
pub fn accept_with<IO, F>(&self, stream: IO, f: F)
|
||||
-> Accept<IO>
|
||||
where
|
||||
IO: AsyncRead + AsyncWrite,
|
||||
F: FnOnce(&mut ServerSession)
|
||||
{
|
||||
Accept(MidHandshake {
|
||||
inner: Some(TlsStream { session, io: stream, is_shutdown: false, eof: false })
|
||||
})
|
||||
let mut session = ServerSession::new(&self.inner);
|
||||
f(&mut session);
|
||||
|
||||
Accept(server::MidHandshake::Handshaking(server::TlsStream {
|
||||
session, io: stream,
|
||||
state: server::TlsState::Stream,
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/// Future returned from `ClientConfigExt::connect_async` which will resolve
|
||||
/// once the connection handshake has finished.
|
||||
pub struct Connect<IO>(MidHandshake<IO, ClientSession>);
|
||||
pub struct Connect<IO>(client::MidHandshake<IO>);
|
||||
|
||||
/// Future returned from `ServerConfigExt::accept_async` which will resolve
|
||||
/// once the accept handshake has finished.
|
||||
pub struct Accept<IO>(MidHandshake<IO, ServerSession>);
|
||||
pub struct Accept<IO>(server::MidHandshake<IO>);
|
||||
|
||||
|
||||
struct MidHandshake<IO, S> {
|
||||
inner: Option<TlsStream<IO, S>>
|
||||
}
|
||||
impl<IO: AsyncRead + AsyncWrite> Future for Connect<IO> {
|
||||
type Item = client::TlsStream<IO>;
|
||||
type Error = io::Error;
|
||||
|
||||
|
||||
/// A wrapper around an underlying raw stream which implements the TLS or SSL
|
||||
/// protocol.
|
||||
#[derive(Debug)]
|
||||
pub struct TlsStream<IO, S> {
|
||||
is_shutdown: bool,
|
||||
eof: bool,
|
||||
io: IO,
|
||||
session: S
|
||||
}
|
||||
|
||||
impl<IO, S> TlsStream<IO, S> {
|
||||
#[inline]
|
||||
pub fn get_ref(&self) -> (&IO, &S) {
|
||||
(&self.io, &self.session)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn get_mut(&mut self) -> (&mut IO, &mut S) {
|
||||
(&mut self.io, &mut self.session)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn into_inner(self) -> (IO, S) {
|
||||
(self.io, self.session)
|
||||
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
|
||||
self.0.poll()
|
||||
}
|
||||
}
|
||||
|
||||
impl<IO, S: Session> From<(IO, S)> for TlsStream<IO, S> {
|
||||
#[inline]
|
||||
fn from((io, session): (IO, S)) -> TlsStream<IO, S> {
|
||||
assert!(!session.is_handshaking());
|
||||
impl<IO: AsyncRead + AsyncWrite> Future for Accept<IO> {
|
||||
type Item = server::TlsStream<IO>;
|
||||
type Error = io::Error;
|
||||
|
||||
TlsStream {
|
||||
is_shutdown: false,
|
||||
eof: false,
|
||||
io, session
|
||||
}
|
||||
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
|
||||
self.0.poll()
|
||||
}
|
||||
}
|
||||
|
||||
impl<IO, S> io::Read for TlsStream<IO, S>
|
||||
where IO: AsyncRead + AsyncWrite, S: Session
|
||||
{
|
||||
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
|
||||
if self.eof {
|
||||
return Ok(0);
|
||||
}
|
||||
|
||||
match Stream::new(&mut self.session, &mut self.io).read(buf) {
|
||||
Ok(0) => { self.eof = true; Ok(0) },
|
||||
Ok(n) => Ok(n),
|
||||
Err(ref e) if e.kind() == io::ErrorKind::ConnectionAborted => {
|
||||
self.eof = true;
|
||||
self.is_shutdown = true;
|
||||
self.session.send_close_notify();
|
||||
Ok(0)
|
||||
},
|
||||
Err(e) => Err(e)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<IO, S> io::Write for TlsStream<IO, S>
|
||||
where IO: AsyncRead + AsyncWrite, S: Session
|
||||
{
|
||||
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
|
||||
Stream::new(&mut self.session, &mut self.io).write(buf)
|
||||
}
|
||||
|
||||
fn flush(&mut self) -> io::Result<()> {
|
||||
Stream::new(&mut self.session, &mut self.io).flush()?;
|
||||
self.io.flush()
|
||||
}
|
||||
}
|
||||
#[cfg(test)]
|
||||
mod test_0rtt;
|
||||
|
139
src/server.rs
Normal file
139
src/server.rs
Normal file
@ -0,0 +1,139 @@
|
||||
use super::*;
|
||||
use rustls::Session;
|
||||
|
||||
|
||||
/// A wrapper around an underlying raw stream which implements the TLS or SSL
|
||||
/// protocol.
|
||||
#[derive(Debug)]
|
||||
pub struct TlsStream<IO> {
|
||||
pub(crate) io: IO,
|
||||
pub(crate) session: ServerSession,
|
||||
pub(crate) state: TlsState
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) enum TlsState {
|
||||
Stream,
|
||||
Eof,
|
||||
Shutdown
|
||||
}
|
||||
|
||||
pub(crate) enum MidHandshake<IO> {
|
||||
Handshaking(TlsStream<IO>),
|
||||
End
|
||||
}
|
||||
|
||||
impl<IO> TlsStream<IO> {
|
||||
#[inline]
|
||||
pub fn get_ref(&self) -> (&IO, &ServerSession) {
|
||||
(&self.io, &self.session)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn get_mut(&mut self) -> (&mut IO, &mut ServerSession) {
|
||||
(&mut self.io, &mut self.session)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn into_inner(self) -> (IO, ServerSession) {
|
||||
(self.io, self.session)
|
||||
}
|
||||
}
|
||||
|
||||
impl<IO> Future for MidHandshake<IO>
|
||||
where IO: AsyncRead + AsyncWrite,
|
||||
{
|
||||
type Item = TlsStream<IO>;
|
||||
type Error = io::Error;
|
||||
|
||||
#[inline]
|
||||
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
|
||||
match self {
|
||||
MidHandshake::Handshaking(stream) => {
|
||||
let (io, session) = stream.get_mut();
|
||||
let mut stream = Stream::new(io, session);
|
||||
|
||||
if stream.session.is_handshaking() {
|
||||
try_nb!(stream.complete_io());
|
||||
}
|
||||
|
||||
if stream.session.wants_write() {
|
||||
try_nb!(stream.complete_io());
|
||||
}
|
||||
},
|
||||
_ => ()
|
||||
}
|
||||
|
||||
match mem::replace(self, MidHandshake::End) {
|
||||
MidHandshake::Handshaking(stream) => Ok(Async::Ready(stream)),
|
||||
MidHandshake::End => panic!()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<IO> io::Read for TlsStream<IO>
|
||||
where IO: AsyncRead + AsyncWrite
|
||||
{
|
||||
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
|
||||
let mut stream = Stream::new(&mut self.io, &mut self.session);
|
||||
|
||||
match self.state {
|
||||
TlsState::Stream => match stream.read(buf) {
|
||||
Ok(0) => {
|
||||
self.state = TlsState::Eof;
|
||||
Ok(0)
|
||||
},
|
||||
Ok(n) => Ok(n),
|
||||
Err(ref e) if e.kind() == io::ErrorKind::ConnectionAborted => {
|
||||
self.state = TlsState::Shutdown;
|
||||
stream.session.send_close_notify();
|
||||
Ok(0)
|
||||
},
|
||||
Err(e) => Err(e)
|
||||
},
|
||||
TlsState::Eof | TlsState::Shutdown => Ok(0)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<IO> io::Write for TlsStream<IO>
|
||||
where IO: AsyncRead + AsyncWrite
|
||||
{
|
||||
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
|
||||
let mut stream = Stream::new(&mut self.io, &mut self.session);
|
||||
stream.write(buf)
|
||||
}
|
||||
|
||||
fn flush(&mut self) -> io::Result<()> {
|
||||
Stream::new(&mut self.io, &mut self.session).flush()?;
|
||||
self.io.flush()
|
||||
}
|
||||
}
|
||||
|
||||
impl<IO> AsyncRead for TlsStream<IO>
|
||||
where IO: AsyncRead + AsyncWrite
|
||||
{
|
||||
unsafe fn prepare_uninitialized_buffer(&self, _: &mut [u8]) -> bool {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
impl<IO> AsyncWrite for TlsStream<IO>
|
||||
where IO: AsyncRead + AsyncWrite,
|
||||
{
|
||||
fn shutdown(&mut self) -> Poll<(), io::Error> {
|
||||
match self.state {
|
||||
TlsState::Shutdown => (),
|
||||
_ => {
|
||||
self.session.send_close_notify();
|
||||
self.state = TlsState::Shutdown;
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
let mut stream = Stream::new(&mut self.io, &mut self.session);
|
||||
try_nb!(stream.complete_io());
|
||||
}
|
||||
self.io.shutdown()
|
||||
}
|
||||
}
|
51
src/test_0rtt.rs
Normal file
51
src/test_0rtt.rs
Normal file
@ -0,0 +1,51 @@
|
||||
extern crate tokio;
|
||||
extern crate webpki;
|
||||
extern crate webpki_roots;
|
||||
|
||||
use std::io;
|
||||
use std::sync::Arc;
|
||||
use std::net::ToSocketAddrs;
|
||||
use self::tokio::io as aio;
|
||||
use self::tokio::prelude::*;
|
||||
use self::tokio::net::TcpStream;
|
||||
use rustls::ClientConfig;
|
||||
use ::{ TlsConnector, client::TlsStream };
|
||||
|
||||
|
||||
fn get(config: Arc<ClientConfig>, domain: &str, rtt0: bool)
|
||||
-> io::Result<(TlsStream<TcpStream>, String)>
|
||||
{
|
||||
let config = TlsConnector::from(config).early_data(rtt0);
|
||||
let input = format!("GET / HTTP/1.0\r\nHost: {}\r\n\r\n", domain);
|
||||
|
||||
let addr = (domain, 443)
|
||||
.to_socket_addrs()?
|
||||
.next().unwrap();
|
||||
|
||||
TcpStream::connect(&addr)
|
||||
.and_then(move |stream| {
|
||||
let domain = webpki::DNSNameRef::try_from_ascii_str(&domain).unwrap();
|
||||
config.connect(domain, stream)
|
||||
})
|
||||
.and_then(move |stream| aio::write_all(stream, input))
|
||||
.and_then(move |(stream, _)| aio::read_to_end(stream, Vec::new()))
|
||||
.map(|(stream, buf)| (stream, String::from_utf8(buf).unwrap()))
|
||||
.wait()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_0rtt() {
|
||||
let mut config = ClientConfig::new();
|
||||
config.root_store.add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS);
|
||||
config.enable_early_data = true;
|
||||
let config = Arc::new(config);
|
||||
let domain = "mozilla-modern.badssl.com";
|
||||
|
||||
let (_, output) = get(config.clone(), domain, false).unwrap();
|
||||
assert!(output.contains("<title>mozilla-modern.badssl.com</title>"));
|
||||
|
||||
let (io, output) = get(config.clone(), domain, true).unwrap();
|
||||
assert!(output.contains("<title>mozilla-modern.badssl.com</title>"));
|
||||
|
||||
assert_eq!(io.early_data.0, 0);
|
||||
}
|
@ -1,90 +0,0 @@
|
||||
use super::*;
|
||||
use tokio_io::{ AsyncRead, AsyncWrite };
|
||||
use futures::{Async, Future, Poll};
|
||||
use common::Stream;
|
||||
|
||||
|
||||
macro_rules! try_async {
|
||||
( $e:expr ) => {
|
||||
match $e {
|
||||
Ok(n) => n,
|
||||
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock =>
|
||||
return Ok(Async::NotReady),
|
||||
Err(e) => return Err(e)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<IO: AsyncRead + AsyncWrite> Future for Connect<IO> {
|
||||
type Item = TlsStream<IO, ClientSession>;
|
||||
type Error = io::Error;
|
||||
|
||||
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
|
||||
self.0.poll()
|
||||
}
|
||||
}
|
||||
|
||||
impl<IO: AsyncRead + AsyncWrite> Future for Accept<IO> {
|
||||
type Item = TlsStream<IO, ServerSession>;
|
||||
type Error = io::Error;
|
||||
|
||||
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
|
||||
self.0.poll()
|
||||
}
|
||||
}
|
||||
|
||||
impl<IO, S> Future for MidHandshake<IO, S>
|
||||
where
|
||||
IO: AsyncRead + AsyncWrite,
|
||||
S: Session
|
||||
{
|
||||
type Item = TlsStream<IO, S>;
|
||||
type Error = io::Error;
|
||||
|
||||
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
|
||||
{
|
||||
let stream = self.inner.as_mut().unwrap();
|
||||
let (io, session) = stream.get_mut();
|
||||
let mut stream = Stream::new(session, io);
|
||||
|
||||
if stream.session.is_handshaking() {
|
||||
try_async!(stream.complete_io());
|
||||
}
|
||||
|
||||
if stream.session.wants_write() {
|
||||
try_async!(stream.complete_io());
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Async::Ready(self.inner.take().unwrap()))
|
||||
}
|
||||
}
|
||||
|
||||
impl<IO, S> AsyncRead for TlsStream<IO, S>
|
||||
where
|
||||
IO: AsyncRead + AsyncWrite,
|
||||
S: Session
|
||||
{
|
||||
unsafe fn prepare_uninitialized_buffer(&self, _: &mut [u8]) -> bool {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
impl<IO, S> AsyncWrite for TlsStream<IO, S>
|
||||
where
|
||||
IO: AsyncRead + AsyncWrite,
|
||||
S: Session
|
||||
{
|
||||
fn shutdown(&mut self) -> Poll<(), io::Error> {
|
||||
if !self.is_shutdown {
|
||||
self.session.send_close_notify();
|
||||
self.is_shutdown = true;
|
||||
}
|
||||
|
||||
{
|
||||
let mut stream = Stream::new(&mut self.session, &mut self.io);
|
||||
try_async!(stream.complete_io());
|
||||
}
|
||||
self.io.shutdown()
|
||||
}
|
||||
}
|
@ -66,17 +66,14 @@ fn start_server() -> &'static (SocketAddr, &'static str, &'static str) {
|
||||
&*TEST_SERVER
|
||||
}
|
||||
|
||||
fn start_client(addr: &SocketAddr, domain: &str, chain: &str) -> io::Result<()> {
|
||||
fn start_client(addr: &SocketAddr, domain: &str, config: Arc<ClientConfig>) -> io::Result<()> {
|
||||
use tokio::prelude::*;
|
||||
use tokio::io as aio;
|
||||
|
||||
const FILE: &'static [u8] = include_bytes!("../README.md");
|
||||
|
||||
let domain = webpki::DNSNameRef::try_from_ascii_str(domain).unwrap();
|
||||
let mut config = ClientConfig::new();
|
||||
let mut chain = BufReader::new(Cursor::new(chain));
|
||||
config.root_store.add_pem_file(&mut chain).unwrap();
|
||||
let config = TlsConnector::from(Arc::new(config));
|
||||
let config = TlsConnector::from(config);
|
||||
|
||||
let done = TcpStream::connect(addr)
|
||||
.and_then(|stream| config.connect(domain, stream))
|
||||
@ -95,13 +92,23 @@ fn start_client(addr: &SocketAddr, domain: &str, chain: &str) -> io::Result<()>
|
||||
fn pass() {
|
||||
let (addr, domain, chain) = start_server();
|
||||
|
||||
start_client(addr, domain, chain).unwrap();
|
||||
let mut config = ClientConfig::new();
|
||||
let mut chain = BufReader::new(Cursor::new(chain));
|
||||
config.root_store.add_pem_file(&mut chain).unwrap();
|
||||
let config = Arc::new(config);
|
||||
|
||||
start_client(addr, domain, config.clone()).unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn fail() {
|
||||
let (addr, domain, chain) = start_server();
|
||||
|
||||
let mut config = ClientConfig::new();
|
||||
let mut chain = BufReader::new(Cursor::new(chain));
|
||||
config.root_store.add_pem_file(&mut chain).unwrap();
|
||||
let config = Arc::new(config);
|
||||
|
||||
assert_ne!(domain, &"google.com");
|
||||
assert!(start_client(addr, "google.com", chain).is_err());
|
||||
assert!(start_client(addr, "google.com", config).is_err());
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user