diff --git a/src/futures_impl.rs b/src/futures_impl.rs index 22c637c..86bc9da 100644 --- a/src/futures_impl.rs +++ b/src/futures_impl.rs @@ -100,13 +100,27 @@ impl AsyncRead for TlsStream C: Session { fn poll_read(&mut self, ctx: &mut Context, buf: &mut [u8]) -> Poll { - let (io, session) = self.get_mut(); - let mut taskio = TaskStream { io, task: ctx }; - let mut stream = Stream::new(session, &mut taskio); + if self.eof { + return Ok(Async::Ready(0)); + } - match io::Read::read(&mut stream, buf) { + // TODO nll + let result = { + let (io, session) = self.get_mut(); + let mut taskio = TaskStream { io, task: ctx }; + let mut stream = Stream::new(session, &mut taskio); + io::Read::read(&mut stream, buf) + }; + + match result { + Ok(0) => { self.eof = true; Ok(Async::Ready(0)) }, Ok(n) => Ok(Async::Ready(n)), - Err(ref e) if e.kind() == io::ErrorKind::ConnectionAborted => Ok(Async::Ready(0)), + Err(ref e) if e.kind() == io::ErrorKind::ConnectionAborted => { + self.eof = true; + self.is_shutdown = true; + self.session.send_close_notify(); + Ok(Async::Ready(0)) + }, Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => Ok(Async::Pending), Err(e) => Err(e) } diff --git a/src/lib.rs b/src/lib.rs index 3337e0d..18b3ae2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -56,7 +56,7 @@ pub fn connect_async_with_session(stream: S, session: ClientSession) where S: io::Read + io::Write { ConnectAsync(MidHandshake { - inner: Some(TlsStream { session, io: stream, is_shutdown: false }) + inner: Some(TlsStream { session, io: stream, is_shutdown: false, eof: false }) }) } @@ -77,7 +77,7 @@ pub fn accept_async_with_session(stream: S, session: ServerSession) where S: io::Read + io::Write { AcceptAsync(MidHandshake { - inner: Some(TlsStream { session, io: stream, is_shutdown: false }) + inner: Some(TlsStream { session, io: stream, is_shutdown: false, eof: false }) }) } @@ -92,6 +92,7 @@ struct MidHandshake { #[derive(Debug)] pub struct TlsStream { is_shutdown: bool, + eof: bool, io: S, session: C } @@ -112,12 +113,26 @@ impl io::Read for TlsStream where S: io::Read + io::Write, C: Session { fn read(&mut self, buf: &mut [u8]) -> io::Result { - let (io, session) = self.get_mut(); - let mut stream = Stream::new(session, io); + if self.eof { + return Ok(0); + } - match stream.read(buf) { + // TODO nll + let result = { + let (io, session) = self.get_mut(); + let mut stream = Stream::new(session, io); + stream.read(buf) + }; + + match result { + Ok(0) => { self.eof = true; Ok(0) }, Ok(n) => Ok(n), - Err(ref e) if e.kind() == io::ErrorKind::ConnectionAborted => Ok(0), + Err(ref e) if e.kind() == io::ErrorKind::ConnectionAborted => { + self.eof = true; + self.is_shutdown = true; + self.session.send_close_notify(); + Ok(0) + }, Err(e) => Err(e) } }