Revert "refactor: separate read and write"

This reverts commit 4dd5e19a19.
This commit is contained in:
quininer 2019-10-01 10:24:22 +08:00
parent 315d927473
commit 4109c34207
6 changed files with 147 additions and 96 deletions

View File

@ -53,11 +53,11 @@ where
let mut stream = Stream::new(io, session).set_eof(eof);
if stream.session.is_handshaking() {
futures::ready!(stream.handshake(cx))?;
futures::ready!(stream.complete_io(cx))?;
}
if stream.session.wants_write() {
futures::ready!(stream.handshake(cx))?;
futures::ready!(stream.complete_io(cx))?;
}
}
@ -81,7 +81,32 @@ where
fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> {
match self.state {
#[cfg(feature = "early-data")]
TlsState::EarlyData => Poll::Pending,
TlsState::EarlyData => {
let this = self.get_mut();
let mut stream = Stream::new(&mut this.io, &mut this.session)
.set_eof(!this.state.readable());
let (pos, data) = &mut this.early_data;
// complete handshake
if stream.session.is_handshaking() {
futures::ready!(stream.complete_io(cx))?;
}
// write early data (fallback)
if !stream.session.is_early_data_accepted() {
while *pos < data.len() {
let len = futures::ready!(stream.as_mut_pin().poll_write(cx, &data[*pos..]))?;
*pos += len;
}
}
// end
this.state = TlsState::Stream;
data.clear();
Pin::new(this).poll_read(cx, buf)
}
TlsState::Stream | TlsState::WriteShutdown => {
let this = self.get_mut();
let mut stream = Stream::new(&mut this.io, &mut this.session)
@ -91,7 +116,7 @@ where
Poll::Ready(Ok(0)) => {
this.state.shutdown_read();
Poll::Ready(Ok(0))
},
}
Poll::Ready(Ok(n)) => Poll::Ready(Ok(n)),
Poll::Ready(Err(ref e)) if e.kind() == io::ErrorKind::ConnectionAborted => {
this.state.shutdown_read();
@ -100,8 +125,9 @@ where
this.state.shutdown_write();
}
Poll::Ready(Ok(0))
},
output => output
}
Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
Poll::Pending => Poll::Pending
}
}
TlsState::ReadShutdown | TlsState::FullyShutdown => Poll::Ready(Ok(0)),
@ -127,7 +153,7 @@ where
// write early data
if let Some(mut early_data) = stream.session.early_data() {
let len = match dbg!(early_data.write(buf)) {
let len = match early_data.write(buf) {
Ok(n) => n,
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock =>
return Poll::Pending,
@ -139,7 +165,7 @@ where
// complete handshake
if stream.session.is_handshaking() {
futures::ready!(stream.handshake(cx))?;
futures::ready!(stream.complete_io(cx))?;
}
// write early data (fallback)
@ -163,14 +189,6 @@ where
let this = self.get_mut();
let mut stream = Stream::new(&mut this.io, &mut this.session)
.set_eof(!this.state.readable());
#[cfg(feature = "early-data")] {
// complete handshake
if stream.session.is_handshaking() {
futures::ready!(stream.handshake(cx))?;
}
}
stream.as_mut_pin().poll_flush(cx)
}
@ -183,11 +201,6 @@ where
let this = self.get_mut();
let mut stream = Stream::new(&mut this.io, &mut this.session)
.set_eof(!this.state.readable());
// TODO
//
// should we complete the handshake?
stream.as_mut_pin().poll_shutdown(cx)
}
}

View File

@ -13,6 +13,17 @@ pub struct Stream<'a, IO, S> {
pub eof: bool
}
trait WriteTls<IO: AsyncWrite, S: Session> {
fn write_tls(&mut self, cx: &mut Context) -> io::Result<usize>;
}
#[derive(Clone, Copy)]
enum Focus {
Empty,
Readable,
Writable
}
impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> {
pub fn new(io: &'a mut IO, session: &'a mut S) -> Self {
Stream {
@ -33,7 +44,11 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> {
Pin::new(self)
}
fn read_io(&mut self, cx: &mut Context) -> Poll<io::Result<usize>> {
pub fn complete_io(&mut self, cx: &mut Context) -> Poll<io::Result<(usize, usize)>> {
self.complete_inner_io(cx, Focus::Empty)
}
fn complete_read_io(&mut self, cx: &mut Context) -> Poll<io::Result<usize>> {
struct Reader<'a, 'b, T> {
io: &'a mut T,
cx: &'a mut Context<'b>
@ -61,7 +76,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> {
// In case we have an alert to send describing this error,
// try a last-gasp write -- but don't predate the primary
// error.
let _ = self.write_io(cx);
let _ = self.write_tls(cx);
io::Error::new(io::ErrorKind::InvalidData, err)
})?;
@ -69,7 +84,85 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> {
Poll::Ready(Ok(n))
}
fn write_io(&mut self, cx: &mut Context) -> Poll<io::Result<usize>> {
fn complete_write_io(&mut self, cx: &mut Context) -> Poll<io::Result<usize>> {
match self.write_tls(cx) {
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
result => Poll::Ready(result)
}
}
fn complete_inner_io(&mut self, cx: &mut Context, focus: Focus) -> Poll<io::Result<(usize, usize)>> {
let mut wrlen = 0;
let mut rdlen = 0;
loop {
let mut write_would_block = false;
let mut read_would_block = false;
while self.session.wants_write() {
match self.complete_write_io(cx) {
Poll::Ready(Ok(n)) => wrlen += n,
Poll::Pending => {
write_would_block = true;
break
},
Poll::Ready(Err(err)) => return Poll::Ready(Err(err))
}
}
if let Focus::Writable = focus {
if !write_would_block {
return Poll::Ready(Ok((rdlen, wrlen)));
} else {
return Poll::Pending;
}
}
if !self.eof && self.session.wants_read() {
match self.complete_read_io(cx) {
Poll::Ready(Ok(0)) => self.eof = true,
Poll::Ready(Ok(n)) => rdlen += n,
Poll::Pending => read_would_block = true,
Poll::Ready(Err(err)) => return Poll::Ready(Err(err))
}
}
let would_block = match focus {
Focus::Empty => write_would_block || read_would_block,
Focus::Readable => read_would_block,
Focus::Writable => write_would_block,
};
match (self.eof, self.session.is_handshaking(), would_block) {
(true, true, _) => {
let err = io::Error::new(io::ErrorKind::UnexpectedEof, "tls handshake eof");
return Poll::Ready(Err(err));
},
(_, false, true) => {
let would_block = match focus {
Focus::Empty => rdlen == 0 && wrlen == 0,
Focus::Readable => rdlen == 0,
Focus::Writable => wrlen == 0
};
return if would_block {
Poll::Pending
} else {
Poll::Ready(Ok((rdlen, wrlen)))
};
},
(_, false, _) => return Poll::Ready(Ok((rdlen, wrlen))),
(_, true, true) => return Poll::Pending,
(..) => ()
}
}
}
}
impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> WriteTls<IO, S> for Stream<'a, IO, S> {
fn write_tls(&mut self, cx: &mut Context) -> io::Result<usize> {
// TODO writev
struct Writer<'a, 'b, T> {
io: &'a mut T,
cx: &'a mut Context<'b>
@ -92,58 +185,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> {
}
let mut writer = Writer { io: self.io, cx };
match self.session.write_tls(&mut writer) {
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
result => Poll::Ready(result)
}
}
pub fn handshake(&mut self, cx: &mut Context) -> Poll<io::Result<(usize, usize)>> {
let mut wrlen = 0;
let mut rdlen = 0;
loop {
let mut write_would_block = false;
let mut read_would_block = false;
while self.session.wants_write() {
match self.write_io(cx) {
Poll::Ready(Ok(n)) => wrlen += n,
Poll::Pending => {
write_would_block = true;
break
},
Poll::Ready(Err(err)) => return Poll::Ready(Err(err))
}
}
if !self.eof && self.session.wants_read() {
match self.read_io(cx) {
Poll::Ready(Ok(0)) => self.eof = true,
Poll::Ready(Ok(n)) => rdlen += n,
Poll::Pending => read_would_block = true,
Poll::Ready(Err(err)) => return Poll::Ready(Err(err))
}
}
let would_block = write_would_block || read_would_block;
return match (self.eof, self.session.is_handshaking(), would_block) {
(true, true, _) => {
let err = io::Error::new(io::ErrorKind::UnexpectedEof, "tls handshake eof");
Poll::Ready(Err(err))
},
(_, false, true) => if rdlen != 0 || wrlen != 0 {
Poll::Ready(Ok((rdlen, wrlen)))
} else {
Poll::Pending
},
(_, false, _) => Poll::Ready(Ok((rdlen, wrlen))),
(_, true, true) => Poll::Pending,
(..) => continue
}
}
self.session.write_tls(&mut writer)
}
}
@ -152,8 +194,8 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncRead for Stream<'a
let this = self.get_mut();
while this.session.wants_read() {
match this.read_io(cx) {
Poll::Ready(Ok(0)) => break,
match this.complete_inner_io(cx, Focus::Readable) {
Poll::Ready(Ok((0, _))) => break,
Poll::Ready(Ok(_)) => (),
Poll::Pending => return Poll::Pending,
Poll::Ready(Err(err)) => return Poll::Ready(Err(err))
@ -178,7 +220,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncWrite for Stream<'
Err(err) => return Poll::Ready(Err(err))
};
while this.session.wants_write() {
match this.write_io(cx) {
match this.complete_inner_io(cx, Focus::Writable) {
Poll::Ready(Ok(_)) => (),
Poll::Pending if len != 0 => break,
Poll::Pending => return Poll::Pending,
@ -204,7 +246,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncWrite for Stream<'
this.session.flush()?;
while this.session.wants_write() {
futures::ready!(this.write_io(cx))?;
futures::ready!(this.complete_inner_io(cx, Focus::Writable))?;
}
Pin::new(&mut this.io).poll_flush(cx)
}
@ -213,7 +255,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncWrite for Stream<'
let this = self.get_mut();
while this.session.wants_write() {
futures::ready!(this.write_io(cx))?;
futures::ready!(this.complete_inner_io(cx, Focus::Writable))?;
}
Pin::new(&mut this.io).poll_shutdown(cx)

View File

@ -83,7 +83,6 @@ async fn stream_good() -> io::Result<()> {
stream.read_to_end(&mut buf).await?;
assert_eq!(buf, FILE);
stream.write_all(b"Hello World!").await?;
stream.flush().await?;
}
let mut buf = String::new();
@ -120,12 +119,12 @@ async fn stream_handshake() -> io::Result<()> {
{
let mut good = Good(&mut server);
let mut stream = Stream::new(&mut good, &mut client);
let (r, w) = poll_fn(|cx| stream.handshake(cx)).await?;
let (r, w) = poll_fn(|cx| stream.complete_io(cx)).await?;
assert!(r > 0);
assert!(w > 0);
poll_fn(|cx| stream.handshake(cx)).await?; // finish server handshake
poll_fn(|cx| stream.complete_io(cx)).await?; // finish server handshake
}
assert!(!server.is_handshaking());
@ -142,7 +141,7 @@ async fn stream_handshake_eof() -> io::Result<()> {
let mut stream = Stream::new(&mut bad, &mut client);
let mut cx = Context::from_waker(noop_waker_ref());
let r = stream.handshake(&mut cx);
let r = stream.complete_io(&mut cx);
assert_eq!(r.map_err(|err| err.kind()), Poll::Ready(Err(io::ErrorKind::UnexpectedEof)));
Ok(()) as io::Result<()>
@ -188,11 +187,11 @@ fn do_handshake(client: &mut ClientSession, server: &mut ServerSession, cx: &mut
let mut stream = Stream::new(&mut good, client);
if stream.session.is_handshaking() {
ready!(stream.handshake(cx))?;
ready!(stream.complete_io(cx))?;
}
if stream.session.wants_write() {
ready!(stream.handshake(cx))?;
ready!(stream.complete_io(cx))?;
}
Poll::Ready(Ok(()))

View File

@ -48,11 +48,11 @@ where
let mut stream = Stream::new(io, session).set_eof(eof);
if stream.session.is_handshaking() {
futures::ready!(stream.handshake(cx))?;
futures::ready!(stream.complete_io(cx))?;
}
if stream.session.wants_write() {
futures::ready!(stream.handshake(cx))?;
futures::ready!(stream.complete_io(cx))?;
}
}

View File

@ -22,7 +22,6 @@ async fn get(config: Arc<ClientConfig>, domain: &str, rtt0: bool)
let stream = TcpStream::connect(&addr).await?;
let mut stream = connector.connect(domain, stream).await?;
stream.write_all(input.as_bytes()).await?;
stream.flush().await?;
stream.read_to_end(&mut buf).await?;
Ok((stream, String::from_utf8(buf).unwrap()))

View File

@ -52,19 +52,18 @@ lazy_static!{
let mut buf = vec![0; 8192];
let n = stream.read(&mut buf).await?;
stream.write(&buf[..n]).await?;
stream.flush().await?;
let _ = stream.read(&mut buf).await?;
Ok(()) as io::Result<()>
}.unwrap_or_else(|err| eprintln!("server: {:?}", err));
};
handle.spawn(fut).unwrap();
handle.spawn(fut.unwrap_or_else(|err| eprintln!("{:?}", err))).unwrap();
}
Ok(()) as io::Result<()>
}.unwrap_or_else(|err| eprintln!("server: {:?}", err));
};
runtime.block_on(done);
runtime.block_on(done.unwrap_or_else(|err| eprintln!("{:?}", err)));
});
let addr = recv.recv().unwrap();
@ -86,7 +85,6 @@ async fn start_client(addr: SocketAddr, domain: &str, config: Arc<ClientConfig>)
let stream = TcpStream::connect(&addr).await?;
let mut stream = config.connect(domain, stream).await?;
stream.write_all(FILE).await?;
stream.flush().await?;
stream.read_exact(&mut buf).await?;
assert_eq!(buf, FILE);