feat: handle CloseNotify alert

This commit is contained in:
quininer 2018-03-23 19:03:30 +08:00
parent 034357336e
commit fddb77759f
2 changed files with 40 additions and 11 deletions

View File

@ -100,13 +100,27 @@ impl<S, C> AsyncRead for TlsStream<S, C>
C: Session C: Session
{ {
fn poll_read(&mut self, ctx: &mut Context, buf: &mut [u8]) -> Poll<usize, Error> { fn poll_read(&mut self, ctx: &mut Context, buf: &mut [u8]) -> Poll<usize, Error> {
if self.eof {
return Ok(Async::Ready(0));
}
// TODO nll
let result = {
let (io, session) = self.get_mut(); let (io, session) = self.get_mut();
let mut taskio = TaskStream { io, task: ctx }; let mut taskio = TaskStream { io, task: ctx };
let mut stream = Stream::new(session, &mut taskio); let mut stream = Stream::new(session, &mut taskio);
io::Read::read(&mut stream, buf)
};
match io::Read::read(&mut stream, buf) { match result {
Ok(0) => { self.eof = true; Ok(Async::Ready(0)) },
Ok(n) => Ok(Async::Ready(n)), Ok(n) => Ok(Async::Ready(n)),
Err(ref e) if e.kind() == io::ErrorKind::ConnectionAborted => Ok(Async::Ready(0)), Err(ref e) if e.kind() == io::ErrorKind::ConnectionAborted => {
self.eof = true;
self.is_shutdown = true;
self.session.send_close_notify();
Ok(Async::Ready(0))
},
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => Ok(Async::Pending), Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => Ok(Async::Pending),
Err(e) => Err(e) Err(e) => Err(e)
} }

View File

@ -56,7 +56,7 @@ pub fn connect_async_with_session<S>(stream: S, session: ClientSession)
where S: io::Read + io::Write where S: io::Read + io::Write
{ {
ConnectAsync(MidHandshake { ConnectAsync(MidHandshake {
inner: Some(TlsStream { session, io: stream, is_shutdown: false }) inner: Some(TlsStream { session, io: stream, is_shutdown: false, eof: false })
}) })
} }
@ -77,7 +77,7 @@ pub fn accept_async_with_session<S>(stream: S, session: ServerSession)
where S: io::Read + io::Write where S: io::Read + io::Write
{ {
AcceptAsync(MidHandshake { AcceptAsync(MidHandshake {
inner: Some(TlsStream { session, io: stream, is_shutdown: false }) inner: Some(TlsStream { session, io: stream, is_shutdown: false, eof: false })
}) })
} }
@ -92,6 +92,7 @@ struct MidHandshake<S, C> {
#[derive(Debug)] #[derive(Debug)]
pub struct TlsStream<S, C> { pub struct TlsStream<S, C> {
is_shutdown: bool, is_shutdown: bool,
eof: bool,
io: S, io: S,
session: C session: C
} }
@ -112,12 +113,26 @@ impl<S, C> io::Read for TlsStream<S, C>
where S: io::Read + io::Write, C: Session where S: io::Read + io::Write, C: Session
{ {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
if self.eof {
return Ok(0);
}
// TODO nll
let result = {
let (io, session) = self.get_mut(); let (io, session) = self.get_mut();
let mut stream = Stream::new(session, io); let mut stream = Stream::new(session, io);
stream.read(buf)
};
match stream.read(buf) { match result {
Ok(0) => { self.eof = true; Ok(0) },
Ok(n) => Ok(n), Ok(n) => Ok(n),
Err(ref e) if e.kind() == io::ErrorKind::ConnectionAborted => Ok(0), Err(ref e) if e.kind() == io::ErrorKind::ConnectionAborted => {
self.eof = true;
self.is_shutdown = true;
self.session.send_close_notify();
Ok(0)
},
Err(e) => Err(e) Err(e) => Err(e)
} }
} }