commit 66835b50403031bd28eb0ee976b40e69e49579dd Author: quininer kel Date: Tue Feb 21 11:52:43 2017 +0800 [Added] init diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..a9d37c5 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +target +Cargo.lock diff --git a/.gitjournal.toml b/.gitjournal.toml new file mode 100644 index 0000000..508a97e --- /dev/null +++ b/.gitjournal.toml @@ -0,0 +1,10 @@ +categories = ["Added", "Changed", "Fixed", "Improved", "Removed"] +category_delimiters = ["[", "]"] +colored_output = true +enable_debug = true +enable_footers = false +excluded_commit_tags = [] +show_commit_hash = false +show_prefix = false +sort_by = "date" +template_prefix = "" diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..e8ea275 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,9 @@ +[package] +name = "tokio-rustls" +version = "0.1.0" +authors = ["quininer kel "] + +[dependencies] +futures = "*" +tokio-core = "*" +rustls = "*" diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..b51ae20 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,192 @@ +extern crate futures; +extern crate tokio_core; +extern crate rustls; + +use std::io; +use std::sync::Arc; +use futures::{ Future, Poll, Async }; +use tokio_core::io::Io; +use rustls::{ Session, ClientSession, ServerSession }; +pub use rustls::{ ClientConfig, ServerConfig }; + + +pub trait TlsConnectorExt { + fn connect_async(&self, domain: &str, stream: S) + -> ConnectAsync + where S: Io; +} + +pub trait TlsAcceptorExt { + fn accept_async(&self, stream: S) + -> AcceptAsync + where S: Io; +} + + +pub struct ConnectAsync(MidHandshake); + +pub struct AcceptAsync(MidHandshake); + + +impl TlsConnectorExt for Arc { + fn connect_async(&self, domain: &str, stream: S) + -> ConnectAsync + where S: Io + { + ConnectAsync(MidHandshake { + inner: Some(TlsStream::new(stream, ClientSession::new(self, domain))) + }) + } +} + +impl TlsAcceptorExt for Arc { + fn accept_async(&self, stream: S) + -> AcceptAsync + where S: Io + { + AcceptAsync(MidHandshake { + inner: Some(TlsStream::new(stream, ServerSession::new(self))) + }) + } +} + +impl Future for ConnectAsync { + type Item = TlsStream; + type Error = io::Error; + + fn poll(&mut self) -> Poll { + self.0.poll() + } +} + +impl Future for AcceptAsync { + type Item = TlsStream; + type Error = io::Error; + + fn poll(&mut self) -> Poll { + self.0.poll() + } +} + + +struct MidHandshake { + inner: Option> +} + +impl Future for MidHandshake + where S: Io, C: Session +{ + type Item = TlsStream; + type Error = io::Error; + + fn poll(&mut self) -> Poll { + loop { + let stream = self.inner.as_mut().unwrap_or_else(|| unreachable!()); + if !stream.session.is_handshaking() { break }; + + match stream.do_io() { + Ok(()) => continue, + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => (), + Err(e) => return Err(e) + } + if !stream.session.is_handshaking() { break }; + + if stream.eof { + return Err(io::Error::from(io::ErrorKind::UnexpectedEof)); + } else { + return Ok(Async::NotReady); + } + } + + Ok(Async::Ready(self.inner.take().unwrap_or_else(|| unreachable!()))) + } +} + + +pub struct TlsStream { + eof: bool, + io: S, + session: C +} + +impl TlsStream + where S: Io, C: Session +{ + #[inline] + pub fn new(io: S, session: C) -> TlsStream { + TlsStream { + eof: false, + io: io, + session: session + } + } + + pub fn do_io(&mut self) -> io::Result<()> { + loop { + let read_would_block = match (!self.eof && self.session.wants_read(), self.io.poll_read()) { + (true, Async::Ready(())) => { + match self.session.read_tls(&mut self.io) { + Ok(0) => self.eof = true, + Ok(_) => self.session.process_new_packets() + .map_err(|err| io::Error::new(io::ErrorKind::Other, err))?, + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => (), + Err(e) => return Err(e) + }; + continue + }, + (true, Async::NotReady) => true, + (false, _) => false, + }; + + let write_would_block = match (self.session.wants_write(), self.io.poll_write()) { + (true, Async::Ready(())) => match self.session.write_tls(&mut self.io) { + Ok(_) => continue, + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => continue, + Err(e) => return Err(e) + }, + (true, Async::NotReady) => true, + (false, _) => false + }; + + if read_would_block || write_would_block { + return Err(io::Error::from(io::ErrorKind::WouldBlock)); + } else { + return Ok(()); + } + } + } +} + +impl io::Read for TlsStream + where S: Io, C: Session +{ + fn read(&mut self, buf: &mut [u8]) -> io::Result { + loop { + match self.session.read(buf) { + Ok(0) if !self.eof => self.do_io()?, + output => return output + } + } + } +} + +impl io::Write for TlsStream + where S: Io, C: Session +{ + fn write(&mut self, buf: &[u8]) -> io::Result { + while self.session.wants_write() && self.io.poll_write().is_ready() { + self.session.write_tls(&mut self.io)?; + } + self.session.write(buf) + } + + fn flush(&mut self) -> io::Result<()> { + self.session.flush()?; + while self.session.wants_write() && self.io.poll_write().is_ready() { + self.session.write_tls(&mut self.io)?; + } + Ok(()) + } +} + +impl Io for TlsStream where S: Io, C: Session {}