From ce16555b13c556277edec825310696be2ad29930 Mon Sep 17 00:00:00 2001 From: quininer Date: Tue, 7 Jan 2020 23:57:00 +0800 Subject: [PATCH] implement WriteV close https://github.com/quininer/tokio-rustls/issues/57 --- Cargo.toml | 5 +- src/client.rs | 2 +- src/common/handshake.rs | 2 +- src/common/mod.rs | 44 +++++++++++++- src/common/vecbuf.rs | 128 ++++++++++++++++++++++++++++++++++++++++ src/lib.rs | 5 +- src/server.rs | 2 +- 7 files changed, 178 insertions(+), 10 deletions(-) create mode 100644 src/common/vecbuf.rs diff --git a/Cargo.toml b/Cargo.toml index 6ed6429..52964d7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,16 +15,17 @@ edition = "2018" github-actions = { repository = "quininer/tokio-rustls", workflow = "Rust" } [dependencies] -bytes = "0.5" tokio = "0.2.0" futures-core = "0.3.1" rustls = "0.16" webpki = "0.21" +bytes = { version = "0.5", optional = true } + [features] early-data = [] dangerous_configuration = ["rustls/dangerous_configuration"] -unstable = [] +unstable = ["bytes"] [dev-dependencies] tokio = { version = "0.2.0", features = ["macros", "net", "io-util", "rt-core", "time"] } diff --git a/src/client.rs b/src/client.rs index 25d5874..5007aa8 100644 --- a/src/client.rs +++ b/src/client.rs @@ -54,7 +54,7 @@ where IO: AsyncRead + AsyncWrite + Unpin, { #[cfg(feature = "unstable")] - unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [std::mem::MaybeUninit]) -> bool { + unsafe fn prepare_uninitialized_buffer(&self, _buf: &mut [std::mem::MaybeUninit]) -> bool { false } diff --git a/src/common/handshake.rs b/src/common/handshake.rs index 0006b56..c59541e 100644 --- a/src/common/handshake.rs +++ b/src/common/handshake.rs @@ -78,7 +78,7 @@ where Poll::Ready(Ok(stream)) } else { - panic!() + panic!("unexpected polling after handshake") } } } diff --git a/src/common/mod.rs b/src/common/mod.rs index 9f6d9ac..1d0dd07 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -1,8 +1,11 @@ mod handshake; +#[cfg(feature = "unstable")] +mod vecbuf; + use std::pin::Pin; use std::task::{ Poll, Context }; -use std::io::{ self, Read, Write }; +use std::io::{ self, Read }; use rustls::Session; use tokio::io::{ AsyncRead, AsyncWrite }; use futures_core as futures; @@ -23,7 +26,8 @@ impl TlsState { #[inline] pub fn shutdown_read(&mut self) { match *self { - TlsState::WriteShutdown | TlsState::FullyShutdown => *self = TlsState::FullyShutdown, + TlsState::WriteShutdown | TlsState::FullyShutdown => + *self = TlsState::FullyShutdown, _ => *self = TlsState::ReadShutdown, } } @@ -31,7 +35,8 @@ impl TlsState { #[inline] pub fn shutdown_write(&mut self) { match *self { - TlsState::ReadShutdown | TlsState::FullyShutdown => *self = TlsState::FullyShutdown, + TlsState::ReadShutdown | TlsState::FullyShutdown => + *self = TlsState::FullyShutdown, _ => *self = TlsState::WriteShutdown, } } @@ -132,7 +137,10 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> { Poll::Ready(Ok(n)) } + #[cfg(not(feature = "unstable"))] pub fn write_io(&mut self, cx: &mut Context) -> Poll> { + use std::io::Write; + struct Writer<'a, 'b, T> { io: &'a mut T, cx: &'a mut Context<'b> @@ -162,6 +170,36 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> { } } + #[cfg(feature = "unstable")] + pub fn write_io(&mut self, cx: &mut Context) -> Poll> { + use rustls::WriteV; + + struct Writer<'a, 'b, T> { + io: &'a mut T, + cx: &'a mut Context<'b> + } + + impl<'a, 'b, T: AsyncWrite + Unpin> WriteV for Writer<'a, 'b, T> { + fn writev(&mut self, vbuf: &[&[u8]]) -> io::Result { + use vecbuf::VecBuf; + + let mut vbuf = VecBuf::new(vbuf); + + match Pin::new(&mut self.io).poll_write_buf(self.cx, &mut vbuf) { + Poll::Ready(result) => result, + Poll::Pending => Err(io::ErrorKind::WouldBlock.into()) + } + } + } + + let mut writer = Writer { io: self.io, cx }; + + match self.session.writev_tls(&mut writer) { + Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending, + result => Poll::Ready(result) + } + } + pub fn handshake(&mut self, cx: &mut Context) -> Poll> { let mut wrlen = 0; let mut rdlen = 0; diff --git a/src/common/vecbuf.rs b/src/common/vecbuf.rs new file mode 100644 index 0000000..6ea19e3 --- /dev/null +++ b/src/common/vecbuf.rs @@ -0,0 +1,128 @@ +use std::io::IoSlice; +use std::cmp::{ self, Ordering }; +use bytes::Buf; + + +pub struct VecBuf<'a, 'b: 'a> { + pos: usize, + cur: usize, + inner: &'a [&'b [u8]] +} + +impl<'a, 'b> VecBuf<'a, 'b> { + pub fn new(vbytes: &'a [&'b [u8]]) -> Self { + VecBuf { pos: 0, cur: 0, inner: vbytes } + } +} + +impl<'a, 'b> Buf for VecBuf<'a, 'b> { + fn remaining(&self) -> usize { + let sum = self.inner + .iter() + .skip(self.pos) + .map(|bytes| bytes.len()) + .sum::(); + sum - self.cur + } + + fn bytes(&self) -> &[u8] { + &self.inner[self.pos][self.cur..] + } + + fn advance(&mut self, cnt: usize) { + let current = self.inner[self.pos].len(); + match (self.cur + cnt).cmp(¤t) { + Ordering::Equal => if self.pos + 1 < self.inner.len() { + self.pos += 1; + self.cur = 0; + } else { + self.cur += cnt; + }, + Ordering::Greater => { + if self.pos + 1 < self.inner.len() { + self.pos += 1; + } + let remaining = self.cur + cnt - current; + self.advance(remaining); + }, + Ordering::Less => self.cur += cnt, + } + } + + #[allow(clippy::needless_range_loop)] + fn bytes_vectored<'c>(&'c self, dst: &mut [IoSlice<'c>]) -> usize { + let len = cmp::min(self.inner.len() - self.pos, dst.len()); + + if len > 0 { + dst[0] = IoSlice::new(self.bytes()); + } + + for i in 1..len { + dst[i] = IoSlice::new(&self.inner[self.pos + i]); + } + + len + } +} + +#[cfg(test)] +mod test_vecbuf { + use super::*; + + #[test] + fn test_fresh_cursor_vec() { + let mut buf = VecBuf::new(&[b"he", b"llo"]); + + assert_eq!(buf.remaining(), 5); + assert_eq!(buf.bytes(), b"he"); + + buf.advance(1); + + assert_eq!(buf.remaining(), 4); + assert_eq!(buf.bytes(), b"e"); + + buf.advance(1); + + assert_eq!(buf.remaining(), 3); + assert_eq!(buf.bytes(), b"llo"); + + buf.advance(3); + + assert_eq!(buf.remaining(), 0); + assert_eq!(buf.bytes(), b""); + } + + #[test] + fn test_get_u8() { + let mut buf = VecBuf::new(&[b"\x21z", b"omg"]); + assert_eq!(0x21, buf.get_u8()); + } + + #[test] + fn test_get_u16() { + let mut buf = VecBuf::new(&[b"\x21\x54z", b"omg"]); + assert_eq!(0x2154, buf.get_u16()); + let mut buf = VecBuf::new(&[b"\x21\x54z", b"omg"]); + assert_eq!(0x5421, buf.get_u16_le()); + } + + #[test] + #[should_panic] + fn test_get_u16_buffer_underflow() { + let mut buf = VecBuf::new(&[b"\x21"]); + buf.get_u16(); + } + + #[test] + fn test_bufs_vec() { + let buf = VecBuf::new(&[b"he", b"llo"]); + + let b1: &[u8] = &mut [0]; + let b2: &[u8] = &mut [0]; + + let mut dst: [IoSlice; 2] = + [IoSlice::new(b1), IoSlice::new(b2)]; + + assert_eq!(2, buf.bytes_vectored(&mut dst[..])); + } +} diff --git a/src/lib.rs b/src/lib.rs index 28d9de1..db34b07 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -51,8 +51,8 @@ impl From> for TlsAcceptor { impl TlsConnector { /// Enable 0-RTT. /// - /// Note that you want to use 0-RTT. - /// You must set `enable_early_data` to `true` in `ClientConfig`. + /// If you want to use 0-RTT, + /// You must also set `ClientConfig.enable_early_data` to `true`. #[cfg(feature = "early-data")] pub fn early_data(mut self, flag: bool) -> TlsConnector { self.early_data = flag; @@ -158,6 +158,7 @@ impl Future for Connect { } impl FusedFuture for Connect { + #[inline] fn is_terminated(&self) -> bool { self.0.is_terminated() } diff --git a/src/server.rs b/src/server.rs index aa7164e..abf86d6 100644 --- a/src/server.rs +++ b/src/server.rs @@ -53,7 +53,7 @@ where IO: AsyncRead + AsyncWrite + Unpin, { #[cfg(feature = "unstable")] - unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [std::mem::MaybeUninit]) -> bool { + unsafe fn prepare_uninitialized_buffer(&self, _buf: &mut [std::mem::MaybeUninit]) -> bool { // TODO // // https://doc.rust-lang.org/nightly/std/io/trait.Read.html#method.initializer