diff --git a/src/lib.rs b/src/lib.rs index 1a48ddc..2c6a7e9 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -52,6 +52,15 @@ impl ClientConfigExt for Arc { } } +pub fn connect_async_with_session(stream: S, session: ClientSession) + -> ConnectAsync + where S: AsyncRead + AsyncWrite +{ + ConnectAsync(MidHandshake { + inner: Some(TlsStream::new(stream, session)) + }) +} + impl ServerConfigExt for Arc { fn accept_async(&self, stream: S) -> AcceptAsync @@ -63,6 +72,15 @@ impl ServerConfigExt for Arc { } } +pub fn accept_async_with_session(stream: S, session: ServerSession) + -> AcceptAsync + where S: AsyncRead + AsyncWrite +{ + AcceptAsync(MidHandshake { + inner: Some(TlsStream::new(stream, session)) + }) +} + impl Future for ConnectAsync { type Item = TlsStream; type Error = io::Error; @@ -209,17 +227,33 @@ impl io::Write for TlsStream where S: AsyncRead + AsyncWrite, C: Session { fn write(&mut self, buf: &[u8]) -> io::Result { - let output = self.session.write(buf)?; - - while self.session.wants_write() { - match self.session.write_tls(&mut self.io) { - Ok(_) => (), - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => break, - Err(e) => return Err(e) - } + if buf.len() == 0 { + return Ok(0); } - Ok(output) + loop { + let output = self.session.write(buf)?; + + while self.session.wants_write() { + match self.session.write_tls(&mut self.io) { + Ok(_) => (), + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + if output == 0 { + // Both rustls buffer and IO buffer are blocking. + return Err(io::Error::from(io::ErrorKind::WouldBlock)); + } else { + break; + } + } + Err(e) => return Err(e) + } + } + + if output > 0 { + // Already wrote something out. + return Ok(output); + } + } } fn flush(&mut self) -> io::Result<()> {