fix vecbuf

This commit is contained in:
quininer 2018-08-16 18:43:35 +08:00
parent 518ad51376
commit 26046efc3c
2 changed files with 91 additions and 36 deletions

View File

@ -12,23 +12,16 @@ pub struct Stream<'a, S: 'a, IO: 'a> {
io: &'a mut IO io: &'a mut IO
} }
pub trait CompleteIo<'a, S: Session, IO: Read + Write>: Read + Write { pub trait WriteTls<'a, S: Session, IO: Read + Write>: Read + Write {
fn write_tls(&mut self) -> io::Result<usize>; fn write_tls(&mut self) -> io::Result<usize>;
fn complete_io(&mut self) -> io::Result<(usize, usize)>;
} }
impl<'a, S: Session, IO: Read + Write> Stream<'a, S, IO> { impl<'a, S: Session, IO: Read + Write> Stream<'a, S, IO> {
pub fn new(session: &'a mut S, io: &'a mut IO) -> Self { pub fn new(session: &'a mut S, io: &'a mut IO) -> Self {
Stream { session, io } Stream { session, io }
} }
}
impl<'a, S: Session, IO: Read + Write> CompleteIo<'a, S, IO> for Stream<'a, S, IO> { pub fn complete_io(&mut self) -> io::Result<(usize, usize)> {
default fn write_tls(&mut self) -> io::Result<usize> {
self.session.write_tls(self.io)
}
fn complete_io(&mut self) -> io::Result<(usize, usize)> {
// fork from https://github.com/ctz/rustls/blob/master/src/session.rs#L161 // fork from https://github.com/ctz/rustls/blob/master/src/session.rs#L161
let until_handshaked = self.session.is_handshaking(); let until_handshaked = self.session.is_handshaking();
@ -74,6 +67,32 @@ impl<'a, S: Session, IO: Read + Write> CompleteIo<'a, S, IO> for Stream<'a, S, I
} }
} }
impl<'a, S: Session, IO: Read + Write> WriteTls<'a, S, IO> for Stream<'a, S, IO> {
default fn write_tls(&mut self) -> io::Result<usize> {
self.session.write_tls(self.io)
}
}
impl<'a, S: Session, IO: Read + AsyncWrite> WriteTls<'a, S, IO> for Stream<'a, S, IO> {
fn write_tls(&mut self) -> io::Result<usize> {
struct V<'a, IO: 'a>(&'a mut IO);
impl<'a, IO: AsyncWrite> WriteV for V<'a, IO> {
fn writev(&mut self, vbytes: &[&[u8]]) -> io::Result<usize> {
let mut vbytes = VecBuf::new(vbytes);
match self.0.write_buf(&mut vbytes) {
Ok(Async::Ready(n)) => Ok(n),
Ok(Async::NotReady) => Err(io::ErrorKind::WouldBlock.into()),
Err(err) => Err(err)
}
}
}
let mut vecbuf = V(self.io);
self.session.writev_tls(&mut vecbuf)
}
}
impl<'a, S: Session, IO: Read + Write> Read for Stream<'a, S, IO> { impl<'a, S: Session, IO: Read + Write> Read for Stream<'a, S, IO> {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
while self.session.wants_read() { while self.session.wants_read() {
@ -102,28 +121,7 @@ impl<'a, S: Session, IO: Read + Write> io::Write for Stream<'a, S, IO> {
} }
} }
impl<'a, S: Session, IO: Read + AsyncWrite> CompleteIo<'a, S, IO> for Stream<'a, S, IO> {
fn write_tls(&mut self) -> io::Result<usize> {
struct V<'a, IO: 'a>(&'a mut IO);
impl<'a, IO: AsyncWrite> WriteV for V<'a, IO> {
fn writev(&mut self, vbytes: &[&[u8]]) -> io::Result<usize> {
let mut vbytes = VecBuf::new(vbytes);
match self.0.write_buf(&mut vbytes) {
Ok(Async::Ready(n)) => Ok(n),
Ok(Async::NotReady) => Err(io::ErrorKind::WouldBlock.into()),
Err(err) => Err(err)
}
}
}
let mut vecbuf = V(self.io);
self.session.writev_tls(&mut vecbuf)
}
}
// TODO test
struct VecBuf<'a, 'b: 'a> { struct VecBuf<'a, 'b: 'a> {
pos: usize, pos: usize,
cur: usize, cur: usize,
@ -153,14 +151,14 @@ impl<'a, 'b> Buf for VecBuf<'a, 'b> {
fn advance(&mut self, cnt: usize) { fn advance(&mut self, cnt: usize) {
let current = self.inner[self.pos].len(); let current = self.inner[self.pos].len();
match (self.cur + cnt).cmp(&current) { match (self.cur + cnt).cmp(&current) {
Ordering::Equal => { Ordering::Equal => if self.pos + 1 < self.inner.len() {
if self.pos < self.inner.len() { self.pos += 1;
self.pos += 1;
}
self.cur = 0; self.cur = 0;
} else {
self.cur += cnt;
}, },
Ordering::Greater => { Ordering::Greater => {
if self.pos < self.inner.len() { if self.pos + 1 < self.inner.len() {
self.pos += 1; self.pos += 1;
} }
let remaining = self.cur + cnt - current; let remaining = self.cur + cnt - current;
@ -180,3 +178,60 @@ impl<'a, 'b> Buf for VecBuf<'a, 'b> {
len len
} }
} }
#[cfg(test)]
mod test_vecbuf {
use super::*;
#[test]
fn test_fresh_cursor_vec() {
let mut buf = VecBuf::new(&[b"he", b"llo"]);
assert_eq!(buf.remaining(), 5);
assert_eq!(buf.bytes(), b"he");
buf.advance(2);
assert_eq!(buf.remaining(), 3);
assert_eq!(buf.bytes(), b"llo");
buf.advance(3);
assert_eq!(buf.remaining(), 0);
assert_eq!(buf.bytes(), b"");
}
#[test]
fn test_get_u8() {
let mut buf = VecBuf::new(&[b"\x21z", b"omg"]);
assert_eq!(0x21, buf.get_u8());
}
#[test]
fn test_get_u16() {
let mut buf = VecBuf::new(&[b"\x21\x54z", b"omg"]);
assert_eq!(0x2154, buf.get_u16_be());
let mut buf = VecBuf::new(&[b"\x21\x54z", b"omg"]);
assert_eq!(0x5421, buf.get_u16_le());
}
#[test]
#[should_panic]
fn test_get_u16_buffer_underflow() {
let mut buf = VecBuf::new(&[b"\x21"]);
buf.get_u16_be();
}
#[test]
fn test_bufs_vec() {
let buf = VecBuf::new(&[b"he", b"llo"]);
let b1: &[u8] = &mut [0];
let b2: &[u8] = &mut [0];
let mut dst: [&IoVec; 2] =
[b1.into(), b2.into()];
assert_eq!(2, buf.bytes_vec(&mut dst[..]));
}
}

View File

@ -2,7 +2,7 @@ use super::*;
use tokio::prelude::*; use tokio::prelude::*;
use tokio::io::{ AsyncRead, AsyncWrite }; use tokio::io::{ AsyncRead, AsyncWrite };
use tokio::prelude::Poll; use tokio::prelude::Poll;
use common::{ Stream, CompleteIo }; use common::Stream;
impl<S: AsyncRead + AsyncWrite> Future for ConnectAsync<S> { impl<S: AsyncRead + AsyncWrite> Future for ConnectAsync<S> {