implement WriteV
close https://github.com/quininer/tokio-rustls/issues/57
This commit is contained in:
parent
7530e2f739
commit
ce16555b13
@ -15,16 +15,17 @@ edition = "2018"
|
|||||||
github-actions = { repository = "quininer/tokio-rustls", workflow = "Rust" }
|
github-actions = { repository = "quininer/tokio-rustls", workflow = "Rust" }
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
bytes = "0.5"
|
|
||||||
tokio = "0.2.0"
|
tokio = "0.2.0"
|
||||||
futures-core = "0.3.1"
|
futures-core = "0.3.1"
|
||||||
rustls = "0.16"
|
rustls = "0.16"
|
||||||
webpki = "0.21"
|
webpki = "0.21"
|
||||||
|
|
||||||
|
bytes = { version = "0.5", optional = true }
|
||||||
|
|
||||||
[features]
|
[features]
|
||||||
early-data = []
|
early-data = []
|
||||||
dangerous_configuration = ["rustls/dangerous_configuration"]
|
dangerous_configuration = ["rustls/dangerous_configuration"]
|
||||||
unstable = []
|
unstable = ["bytes"]
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
tokio = { version = "0.2.0", features = ["macros", "net", "io-util", "rt-core", "time"] }
|
tokio = { version = "0.2.0", features = ["macros", "net", "io-util", "rt-core", "time"] }
|
||||||
|
@ -54,7 +54,7 @@ where
|
|||||||
IO: AsyncRead + AsyncWrite + Unpin,
|
IO: AsyncRead + AsyncWrite + Unpin,
|
||||||
{
|
{
|
||||||
#[cfg(feature = "unstable")]
|
#[cfg(feature = "unstable")]
|
||||||
unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [std::mem::MaybeUninit<u8>]) -> bool {
|
unsafe fn prepare_uninitialized_buffer(&self, _buf: &mut [std::mem::MaybeUninit<u8>]) -> bool {
|
||||||
false
|
false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -78,7 +78,7 @@ where
|
|||||||
|
|
||||||
Poll::Ready(Ok(stream))
|
Poll::Ready(Ok(stream))
|
||||||
} else {
|
} else {
|
||||||
panic!()
|
panic!("unexpected polling after handshake")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,8 +1,11 @@
|
|||||||
mod handshake;
|
mod handshake;
|
||||||
|
|
||||||
|
#[cfg(feature = "unstable")]
|
||||||
|
mod vecbuf;
|
||||||
|
|
||||||
use std::pin::Pin;
|
use std::pin::Pin;
|
||||||
use std::task::{ Poll, Context };
|
use std::task::{ Poll, Context };
|
||||||
use std::io::{ self, Read, Write };
|
use std::io::{ self, Read };
|
||||||
use rustls::Session;
|
use rustls::Session;
|
||||||
use tokio::io::{ AsyncRead, AsyncWrite };
|
use tokio::io::{ AsyncRead, AsyncWrite };
|
||||||
use futures_core as futures;
|
use futures_core as futures;
|
||||||
@ -23,7 +26,8 @@ impl TlsState {
|
|||||||
#[inline]
|
#[inline]
|
||||||
pub fn shutdown_read(&mut self) {
|
pub fn shutdown_read(&mut self) {
|
||||||
match *self {
|
match *self {
|
||||||
TlsState::WriteShutdown | TlsState::FullyShutdown => *self = TlsState::FullyShutdown,
|
TlsState::WriteShutdown | TlsState::FullyShutdown =>
|
||||||
|
*self = TlsState::FullyShutdown,
|
||||||
_ => *self = TlsState::ReadShutdown,
|
_ => *self = TlsState::ReadShutdown,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -31,7 +35,8 @@ impl TlsState {
|
|||||||
#[inline]
|
#[inline]
|
||||||
pub fn shutdown_write(&mut self) {
|
pub fn shutdown_write(&mut self) {
|
||||||
match *self {
|
match *self {
|
||||||
TlsState::ReadShutdown | TlsState::FullyShutdown => *self = TlsState::FullyShutdown,
|
TlsState::ReadShutdown | TlsState::FullyShutdown =>
|
||||||
|
*self = TlsState::FullyShutdown,
|
||||||
_ => *self = TlsState::WriteShutdown,
|
_ => *self = TlsState::WriteShutdown,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -132,7 +137,10 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> {
|
|||||||
Poll::Ready(Ok(n))
|
Poll::Ready(Ok(n))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(not(feature = "unstable"))]
|
||||||
pub fn write_io(&mut self, cx: &mut Context) -> Poll<io::Result<usize>> {
|
pub fn write_io(&mut self, cx: &mut Context) -> Poll<io::Result<usize>> {
|
||||||
|
use std::io::Write;
|
||||||
|
|
||||||
struct Writer<'a, 'b, T> {
|
struct Writer<'a, 'b, T> {
|
||||||
io: &'a mut T,
|
io: &'a mut T,
|
||||||
cx: &'a mut Context<'b>
|
cx: &'a mut Context<'b>
|
||||||
@ -162,6 +170,36 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "unstable")]
|
||||||
|
pub fn write_io(&mut self, cx: &mut Context) -> Poll<io::Result<usize>> {
|
||||||
|
use rustls::WriteV;
|
||||||
|
|
||||||
|
struct Writer<'a, 'b, T> {
|
||||||
|
io: &'a mut T,
|
||||||
|
cx: &'a mut Context<'b>
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, 'b, T: AsyncWrite + Unpin> WriteV for Writer<'a, 'b, T> {
|
||||||
|
fn writev(&mut self, vbuf: &[&[u8]]) -> io::Result<usize> {
|
||||||
|
use vecbuf::VecBuf;
|
||||||
|
|
||||||
|
let mut vbuf = VecBuf::new(vbuf);
|
||||||
|
|
||||||
|
match Pin::new(&mut self.io).poll_write_buf(self.cx, &mut vbuf) {
|
||||||
|
Poll::Ready(result) => result,
|
||||||
|
Poll::Pending => Err(io::ErrorKind::WouldBlock.into())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut writer = Writer { io: self.io, cx };
|
||||||
|
|
||||||
|
match self.session.writev_tls(&mut writer) {
|
||||||
|
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
|
||||||
|
result => Poll::Ready(result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn handshake(&mut self, cx: &mut Context) -> Poll<io::Result<(usize, usize)>> {
|
pub fn handshake(&mut self, cx: &mut Context) -> Poll<io::Result<(usize, usize)>> {
|
||||||
let mut wrlen = 0;
|
let mut wrlen = 0;
|
||||||
let mut rdlen = 0;
|
let mut rdlen = 0;
|
||||||
|
128
src/common/vecbuf.rs
Normal file
128
src/common/vecbuf.rs
Normal file
@ -0,0 +1,128 @@
|
|||||||
|
use std::io::IoSlice;
|
||||||
|
use std::cmp::{ self, Ordering };
|
||||||
|
use bytes::Buf;
|
||||||
|
|
||||||
|
|
||||||
|
pub struct VecBuf<'a, 'b: 'a> {
|
||||||
|
pos: usize,
|
||||||
|
cur: usize,
|
||||||
|
inner: &'a [&'b [u8]]
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, 'b> VecBuf<'a, 'b> {
|
||||||
|
pub fn new(vbytes: &'a [&'b [u8]]) -> Self {
|
||||||
|
VecBuf { pos: 0, cur: 0, inner: vbytes }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, 'b> Buf for VecBuf<'a, 'b> {
|
||||||
|
fn remaining(&self) -> usize {
|
||||||
|
let sum = self.inner
|
||||||
|
.iter()
|
||||||
|
.skip(self.pos)
|
||||||
|
.map(|bytes| bytes.len())
|
||||||
|
.sum::<usize>();
|
||||||
|
sum - self.cur
|
||||||
|
}
|
||||||
|
|
||||||
|
fn bytes(&self) -> &[u8] {
|
||||||
|
&self.inner[self.pos][self.cur..]
|
||||||
|
}
|
||||||
|
|
||||||
|
fn advance(&mut self, cnt: usize) {
|
||||||
|
let current = self.inner[self.pos].len();
|
||||||
|
match (self.cur + cnt).cmp(¤t) {
|
||||||
|
Ordering::Equal => if self.pos + 1 < self.inner.len() {
|
||||||
|
self.pos += 1;
|
||||||
|
self.cur = 0;
|
||||||
|
} else {
|
||||||
|
self.cur += cnt;
|
||||||
|
},
|
||||||
|
Ordering::Greater => {
|
||||||
|
if self.pos + 1 < self.inner.len() {
|
||||||
|
self.pos += 1;
|
||||||
|
}
|
||||||
|
let remaining = self.cur + cnt - current;
|
||||||
|
self.advance(remaining);
|
||||||
|
},
|
||||||
|
Ordering::Less => self.cur += cnt,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(clippy::needless_range_loop)]
|
||||||
|
fn bytes_vectored<'c>(&'c self, dst: &mut [IoSlice<'c>]) -> usize {
|
||||||
|
let len = cmp::min(self.inner.len() - self.pos, dst.len());
|
||||||
|
|
||||||
|
if len > 0 {
|
||||||
|
dst[0] = IoSlice::new(self.bytes());
|
||||||
|
}
|
||||||
|
|
||||||
|
for i in 1..len {
|
||||||
|
dst[i] = IoSlice::new(&self.inner[self.pos + i]);
|
||||||
|
}
|
||||||
|
|
||||||
|
len
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod test_vecbuf {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_fresh_cursor_vec() {
|
||||||
|
let mut buf = VecBuf::new(&[b"he", b"llo"]);
|
||||||
|
|
||||||
|
assert_eq!(buf.remaining(), 5);
|
||||||
|
assert_eq!(buf.bytes(), b"he");
|
||||||
|
|
||||||
|
buf.advance(1);
|
||||||
|
|
||||||
|
assert_eq!(buf.remaining(), 4);
|
||||||
|
assert_eq!(buf.bytes(), b"e");
|
||||||
|
|
||||||
|
buf.advance(1);
|
||||||
|
|
||||||
|
assert_eq!(buf.remaining(), 3);
|
||||||
|
assert_eq!(buf.bytes(), b"llo");
|
||||||
|
|
||||||
|
buf.advance(3);
|
||||||
|
|
||||||
|
assert_eq!(buf.remaining(), 0);
|
||||||
|
assert_eq!(buf.bytes(), b"");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_get_u8() {
|
||||||
|
let mut buf = VecBuf::new(&[b"\x21z", b"omg"]);
|
||||||
|
assert_eq!(0x21, buf.get_u8());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_get_u16() {
|
||||||
|
let mut buf = VecBuf::new(&[b"\x21\x54z", b"omg"]);
|
||||||
|
assert_eq!(0x2154, buf.get_u16());
|
||||||
|
let mut buf = VecBuf::new(&[b"\x21\x54z", b"omg"]);
|
||||||
|
assert_eq!(0x5421, buf.get_u16_le());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
#[should_panic]
|
||||||
|
fn test_get_u16_buffer_underflow() {
|
||||||
|
let mut buf = VecBuf::new(&[b"\x21"]);
|
||||||
|
buf.get_u16();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_bufs_vec() {
|
||||||
|
let buf = VecBuf::new(&[b"he", b"llo"]);
|
||||||
|
|
||||||
|
let b1: &[u8] = &mut [0];
|
||||||
|
let b2: &[u8] = &mut [0];
|
||||||
|
|
||||||
|
let mut dst: [IoSlice; 2] =
|
||||||
|
[IoSlice::new(b1), IoSlice::new(b2)];
|
||||||
|
|
||||||
|
assert_eq!(2, buf.bytes_vectored(&mut dst[..]));
|
||||||
|
}
|
||||||
|
}
|
@ -51,8 +51,8 @@ impl From<Arc<ServerConfig>> for TlsAcceptor {
|
|||||||
impl TlsConnector {
|
impl TlsConnector {
|
||||||
/// Enable 0-RTT.
|
/// Enable 0-RTT.
|
||||||
///
|
///
|
||||||
/// Note that you want to use 0-RTT.
|
/// If you want to use 0-RTT,
|
||||||
/// You must set `enable_early_data` to `true` in `ClientConfig`.
|
/// You must also set `ClientConfig.enable_early_data` to `true`.
|
||||||
#[cfg(feature = "early-data")]
|
#[cfg(feature = "early-data")]
|
||||||
pub fn early_data(mut self, flag: bool) -> TlsConnector {
|
pub fn early_data(mut self, flag: bool) -> TlsConnector {
|
||||||
self.early_data = flag;
|
self.early_data = flag;
|
||||||
@ -158,6 +158,7 @@ impl<IO: AsyncRead + AsyncWrite + Unpin> Future for Connect<IO> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl<IO: AsyncRead + AsyncWrite + Unpin> FusedFuture for Connect<IO> {
|
impl<IO: AsyncRead + AsyncWrite + Unpin> FusedFuture for Connect<IO> {
|
||||||
|
#[inline]
|
||||||
fn is_terminated(&self) -> bool {
|
fn is_terminated(&self) -> bool {
|
||||||
self.0.is_terminated()
|
self.0.is_terminated()
|
||||||
}
|
}
|
||||||
|
@ -53,7 +53,7 @@ where
|
|||||||
IO: AsyncRead + AsyncWrite + Unpin,
|
IO: AsyncRead + AsyncWrite + Unpin,
|
||||||
{
|
{
|
||||||
#[cfg(feature = "unstable")]
|
#[cfg(feature = "unstable")]
|
||||||
unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [std::mem::MaybeUninit<u8>]) -> bool {
|
unsafe fn prepare_uninitialized_buffer(&self, _buf: &mut [std::mem::MaybeUninit<u8>]) -> bool {
|
||||||
// TODO
|
// TODO
|
||||||
//
|
//
|
||||||
// https://doc.rust-lang.org/nightly/std/io/trait.Read.html#method.initializer
|
// https://doc.rust-lang.org/nightly/std/io/trait.Read.html#method.initializer
|
||||||
|
Loading…
Reference in New Issue
Block a user