wip client

This commit is contained in:
quininer 2019-05-18 16:05:10 +08:00
parent 017b1b64d1
commit 41c26ee63a
4 changed files with 132 additions and 115 deletions

View File

@ -1,6 +1,6 @@
[package]
name = "tokio-rustls"
version = "0.10.0-alpha.2"
version = "0.12.0-alpha"
authors = ["quininer kel <quininer@live.com>"]
license = "MIT/Apache-2.0"
repository = "https://github.com/quininer/tokio-rustls"

View File

@ -40,160 +40,154 @@ impl<IO> TlsStream<IO> {
impl<IO> Future for MidHandshake<IO>
where
IO: AsyncRead + AsyncWrite,
IO: AsyncRead + AsyncWrite + Unpin,
{
type Item = TlsStream<IO>;
type Error = io::Error;
type Output = io::Result<TlsStream<IO>>;
#[inline]
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
if let MidHandshake::Handshaking(stream) = self {
fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
if let MidHandshake::Handshaking(stream) = &mut *self {
let (io, session) = stream.get_mut();
let mut stream = Stream::new(io, session);
if stream.session.is_handshaking() {
try_nb!(stream.complete_io());
try_ready!(stream.complete_io(cx));
}
if stream.session.wants_write() {
try_nb!(stream.complete_io());
try_ready!(stream.complete_io(cx));
}
}
match mem::replace(self, MidHandshake::End) {
MidHandshake::Handshaking(stream) => Ok(Async::Ready(stream)),
match mem::replace(&mut *self, MidHandshake::End) {
MidHandshake::Handshaking(stream) => Poll::Ready(Ok(stream)),
#[cfg(feature = "early-data")]
MidHandshake::EarlyData(stream) => Ok(Async::Ready(stream)),
MidHandshake::EarlyData(stream) => Poll::Ready(Ok(stream)),
MidHandshake::End => panic!(),
}
}
}
impl<IO> io::Read for TlsStream<IO>
where
IO: AsyncRead + AsyncWrite,
{
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
match self.state {
#[cfg(feature = "early-data")]
TlsState::EarlyData => {
{
let mut stream = Stream::new(&mut self.io, &mut self.session);
let (pos, data) = &mut self.early_data;
// complete handshake
if stream.session.is_handshaking() {
stream.complete_io()?;
}
// write early data (fallback)
if !stream.session.is_early_data_accepted() {
while *pos < data.len() {
let len = stream.write(&data[*pos..])?;
*pos += len;
}
}
// end
self.state = TlsState::Stream;
data.clear();
}
self.read(buf)
}
TlsState::Stream | TlsState::WriteShutdown => {
let mut stream = Stream::new(&mut self.io, &mut self.session);
match stream.read(buf) {
Ok(0) => {
self.state.shutdown_read();
Ok(0)
}
Ok(n) => Ok(n),
Err(ref e) if e.kind() == io::ErrorKind::ConnectionAborted => {
self.state.shutdown_read();
if self.state.writeable() {
stream.session.send_close_notify();
self.state.shutdown_write();
}
Ok(0)
}
Err(e) => Err(e),
}
}
TlsState::ReadShutdown | TlsState::FullyShutdown => Ok(0),
}
}
}
impl<IO> io::Write for TlsStream<IO>
where
IO: AsyncRead + AsyncWrite,
{
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
let mut stream = Stream::new(&mut self.io, &mut self.session);
match self.state {
#[cfg(feature = "early-data")]
TlsState::EarlyData => {
let (pos, data) = &mut self.early_data;
// write early data
if let Some(mut early_data) = stream.session.early_data() {
let len = early_data.write(buf)?;
data.extend_from_slice(&buf[..len]);
return Ok(len);
}
// complete handshake
if stream.session.is_handshaking() {
stream.complete_io()?;
}
// write early data (fallback)
if !stream.session.is_early_data_accepted() {
while *pos < data.len() {
let len = stream.write(&data[*pos..])?;
*pos += len;
}
}
// end
self.state = TlsState::Stream;
data.clear();
stream.write(buf)
}
_ => stream.write(buf),
}
}
fn flush(&mut self) -> io::Result<()> {
Stream::new(&mut self.io, &mut self.session).flush()?;
self.io.flush()
}
}
impl<IO> AsyncRead for TlsStream<IO>
where
IO: AsyncRead + AsyncWrite,
IO: AsyncRead + AsyncWrite + Unpin,
{
unsafe fn prepare_uninitialized_buffer(&self, _: &mut [u8]) -> bool {
false
unsafe fn initializer(&self) -> Initializer {
// TODO
Initializer::nop()
}
fn poll_read(self: Pin<&mut Self>, cx: &mut Context, buf: &mut [u8]) -> Poll<io::Result<usize>> {
match self.state {
#[cfg(feature = "early-data")]
TlsState::EarlyData => {
let this = self.get_mut();
let mut stream = Stream::new(&mut this.io, &mut this.session);
let (pos, data) = &mut this.early_data;
// complete handshake
if stream.session.is_handshaking() {
try_ready!(stream.complete_io(cx));
}
// write early data (fallback)
if !stream.session.is_early_data_accepted() {
while *pos < data.len() {
let len = try_ready!(stream.poll_write(cx, &data[*pos..]));
*pos += len;
}
}
// end
this.state = TlsState::Stream;
data.clear();
Pin::new(this).poll_read(cx, buf)
}
TlsState::Stream | TlsState::WriteShutdown => {
let this = self.get_mut();
let mut stream = Stream::new(&mut this.io, &mut this.session);
match stream.poll_read(cx, buf) {
Poll::Ready(Ok(0)) => {
this.state.shutdown_read();
Poll::Ready(Ok(0))
}
Poll::Ready(Ok(n)) => Poll::Ready(Ok(n)),
Poll::Ready(Err(ref e)) if e.kind() == io::ErrorKind::ConnectionAborted => {
this.state.shutdown_read();
if this.state.writeable() {
stream.session.send_close_notify();
this.state.shutdown_write();
}
Poll::Ready(Ok(0))
}
Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
Poll::Pending => Poll::Pending
}
}
TlsState::ReadShutdown | TlsState::FullyShutdown => Poll::Ready(Ok(0)),
}
}
}
impl<IO> AsyncWrite for TlsStream<IO>
where
IO: AsyncRead + AsyncWrite,
IO: AsyncRead + AsyncWrite + Unpin,
{
fn shutdown(&mut self) -> Poll<(), io::Error> {
fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
let this = self.get_mut();
let mut stream = Stream::new(&mut this.io, &mut this.session);
match this.state {
#[cfg(feature = "early-data")]
TlsState::EarlyData => {
let (pos, data) = &mut this.early_data;
// write early data
if let Some(mut early_data) = stream.session.early_data() {
let len = early_data.write(buf)?; // TODO check pending
data.extend_from_slice(&buf[..len]);
return Poll::Ready(Ok(len));
}
// complete handshake
if stream.session.is_handshaking() {
try_ready!(stream.complete_io(cx));
}
// write early data (fallback)
if !stream.session.is_early_data_accepted() {
while *pos < data.len() {
let len = try_ready!(stream.poll_write(cx, &data[*pos..]));
*pos += len;
}
}
// end
this.state = TlsState::Stream;
data.clear();
stream.poll_write(cx, buf)
}
_ => stream.poll_write(cx, buf),
}
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
let this = self.get_mut();
Stream::new(&mut this.io, &mut this.session).poll_flush(cx)
}
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
if self.state.writeable() {
self.session.send_close_notify();
self.state.shutdown_write();
}
let mut stream = Stream::new(&mut self.io, &mut self.session);
try_nb!(stream.flush());
stream.io.shutdown()
let this = self.get_mut();
let mut stream = Stream::new(&mut this.io, &mut this.session);
try_ready!(stream.poll_flush(cx));
Pin::new(&mut this.io).poll_close(cx)
}
}

View File

@ -14,6 +14,7 @@ use smallvec::SmallVec;
pub struct Stream<'a, IO, S> {
pub io: &'a mut IO,
pub session: &'a mut S,
pub eof: bool
}
pub trait WriteTls<IO: AsyncWrite, S: Session> {
@ -29,7 +30,18 @@ enum Focus {
impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> {
pub fn new(io: &'a mut IO, session: &'a mut S) -> Self {
Stream { io, session }
Stream {
io,
session,
// The state so far is only used to detect EOF, so either Stream
// or EarlyData state should both be all right.
eof: false,
}
}
pub fn set_eof(mut self, eof: bool) -> Self {
self.eof = eof;
self
}
pub fn complete_io(&mut self, cx: &mut Context) -> Poll<io::Result<(usize, usize)>> {
@ -82,7 +94,6 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> {
fn complete_inner_io(&mut self, cx: &mut Context, focus: Focus) -> Poll<io::Result<(usize, usize)>> {
let mut wrlen = 0;
let mut rdlen = 0;
let mut eof = false;
loop {
let mut write_would_block = false;
@ -99,9 +110,9 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> {
}
}
if !eof && self.session.wants_read() {
if !self.eof && self.session.wants_read() {
match self.complete_read_io(cx) {
Poll::Ready(Ok(0)) => eof = true,
Poll::Ready(Ok(0)) => self.eof = true,
Poll::Ready(Ok(n)) => rdlen += n,
Poll::Pending => read_would_block = true,
Poll::Ready(Err(err)) => return Poll::Ready(Err(err))
@ -114,7 +125,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> {
Focus::Writable => write_would_block,
};
match (eof, self.session.is_handshaking(), would_block) {
match (self.eof, self.session.is_handshaking(), would_block) {
(true, true, _) => return Poll::Pending,
(_, false, true) => {
let would_block = match focus {
@ -167,7 +178,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> WriteTls<IO, S> for Str
}
impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> {
fn poll_read(&mut self, cx: &mut Context, buf: &mut [u8]) -> Poll<io::Result<usize>> {
pub fn poll_read(&mut self, cx: &mut Context, buf: &mut [u8]) -> Poll<io::Result<usize>> {
while self.session.wants_read() {
match self.complete_inner_io(cx, Focus::Readable) {
Poll::Ready(Ok((0, _))) => break,
@ -181,7 +192,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> {
Poll::Ready(self.session.read(buf))
}
fn poll_write(&mut self, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
pub fn poll_write(&mut self, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
let len = self.session.write(buf)?;
while self.session.wants_write() {
match self.complete_inner_io(cx, Focus::Writable) {
@ -204,7 +215,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> {
}
}
fn poll_flush(&mut self, cx: &mut Context) -> Poll<io::Result<()>> {
pub fn poll_flush(&mut self, cx: &mut Context) -> Poll<io::Result<()>> {
self.session.flush()?;
while self.session.wants_write() {
match self.complete_inner_io(cx, Focus::Writable) {
@ -213,7 +224,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> {
Poll::Ready(Err(err)) => return Poll::Ready(Err(err))
}
}
Poll::Ready(Ok(()))
Pin::new(&mut self.io).poll_flush(cx)
}
}

View File

@ -1,16 +1,27 @@
//! Asynchronous TLS/SSL streams for Tokio using [Rustls](https://github.com/ctz/rustls).
// pub mod client;
macro_rules! try_ready {
( $e:expr ) => {
match $e {
Poll::Ready(Ok(output)) => output,
Poll::Ready(Err(err)) => return Poll::Ready(Err(err.into())),
Poll::Pending => return Poll::Pending
}
}
}
pub mod client;
mod common;
// pub mod server;
/*
use common::Stream;
use futures::{Async, Future, Poll};
use std::pin::Pin;
use std::task::{ Poll, Context };
use std::future::Future;
use futures::io::{ AsyncRead, AsyncWrite, Initializer };
use rustls::{ClientConfig, ClientSession, ServerConfig, ServerSession};
use std::sync::Arc;
use std::{io, mem};
use tokio_io::{try_nb, AsyncRead, AsyncWrite};
use webpki::DNSNameRef;
#[derive(Debug, Copy, Clone)]
@ -54,6 +65,7 @@ pub struct TlsConnector {
early_data: bool,
}
/*
/// A wrapper around a `rustls::ServerConfig`, providing an async `accept` method.
#[derive(Clone)]
pub struct TlsAcceptor {