diff --git a/src/common.rs b/src/common.rs index df83537..070caeb 100644 --- a/src/common.rs +++ b/src/common.rs @@ -12,23 +12,16 @@ pub struct Stream<'a, S: 'a, IO: 'a> { io: &'a mut IO } -pub trait CompleteIo<'a, S: Session, IO: Read + Write>: Read + Write { +pub trait WriteTls<'a, S: Session, IO: Read + Write>: Read + Write { fn write_tls(&mut self) -> io::Result; - fn complete_io(&mut self) -> io::Result<(usize, usize)>; } impl<'a, S: Session, IO: Read + Write> Stream<'a, S, IO> { pub fn new(session: &'a mut S, io: &'a mut IO) -> Self { Stream { session, io } } -} -impl<'a, S: Session, IO: Read + Write> CompleteIo<'a, S, IO> for Stream<'a, S, IO> { - default fn write_tls(&mut self) -> io::Result { - self.session.write_tls(self.io) - } - - fn complete_io(&mut self) -> io::Result<(usize, usize)> { + pub fn complete_io(&mut self) -> io::Result<(usize, usize)> { // fork from https://github.com/ctz/rustls/blob/master/src/session.rs#L161 let until_handshaked = self.session.is_handshaking(); @@ -74,6 +67,32 @@ impl<'a, S: Session, IO: Read + Write> CompleteIo<'a, S, IO> for Stream<'a, S, I } } +impl<'a, S: Session, IO: Read + Write> WriteTls<'a, S, IO> for Stream<'a, S, IO> { + default fn write_tls(&mut self) -> io::Result { + self.session.write_tls(self.io) + } +} + +impl<'a, S: Session, IO: Read + AsyncWrite> WriteTls<'a, S, IO> for Stream<'a, S, IO> { + fn write_tls(&mut self) -> io::Result { + struct V<'a, IO: 'a>(&'a mut IO); + + impl<'a, IO: AsyncWrite> WriteV for V<'a, IO> { + fn writev(&mut self, vbytes: &[&[u8]]) -> io::Result { + let mut vbytes = VecBuf::new(vbytes); + match self.0.write_buf(&mut vbytes) { + Ok(Async::Ready(n)) => Ok(n), + Ok(Async::NotReady) => Err(io::ErrorKind::WouldBlock.into()), + Err(err) => Err(err) + } + } + } + + let mut vecbuf = V(self.io); + self.session.writev_tls(&mut vecbuf) + } +} + impl<'a, S: Session, IO: Read + Write> Read for Stream<'a, S, IO> { fn read(&mut self, buf: &mut [u8]) -> io::Result { while self.session.wants_read() { @@ -102,28 +121,7 @@ impl<'a, S: Session, IO: Read + Write> io::Write for Stream<'a, S, IO> { } } -impl<'a, S: Session, IO: Read + AsyncWrite> CompleteIo<'a, S, IO> for Stream<'a, S, IO> { - fn write_tls(&mut self) -> io::Result { - struct V<'a, IO: 'a>(&'a mut IO); - impl<'a, IO: AsyncWrite> WriteV for V<'a, IO> { - fn writev(&mut self, vbytes: &[&[u8]]) -> io::Result { - let mut vbytes = VecBuf::new(vbytes); - match self.0.write_buf(&mut vbytes) { - Ok(Async::Ready(n)) => Ok(n), - Ok(Async::NotReady) => Err(io::ErrorKind::WouldBlock.into()), - Err(err) => Err(err) - } - } - } - - let mut vecbuf = V(self.io); - self.session.writev_tls(&mut vecbuf) - } -} - - -// TODO test struct VecBuf<'a, 'b: 'a> { pos: usize, cur: usize, @@ -153,14 +151,14 @@ impl<'a, 'b> Buf for VecBuf<'a, 'b> { fn advance(&mut self, cnt: usize) { let current = self.inner[self.pos].len(); match (self.cur + cnt).cmp(¤t) { - Ordering::Equal => { - if self.pos < self.inner.len() { - self.pos += 1; - } + Ordering::Equal => if self.pos + 1 < self.inner.len() { + self.pos += 1; self.cur = 0; + } else { + self.cur += cnt; }, Ordering::Greater => { - if self.pos < self.inner.len() { + if self.pos + 1 < self.inner.len() { self.pos += 1; } let remaining = self.cur + cnt - current; @@ -180,3 +178,60 @@ impl<'a, 'b> Buf for VecBuf<'a, 'b> { len } } + +#[cfg(test)] +mod test_vecbuf { + use super::*; + + #[test] + fn test_fresh_cursor_vec() { + let mut buf = VecBuf::new(&[b"he", b"llo"]); + + assert_eq!(buf.remaining(), 5); + assert_eq!(buf.bytes(), b"he"); + + buf.advance(2); + + assert_eq!(buf.remaining(), 3); + assert_eq!(buf.bytes(), b"llo"); + + buf.advance(3); + + assert_eq!(buf.remaining(), 0); + assert_eq!(buf.bytes(), b""); + } + + #[test] + fn test_get_u8() { + let mut buf = VecBuf::new(&[b"\x21z", b"omg"]); + assert_eq!(0x21, buf.get_u8()); + } + + #[test] + fn test_get_u16() { + let mut buf = VecBuf::new(&[b"\x21\x54z", b"omg"]); + assert_eq!(0x2154, buf.get_u16_be()); + let mut buf = VecBuf::new(&[b"\x21\x54z", b"omg"]); + assert_eq!(0x5421, buf.get_u16_le()); + } + + #[test] + #[should_panic] + fn test_get_u16_buffer_underflow() { + let mut buf = VecBuf::new(&[b"\x21"]); + buf.get_u16_be(); + } + + #[test] + fn test_bufs_vec() { + let buf = VecBuf::new(&[b"he", b"llo"]); + + let b1: &[u8] = &mut [0]; + let b2: &[u8] = &mut [0]; + + let mut dst: [&IoVec; 2] = + [b1.into(), b2.into()]; + + assert_eq!(2, buf.bytes_vec(&mut dst[..])); + } +} diff --git a/src/tokio_impl.rs b/src/tokio_impl.rs index 663d6ca..9f09705 100644 --- a/src/tokio_impl.rs +++ b/src/tokio_impl.rs @@ -2,7 +2,7 @@ use super::*; use tokio::prelude::*; use tokio::io::{ AsyncRead, AsyncWrite }; use tokio::prelude::Poll; -use common::{ Stream, CompleteIo }; +use common::Stream; impl Future for ConnectAsync {