diff --git a/src/client.rs b/src/client.rs new file mode 100644 index 0000000..8d6758e --- /dev/null +++ b/src/client.rs @@ -0,0 +1,196 @@ +use super::*; +use std::io::Write; +use rustls::Session; + + +/// A wrapper around an underlying raw stream which implements the TLS or SSL +/// protocol. +#[derive(Debug)] +pub struct TlsStream { + pub(crate) io: IO, + pub(crate) session: ClientSession, + pub(crate) state: TlsState, + pub(crate) early_data: (usize, Vec) +} + +#[derive(Debug)] +pub(crate) enum TlsState { + EarlyData, + Stream, + Eof, + Shutdown +} + +pub(crate) enum MidHandshake { + Handshaking(TlsStream), + EarlyData(TlsStream), + End +} + +impl TlsStream { + #[inline] + pub fn get_ref(&self) -> (&IO, &ClientSession) { + (&self.io, &self.session) + } + + #[inline] + pub fn get_mut(&mut self) -> (&mut IO, &mut ClientSession) { + (&mut self.io, &mut self.session) + } + + #[inline] + pub fn into_inner(self) -> (IO, ClientSession) { + (self.io, self.session) + } +} + +impl Future for MidHandshake +where IO: AsyncRead + AsyncWrite, +{ + type Item = TlsStream; + type Error = io::Error; + + #[inline] + fn poll(&mut self) -> Poll { + match self { + MidHandshake::Handshaking(stream) => { + let (io, session) = stream.get_mut(); + let mut stream = Stream::new(io, session); + + if stream.session.is_handshaking() { + try_nb!(stream.complete_io()); + } + + if stream.session.wants_write() { + try_nb!(stream.complete_io()); + } + }, + _ => () + } + + match mem::replace(self, MidHandshake::End) { + MidHandshake::Handshaking(stream) + | MidHandshake::EarlyData(stream) => Ok(Async::Ready(stream)), + MidHandshake::End => panic!() + } + } +} + +impl io::Read for TlsStream +where IO: AsyncRead + AsyncWrite +{ + fn read(&mut self, buf: &mut [u8]) -> io::Result { + let mut stream = Stream::new(&mut self.io, &mut self.session); + + match self.state { + TlsState::EarlyData => { + let (pos, data) = &mut self.early_data; + + // complete handshake + if stream.session.is_handshaking() { + stream.complete_io()?; + } + + // write early data (fallback) + if !stream.session.is_early_data_accepted() { + while *pos < data.len() { + let len = stream.write(&data[*pos..])?; + *pos += len; + } + } + + // end + self.state = TlsState::Stream; + data.clear(); + stream.read(buf) + }, + TlsState::Stream => match stream.read(buf) { + Ok(0) => { + self.state = TlsState::Eof; + Ok(0) + }, + Ok(n) => Ok(n), + Err(ref e) if e.kind() == io::ErrorKind::ConnectionAborted => { + self.state = TlsState::Shutdown; + stream.session.send_close_notify(); + Ok(0) + }, + Err(e) => Err(e) + }, + TlsState::Eof | TlsState::Shutdown => Ok(0), + } + } +} + +impl io::Write for TlsStream +where IO: AsyncRead + AsyncWrite +{ + fn write(&mut self, buf: &[u8]) -> io::Result { + let mut stream = Stream::new(&mut self.io, &mut self.session); + + match self.state { + TlsState::EarlyData => { + let (pos, data) = &mut self.early_data; + + // write early data + if let Some(mut early_data) = stream.session.early_data() { + let len = early_data.write(buf)?; + data.extend_from_slice(&buf[..len]); + return Ok(len); + } + + // complete handshake + if stream.session.is_handshaking() { + stream.complete_io()?; + } + + // write early data (fallback) + if !stream.session.is_early_data_accepted() { + while *pos < data.len() { + let len = stream.write(&data[*pos..])?; + *pos += len; + } + } + + // end + self.state = TlsState::Stream; + data.clear(); + stream.write(buf) + }, + _ => stream.write(buf) + } + } + + fn flush(&mut self) -> io::Result<()> { + Stream::new(&mut self.io, &mut self.session).flush()?; + self.io.flush() + } +} + +impl AsyncRead for TlsStream +where IO: AsyncRead + AsyncWrite +{ + unsafe fn prepare_uninitialized_buffer(&self, _: &mut [u8]) -> bool { + false + } +} + +impl AsyncWrite for TlsStream +where IO: AsyncRead + AsyncWrite +{ + fn shutdown(&mut self) -> Poll<(), io::Error> { + match self.state { + TlsState::Shutdown => (), + _ => { + self.session.send_close_notify(); + self.state = TlsState::Shutdown; + } + } + + { + let mut stream = Stream::new(&mut self.io, &mut self.session); + try_nb!(stream.complete_io()); + } + self.io.shutdown() + } +} diff --git a/src/lib.rs b/src/lib.rs index dd34452..446e80d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -8,19 +8,19 @@ extern crate tokio_io; extern crate bytes; extern crate iovec; - mod common; -mod tokio_impl; +pub mod client; +pub mod server; -use std::mem; -use std::io::{ self, Write }; +use std::{ io, mem }; use std::sync::Arc; use webpki::DNSNameRef; use rustls::{ - Session, ClientSession, ServerSession, + ClientSession, ServerSession, ClientConfig, ServerConfig }; -use tokio_io::{ AsyncRead, AsyncWrite }; +use futures::{Async, Future, Poll}; +use tokio_io::{ AsyncRead, AsyncWrite, try_nb }; use common::Stream; @@ -74,15 +74,15 @@ impl TlsConnector { f(&mut session); Connect(if self.early_data { - MidHandshake::EarlyData(TlsStream { + client::MidHandshake::EarlyData(client::TlsStream { session, io: stream, - state: TlsState::EarlyData, + state: client::TlsState::EarlyData, early_data: (0, Vec::new()) }) } else { - MidHandshake::Handshaking(TlsStream { + client::MidHandshake::Handshaking(client::TlsStream { session, io: stream, - state: TlsState::Stream, + state: client::TlsState::Stream, early_data: (0, Vec::new()) }) }) @@ -106,10 +106,9 @@ impl TlsAcceptor { let mut session = ServerSession::new(&self.inner); f(&mut session); - Accept(MidHandshake::Handshaking(TlsStream { + Accept(server::MidHandshake::Handshaking(server::TlsStream { session, io: stream, - state: TlsState::Stream, - early_data: (0, Vec::new()) + state: server::TlsState::Stream, })) } } @@ -117,182 +116,28 @@ impl TlsAcceptor { /// Future returned from `ClientConfigExt::connect_async` which will resolve /// once the connection handshake has finished. -pub struct Connect(MidHandshake); +pub struct Connect(client::MidHandshake); /// Future returned from `ServerConfigExt::accept_async` which will resolve /// once the accept handshake has finished. -pub struct Accept(MidHandshake); - -enum MidHandshake { - Handshaking(TlsStream), - EarlyData(TlsStream), - End -} +pub struct Accept(server::MidHandshake); -/// A wrapper around an underlying raw stream which implements the TLS or SSL -/// protocol. -#[derive(Debug)] -pub struct TlsStream { - io: IO, - session: S, - state: TlsState, - early_data: (usize, Vec) -} +impl Future for Connect { + type Item = client::TlsStream; + type Error = io::Error; -#[derive(Debug)] -enum TlsState { - EarlyData, - Stream, - Eof, - Shutdown -} - -impl TlsStream { - #[inline] - pub fn get_ref(&self) -> (&IO, &S) { - (&self.io, &self.session) - } - - #[inline] - pub fn get_mut(&mut self) -> (&mut IO, &mut S) { - (&mut self.io, &mut self.session) - } - - #[inline] - pub fn into_inner(self) -> (IO, S) { - (self.io, self.session) + fn poll(&mut self) -> Poll { + self.0.poll() } } -impl io::Read for TlsStream -where IO: AsyncRead + AsyncWrite -{ - fn read(&mut self, buf: &mut [u8]) -> io::Result { - let mut stream = Stream::new(&mut self.io, &mut self.session); +impl Future for Accept { + type Item = server::TlsStream; + type Error = io::Error; - match self.state { - TlsState::EarlyData => { - let (pos, data) = &mut self.early_data; - - // complete handshake - if stream.session.is_handshaking() { - stream.complete_io()?; - } - - // write early data (fallback) - if !stream.session.is_early_data_accepted() { - while *pos < data.len() { - let len = stream.write(&data[*pos..])?; - *pos += len; - } - } - - // end - self.state = TlsState::Stream; - data.clear(); - stream.read(buf) - }, - TlsState::Stream => match stream.read(buf) { - Ok(0) => { - self.state = TlsState::Eof; - Ok(0) - }, - Ok(n) => Ok(n), - Err(ref e) if e.kind() == io::ErrorKind::ConnectionAborted => { - self.state = TlsState::Shutdown; - stream.session.send_close_notify(); - Ok(0) - }, - Err(e) => Err(e) - }, - TlsState::Eof | TlsState::Shutdown => Ok(0), - } - } -} - -impl io::Read for TlsStream -where IO: AsyncRead + AsyncWrite -{ - fn read(&mut self, buf: &mut [u8]) -> io::Result { - let mut stream = Stream::new(&mut self.io, &mut self.session); - - match self.state { - TlsState::Stream => match stream.read(buf) { - Ok(0) => { - self.state = TlsState::Eof; - Ok(0) - }, - Ok(n) => Ok(n), - Err(ref e) if e.kind() == io::ErrorKind::ConnectionAborted => { - self.state = TlsState::Shutdown; - stream.session.send_close_notify(); - Ok(0) - }, - Err(e) => Err(e) - }, - TlsState::Eof | TlsState::Shutdown => Ok(0), - TlsState::EarlyData => unreachable!() - } - } -} - -impl io::Write for TlsStream -where IO: AsyncRead + AsyncWrite -{ - fn write(&mut self, buf: &[u8]) -> io::Result { - let mut stream = Stream::new(&mut self.io, &mut self.session); - - match self.state { - TlsState::EarlyData => { - let (pos, data) = &mut self.early_data; - - // write early data - if let Some(mut early_data) = stream.session.early_data() { - let len = early_data.write(buf)?; - data.extend_from_slice(&buf[..len]); - return Ok(len); - } - - // complete handshake - if stream.session.is_handshaking() { - stream.complete_io()?; - } - - // write early data (fallback) - if !stream.session.is_early_data_accepted() { - while *pos < data.len() { - let len = stream.write(&data[*pos..])?; - *pos += len; - } - } - - // end - self.state = TlsState::Stream; - data.clear(); - stream.write(buf) - }, - _ => stream.write(buf) - } - } - - fn flush(&mut self) -> io::Result<()> { - Stream::new(&mut self.io, &mut self.session).flush()?; - self.io.flush() - } -} - -impl io::Write for TlsStream -where IO: AsyncRead + AsyncWrite -{ - fn write(&mut self, buf: &[u8]) -> io::Result { - let mut stream = Stream::new(&mut self.io, &mut self.session); - stream.write(buf) - } - - fn flush(&mut self) -> io::Result<()> { - Stream::new(&mut self.io, &mut self.session).flush()?; - self.io.flush() + fn poll(&mut self) -> Poll { + self.0.poll() } } diff --git a/src/server.rs b/src/server.rs new file mode 100644 index 0000000..42dd18d --- /dev/null +++ b/src/server.rs @@ -0,0 +1,139 @@ +use super::*; +use rustls::Session; + + +/// A wrapper around an underlying raw stream which implements the TLS or SSL +/// protocol. +#[derive(Debug)] +pub struct TlsStream { + pub(crate) io: IO, + pub(crate) session: ServerSession, + pub(crate) state: TlsState +} + +#[derive(Debug)] +pub(crate) enum TlsState { + Stream, + Eof, + Shutdown +} + +pub(crate) enum MidHandshake { + Handshaking(TlsStream), + End +} + +impl TlsStream { + #[inline] + pub fn get_ref(&self) -> (&IO, &ServerSession) { + (&self.io, &self.session) + } + + #[inline] + pub fn get_mut(&mut self) -> (&mut IO, &mut ServerSession) { + (&mut self.io, &mut self.session) + } + + #[inline] + pub fn into_inner(self) -> (IO, ServerSession) { + (self.io, self.session) + } +} + +impl Future for MidHandshake +where IO: AsyncRead + AsyncWrite, +{ + type Item = TlsStream; + type Error = io::Error; + + #[inline] + fn poll(&mut self) -> Poll { + match self { + MidHandshake::Handshaking(stream) => { + let (io, session) = stream.get_mut(); + let mut stream = Stream::new(io, session); + + if stream.session.is_handshaking() { + try_nb!(stream.complete_io()); + } + + if stream.session.wants_write() { + try_nb!(stream.complete_io()); + } + }, + _ => () + } + + match mem::replace(self, MidHandshake::End) { + MidHandshake::Handshaking(stream) => Ok(Async::Ready(stream)), + MidHandshake::End => panic!() + } + } +} + +impl io::Read for TlsStream +where IO: AsyncRead + AsyncWrite +{ + fn read(&mut self, buf: &mut [u8]) -> io::Result { + let mut stream = Stream::new(&mut self.io, &mut self.session); + + match self.state { + TlsState::Stream => match stream.read(buf) { + Ok(0) => { + self.state = TlsState::Eof; + Ok(0) + }, + Ok(n) => Ok(n), + Err(ref e) if e.kind() == io::ErrorKind::ConnectionAborted => { + self.state = TlsState::Shutdown; + stream.session.send_close_notify(); + Ok(0) + }, + Err(e) => Err(e) + }, + TlsState::Eof | TlsState::Shutdown => Ok(0) + } + } +} + +impl io::Write for TlsStream +where IO: AsyncRead + AsyncWrite +{ + fn write(&mut self, buf: &[u8]) -> io::Result { + let mut stream = Stream::new(&mut self.io, &mut self.session); + stream.write(buf) + } + + fn flush(&mut self) -> io::Result<()> { + Stream::new(&mut self.io, &mut self.session).flush()?; + self.io.flush() + } +} + +impl AsyncRead for TlsStream +where IO: AsyncRead + AsyncWrite +{ + unsafe fn prepare_uninitialized_buffer(&self, _: &mut [u8]) -> bool { + false + } +} + +impl AsyncWrite for TlsStream +where IO: AsyncRead + AsyncWrite, +{ + fn shutdown(&mut self) -> Poll<(), io::Error> { + match self.state { + TlsState::Shutdown => (), + _ => { + self.session.send_close_notify(); + self.state = TlsState::Shutdown; + } + } + + { + let mut stream = Stream::new(&mut self.io, &mut self.session); + try_nb!(stream.complete_io()); + } + self.io.shutdown() + } +} diff --git a/src/test_0rtt.rs b/src/test_0rtt.rs index 56c9d7b..0182406 100644 --- a/src/test_0rtt.rs +++ b/src/test_0rtt.rs @@ -8,12 +8,12 @@ use std::net::ToSocketAddrs; use self::tokio::io as aio; use self::tokio::prelude::*; use self::tokio::net::TcpStream; -use rustls::{ ClientConfig, ClientSession }; -use ::{ TlsConnector, TlsStream }; +use rustls::ClientConfig; +use ::{ TlsConnector, client::TlsStream }; fn get(config: Arc, domain: &str, rtt0: bool) - -> io::Result<(TlsStream, String)> + -> io::Result<(TlsStream, String)> { let config = TlsConnector::from(config).early_data(rtt0); let input = format!("GET / HTTP/1.0\r\nHost: {}\r\n\r\n", domain); diff --git a/src/tokio_impl.rs b/src/tokio_impl.rs deleted file mode 100644 index f97cde3..0000000 --- a/src/tokio_impl.rs +++ /dev/null @@ -1,123 +0,0 @@ -use super::*; -use tokio_io::{ AsyncRead, AsyncWrite }; -use futures::{Async, Future, Poll}; -use common::Stream; - - -macro_rules! try_async { - ( $e:expr ) => { - match $e { - Ok(n) => n, - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => - return Ok(Async::NotReady), - Err(e) => return Err(e) - } - } -} - -impl Future for Connect { - type Item = TlsStream; - type Error = io::Error; - - fn poll(&mut self) -> Poll { - self.0.poll() - } -} - -impl Future for Accept { - type Item = TlsStream; - type Error = io::Error; - - fn poll(&mut self) -> Poll { - self.0.poll() - } -} - -impl Future for MidHandshake -where - IO: AsyncRead + AsyncWrite, - S: Session -{ - type Item = TlsStream; - type Error = io::Error; - - fn poll(&mut self) -> Poll { - match self { - MidHandshake::Handshaking(stream) => { - let (io, session) = stream.get_mut(); - let mut stream = Stream::new(io, session); - - if stream.session.is_handshaking() { - try_async!(stream.complete_io()); - } - - if stream.session.wants_write() { - try_async!(stream.complete_io()); - } - }, - _ => () - } - - match mem::replace(self, MidHandshake::End) { - MidHandshake::Handshaking(stream) - | MidHandshake::EarlyData(stream) => Ok(Async::Ready(stream)), - MidHandshake::End => panic!() - } - } -} - -impl AsyncRead for TlsStream -where IO: AsyncRead + AsyncWrite -{ - unsafe fn prepare_uninitialized_buffer(&self, _: &mut [u8]) -> bool { - false - } -} - -impl AsyncRead for TlsStream -where IO: AsyncRead + AsyncWrite -{ - unsafe fn prepare_uninitialized_buffer(&self, _: &mut [u8]) -> bool { - false - } -} - -impl AsyncWrite for TlsStream -where IO: AsyncRead + AsyncWrite, -{ - fn shutdown(&mut self) -> Poll<(), io::Error> { - match self.state { - TlsState::Shutdown => (), - _ => { - self.session.send_close_notify(); - self.state = TlsState::Shutdown; - } - } - - { - let mut stream = Stream::new(&mut self.io, &mut self.session); - try_async!(stream.complete_io()); - } - self.io.shutdown() - } -} - -impl AsyncWrite for TlsStream -where IO: AsyncRead + AsyncWrite, -{ - fn shutdown(&mut self) -> Poll<(), io::Error> { - match self.state { - TlsState::Shutdown => (), - _ => { - self.session.send_close_notify(); - self.state = TlsState::Shutdown; - } - } - - { - let mut stream = Stream::new(&mut self.io, &mut self.session); - try_async!(stream.complete_io()); - } - self.io.shutdown() - } -}