diff --git a/Cargo.toml b/Cargo.toml index e289e0f..94edce4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,6 +18,8 @@ appveyor = { repository = "quininer/tokio-rustls" } futures-core = { version = "0.2.0", optional = true } futures-io = { version = "0.2.0", optional = true } tokio = { version = "0.1.6", optional = true } +bytes = { version = "*" } +iovec = { version = "*" } rustls = "0.13" webpki = "0.18.1" diff --git a/src/common.rs b/src/common.rs new file mode 100644 index 0000000..799ee7e --- /dev/null +++ b/src/common.rs @@ -0,0 +1,99 @@ +use std::cmp::{ self, Ordering }; +use std::io::{ self, Read, Write }; +use rustls::{ Session, WriteV }; +use tokio::prelude::Async; +use tokio::io::AsyncWrite; +use bytes::Buf; +use iovec::IoVec; + + +pub struct Stream<'a, S: 'a, IO: 'a> { + session: &'a mut S, + io: &'a mut IO +} + +/* +impl<'a, S: Session, IO: Write> Stream<'a, S, IO> { + pub default fn write_tls(&mut self) -> io::Result { + self.session.write_tls(self.io) + } +} +*/ + +impl<'a, S: Session, IO: AsyncWrite> Stream<'a, S, IO> { + pub fn write_tls(&mut self) -> io::Result { + struct V<'a, IO: 'a>(&'a mut IO); + + impl<'a, IO: AsyncWrite> WriteV for V<'a, IO> { + fn writev(&mut self, vbytes: &[&[u8]]) -> io::Result { + let mut vbytes = VecBuf::new(vbytes); + match self.0.write_buf(&mut vbytes) { + Ok(Async::Ready(n)) => Ok(n), + Ok(Async::NotReady) => Err(io::ErrorKind::WouldBlock.into()), + Err(err) => Err(err) + } + } + } + + let mut vecbuf = V(self.io); + self.session.writev_tls(&mut vecbuf) + } +} + + +struct VecBuf<'a, 'b: 'a> { + pos: usize, + cur: usize, + inner: &'a [&'b [u8]] +} + +impl<'a, 'b> VecBuf<'a, 'b> { + fn new(vbytes: &'a [&'b [u8]]) -> Self { + VecBuf { pos: 0, cur: 0, inner: vbytes } + } +} + +impl<'a, 'b> Buf for VecBuf<'a, 'b> { + fn remaining(&self) -> usize { + let sum = self.inner + .iter() + .skip(self.pos) + .map(|bytes| bytes.len()) + .sum::(); + sum - self.cur + } + + fn bytes(&self) -> &[u8] { + &self.inner[self.pos][self.cur..] + } + + fn advance(&mut self, cnt: usize) { + let current = self.inner[self.pos].len(); + match (self.cur + cnt).cmp(¤t) { + Ordering::Equal => { + if self.pos < self.inner.len() { + self.pos += 1; + } + self.cur = 0; + }, + Ordering::Greater => { + if self.pos < self.inner.len() { + self.pos += 1; + } + let remaining = self.cur + cnt - current; + self.advance(remaining); + }, + Ordering::Less => self.cur += cnt, + } + } + + fn bytes_vec<'c>(&'c self, dst: &mut [&'c IoVec]) -> usize { + let len = cmp::min(self.inner.len() - self.pos, dst.len()); + + for i in 0..len { + dst[i] = self.inner[self.pos + i].into(); + } + + len + } +} diff --git a/src/lib.rs b/src/lib.rs index d1c6c7d..81da5fe 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,6 +3,12 @@ pub extern crate rustls; pub extern crate webpki; +extern crate tokio; +extern crate bytes; +extern crate iovec; + + +mod common; #[cfg(feature = "tokio")] mod tokio_impl; #[cfg(feature = "unstable-futures")] mod futures_impl; diff --git a/src/tokio_impl.rs b/src/tokio_impl.rs index 936c14b..e9a00a9 100644 --- a/src/tokio_impl.rs +++ b/src/tokio_impl.rs @@ -1,9 +1,7 @@ -extern crate tokio; - use super::*; -use self::tokio::prelude::*; -use self::tokio::io::{ AsyncRead, AsyncWrite }; -use self::tokio::prelude::Poll; +use tokio::prelude::*; +use tokio::io::{ AsyncRead, AsyncWrite }; +use tokio::prelude::Poll; impl Future for ConnectAsync {