wip client

This commit is contained in:
quininer 2019-05-18 16:05:10 +08:00
parent 017b1b64d1
commit 41c26ee63a
4 changed files with 132 additions and 115 deletions

View File

@ -1,6 +1,6 @@
[package] [package]
name = "tokio-rustls" name = "tokio-rustls"
version = "0.10.0-alpha.2" version = "0.12.0-alpha"
authors = ["quininer kel <quininer@live.com>"] authors = ["quininer kel <quininer@live.com>"]
license = "MIT/Apache-2.0" license = "MIT/Apache-2.0"
repository = "https://github.com/quininer/tokio-rustls" repository = "https://github.com/quininer/tokio-rustls"

View File

@ -40,160 +40,154 @@ impl<IO> TlsStream<IO> {
impl<IO> Future for MidHandshake<IO> impl<IO> Future for MidHandshake<IO>
where where
IO: AsyncRead + AsyncWrite, IO: AsyncRead + AsyncWrite + Unpin,
{ {
type Item = TlsStream<IO>; type Output = io::Result<TlsStream<IO>>;
type Error = io::Error;
#[inline] #[inline]
fn poll(&mut self) -> Poll<Self::Item, Self::Error> { fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
if let MidHandshake::Handshaking(stream) = self { if let MidHandshake::Handshaking(stream) = &mut *self {
let (io, session) = stream.get_mut(); let (io, session) = stream.get_mut();
let mut stream = Stream::new(io, session); let mut stream = Stream::new(io, session);
if stream.session.is_handshaking() { if stream.session.is_handshaking() {
try_nb!(stream.complete_io()); try_ready!(stream.complete_io(cx));
} }
if stream.session.wants_write() { if stream.session.wants_write() {
try_nb!(stream.complete_io()); try_ready!(stream.complete_io(cx));
} }
} }
match mem::replace(self, MidHandshake::End) { match mem::replace(&mut *self, MidHandshake::End) {
MidHandshake::Handshaking(stream) => Ok(Async::Ready(stream)), MidHandshake::Handshaking(stream) => Poll::Ready(Ok(stream)),
#[cfg(feature = "early-data")] #[cfg(feature = "early-data")]
MidHandshake::EarlyData(stream) => Ok(Async::Ready(stream)), MidHandshake::EarlyData(stream) => Poll::Ready(Ok(stream)),
MidHandshake::End => panic!(), MidHandshake::End => panic!(),
} }
} }
} }
impl<IO> io::Read for TlsStream<IO> impl<IO> AsyncRead for TlsStream<IO>
where where
IO: AsyncRead + AsyncWrite, IO: AsyncRead + AsyncWrite + Unpin,
{ {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { unsafe fn initializer(&self) -> Initializer {
match self.state { // TODO
#[cfg(feature = "early-data")] Initializer::nop()
TlsState::EarlyData => {
{
let mut stream = Stream::new(&mut self.io, &mut self.session);
let (pos, data) = &mut self.early_data;
// complete handshake
if stream.session.is_handshaking() {
stream.complete_io()?;
}
// write early data (fallback)
if !stream.session.is_early_data_accepted() {
while *pos < data.len() {
let len = stream.write(&data[*pos..])?;
*pos += len;
}
}
// end
self.state = TlsState::Stream;
data.clear();
}
self.read(buf)
}
TlsState::Stream | TlsState::WriteShutdown => {
let mut stream = Stream::new(&mut self.io, &mut self.session);
match stream.read(buf) {
Ok(0) => {
self.state.shutdown_read();
Ok(0)
}
Ok(n) => Ok(n),
Err(ref e) if e.kind() == io::ErrorKind::ConnectionAborted => {
self.state.shutdown_read();
if self.state.writeable() {
stream.session.send_close_notify();
self.state.shutdown_write();
}
Ok(0)
}
Err(e) => Err(e),
}
}
TlsState::ReadShutdown | TlsState::FullyShutdown => Ok(0),
}
} }
}
impl<IO> io::Write for TlsStream<IO>
where
IO: AsyncRead + AsyncWrite,
{
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
let mut stream = Stream::new(&mut self.io, &mut self.session);
fn poll_read(self: Pin<&mut Self>, cx: &mut Context, buf: &mut [u8]) -> Poll<io::Result<usize>> {
match self.state { match self.state {
#[cfg(feature = "early-data")] #[cfg(feature = "early-data")]
TlsState::EarlyData => { TlsState::EarlyData => {
let (pos, data) = &mut self.early_data; let this = self.get_mut();
// write early data let mut stream = Stream::new(&mut this.io, &mut this.session);
if let Some(mut early_data) = stream.session.early_data() { let (pos, data) = &mut this.early_data;
let len = early_data.write(buf)?;
data.extend_from_slice(&buf[..len]);
return Ok(len);
}
// complete handshake // complete handshake
if stream.session.is_handshaking() { if stream.session.is_handshaking() {
stream.complete_io()?; try_ready!(stream.complete_io(cx));
} }
// write early data (fallback) // write early data (fallback)
if !stream.session.is_early_data_accepted() { if !stream.session.is_early_data_accepted() {
while *pos < data.len() { while *pos < data.len() {
let len = stream.write(&data[*pos..])?; let len = try_ready!(stream.poll_write(cx, &data[*pos..]));
*pos += len; *pos += len;
} }
} }
// end // end
self.state = TlsState::Stream; this.state = TlsState::Stream;
data.clear(); data.clear();
stream.write(buf)
Pin::new(this).poll_read(cx, buf)
} }
_ => stream.write(buf), TlsState::Stream | TlsState::WriteShutdown => {
let this = self.get_mut();
let mut stream = Stream::new(&mut this.io, &mut this.session);
match stream.poll_read(cx, buf) {
Poll::Ready(Ok(0)) => {
this.state.shutdown_read();
Poll::Ready(Ok(0))
}
Poll::Ready(Ok(n)) => Poll::Ready(Ok(n)),
Poll::Ready(Err(ref e)) if e.kind() == io::ErrorKind::ConnectionAborted => {
this.state.shutdown_read();
if this.state.writeable() {
stream.session.send_close_notify();
this.state.shutdown_write();
}
Poll::Ready(Ok(0))
}
Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
Poll::Pending => Poll::Pending
}
}
TlsState::ReadShutdown | TlsState::FullyShutdown => Poll::Ready(Ok(0)),
} }
} }
fn flush(&mut self) -> io::Result<()> {
Stream::new(&mut self.io, &mut self.session).flush()?;
self.io.flush()
}
}
impl<IO> AsyncRead for TlsStream<IO>
where
IO: AsyncRead + AsyncWrite,
{
unsafe fn prepare_uninitialized_buffer(&self, _: &mut [u8]) -> bool {
false
}
} }
impl<IO> AsyncWrite for TlsStream<IO> impl<IO> AsyncWrite for TlsStream<IO>
where where
IO: AsyncRead + AsyncWrite, IO: AsyncRead + AsyncWrite + Unpin,
{ {
fn shutdown(&mut self) -> Poll<(), io::Error> { fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
let this = self.get_mut();
let mut stream = Stream::new(&mut this.io, &mut this.session);
match this.state {
#[cfg(feature = "early-data")]
TlsState::EarlyData => {
let (pos, data) = &mut this.early_data;
// write early data
if let Some(mut early_data) = stream.session.early_data() {
let len = early_data.write(buf)?; // TODO check pending
data.extend_from_slice(&buf[..len]);
return Poll::Ready(Ok(len));
}
// complete handshake
if stream.session.is_handshaking() {
try_ready!(stream.complete_io(cx));
}
// write early data (fallback)
if !stream.session.is_early_data_accepted() {
while *pos < data.len() {
let len = try_ready!(stream.poll_write(cx, &data[*pos..]));
*pos += len;
}
}
// end
this.state = TlsState::Stream;
data.clear();
stream.poll_write(cx, buf)
}
_ => stream.poll_write(cx, buf),
}
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
let this = self.get_mut();
Stream::new(&mut this.io, &mut this.session).poll_flush(cx)
}
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
if self.state.writeable() { if self.state.writeable() {
self.session.send_close_notify(); self.session.send_close_notify();
self.state.shutdown_write(); self.state.shutdown_write();
} }
let mut stream = Stream::new(&mut self.io, &mut self.session); let this = self.get_mut();
try_nb!(stream.flush()); let mut stream = Stream::new(&mut this.io, &mut this.session);
stream.io.shutdown() try_ready!(stream.poll_flush(cx));
Pin::new(&mut this.io).poll_close(cx)
} }
} }

View File

@ -14,6 +14,7 @@ use smallvec::SmallVec;
pub struct Stream<'a, IO, S> { pub struct Stream<'a, IO, S> {
pub io: &'a mut IO, pub io: &'a mut IO,
pub session: &'a mut S, pub session: &'a mut S,
pub eof: bool
} }
pub trait WriteTls<IO: AsyncWrite, S: Session> { pub trait WriteTls<IO: AsyncWrite, S: Session> {
@ -29,7 +30,18 @@ enum Focus {
impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> { impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> {
pub fn new(io: &'a mut IO, session: &'a mut S) -> Self { pub fn new(io: &'a mut IO, session: &'a mut S) -> Self {
Stream { io, session } Stream {
io,
session,
// The state so far is only used to detect EOF, so either Stream
// or EarlyData state should both be all right.
eof: false,
}
}
pub fn set_eof(mut self, eof: bool) -> Self {
self.eof = eof;
self
} }
pub fn complete_io(&mut self, cx: &mut Context) -> Poll<io::Result<(usize, usize)>> { pub fn complete_io(&mut self, cx: &mut Context) -> Poll<io::Result<(usize, usize)>> {
@ -82,7 +94,6 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> {
fn complete_inner_io(&mut self, cx: &mut Context, focus: Focus) -> Poll<io::Result<(usize, usize)>> { fn complete_inner_io(&mut self, cx: &mut Context, focus: Focus) -> Poll<io::Result<(usize, usize)>> {
let mut wrlen = 0; let mut wrlen = 0;
let mut rdlen = 0; let mut rdlen = 0;
let mut eof = false;
loop { loop {
let mut write_would_block = false; let mut write_would_block = false;
@ -99,9 +110,9 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> {
} }
} }
if !eof && self.session.wants_read() { if !self.eof && self.session.wants_read() {
match self.complete_read_io(cx) { match self.complete_read_io(cx) {
Poll::Ready(Ok(0)) => eof = true, Poll::Ready(Ok(0)) => self.eof = true,
Poll::Ready(Ok(n)) => rdlen += n, Poll::Ready(Ok(n)) => rdlen += n,
Poll::Pending => read_would_block = true, Poll::Pending => read_would_block = true,
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)) Poll::Ready(Err(err)) => return Poll::Ready(Err(err))
@ -114,7 +125,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> {
Focus::Writable => write_would_block, Focus::Writable => write_would_block,
}; };
match (eof, self.session.is_handshaking(), would_block) { match (self.eof, self.session.is_handshaking(), would_block) {
(true, true, _) => return Poll::Pending, (true, true, _) => return Poll::Pending,
(_, false, true) => { (_, false, true) => {
let would_block = match focus { let would_block = match focus {
@ -167,7 +178,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> WriteTls<IO, S> for Str
} }
impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> { impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> {
fn poll_read(&mut self, cx: &mut Context, buf: &mut [u8]) -> Poll<io::Result<usize>> { pub fn poll_read(&mut self, cx: &mut Context, buf: &mut [u8]) -> Poll<io::Result<usize>> {
while self.session.wants_read() { while self.session.wants_read() {
match self.complete_inner_io(cx, Focus::Readable) { match self.complete_inner_io(cx, Focus::Readable) {
Poll::Ready(Ok((0, _))) => break, Poll::Ready(Ok((0, _))) => break,
@ -181,7 +192,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> {
Poll::Ready(self.session.read(buf)) Poll::Ready(self.session.read(buf))
} }
fn poll_write(&mut self, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> { pub fn poll_write(&mut self, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
let len = self.session.write(buf)?; let len = self.session.write(buf)?;
while self.session.wants_write() { while self.session.wants_write() {
match self.complete_inner_io(cx, Focus::Writable) { match self.complete_inner_io(cx, Focus::Writable) {
@ -204,7 +215,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> {
} }
} }
fn poll_flush(&mut self, cx: &mut Context) -> Poll<io::Result<()>> { pub fn poll_flush(&mut self, cx: &mut Context) -> Poll<io::Result<()>> {
self.session.flush()?; self.session.flush()?;
while self.session.wants_write() { while self.session.wants_write() {
match self.complete_inner_io(cx, Focus::Writable) { match self.complete_inner_io(cx, Focus::Writable) {
@ -213,7 +224,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> {
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)) Poll::Ready(Err(err)) => return Poll::Ready(Err(err))
} }
} }
Poll::Ready(Ok(())) Pin::new(&mut self.io).poll_flush(cx)
} }
} }

View File

@ -1,16 +1,27 @@
//! Asynchronous TLS/SSL streams for Tokio using [Rustls](https://github.com/ctz/rustls). //! Asynchronous TLS/SSL streams for Tokio using [Rustls](https://github.com/ctz/rustls).
// pub mod client; macro_rules! try_ready {
( $e:expr ) => {
match $e {
Poll::Ready(Ok(output)) => output,
Poll::Ready(Err(err)) => return Poll::Ready(Err(err.into())),
Poll::Pending => return Poll::Pending
}
}
}
pub mod client;
mod common; mod common;
// pub mod server; // pub mod server;
/*
use common::Stream; use common::Stream;
use futures::{Async, Future, Poll}; use std::pin::Pin;
use std::task::{ Poll, Context };
use std::future::Future;
use futures::io::{ AsyncRead, AsyncWrite, Initializer };
use rustls::{ClientConfig, ClientSession, ServerConfig, ServerSession}; use rustls::{ClientConfig, ClientSession, ServerConfig, ServerSession};
use std::sync::Arc; use std::sync::Arc;
use std::{io, mem}; use std::{io, mem};
use tokio_io::{try_nb, AsyncRead, AsyncWrite};
use webpki::DNSNameRef; use webpki::DNSNameRef;
#[derive(Debug, Copy, Clone)] #[derive(Debug, Copy, Clone)]
@ -54,6 +65,7 @@ pub struct TlsConnector {
early_data: bool, early_data: bool,
} }
/*
/// A wrapper around a `rustls::ServerConfig`, providing an async `accept` method. /// A wrapper around a `rustls::ServerConfig`, providing an async `accept` method.
#[derive(Clone)] #[derive(Clone)]
pub struct TlsAcceptor { pub struct TlsAcceptor {