diff --git a/tokio-rustls/Cargo.toml b/tokio-rustls/Cargo.toml index 27e2d04..044d3a7 100644 --- a/tokio-rustls/Cargo.toml +++ b/tokio-rustls/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tokio-rustls" -version = "0.22.0" +version = "0.23.0" authors = ["quininer kel "] license = "MIT/Apache-2.0" repository = "https://github.com/tokio-rs/tls" @@ -13,15 +13,19 @@ edition = "2018" [dependencies] tokio = "1.0" -rustls = "0.19" -webpki = "0.21" +rustls = { version = "0.20", default-features = false } +webpki = "0.22" [features] -early-data = [] +default = ["logging", "tls12"] dangerous_configuration = ["rustls/dangerous_configuration"] +early-data = [] +logging = ["rustls/logging"] +tls12 = ["rustls/tls12"] [dev-dependencies] tokio = { version = "1.0", features = ["full"] } futures-util = "0.3.1" lazy_static = "1" -webpki-roots = "0.21" +webpki-roots = "0.22" +rustls-pemfile = "0.2.1" diff --git a/tokio-rustls/examples/client/Cargo.toml b/tokio-rustls/examples/client/Cargo.toml index eef9250..4506cc9 100644 --- a/tokio-rustls/examples/client/Cargo.toml +++ b/tokio-rustls/examples/client/Cargo.toml @@ -8,4 +8,5 @@ edition = "2018" tokio = { version = "1.0", features = [ "full" ] } argh = "0.1" tokio-rustls = { path = "../.." } -webpki-roots = "0.21" +webpki-roots = "0.22" +rustls-pemfile = "0.2" \ No newline at end of file diff --git a/tokio-rustls/examples/client/src/main.rs b/tokio-rustls/examples/client/src/main.rs index 3527e20..44d21f7 100644 --- a/tokio-rustls/examples/client/src/main.rs +++ b/tokio-rustls/examples/client/src/main.rs @@ -1,4 +1,5 @@ use argh::FromArgs; +use std::convert::TryFrom; use std::fs::File; use std::io; use std::io::BufReader; @@ -7,7 +8,8 @@ use std::path::PathBuf; use std::sync::Arc; use tokio::io::{copy, split, stdin as tokio_stdin, stdout as tokio_stdout, AsyncWriteExt}; use tokio::net::TcpStream; -use tokio_rustls::{rustls::ClientConfig, webpki::DNSNameRef, TlsConnector}; +use tokio_rustls::rustls::{self, OwnedTrustAnchor}; +use tokio_rustls::{webpki, TlsConnector}; /// Tokio Rustls client example #[derive(FromArgs)] @@ -40,25 +42,42 @@ async fn main() -> io::Result<()> { let domain = options.domain.unwrap_or(options.host); let content = format!("GET / HTTP/1.0\r\nHost: {}\r\n\r\n", domain); - let mut config = ClientConfig::new(); + let mut root_cert_store = rustls::RootCertStore::empty(); if let Some(cafile) = &options.cafile { let mut pem = BufReader::new(File::open(cafile)?); - config - .root_store - .add_pem_file(&mut pem) - .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid cert"))?; + let certs = rustls_pemfile::certs(&mut pem)?; + let trust_anchors = certs.iter().map(|cert| { + let ta = webpki::TrustAnchor::try_from_cert_der(&cert[..]).unwrap(); + OwnedTrustAnchor::from_subject_spki_name_constraints( + ta.subject, + ta.spki, + ta.name_constraints, + ) + }); + root_cert_store.add_server_trust_anchors(trust_anchors); } else { - config - .root_store - .add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS); + root_cert_store.add_server_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.0.iter().map( + |ta| { + OwnedTrustAnchor::from_subject_spki_name_constraints( + ta.subject, + ta.spki, + ta.name_constraints, + ) + }, + )); } + + let config = rustls::ClientConfig::builder() + .with_safe_defaults() + .with_root_certificates(root_cert_store) + .with_no_client_auth(); // i guess this was previously the default? let connector = TlsConnector::from(Arc::new(config)); let stream = TcpStream::connect(&addr).await?; let (mut stdin, mut stdout) = (tokio_stdin(), tokio_stdout()); - let domain = DNSNameRef::try_from_ascii_str(&domain) + let domain = rustls::ServerName::try_from(domain.as_str()) .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid dnsname"))?; let mut stream = connector.connect(domain, stream).await?; diff --git a/tokio-rustls/examples/server/Cargo.toml b/tokio-rustls/examples/server/Cargo.toml index ba91b44..c563ed7 100644 --- a/tokio-rustls/examples/server/Cargo.toml +++ b/tokio-rustls/examples/server/Cargo.toml @@ -8,3 +8,4 @@ edition = "2018" tokio = { version = "1.0", features = [ "full" ] } argh = "0.1" tokio-rustls = { path = "../.." } +rustls-pemfile = "0.2.1" \ No newline at end of file diff --git a/tokio-rustls/examples/server/src/main.rs b/tokio-rustls/examples/server/src/main.rs index 65fcf39..1a7c1f7 100644 --- a/tokio-rustls/examples/server/src/main.rs +++ b/tokio-rustls/examples/server/src/main.rs @@ -1,4 +1,5 @@ use argh::FromArgs; +use rustls_pemfile::{certs, rsa_private_keys}; use std::fs::File; use std::io::{self, BufReader}; use std::net::ToSocketAddrs; @@ -6,8 +7,7 @@ use std::path::{Path, PathBuf}; use std::sync::Arc; use tokio::io::{copy, sink, split, AsyncWriteExt}; use tokio::net::TcpListener; -use tokio_rustls::rustls::internal::pemfile::{certs, rsa_private_keys}; -use tokio_rustls::rustls::{Certificate, NoClientAuth, PrivateKey, ServerConfig}; +use tokio_rustls::rustls::{self, Certificate, PrivateKey}; use tokio_rustls::TlsAcceptor; /// Tokio Rustls server example @@ -33,11 +33,13 @@ struct Options { fn load_certs(path: &Path) -> io::Result> { certs(&mut BufReader::new(File::open(path)?)) .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid cert")) + .map(|mut certs| certs.drain(..).map(Certificate).collect()) } fn load_keys(path: &Path) -> io::Result> { rsa_private_keys(&mut BufReader::new(File::open(path)?)) .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid key")) + .map(|mut keys| keys.drain(..).map(PrivateKey).collect()) } #[tokio::main] @@ -53,9 +55,10 @@ async fn main() -> io::Result<()> { let mut keys = load_keys(&options.key)?; let flag_echo = options.echo_mode; - let mut config = ServerConfig::new(NoClientAuth::new()); - config - .set_single_cert(certs, keys.remove(0)) + let config = rustls::ServerConfig::builder() + .with_safe_defaults() + .with_no_client_auth() + .with_single_cert(certs, keys.remove(0)) .map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?; let acceptor = TlsAcceptor::from(Arc::new(config)); diff --git a/tokio-rustls/src/client.rs b/tokio-rustls/src/client.rs index 9bd20ad..3bd0e1f 100644 --- a/tokio-rustls/src/client.rs +++ b/tokio-rustls/src/client.rs @@ -1,36 +1,35 @@ use super::*; use crate::common::IoSession; -use rustls::Session; /// A wrapper around an underlying raw stream which implements the TLS or SSL /// protocol. #[derive(Debug)] pub struct TlsStream { pub(crate) io: IO, - pub(crate) session: ClientSession, + pub(crate) session: ClientConnection, pub(crate) state: TlsState, } impl TlsStream { #[inline] - pub fn get_ref(&self) -> (&IO, &ClientSession) { + pub fn get_ref(&self) -> (&IO, &ClientConnection) { (&self.io, &self.session) } #[inline] - pub fn get_mut(&mut self) -> (&mut IO, &mut ClientSession) { + pub fn get_mut(&mut self) -> (&mut IO, &mut ClientConnection) { (&mut self.io, &mut self.session) } #[inline] - pub fn into_inner(self) -> (IO, ClientSession) { + pub fn into_inner(self) -> (IO, ClientConnection) { (self.io, self.session) } } impl IoSession for TlsStream { type Io = IO; - type Session = ClientSession; + type Session = ClientConnection; #[inline] fn skip_handshake(&self) -> bool { @@ -68,7 +67,7 @@ where match stream.as_mut_pin().poll_read(cx, buf) { Poll::Ready(Ok(())) => { - if prev == buf.remaining() { + if prev == buf.remaining() || stream.eof { this.state.shutdown_read(); } diff --git a/tokio-rustls/src/common/handshake.rs b/tokio-rustls/src/common/handshake.rs index 39139fa..fcb6dc9 100644 --- a/tokio-rustls/src/common/handshake.rs +++ b/tokio-rustls/src/common/handshake.rs @@ -1,6 +1,7 @@ use crate::common::{Stream, TlsState}; -use rustls::Session; +use rustls::{ConnectionCommon, SideData}; use std::future::Future; +use std::ops::{Deref, DerefMut}; use std::pin::Pin; use std::task::{Context, Poll}; use std::{io, mem}; @@ -15,28 +16,30 @@ pub(crate) trait IoSession { fn into_io(self) -> Self::Io; } -pub(crate) enum MidHandshake { +pub(crate) enum MidHandshake { Handshaking(IS), End, + Error { io: IS::Io, error: io::Error }, } -impl Future for MidHandshake +impl Future for MidHandshake where IS: IoSession + Unpin, IS::Io: AsyncRead + AsyncWrite + Unpin, - IS::Session: Session + Unpin, + IS::Session: DerefMut + Deref> + Unpin, + SD: SideData, { type Output = Result; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.get_mut(); - let mut stream = - if let MidHandshake::Handshaking(stream) = mem::replace(this, MidHandshake::End) { - stream - } else { - panic!("unexpected polling after handshake") - }; + let mut stream = match mem::replace(this, MidHandshake::End) { + MidHandshake::Handshaking(stream) => stream, + // Starting the handshake returned an error; fail the future immediately. + MidHandshake::Error { io, error } => return Poll::Ready(Err((error, io))), + _ => panic!("unexpected polling after handshake"), + }; if !stream.skip_handshake() { let (state, io, session) = stream.get_mut(); diff --git a/tokio-rustls/src/common/mod.rs b/tokio-rustls/src/common/mod.rs index a7b9fa6..06dc39b 100644 --- a/tokio-rustls/src/common/mod.rs +++ b/tokio-rustls/src/common/mod.rs @@ -1,8 +1,9 @@ mod handshake; pub(crate) use handshake::{IoSession, MidHandshake}; -use rustls::Session; +use rustls::{ConnectionCommon, SideData}; use std::io::{self, IoSlice, Read, Write}; +use std::ops::{Deref, DerefMut}; use std::pin::Pin; use std::task::{Context, Poll}; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; @@ -57,20 +58,26 @@ impl TlsState { } } -pub struct Stream<'a, IO, S> { +pub struct Stream<'a, IO, C> { pub io: &'a mut IO, - pub session: &'a mut S, + pub session: &'a mut C, pub eof: bool, + pub unexpected_eof: bool, } -impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> { - pub fn new(io: &'a mut IO, session: &'a mut S) -> Self { +impl<'a, IO: AsyncRead + AsyncWrite + Unpin, C, SD> Stream<'a, IO, C> +where + C: DerefMut + Deref>, + SD: SideData, +{ + pub fn new(io: &'a mut IO, session: &'a mut C) -> Self { Stream { io, session, // The state so far is only used to detect EOF, so either Stream // or EarlyData state should both be all right. eof: false, + unexpected_eof: false, } } @@ -214,7 +221,11 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> { } } -impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncRead for Stream<'a, IO, S> { +impl<'a, IO: AsyncRead + AsyncWrite + Unpin, C, SD> AsyncRead for Stream<'a, IO, C> +where + C: DerefMut + Deref>, + SD: SideData, +{ fn poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, @@ -223,10 +234,10 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncRead for Stream<'a let prev = buf.remaining(); while buf.remaining() != 0 { - let mut would_block = false; + let mut io_pending = false; // read a packet - while self.session.wants_read() { + while !self.eof && self.session.wants_read() { match self.read_io(cx) { Poll::Ready(Ok(0)) => { self.eof = true; @@ -234,30 +245,51 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncRead for Stream<'a } Poll::Ready(Ok(_)) => (), Poll::Pending => { - would_block = true; + io_pending = true; break; } Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), } } - return match self.session.read(buf.initialize_unfilled()) { - Ok(0) if prev == buf.remaining() && would_block => Poll::Pending, + return match self.session.reader().read(buf.initialize_unfilled()) { + // If Rustls returns `Ok(0)` (while `buf` is non-empty), the peer closed the + // connection with a `CloseNotify` message and no more data will be forthcoming. + Ok(0) => break, + + // Rustls yielded more data: advance the buffer, then see if more data is coming. Ok(n) => { buf.advance(n); - if self.eof || would_block { + if self.eof || io_pending { break; } else { continue; } } - Err(ref err) - if err.kind() == io::ErrorKind::ConnectionAborted - && prev != buf.remaining() => - { - break + + // Rustls doesn't have more data to yield, but it believes the connection is open. + Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => { + if prev == buf.remaining() && io_pending { + Poll::Pending + } else if self.eof || io_pending { + break; + } else { + continue; + } } + + Err(err) if err.kind() == io::ErrorKind::UnexpectedEof => { + self.eof = true; + self.unexpected_eof = true; + if prev == buf.remaining() { + Poll::Ready(Err(err)) + } else { + break; + } + } + + // This should be unreachable. Err(err) => Poll::Ready(Err(err)), }; } @@ -266,7 +298,11 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncRead for Stream<'a } } -impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncWrite for Stream<'a, IO, S> { +impl<'a, IO: AsyncRead + AsyncWrite + Unpin, C, SD> AsyncWrite for Stream<'a, IO, C> +where + C: DerefMut + Deref>, + SD: SideData, +{ fn poll_write( mut self: Pin<&mut Self>, cx: &mut Context, @@ -277,7 +313,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncWrite for Stream<' while pos != buf.len() { let mut would_block = false; - match self.session.write(&buf[pos..]) { + match self.session.writer().write(&buf[pos..]) { Ok(n) => pos += n, Err(err) => return Poll::Ready(Err(err)), }; @@ -304,7 +340,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncWrite for Stream<' } fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - self.session.flush()?; + self.session.writer().flush()?; while self.session.wants_write() { ready!(self.write_io(cx))?; } diff --git a/tokio-rustls/src/common/test_stream.rs b/tokio-rustls/src/common/test_stream.rs index 034f292..9f1359c 100644 --- a/tokio-rustls/src/common/test_stream.rs +++ b/tokio-rustls/src/common/test_stream.rs @@ -1,16 +1,15 @@ use super::Stream; use futures_util::future::poll_fn; use futures_util::task::noop_waker_ref; -use rustls::internal::pemfile::{certs, rsa_private_keys}; -use rustls::{ClientConfig, ClientSession, NoClientAuth, ServerConfig, ServerSession, Session}; +use rustls::{ClientConnection, Connection, OwnedTrustAnchor, RootCertStore, ServerConnection}; +use rustls_pemfile::{certs, rsa_private_keys}; use std::io::{self, BufReader, Cursor, Read, Write}; use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf}; -use webpki::DNSNameRef; -struct Good<'a>(&'a mut dyn Session); +struct Good<'a>(&'a mut Connection); impl<'a> AsyncRead for Good<'a> { fn poll_read( @@ -50,9 +49,10 @@ impl<'a> AsyncWrite for Good<'a> { Poll::Ready(Ok(())) } - fn poll_shutdown(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { self.0.send_close_notify(); - Poll::Ready(Ok(())) + dbg!("sent close notify"); + self.poll_flush(cx) } } @@ -120,23 +120,28 @@ impl AsyncWrite for Eof { async fn stream_good() -> io::Result<()> { const FILE: &[u8] = include_bytes!("../../README.md"); - let (mut server, mut client) = make_pair(); + let (server, mut client) = make_pair(); + let mut server = Connection::from(server); poll_fn(|cx| do_handshake(&mut client, &mut server, cx)).await?; - io::copy(&mut Cursor::new(FILE), &mut server)?; + io::copy(&mut Cursor::new(FILE), &mut server.writer())?; + server.send_close_notify(); + let mut server = Connection::from(server); { let mut good = Good(&mut server); let mut stream = Stream::new(&mut good, &mut client); let mut buf = Vec::new(); - stream.read_to_end(&mut buf).await?; + dbg!(stream.read_to_end(&mut buf).await)?; assert_eq!(buf, FILE); - stream.write_all(b"Hello World!").await?; - stream.flush().await?; + dbg!(stream.write_all(b"Hello World!").await)?; + stream.session.send_close_notify(); + dbg!(stream.shutdown().await)?; } let mut buf = String::new(); - server.read_to_string(&mut buf)?; + dbg!(server.process_new_packets()).map_err(|e| io::Error::new(io::ErrorKind::Other, e))?; + dbg!(server.reader().read_to_string(&mut buf))?; assert_eq!(buf, "Hello World!"); Ok(()) as io::Result<()> @@ -144,9 +149,10 @@ async fn stream_good() -> io::Result<()> { #[tokio::test] async fn stream_bad() -> io::Result<()> { - let (mut server, mut client) = make_pair(); + let (server, mut client) = make_pair(); + let mut server = Connection::from(server); poll_fn(|cx| do_handshake(&mut client, &mut server, cx)).await?; - client.set_buffer_limit(1024); + client.set_buffer_limit(Some(1024)); let mut bad = Pending; let mut stream = Stream::new(&mut bad, &mut client); @@ -170,7 +176,8 @@ async fn stream_bad() -> io::Result<()> { #[tokio::test] async fn stream_handshake() -> io::Result<()> { - let (mut server, mut client) = make_pair(); + let (server, mut client) = make_pair(); + let mut server = Connection::from(server); { let mut good = Good(&mut server); @@ -208,42 +215,72 @@ async fn stream_handshake_eof() -> io::Result<()> { #[tokio::test] async fn stream_eof() -> io::Result<()> { - let (mut server, mut client) = make_pair(); + let (server, mut client) = make_pair(); + let mut server = Connection::from(server); poll_fn(|cx| do_handshake(&mut client, &mut server, cx)).await?; let mut good = Good(&mut server); let mut stream = Stream::new(&mut good, &mut client).set_eof(true); let mut buf = Vec::new(); - stream.read_to_end(&mut buf).await?; - assert_eq!(buf.len(), 0); + let result = stream.read_to_end(&mut buf).await; + assert_eq!( + result.err().map(|e| e.kind()), + Some(io::ErrorKind::UnexpectedEof) + ); Ok(()) as io::Result<()> } -fn make_pair() -> (ServerSession, ClientSession) { +fn make_pair() -> (ServerConnection, ClientConnection) { + use std::convert::TryFrom; + const CERT: &str = include_str!("../../tests/end.cert"); const CHAIN: &str = include_str!("../../tests/end.chain"); const RSA: &str = include_str!("../../tests/end.rsa"); - let cert = certs(&mut BufReader::new(Cursor::new(CERT))).unwrap(); + let cert = certs(&mut BufReader::new(Cursor::new(CERT))) + .unwrap() + .drain(..) + .map(rustls::Certificate) + .collect(); let mut keys = rsa_private_keys(&mut BufReader::new(Cursor::new(RSA))).unwrap(); - let mut sconfig = ServerConfig::new(NoClientAuth::new()); - sconfig.set_single_cert(cert, keys.pop().unwrap()).unwrap(); - let server = ServerSession::new(&Arc::new(sconfig)); + let mut keys = keys.drain(..).map(rustls::PrivateKey); + let sconfig = rustls::ServerConfig::builder() + .with_safe_defaults() + .with_no_client_auth() + .with_single_cert(cert, keys.next().unwrap()) + .unwrap(); + let server = ServerConnection::new(Arc::new(sconfig)).unwrap(); - let domain = DNSNameRef::try_from_ascii_str("localhost").unwrap(); - let mut cconfig = ClientConfig::new(); + let domain = rustls::ServerName::try_from("localhost").unwrap(); + let mut client_root_cert_store = RootCertStore::empty(); let mut chain = BufReader::new(Cursor::new(CHAIN)); - cconfig.root_store.add_pem_file(&mut chain).unwrap(); - let client = ClientSession::new(&Arc::new(cconfig), domain); + let certs = certs(&mut chain).unwrap(); + let trust_anchors = certs + .iter() + .map(|cert| { + let ta = webpki::TrustAnchor::try_from_cert_der(&cert[..]).unwrap(); + OwnedTrustAnchor::from_subject_spki_name_constraints( + ta.subject, + ta.spki, + ta.name_constraints, + ) + }) + .collect::>(); + client_root_cert_store.add_server_trust_anchors(trust_anchors.into_iter()); + let cconfig = rustls::ClientConfig::builder() + .with_safe_defaults() + .with_root_certificates(client_root_cert_store) + .with_no_client_auth(); + let client = ClientConnection::new(Arc::new(cconfig), domain).unwrap(); (server, client) } fn do_handshake( - client: &mut ClientSession, - server: &mut ServerSession, + client: &mut ClientConnection, + server: &mut Connection, cx: &mut Context<'_>, ) -> Poll> { let mut good = Good(server); diff --git a/tokio-rustls/src/lib.rs b/tokio-rustls/src/lib.rs index 8f07b58..a8e7302 100644 --- a/tokio-rustls/src/lib.rs +++ b/tokio-rustls/src/lib.rs @@ -14,14 +14,13 @@ mod common; pub mod server; use common::{MidHandshake, Stream, TlsState}; -use rustls::{ClientConfig, ClientSession, ServerConfig, ServerSession, Session}; +use rustls::{ClientConfig, ClientConnection, CommonState, ServerConfig, ServerConnection}; use std::future::Future; use std::io; use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; -use webpki::DNSNameRef; pub use rustls; pub use webpki; @@ -68,19 +67,29 @@ impl TlsConnector { } #[inline] - pub fn connect(&self, domain: DNSNameRef, stream: IO) -> Connect + pub fn connect(&self, domain: rustls::ServerName, stream: IO) -> Connect where IO: AsyncRead + AsyncWrite + Unpin, { self.connect_with(domain, stream, |_| ()) } - pub fn connect_with(&self, domain: DNSNameRef, stream: IO, f: F) -> Connect + pub fn connect_with(&self, domain: rustls::ServerName, stream: IO, f: F) -> Connect where IO: AsyncRead + AsyncWrite + Unpin, - F: FnOnce(&mut ClientSession), + F: FnOnce(&mut ClientConnection), { - let mut session = ClientSession::new(&self.inner, domain); + let mut session = match ClientConnection::new(self.inner.clone(), domain) { + Ok(session) => session, + Err(error) => { + return Connect(MidHandshake::Error { + io: stream, + // TODO(eliza): should this really return an `io::Error`? + // Probably not... + error: io::Error::new(io::ErrorKind::Other, error), + }); + } + }; f(&mut session); Connect(MidHandshake::Handshaking(client::TlsStream { @@ -113,9 +122,19 @@ impl TlsAcceptor { pub fn accept_with(&self, stream: IO, f: F) -> Accept where IO: AsyncRead + AsyncWrite + Unpin, - F: FnOnce(&mut ServerSession), + F: FnOnce(&mut ServerConnection), { - let mut session = ServerSession::new(&self.inner); + let mut session = match ServerConnection::new(self.inner.clone()) { + Ok(session) => session, + Err(error) => { + return Accept(MidHandshake::Error { + io: stream, + // TODO(eliza): should this really return an `io::Error`? + // Probably not... + error: io::Error::new(io::ErrorKind::Other, error), + }); + } + }; f(&mut session); Accept(MidHandshake::Handshaking(server::TlsStream { @@ -201,7 +220,7 @@ pub enum TlsStream { } impl TlsStream { - pub fn get_ref(&self) -> (&T, &dyn Session) { + pub fn get_ref(&self) -> (&T, &CommonState) { use TlsStream::*; match self { Client(io) => { @@ -215,7 +234,7 @@ impl TlsStream { } } - pub fn get_mut(&mut self) -> (&mut T, &mut dyn Session) { + pub fn get_mut(&mut self) -> (&mut T, &mut CommonState) { use TlsStream::*; match self { Client(io) => { diff --git a/tokio-rustls/src/server.rs b/tokio-rustls/src/server.rs index 7ea7ce9..cf30b11 100644 --- a/tokio-rustls/src/server.rs +++ b/tokio-rustls/src/server.rs @@ -1,36 +1,35 @@ use super::*; use crate::common::IoSession; -use rustls::Session; /// A wrapper around an underlying raw stream which implements the TLS or SSL /// protocol. #[derive(Debug)] pub struct TlsStream { pub(crate) io: IO, - pub(crate) session: ServerSession, + pub(crate) session: ServerConnection, pub(crate) state: TlsState, } impl TlsStream { #[inline] - pub fn get_ref(&self) -> (&IO, &ServerSession) { + pub fn get_ref(&self) -> (&IO, &ServerConnection) { (&self.io, &self.session) } #[inline] - pub fn get_mut(&mut self) -> (&mut IO, &mut ServerSession) { + pub fn get_mut(&mut self) -> (&mut IO, &mut ServerConnection) { (&mut self.io, &mut self.session) } #[inline] - pub fn into_inner(self) -> (IO, ServerSession) { + pub fn into_inner(self) -> (IO, ServerConnection) { (self.io, self.session) } } impl IoSession for TlsStream { type Io = IO; - type Session = ServerSession; + type Session = ServerConnection; #[inline] fn skip_handshake(&self) -> bool { @@ -67,7 +66,7 @@ where match stream.as_mut_pin().poll_read(cx, buf) { Poll::Ready(Ok(())) => { - if prev == buf.remaining() { + if prev == buf.remaining() || stream.eof { this.state.shutdown_read(); } diff --git a/tokio-rustls/tests/badssl.rs b/tokio-rustls/tests/badssl.rs index 54abdef..0ab9f43 100644 --- a/tokio-rustls/tests/badssl.rs +++ b/tokio-rustls/tests/badssl.rs @@ -1,10 +1,14 @@ -use rustls::ClientConfig; +use std::convert::TryFrom; use std::io; use std::net::ToSocketAddrs; use std::sync::Arc; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::net::TcpStream; -use tokio_rustls::{client::TlsStream, TlsConnector}; +use tokio_rustls::{ + client::TlsStream, + rustls::{self, ClientConfig, OwnedTrustAnchor}, + TlsConnector, +}; async fn get( config: Arc, @@ -15,7 +19,7 @@ async fn get( let input = format!("GET / HTTP/1.0\r\nHost: {}\r\n\r\n", domain); let addr = (domain, port).to_socket_addrs()?.next().unwrap(); - let domain = webpki::DNSNameRef::try_from_ascii_str(&domain).unwrap(); + let domain = rustls::ServerName::try_from(domain).unwrap(); let mut buf = Vec::new(); let stream = TcpStream::connect(&addr).await?; @@ -29,16 +33,31 @@ async fn get( #[tokio::test] async fn test_tls12() -> io::Result<()> { - let mut config = ClientConfig::new(); - config - .root_store - .add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS); - config.versions = vec![rustls::ProtocolVersion::TLSv1_2]; + let mut root_store = rustls::RootCertStore::empty(); + root_store.add_server_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| { + OwnedTrustAnchor::from_subject_spki_name_constraints( + ta.subject, + ta.spki, + ta.name_constraints, + ) + })); + let config = rustls::ClientConfig::builder() + .with_safe_default_cipher_suites() + .with_safe_default_kx_groups() + .with_protocol_versions(&[&rustls::version::TLS12]) + .unwrap() + .with_root_certificates(root_store) + .with_no_client_auth(); + let config = Arc::new(config); let domain = "tls-v1-2.badssl.com"; let (_, output) = get(config.clone(), domain, 1012).await?; - assert!(output.contains("tls-v1-2.badssl.com")); + assert!( + output.contains("tls-v1-2.badssl.com"), + "failed badssl test, output: {}", + output + ); Ok(()) } @@ -52,15 +71,27 @@ fn test_tls13() { #[tokio::test] async fn test_modern() -> io::Result<()> { - let mut config = ClientConfig::new(); - config - .root_store - .add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS); + let mut root_store = rustls::RootCertStore::empty(); + root_store.add_server_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| { + OwnedTrustAnchor::from_subject_spki_name_constraints( + ta.subject, + ta.spki, + ta.name_constraints, + ) + })); + let config = rustls::ClientConfig::builder() + .with_safe_defaults() + .with_root_certificates(root_store) + .with_no_client_auth(); let config = Arc::new(config); let domain = "mozilla-modern.badssl.com"; let (_, output) = get(config.clone(), domain, 443).await?; - assert!(output.contains("mozilla-modern.badssl.com")); + assert!( + output.contains("mozilla-modern.badssl.com"), + "failed badssl test, output: {}", + output + ); Ok(()) } diff --git a/tokio-rustls/tests/early-data.rs b/tokio-rustls/tests/early-data.rs index 4a718f3..80d6d15 100644 --- a/tokio-rustls/tests/early-data.rs +++ b/tokio-rustls/tests/early-data.rs @@ -1,7 +1,8 @@ #![cfg(feature = "early-data")] use futures_util::{future, future::Future, ready}; -use rustls::ClientConfig; +use rustls::RootCertStore; +use std::convert::TryFrom; use std::io::{self, BufRead, BufReader, Cursor}; use std::net::SocketAddr; use std::pin::Pin; @@ -12,7 +13,11 @@ use std::time::Duration; use tokio::io::{AsyncRead, AsyncWriteExt, ReadBuf}; use tokio::net::TcpStream; use tokio::time::sleep; -use tokio_rustls::{client::TlsStream, TlsConnector}; +use tokio_rustls::{ + client::TlsStream, + rustls::{self, ClientConfig, OwnedTrustAnchor}, + TlsConnector, +}; struct Read1(T); @@ -34,7 +39,7 @@ async fn send( ) -> io::Result> { let connector = TlsConnector::from(config).early_data(true); let stream = TcpStream::connect(&addr).await?; - let domain = webpki::DNSNameRef::try_from_ascii_str("testserver.com").unwrap(); + let domain = rustls::ServerName::try_from("testserver.com").unwrap(); let mut stream = connector.connect(domain, stream).await?; stream.write_all(data).await?; @@ -81,10 +86,28 @@ async fn test_0rtt() -> io::Result<()> { // wait openssl server sleep(Duration::from_secs(1)).await; - let mut config = ClientConfig::new(); let mut chain = BufReader::new(Cursor::new(include_str!("end.chain"))); - config.root_store.add_pem_file(&mut chain).unwrap(); - config.versions = vec![rustls::ProtocolVersion::TLSv1_3]; + let certs = rustls_pemfile::certs(&mut chain).unwrap(); + let trust_anchors = certs + .iter() + .map(|cert| { + let ta = webpki::TrustAnchor::try_from_cert_der(&cert[..]).unwrap(); + OwnedTrustAnchor::from_subject_spki_name_constraints( + ta.subject, + ta.spki, + ta.name_constraints, + ) + }) + .collect::>(); + let mut root_store = RootCertStore::empty(); + root_store.add_server_trust_anchors(trust_anchors.into_iter()); + let mut config = rustls::ClientConfig::builder() + .with_safe_default_cipher_suites() + .with_safe_default_kx_groups() + .with_protocol_versions(&[&rustls::version::TLS13]) + .unwrap() + .with_root_certificates(root_store) + .with_no_client_auth(); config.enable_early_data = true; let config = Arc::new(config); let addr = SocketAddr::from(([127, 0, 0, 1], 12354)); diff --git a/tokio-rustls/tests/test.rs b/tokio-rustls/tests/test.rs index 28de255..f0f5f56 100644 --- a/tokio-rustls/tests/test.rs +++ b/tokio-rustls/tests/test.rs @@ -1,7 +1,8 @@ use futures_util::future::TryFutureExt; use lazy_static::lazy_static; -use rustls::internal::pemfile::{certs, rsa_private_keys}; -use rustls::{ClientConfig, ServerConfig}; +use rustls::{ClientConfig, OwnedTrustAnchor}; +use rustls_pemfile::{certs, rsa_private_keys}; +use std::convert::TryFrom; use std::io::{BufReader, Cursor}; use std::net::SocketAddr; use std::sync::mpsc::channel; @@ -13,18 +14,24 @@ use tokio::runtime; use tokio_rustls::{TlsAcceptor, TlsConnector}; const CERT: &str = include_str!("end.cert"); -const CHAIN: &str = include_str!("end.chain"); +const CHAIN: &[u8] = include_bytes!("end.chain"); const RSA: &str = include_str!("end.rsa"); lazy_static! { - static ref TEST_SERVER: (SocketAddr, &'static str, &'static str) = { - let cert = certs(&mut BufReader::new(Cursor::new(CERT))).unwrap(); + static ref TEST_SERVER: (SocketAddr, &'static str, &'static [u8]) = { + let cert = certs(&mut BufReader::new(Cursor::new(CERT))) + .unwrap() + .drain(..) + .map(rustls::Certificate) + .collect(); let mut keys = rsa_private_keys(&mut BufReader::new(Cursor::new(RSA))).unwrap(); + let mut keys = keys.drain(..).map(rustls::PrivateKey); - let mut config = ServerConfig::new(rustls::NoClientAuth::new()); - config - .set_single_cert(cert, keys.pop().unwrap()) - .expect("invalid key or certificate"); + let config = rustls::ServerConfig::builder() + .with_safe_defaults() + .with_no_client_auth() + .with_single_cert(cert, keys.next().unwrap()) + .unwrap(); let acceptor = TlsAcceptor::from(Arc::new(config)); let (send, recv) = channel(); @@ -70,14 +77,14 @@ lazy_static! { }; } -fn start_server() -> &'static (SocketAddr, &'static str, &'static str) { +fn start_server() -> &'static (SocketAddr, &'static str, &'static [u8]) { &*TEST_SERVER } async fn start_client(addr: SocketAddr, domain: &str, config: Arc) -> io::Result<()> { const FILE: &[u8] = include_bytes!("../README.md"); - let domain = webpki::DNSNameRef::try_from_ascii_str(domain).unwrap(); + let domain = rustls::ServerName::try_from(domain).unwrap(); let config = TlsConnector::from(config); let mut buf = vec![0; FILE.len()]; @@ -102,12 +109,27 @@ async fn pass() -> io::Result<()> { use std::time::*; tokio::time::sleep(Duration::from_secs(1)).await; - let mut config = ClientConfig::new(); - let mut chain = BufReader::new(Cursor::new(chain)); - config.root_store.add_pem_file(&mut chain).unwrap(); + let chain = certs(&mut std::io::Cursor::new(*chain)).unwrap(); + let trust_anchors = chain + .iter() + .map(|cert| { + let ta = webpki::TrustAnchor::try_from_cert_der(&cert[..]).unwrap(); + OwnedTrustAnchor::from_subject_spki_name_constraints( + ta.subject, + ta.spki, + ta.name_constraints, + ) + }) + .collect::>(); + let mut root_store = rustls::RootCertStore::empty(); + root_store.add_server_trust_anchors(trust_anchors.into_iter()); + let config = rustls::ClientConfig::builder() + .with_safe_defaults() + .with_root_certificates(root_store) + .with_no_client_auth(); let config = Arc::new(config); - start_client(*addr, domain, config.clone()).await?; + start_client(*addr, domain, config).await?; Ok(()) } @@ -116,9 +138,24 @@ async fn pass() -> io::Result<()> { async fn fail() -> io::Result<()> { let (addr, domain, chain) = start_server(); - let mut config = ClientConfig::new(); - let mut chain = BufReader::new(Cursor::new(chain)); - config.root_store.add_pem_file(&mut chain).unwrap(); + let chain = certs(&mut std::io::Cursor::new(*chain)).unwrap(); + let trust_anchors = chain + .iter() + .map(|cert| { + let ta = webpki::TrustAnchor::try_from_cert_der(&cert[..]).unwrap(); + OwnedTrustAnchor::from_subject_spki_name_constraints( + ta.subject, + ta.spki, + ta.name_constraints, + ) + }) + .collect::>(); + let mut root_store = rustls::RootCertStore::empty(); + root_store.add_server_trust_anchors(trust_anchors.into_iter()); + let config = rustls::ClientConfig::builder() + .with_safe_defaults() + .with_root_certificates(root_store) + .with_no_client_auth(); let config = Arc::new(config); assert_ne!(domain, &"google.com");