Improve for ServerSesssion

This commit is contained in:
quininer 2019-02-18 20:41:52 +08:00
parent 65932f5150
commit 527db99d02
5 changed files with 362 additions and 305 deletions

196
src/client.rs Normal file
View File

@ -0,0 +1,196 @@
use super::*;
use std::io::Write;
use rustls::Session;
/// A wrapper around an underlying raw stream which implements the TLS or SSL
/// protocol.
#[derive(Debug)]
pub struct TlsStream<IO> {
pub(crate) io: IO,
pub(crate) session: ClientSession,
pub(crate) state: TlsState,
pub(crate) early_data: (usize, Vec<u8>)
}
#[derive(Debug)]
pub(crate) enum TlsState {
EarlyData,
Stream,
Eof,
Shutdown
}
pub(crate) enum MidHandshake<IO> {
Handshaking(TlsStream<IO>),
EarlyData(TlsStream<IO>),
End
}
impl<IO> TlsStream<IO> {
#[inline]
pub fn get_ref(&self) -> (&IO, &ClientSession) {
(&self.io, &self.session)
}
#[inline]
pub fn get_mut(&mut self) -> (&mut IO, &mut ClientSession) {
(&mut self.io, &mut self.session)
}
#[inline]
pub fn into_inner(self) -> (IO, ClientSession) {
(self.io, self.session)
}
}
impl<IO> Future for MidHandshake<IO>
where IO: AsyncRead + AsyncWrite,
{
type Item = TlsStream<IO>;
type Error = io::Error;
#[inline]
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
match self {
MidHandshake::Handshaking(stream) => {
let (io, session) = stream.get_mut();
let mut stream = Stream::new(io, session);
if stream.session.is_handshaking() {
try_nb!(stream.complete_io());
}
if stream.session.wants_write() {
try_nb!(stream.complete_io());
}
},
_ => ()
}
match mem::replace(self, MidHandshake::End) {
MidHandshake::Handshaking(stream)
| MidHandshake::EarlyData(stream) => Ok(Async::Ready(stream)),
MidHandshake::End => panic!()
}
}
}
impl<IO> io::Read for TlsStream<IO>
where IO: AsyncRead + AsyncWrite
{
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
let mut stream = Stream::new(&mut self.io, &mut self.session);
match self.state {
TlsState::EarlyData => {
let (pos, data) = &mut self.early_data;
// complete handshake
if stream.session.is_handshaking() {
stream.complete_io()?;
}
// write early data (fallback)
if !stream.session.is_early_data_accepted() {
while *pos < data.len() {
let len = stream.write(&data[*pos..])?;
*pos += len;
}
}
// end
self.state = TlsState::Stream;
data.clear();
stream.read(buf)
},
TlsState::Stream => match stream.read(buf) {
Ok(0) => {
self.state = TlsState::Eof;
Ok(0)
},
Ok(n) => Ok(n),
Err(ref e) if e.kind() == io::ErrorKind::ConnectionAborted => {
self.state = TlsState::Shutdown;
stream.session.send_close_notify();
Ok(0)
},
Err(e) => Err(e)
},
TlsState::Eof | TlsState::Shutdown => Ok(0),
}
}
}
impl<IO> io::Write for TlsStream<IO>
where IO: AsyncRead + AsyncWrite
{
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
let mut stream = Stream::new(&mut self.io, &mut self.session);
match self.state {
TlsState::EarlyData => {
let (pos, data) = &mut self.early_data;
// write early data
if let Some(mut early_data) = stream.session.early_data() {
let len = early_data.write(buf)?;
data.extend_from_slice(&buf[..len]);
return Ok(len);
}
// complete handshake
if stream.session.is_handshaking() {
stream.complete_io()?;
}
// write early data (fallback)
if !stream.session.is_early_data_accepted() {
while *pos < data.len() {
let len = stream.write(&data[*pos..])?;
*pos += len;
}
}
// end
self.state = TlsState::Stream;
data.clear();
stream.write(buf)
},
_ => stream.write(buf)
}
}
fn flush(&mut self) -> io::Result<()> {
Stream::new(&mut self.io, &mut self.session).flush()?;
self.io.flush()
}
}
impl<IO> AsyncRead for TlsStream<IO>
where IO: AsyncRead + AsyncWrite
{
unsafe fn prepare_uninitialized_buffer(&self, _: &mut [u8]) -> bool {
false
}
}
impl<IO> AsyncWrite for TlsStream<IO>
where IO: AsyncRead + AsyncWrite
{
fn shutdown(&mut self) -> Poll<(), io::Error> {
match self.state {
TlsState::Shutdown => (),
_ => {
self.session.send_close_notify();
self.state = TlsState::Shutdown;
}
}
{
let mut stream = Stream::new(&mut self.io, &mut self.session);
try_nb!(stream.complete_io());
}
self.io.shutdown()
}
}

View File

@ -8,19 +8,19 @@ extern crate tokio_io;
extern crate bytes; extern crate bytes;
extern crate iovec; extern crate iovec;
mod common; mod common;
mod tokio_impl; pub mod client;
pub mod server;
use std::mem; use std::{ io, mem };
use std::io::{ self, Write };
use std::sync::Arc; use std::sync::Arc;
use webpki::DNSNameRef; use webpki::DNSNameRef;
use rustls::{ use rustls::{
Session, ClientSession, ServerSession, ClientSession, ServerSession,
ClientConfig, ServerConfig ClientConfig, ServerConfig
}; };
use tokio_io::{ AsyncRead, AsyncWrite }; use futures::{Async, Future, Poll};
use tokio_io::{ AsyncRead, AsyncWrite, try_nb };
use common::Stream; use common::Stream;
@ -74,15 +74,15 @@ impl TlsConnector {
f(&mut session); f(&mut session);
Connect(if self.early_data { Connect(if self.early_data {
MidHandshake::EarlyData(TlsStream { client::MidHandshake::EarlyData(client::TlsStream {
session, io: stream, session, io: stream,
state: TlsState::EarlyData, state: client::TlsState::EarlyData,
early_data: (0, Vec::new()) early_data: (0, Vec::new())
}) })
} else { } else {
MidHandshake::Handshaking(TlsStream { client::MidHandshake::Handshaking(client::TlsStream {
session, io: stream, session, io: stream,
state: TlsState::Stream, state: client::TlsState::Stream,
early_data: (0, Vec::new()) early_data: (0, Vec::new())
}) })
}) })
@ -106,10 +106,9 @@ impl TlsAcceptor {
let mut session = ServerSession::new(&self.inner); let mut session = ServerSession::new(&self.inner);
f(&mut session); f(&mut session);
Accept(MidHandshake::Handshaking(TlsStream { Accept(server::MidHandshake::Handshaking(server::TlsStream {
session, io: stream, session, io: stream,
state: TlsState::Stream, state: server::TlsState::Stream,
early_data: (0, Vec::new())
})) }))
} }
} }
@ -117,182 +116,28 @@ impl TlsAcceptor {
/// Future returned from `ClientConfigExt::connect_async` which will resolve /// Future returned from `ClientConfigExt::connect_async` which will resolve
/// once the connection handshake has finished. /// once the connection handshake has finished.
pub struct Connect<IO>(MidHandshake<IO, ClientSession>); pub struct Connect<IO>(client::MidHandshake<IO>);
/// Future returned from `ServerConfigExt::accept_async` which will resolve /// Future returned from `ServerConfigExt::accept_async` which will resolve
/// once the accept handshake has finished. /// once the accept handshake has finished.
pub struct Accept<IO>(MidHandshake<IO, ServerSession>); pub struct Accept<IO>(server::MidHandshake<IO>);
enum MidHandshake<IO, S> {
Handshaking(TlsStream<IO, S>),
EarlyData(TlsStream<IO, S>),
End
}
/// A wrapper around an underlying raw stream which implements the TLS or SSL impl<IO: AsyncRead + AsyncWrite> Future for Connect<IO> {
/// protocol. type Item = client::TlsStream<IO>;
#[derive(Debug)] type Error = io::Error;
pub struct TlsStream<IO, S> {
io: IO,
session: S,
state: TlsState,
early_data: (usize, Vec<u8>)
}
#[derive(Debug)] fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
enum TlsState { self.0.poll()
EarlyData,
Stream,
Eof,
Shutdown
}
impl<IO, S> TlsStream<IO, S> {
#[inline]
pub fn get_ref(&self) -> (&IO, &S) {
(&self.io, &self.session)
}
#[inline]
pub fn get_mut(&mut self) -> (&mut IO, &mut S) {
(&mut self.io, &mut self.session)
}
#[inline]
pub fn into_inner(self) -> (IO, S) {
(self.io, self.session)
} }
} }
impl<IO> io::Read for TlsStream<IO, ClientSession> impl<IO: AsyncRead + AsyncWrite> Future for Accept<IO> {
where IO: AsyncRead + AsyncWrite type Item = server::TlsStream<IO>;
{ type Error = io::Error;
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
let mut stream = Stream::new(&mut self.io, &mut self.session);
match self.state { fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
TlsState::EarlyData => { self.0.poll()
let (pos, data) = &mut self.early_data;
// complete handshake
if stream.session.is_handshaking() {
stream.complete_io()?;
}
// write early data (fallback)
if !stream.session.is_early_data_accepted() {
while *pos < data.len() {
let len = stream.write(&data[*pos..])?;
*pos += len;
}
}
// end
self.state = TlsState::Stream;
data.clear();
stream.read(buf)
},
TlsState::Stream => match stream.read(buf) {
Ok(0) => {
self.state = TlsState::Eof;
Ok(0)
},
Ok(n) => Ok(n),
Err(ref e) if e.kind() == io::ErrorKind::ConnectionAborted => {
self.state = TlsState::Shutdown;
stream.session.send_close_notify();
Ok(0)
},
Err(e) => Err(e)
},
TlsState::Eof | TlsState::Shutdown => Ok(0),
}
}
}
impl<IO> io::Read for TlsStream<IO, ServerSession>
where IO: AsyncRead + AsyncWrite
{
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
let mut stream = Stream::new(&mut self.io, &mut self.session);
match self.state {
TlsState::Stream => match stream.read(buf) {
Ok(0) => {
self.state = TlsState::Eof;
Ok(0)
},
Ok(n) => Ok(n),
Err(ref e) if e.kind() == io::ErrorKind::ConnectionAborted => {
self.state = TlsState::Shutdown;
stream.session.send_close_notify();
Ok(0)
},
Err(e) => Err(e)
},
TlsState::Eof | TlsState::Shutdown => Ok(0),
TlsState::EarlyData => unreachable!()
}
}
}
impl<IO> io::Write for TlsStream<IO, ClientSession>
where IO: AsyncRead + AsyncWrite
{
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
let mut stream = Stream::new(&mut self.io, &mut self.session);
match self.state {
TlsState::EarlyData => {
let (pos, data) = &mut self.early_data;
// write early data
if let Some(mut early_data) = stream.session.early_data() {
let len = early_data.write(buf)?;
data.extend_from_slice(&buf[..len]);
return Ok(len);
}
// complete handshake
if stream.session.is_handshaking() {
stream.complete_io()?;
}
// write early data (fallback)
if !stream.session.is_early_data_accepted() {
while *pos < data.len() {
let len = stream.write(&data[*pos..])?;
*pos += len;
}
}
// end
self.state = TlsState::Stream;
data.clear();
stream.write(buf)
},
_ => stream.write(buf)
}
}
fn flush(&mut self) -> io::Result<()> {
Stream::new(&mut self.io, &mut self.session).flush()?;
self.io.flush()
}
}
impl<IO> io::Write for TlsStream<IO, ServerSession>
where IO: AsyncRead + AsyncWrite
{
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
let mut stream = Stream::new(&mut self.io, &mut self.session);
stream.write(buf)
}
fn flush(&mut self) -> io::Result<()> {
Stream::new(&mut self.io, &mut self.session).flush()?;
self.io.flush()
} }
} }

139
src/server.rs Normal file
View File

@ -0,0 +1,139 @@
use super::*;
use rustls::Session;
/// A wrapper around an underlying raw stream which implements the TLS or SSL
/// protocol.
#[derive(Debug)]
pub struct TlsStream<IO> {
pub(crate) io: IO,
pub(crate) session: ServerSession,
pub(crate) state: TlsState
}
#[derive(Debug)]
pub(crate) enum TlsState {
Stream,
Eof,
Shutdown
}
pub(crate) enum MidHandshake<IO> {
Handshaking(TlsStream<IO>),
End
}
impl<IO> TlsStream<IO> {
#[inline]
pub fn get_ref(&self) -> (&IO, &ServerSession) {
(&self.io, &self.session)
}
#[inline]
pub fn get_mut(&mut self) -> (&mut IO, &mut ServerSession) {
(&mut self.io, &mut self.session)
}
#[inline]
pub fn into_inner(self) -> (IO, ServerSession) {
(self.io, self.session)
}
}
impl<IO> Future for MidHandshake<IO>
where IO: AsyncRead + AsyncWrite,
{
type Item = TlsStream<IO>;
type Error = io::Error;
#[inline]
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
match self {
MidHandshake::Handshaking(stream) => {
let (io, session) = stream.get_mut();
let mut stream = Stream::new(io, session);
if stream.session.is_handshaking() {
try_nb!(stream.complete_io());
}
if stream.session.wants_write() {
try_nb!(stream.complete_io());
}
},
_ => ()
}
match mem::replace(self, MidHandshake::End) {
MidHandshake::Handshaking(stream) => Ok(Async::Ready(stream)),
MidHandshake::End => panic!()
}
}
}
impl<IO> io::Read for TlsStream<IO>
where IO: AsyncRead + AsyncWrite
{
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
let mut stream = Stream::new(&mut self.io, &mut self.session);
match self.state {
TlsState::Stream => match stream.read(buf) {
Ok(0) => {
self.state = TlsState::Eof;
Ok(0)
},
Ok(n) => Ok(n),
Err(ref e) if e.kind() == io::ErrorKind::ConnectionAborted => {
self.state = TlsState::Shutdown;
stream.session.send_close_notify();
Ok(0)
},
Err(e) => Err(e)
},
TlsState::Eof | TlsState::Shutdown => Ok(0)
}
}
}
impl<IO> io::Write for TlsStream<IO>
where IO: AsyncRead + AsyncWrite
{
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
let mut stream = Stream::new(&mut self.io, &mut self.session);
stream.write(buf)
}
fn flush(&mut self) -> io::Result<()> {
Stream::new(&mut self.io, &mut self.session).flush()?;
self.io.flush()
}
}
impl<IO> AsyncRead for TlsStream<IO>
where IO: AsyncRead + AsyncWrite
{
unsafe fn prepare_uninitialized_buffer(&self, _: &mut [u8]) -> bool {
false
}
}
impl<IO> AsyncWrite for TlsStream<IO>
where IO: AsyncRead + AsyncWrite,
{
fn shutdown(&mut self) -> Poll<(), io::Error> {
match self.state {
TlsState::Shutdown => (),
_ => {
self.session.send_close_notify();
self.state = TlsState::Shutdown;
}
}
{
let mut stream = Stream::new(&mut self.io, &mut self.session);
try_nb!(stream.complete_io());
}
self.io.shutdown()
}
}

View File

@ -8,12 +8,12 @@ use std::net::ToSocketAddrs;
use self::tokio::io as aio; use self::tokio::io as aio;
use self::tokio::prelude::*; use self::tokio::prelude::*;
use self::tokio::net::TcpStream; use self::tokio::net::TcpStream;
use rustls::{ ClientConfig, ClientSession }; use rustls::ClientConfig;
use ::{ TlsConnector, TlsStream }; use ::{ TlsConnector, client::TlsStream };
fn get(config: Arc<ClientConfig>, domain: &str, rtt0: bool) fn get(config: Arc<ClientConfig>, domain: &str, rtt0: bool)
-> io::Result<(TlsStream<TcpStream, ClientSession>, String)> -> io::Result<(TlsStream<TcpStream>, String)>
{ {
let config = TlsConnector::from(config).early_data(rtt0); let config = TlsConnector::from(config).early_data(rtt0);
let input = format!("GET / HTTP/1.0\r\nHost: {}\r\n\r\n", domain); let input = format!("GET / HTTP/1.0\r\nHost: {}\r\n\r\n", domain);

View File

@ -1,123 +0,0 @@
use super::*;
use tokio_io::{ AsyncRead, AsyncWrite };
use futures::{Async, Future, Poll};
use common::Stream;
macro_rules! try_async {
( $e:expr ) => {
match $e {
Ok(n) => n,
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock =>
return Ok(Async::NotReady),
Err(e) => return Err(e)
}
}
}
impl<IO: AsyncRead + AsyncWrite> Future for Connect<IO> {
type Item = TlsStream<IO, ClientSession>;
type Error = io::Error;
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
self.0.poll()
}
}
impl<IO: AsyncRead + AsyncWrite> Future for Accept<IO> {
type Item = TlsStream<IO, ServerSession>;
type Error = io::Error;
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
self.0.poll()
}
}
impl<IO, S> Future for MidHandshake<IO, S>
where
IO: AsyncRead + AsyncWrite,
S: Session
{
type Item = TlsStream<IO, S>;
type Error = io::Error;
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
match self {
MidHandshake::Handshaking(stream) => {
let (io, session) = stream.get_mut();
let mut stream = Stream::new(io, session);
if stream.session.is_handshaking() {
try_async!(stream.complete_io());
}
if stream.session.wants_write() {
try_async!(stream.complete_io());
}
},
_ => ()
}
match mem::replace(self, MidHandshake::End) {
MidHandshake::Handshaking(stream)
| MidHandshake::EarlyData(stream) => Ok(Async::Ready(stream)),
MidHandshake::End => panic!()
}
}
}
impl<IO> AsyncRead for TlsStream<IO, ClientSession>
where IO: AsyncRead + AsyncWrite
{
unsafe fn prepare_uninitialized_buffer(&self, _: &mut [u8]) -> bool {
false
}
}
impl<IO> AsyncRead for TlsStream<IO, ServerSession>
where IO: AsyncRead + AsyncWrite
{
unsafe fn prepare_uninitialized_buffer(&self, _: &mut [u8]) -> bool {
false
}
}
impl<IO> AsyncWrite for TlsStream<IO, ClientSession>
where IO: AsyncRead + AsyncWrite,
{
fn shutdown(&mut self) -> Poll<(), io::Error> {
match self.state {
TlsState::Shutdown => (),
_ => {
self.session.send_close_notify();
self.state = TlsState::Shutdown;
}
}
{
let mut stream = Stream::new(&mut self.io, &mut self.session);
try_async!(stream.complete_io());
}
self.io.shutdown()
}
}
impl<IO> AsyncWrite for TlsStream<IO, ServerSession>
where IO: AsyncRead + AsyncWrite,
{
fn shutdown(&mut self) -> Poll<(), io::Error> {
match self.state {
TlsState::Shutdown => (),
_ => {
self.session.send_close_notify();
self.state = TlsState::Shutdown;
}
}
{
let mut stream = Stream::new(&mut self.io, &mut self.session);
try_async!(stream.complete_io());
}
self.io.shutdown()
}
}