wip server

This commit is contained in:
quininer 2019-05-18 18:18:26 +08:00
parent 41c26ee63a
commit 4cc374fd4c
4 changed files with 102 additions and 86 deletions

View File

@ -45,10 +45,13 @@ where
type Output = io::Result<TlsStream<IO>>; type Output = io::Result<TlsStream<IO>>;
#[inline] #[inline]
fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> { fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
if let MidHandshake::Handshaking(stream) = &mut *self { let this = self.get_mut();
if let MidHandshake::Handshaking(stream) = this {
let eof = !stream.state.readable();
let (io, session) = stream.get_mut(); let (io, session) = stream.get_mut();
let mut stream = Stream::new(io, session); let mut stream = Stream::new(io, session).set_eof(eof);
if stream.session.is_handshaking() { if stream.session.is_handshaking() {
try_ready!(stream.complete_io(cx)); try_ready!(stream.complete_io(cx));
@ -59,7 +62,7 @@ where
} }
} }
match mem::replace(&mut *self, MidHandshake::End) { match mem::replace(this, MidHandshake::End) {
MidHandshake::Handshaking(stream) => Poll::Ready(Ok(stream)), MidHandshake::Handshaking(stream) => Poll::Ready(Ok(stream)),
#[cfg(feature = "early-data")] #[cfg(feature = "early-data")]
MidHandshake::EarlyData(stream) => Poll::Ready(Ok(stream)), MidHandshake::EarlyData(stream) => Poll::Ready(Ok(stream)),
@ -83,7 +86,8 @@ where
TlsState::EarlyData => { TlsState::EarlyData => {
let this = self.get_mut(); let this = self.get_mut();
let mut stream = Stream::new(&mut this.io, &mut this.session); let mut stream = Stream::new(&mut this.io, &mut this.session)
.set_eof(!this.state.readable());
let (pos, data) = &mut this.early_data; let (pos, data) = &mut this.early_data;
// complete handshake // complete handshake
@ -107,7 +111,8 @@ where
} }
TlsState::Stream | TlsState::WriteShutdown => { TlsState::Stream | TlsState::WriteShutdown => {
let this = self.get_mut(); let this = self.get_mut();
let mut stream = Stream::new(&mut this.io, &mut this.session); let mut stream = Stream::new(&mut this.io, &mut this.session)
.set_eof(!this.state.readable());
match stream.poll_read(cx, buf) { match stream.poll_read(cx, buf) {
Poll::Ready(Ok(0)) => { Poll::Ready(Ok(0)) => {
@ -136,9 +141,10 @@ impl<IO> AsyncWrite for TlsStream<IO>
where where
IO: AsyncRead + AsyncWrite + Unpin, IO: AsyncRead + AsyncWrite + Unpin,
{ {
fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> { fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
let this = self.get_mut(); let this = self.get_mut();
let mut stream = Stream::new(&mut this.io, &mut this.session); let mut stream = Stream::new(&mut this.io, &mut this.session)
.set_eof(!this.state.readable());
match this.state { match this.state {
#[cfg(feature = "early-data")] #[cfg(feature = "early-data")]
@ -174,9 +180,11 @@ where
} }
} }
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> { fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
let this = self.get_mut(); let this = self.get_mut();
Stream::new(&mut this.io, &mut this.session).poll_flush(cx) Stream::new(&mut this.io, &mut this.session)
.set_eof(!this.state.readable())
.poll_flush(cx)
} }
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> { fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
@ -186,7 +194,8 @@ where
} }
let this = self.get_mut(); let this = self.get_mut();
let mut stream = Stream::new(&mut this.io, &mut this.session); let mut stream = Stream::new(&mut this.io, &mut this.session)
.set_eof(!this.state.readable());
try_ready!(stream.poll_flush(cx)); try_ready!(stream.poll_flush(cx));
Pin::new(&mut this.io).poll_close(cx) Pin::new(&mut this.io).poll_close(cx)
} }

View File

@ -3,7 +3,7 @@
use std::pin::Pin; use std::pin::Pin;
use std::task::Poll; use std::task::Poll;
use std::marker::Unpin; use std::marker::Unpin;
use std::io::{ self, Read, Write }; use std::io::{ self, Read };
use rustls::Session; use rustls::Session;
use rustls::WriteV; use rustls::WriteV;
use futures::task::Context; use futures::task::Context;

View File

@ -12,7 +12,7 @@ macro_rules! try_ready {
pub mod client; pub mod client;
mod common; mod common;
// pub mod server; pub mod server;
use common::Stream; use common::Stream;
use std::pin::Pin; use std::pin::Pin;
@ -25,7 +25,7 @@ use std::{io, mem};
use webpki::DNSNameRef; use webpki::DNSNameRef;
#[derive(Debug, Copy, Clone)] #[derive(Debug, Copy, Clone)]
pub enum TlsState { enum TlsState {
#[cfg(feature = "early-data")] #[cfg(feature = "early-data")]
EarlyData, EarlyData,
Stream, Stream,
@ -35,26 +35,33 @@ pub enum TlsState {
} }
impl TlsState { impl TlsState {
pub(crate) fn shutdown_read(&mut self) { fn shutdown_read(&mut self) {
match *self { match *self {
TlsState::WriteShutdown | TlsState::FullyShutdown => *self = TlsState::FullyShutdown, TlsState::WriteShutdown | TlsState::FullyShutdown => *self = TlsState::FullyShutdown,
_ => *self = TlsState::ReadShutdown, _ => *self = TlsState::ReadShutdown,
} }
} }
pub(crate) fn shutdown_write(&mut self) { fn shutdown_write(&mut self) {
match *self { match *self {
TlsState::ReadShutdown | TlsState::FullyShutdown => *self = TlsState::FullyShutdown, TlsState::ReadShutdown | TlsState::FullyShutdown => *self = TlsState::FullyShutdown,
_ => *self = TlsState::WriteShutdown, _ => *self = TlsState::WriteShutdown,
} }
} }
pub(crate) fn writeable(&self) -> bool { fn writeable(&self) -> bool {
match *self { match *self {
TlsState::WriteShutdown | TlsState::FullyShutdown => false, TlsState::WriteShutdown | TlsState::FullyShutdown => false,
_ => true, _ => true,
} }
} }
fn readable(self) -> bool {
match self {
TlsState::ReadShutdown | TlsState::FullyShutdown => false,
_ => true,
}
}
} }
/// A wrapper around a `rustls::ClientConfig`, providing an async `connect` method. /// A wrapper around a `rustls::ClientConfig`, providing an async `connect` method.
@ -65,7 +72,6 @@ pub struct TlsConnector {
early_data: bool, early_data: bool,
} }
/*
/// A wrapper around a `rustls::ServerConfig`, providing an async `accept` method. /// A wrapper around a `rustls::ServerConfig`, providing an async `accept` method.
#[derive(Clone)] #[derive(Clone)]
pub struct TlsAcceptor { pub struct TlsAcceptor {
@ -170,32 +176,31 @@ impl TlsAcceptor {
} }
} }
/// Future returned from `ClientConfigExt::connect_async` which will resolve /// Future returned from `TlsConnector::connect` which will resolve
/// once the connection handshake has finished. /// once the connection handshake has finished.
pub struct Connect<IO>(client::MidHandshake<IO>); pub struct Connect<IO>(client::MidHandshake<IO>);
/// Future returned from `ServerConfigExt::accept_async` which will resolve /// Future returned from `TlsAcceptor::accept` which will resolve
/// once the accept handshake has finished. /// once the accept handshake has finished.
pub struct Accept<IO>(server::MidHandshake<IO>); pub struct Accept<IO>(server::MidHandshake<IO>);
impl<IO: AsyncRead + AsyncWrite> Future for Connect<IO> { impl<IO: AsyncRead + AsyncWrite + Unpin> Future for Connect<IO> {
type Item = client::TlsStream<IO>; type Output = io::Result<client::TlsStream<IO>>;
type Error = io::Error;
fn poll(&mut self) -> Poll<Self::Item, Self::Error> { fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
self.0.poll() Pin::new(&mut self.0).poll(cx)
} }
} }
impl<IO: AsyncRead + AsyncWrite> Future for Accept<IO> { impl<IO: AsyncRead + AsyncWrite + Unpin> Future for Accept<IO> {
type Item = server::TlsStream<IO>; type Output = io::Result<server::TlsStream<IO>>;
type Error = io::Error;
fn poll(&mut self) -> Poll<Self::Item, Self::Error> { fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
self.0.poll() Pin::new(&mut self.0).poll(cx)
} }
} }
/*
#[cfg(feature = "early-data")] #[cfg(feature = "early-data")]
#[cfg(test)] #[cfg(test)]
mod test_0rtt; mod test_0rtt;

View File

@ -34,100 +34,102 @@ impl<IO> TlsStream<IO> {
impl<IO> Future for MidHandshake<IO> impl<IO> Future for MidHandshake<IO>
where where
IO: AsyncRead + AsyncWrite, IO: AsyncRead + AsyncWrite + Unpin,
{ {
type Item = TlsStream<IO>; type Output = io::Result<TlsStream<IO>>;
type Error = io::Error;
#[inline] #[inline]
fn poll(&mut self) -> Poll<Self::Item, Self::Error> { fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
if let MidHandshake::Handshaking(stream) = self { let this = self.get_mut();
if let MidHandshake::Handshaking(stream) = this {
let eof = !stream.state.readable();
let (io, session) = stream.get_mut(); let (io, session) = stream.get_mut();
let mut stream = Stream::new(io, session); let mut stream = Stream::new(io, session).set_eof(eof);
if stream.session.is_handshaking() { if stream.session.is_handshaking() {
try_nb!(stream.complete_io()); try_ready!(stream.complete_io(cx));
} }
if stream.session.wants_write() { if stream.session.wants_write() {
try_nb!(stream.complete_io()); try_ready!(stream.complete_io(cx));
} }
} }
match mem::replace(self, MidHandshake::End) { match mem::replace(this, MidHandshake::End) {
MidHandshake::Handshaking(stream) => Ok(Async::Ready(stream)), MidHandshake::Handshaking(stream) => Poll::Ready(Ok(stream)),
MidHandshake::End => panic!(), MidHandshake::End => panic!(),
} }
} }
} }
impl<IO> io::Read for TlsStream<IO> impl<IO> AsyncRead for TlsStream<IO>
where where
IO: AsyncRead + AsyncWrite, IO: AsyncRead + AsyncWrite + Unpin,
{ {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { unsafe fn initializer(&self) -> Initializer {
let mut stream = Stream::new(&mut self.io, &mut self.session); // TODO
Initializer::nop()
}
match self.state { fn poll_read(self: Pin<&mut Self>, cx: &mut Context, buf: &mut [u8]) -> Poll<io::Result<usize>> {
TlsState::Stream | TlsState::WriteShutdown => match stream.read(buf) { let this = self.get_mut();
Ok(0) => { let mut stream = Stream::new(&mut this.io, &mut this.session)
self.state.shutdown_read(); .set_eof(!this.state.readable());
Ok(0)
match this.state {
TlsState::Stream | TlsState::WriteShutdown => match stream.poll_read(cx, buf) {
Poll::Ready(Ok(0)) => {
this.state.shutdown_read();
Poll::Ready(Ok(0))
} }
Ok(n) => Ok(n), Poll::Ready(Ok(n)) => Poll::Ready(Ok(n)),
Err(ref e) if e.kind() == io::ErrorKind::ConnectionAborted => { Poll::Ready(Err(ref err)) if err.kind() == io::ErrorKind::ConnectionAborted => {
self.state.shutdown_read(); this.state.shutdown_read();
if self.state.writeable() { if this.state.writeable() {
stream.session.send_close_notify(); stream.session.send_close_notify();
self.state.shutdown_write(); this.state.shutdown_write();
} }
Ok(0) Poll::Ready(Ok(0))
} }
Err(e) => Err(e), Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
Poll::Pending => Poll::Pending
}, },
TlsState::ReadShutdown | TlsState::FullyShutdown => Ok(0), TlsState::ReadShutdown | TlsState::FullyShutdown => Poll::Ready(Ok(0)),
#[cfg(feature = "early-data")] #[cfg(feature = "early-data")]
s => unreachable!("server TLS can not hit this state: {:?}", s), s => unreachable!("server TLS can not hit this state: {:?}", s),
} }
} }
} }
impl<IO> io::Write for TlsStream<IO>
where
IO: AsyncRead + AsyncWrite,
{
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
let mut stream = Stream::new(&mut self.io, &mut self.session);
stream.write(buf)
}
fn flush(&mut self) -> io::Result<()> {
Stream::new(&mut self.io, &mut self.session).flush()?;
self.io.flush()
}
}
impl<IO> AsyncRead for TlsStream<IO>
where
IO: AsyncRead + AsyncWrite,
{
unsafe fn prepare_uninitialized_buffer(&self, _: &mut [u8]) -> bool {
false
}
}
impl<IO> AsyncWrite for TlsStream<IO> impl<IO> AsyncWrite for TlsStream<IO>
where where
IO: AsyncRead + AsyncWrite, IO: AsyncRead + AsyncWrite + Unpin,
{ {
fn shutdown(&mut self) -> Poll<(), io::Error> { fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
let this = self.get_mut();
Stream::new(&mut this.io, &mut this.session)
.set_eof(!this.state.readable())
.poll_write(cx, buf)
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
let this = self.get_mut();
Stream::new(&mut this.io, &mut this.session)
.set_eof(!this.state.readable())
.poll_flush(cx)
}
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
if self.state.writeable() { if self.state.writeable() {
self.session.send_close_notify(); self.session.send_close_notify();
self.state.shutdown_write(); self.state.shutdown_write();
} }
let mut stream = Stream::new(&mut self.io, &mut self.session); let this = self.get_mut();
try_nb!(stream.complete_io()); let mut stream = Stream::new(&mut this.io, &mut this.session)
stream.io.shutdown() .set_eof(!this.state.readable());
try_ready!(stream.complete_io(cx));
Pin::new(&mut this.io).poll_close(cx)
} }
} }