Merge branch 'master' into new-api
This commit is contained in:
commit
f633a72f02
20
.travis.yml
20
.travis.yml
@ -1,13 +1,21 @@
|
|||||||
language: rust
|
language: rust
|
||||||
rust:
|
|
||||||
- stable
|
|
||||||
cache: cargo
|
cache: cargo
|
||||||
os:
|
|
||||||
- linux
|
matrix:
|
||||||
- osx
|
include:
|
||||||
|
- rust: stable
|
||||||
|
os: linux
|
||||||
|
- rust: nightly
|
||||||
|
env: FEATURE=nightly
|
||||||
|
os: linux
|
||||||
|
- rust: stable
|
||||||
|
os: osx
|
||||||
|
- rust: nightly
|
||||||
|
env: FEATURE=nightly
|
||||||
|
os: osx
|
||||||
|
|
||||||
script:
|
script:
|
||||||
- cargo test --all-features
|
- cargo test --features "$FEATURE"
|
||||||
- cd examples/server
|
- cd examples/server
|
||||||
- cargo check
|
- cargo check
|
||||||
- cd ../../examples/client
|
- cd ../../examples/client
|
||||||
|
17
Cargo.toml
17
Cargo.toml
@ -1,6 +1,6 @@
|
|||||||
[package]
|
[package]
|
||||||
name = "tokio-rustls"
|
name = "tokio-rustls"
|
||||||
version = "0.7.1"
|
version = "0.7.2"
|
||||||
authors = ["quininer kel <quininer@live.com>"]
|
authors = ["quininer kel <quininer@live.com>"]
|
||||||
license = "MIT/Apache-2.0"
|
license = "MIT/Apache-2.0"
|
||||||
repository = "https://github.com/quininer/tokio-rustls"
|
repository = "https://github.com/quininer/tokio-rustls"
|
||||||
@ -15,20 +15,17 @@ travis-ci = { repository = "quininer/tokio-rustls" }
|
|||||||
appveyor = { repository = "quininer/tokio-rustls" }
|
appveyor = { repository = "quininer/tokio-rustls" }
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
futures-core = { version = "0.2.0", optional = true }
|
|
||||||
futures-io = { version = "0.2.0", optional = true }
|
|
||||||
tokio = { version = "0.1.6", optional = true }
|
tokio = { version = "0.1.6", optional = true }
|
||||||
|
bytes = { version = "0.4", optional = true }
|
||||||
|
iovec = { version = "0.1", optional = true }
|
||||||
rustls = "0.13"
|
rustls = "0.13"
|
||||||
webpki = "0.18.1"
|
webpki = "0.18.1"
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
# futures = "0.2.0"
|
|
||||||
tokio = "0.1.6"
|
tokio = "0.1.6"
|
||||||
|
lazy_static = "1"
|
||||||
|
|
||||||
[features]
|
[features]
|
||||||
default = [ "tokio" ]
|
default = ["tokio-support"]
|
||||||
# unstable-futures = [
|
nightly = ["bytes", "iovec"]
|
||||||
# "futures-core",
|
tokio-support = ["tokio"]
|
||||||
# "futures-io",
|
|
||||||
# "tokio/unstable-futures"
|
|
||||||
# ]
|
|
||||||
|
@ -13,7 +13,7 @@ install:
|
|||||||
build: false
|
build: false
|
||||||
|
|
||||||
test_script:
|
test_script:
|
||||||
- 'cargo test --all-features'
|
- 'cargo test'
|
||||||
- 'cd examples/server'
|
- 'cd examples/server'
|
||||||
- 'cargo check'
|
- 'cargo check'
|
||||||
- 'cd ../../examples/client'
|
- 'cd ../../examples/client'
|
||||||
|
@ -5,15 +5,8 @@ authors = ["quininer <quininer@live.com>"]
|
|||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
webpki = "0.18.1"
|
webpki = "0.18.1"
|
||||||
tokio-rustls = { path = "../..", default-features = false, features = [ "tokio" ] }
|
tokio-rustls = { path = "../.." }
|
||||||
|
|
||||||
tokio = "0.1"
|
tokio = "0.1"
|
||||||
|
clap = "2"
|
||||||
clap = "2.26"
|
|
||||||
webpki-roots = "0.15"
|
webpki-roots = "0.15"
|
||||||
|
tokio-stdin-stdout = "0.1"
|
||||||
[target.'cfg(unix)'.dependencies]
|
|
||||||
tokio-file-unix = "0.5"
|
|
||||||
|
|
||||||
[target.'cfg(not(unix))'.dependencies]
|
|
||||||
tokio-fs = "0.1"
|
|
||||||
|
@ -4,8 +4,7 @@ extern crate webpki;
|
|||||||
extern crate webpki_roots;
|
extern crate webpki_roots;
|
||||||
extern crate tokio_rustls;
|
extern crate tokio_rustls;
|
||||||
|
|
||||||
#[cfg(unix)] extern crate tokio_file_unix;
|
extern crate tokio_stdin_stdout;
|
||||||
#[cfg(not(unix))] extern crate tokio_fs;
|
|
||||||
|
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::net::ToSocketAddrs;
|
use std::net::ToSocketAddrs;
|
||||||
@ -16,6 +15,7 @@ use tokio::net::TcpStream;
|
|||||||
use tokio::prelude::*;
|
use tokio::prelude::*;
|
||||||
use clap::{ App, Arg };
|
use clap::{ App, Arg };
|
||||||
use tokio_rustls::{ TlsConnector, rustls::ClientConfig };
|
use tokio_rustls::{ TlsConnector, rustls::ClientConfig };
|
||||||
|
use tokio_stdin_stdout::{ stdin as tokio_stdin, stdout as tokio_stdout };
|
||||||
|
|
||||||
fn app() -> App<'static, 'static> {
|
fn app() -> App<'static, 'static> {
|
||||||
App::new("client")
|
App::new("client")
|
||||||
@ -52,59 +52,23 @@ fn main() {
|
|||||||
let config = TlsConnector::from(Arc::new(config));
|
let config = TlsConnector::from(Arc::new(config));
|
||||||
|
|
||||||
let socket = TcpStream::connect(&addr);
|
let socket = TcpStream::connect(&addr);
|
||||||
|
let (stdin, stdout) = (tokio_stdin(0), tokio_stdout(0));
|
||||||
|
|
||||||
#[cfg(unix)]
|
let done = socket
|
||||||
let resp = {
|
.and_then(move |stream| {
|
||||||
use tokio::reactor::Handle;
|
let domain = webpki::DNSNameRef::try_from_ascii_str(&domain).unwrap();
|
||||||
use tokio_file_unix::{ raw_stdin, raw_stdout, File };
|
config.connect(domain, stream)
|
||||||
|
})
|
||||||
|
.and_then(move |stream| io::write_all(stream, text))
|
||||||
|
.and_then(move |(stream, _)| {
|
||||||
|
let (r, w) = stream.split();
|
||||||
|
io::copy(r, stdout)
|
||||||
|
.map(drop)
|
||||||
|
.select2(io::copy(stdin, w).map(drop))
|
||||||
|
.map_err(|res| res.split().0)
|
||||||
|
})
|
||||||
|
.map(drop)
|
||||||
|
.map_err(|err| eprintln!("{:?}", err));
|
||||||
|
|
||||||
let stdin = raw_stdin()
|
tokio::run(done);
|
||||||
.and_then(File::new_nb)
|
|
||||||
.and_then(|fd| fd.into_reader(&Handle::current()))
|
|
||||||
.unwrap();
|
|
||||||
let stdout = raw_stdout()
|
|
||||||
.and_then(File::new_nb)
|
|
||||||
.and_then(|fd| fd.into_io(&Handle::current()))
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
socket
|
|
||||||
.and_then(move |stream| {
|
|
||||||
let domain = webpki::DNSNameRef::try_from_ascii_str(&domain).unwrap();
|
|
||||||
config.connect(domain, stream)
|
|
||||||
})
|
|
||||||
.and_then(move |stream| io::write_all(stream, text))
|
|
||||||
.and_then(move |(stream, _)| {
|
|
||||||
let (r, w) = stream.split();
|
|
||||||
io::copy(r, stdout)
|
|
||||||
.map(drop)
|
|
||||||
.select2(io::copy(stdin, w).map(drop))
|
|
||||||
.map_err(|res| res.split().0)
|
|
||||||
})
|
|
||||||
.map(drop)
|
|
||||||
.map_err(|err| eprintln!("{:?}", err))
|
|
||||||
};
|
|
||||||
|
|
||||||
#[cfg(not(unix))]
|
|
||||||
let resp = {
|
|
||||||
use tokio_fs::{ stdin as tokio_stdin, stdout as tokio_stdout };
|
|
||||||
|
|
||||||
let (stdin, stdout) = (tokio_stdin(), tokio_stdout());
|
|
||||||
|
|
||||||
socket
|
|
||||||
.and_then(move |stream| {
|
|
||||||
let domain = webpki::DNSNameRef::try_from_ascii_str(&domain).unwrap();
|
|
||||||
config.connect(domain, stream)
|
|
||||||
})
|
|
||||||
.and_then(move |stream| io::write_all(stream, text))
|
|
||||||
.and_then(move |(stream, _)| {
|
|
||||||
let (r, w) = stream.split();
|
|
||||||
io::copy(r, stdout)
|
|
||||||
.map(drop)
|
|
||||||
.join(io::copy(stdin, w).map(drop))
|
|
||||||
})
|
|
||||||
.map(drop)
|
|
||||||
.map_err(|err| eprintln!("{:?}", err))
|
|
||||||
};
|
|
||||||
|
|
||||||
tokio::run(resp);
|
|
||||||
}
|
}
|
||||||
|
@ -4,9 +4,6 @@ version = "0.1.0"
|
|||||||
authors = ["quininer <quininer@live.com>"]
|
authors = ["quininer <quininer@live.com>"]
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
tokio-rustls = { path = "../..", default-features = false, features = [ "tokio" ] }
|
tokio-rustls = { path = "../.." }
|
||||||
|
|
||||||
tokio = { version = "0.1.6" }
|
tokio = { version = "0.1.6" }
|
||||||
# futures = "0.2.0-beta"
|
clap = "2"
|
||||||
|
|
||||||
clap = "2.26"
|
|
||||||
|
153
src/common/mod.rs
Normal file
153
src/common/mod.rs
Normal file
@ -0,0 +1,153 @@
|
|||||||
|
#[cfg(feature = "nightly")]
|
||||||
|
#[cfg(feature = "tokio-support")]
|
||||||
|
mod vecbuf;
|
||||||
|
|
||||||
|
use std::io::{ self, Read, Write };
|
||||||
|
#[cfg(feature = "nightly")]
|
||||||
|
use std::io::Initializer;
|
||||||
|
use rustls::Session;
|
||||||
|
#[cfg(feature = "nightly")]
|
||||||
|
use rustls::WriteV;
|
||||||
|
#[cfg(feature = "nightly")]
|
||||||
|
#[cfg(feature = "tokio-support")]
|
||||||
|
use tokio::io::AsyncWrite;
|
||||||
|
|
||||||
|
pub struct Stream<'a, S: 'a, IO: 'a> {
|
||||||
|
session: &'a mut S,
|
||||||
|
io: &'a mut IO
|
||||||
|
}
|
||||||
|
|
||||||
|
pub trait WriteTls<'a, S: Session, IO: Read + Write>: Read + Write {
|
||||||
|
fn write_tls(&mut self) -> io::Result<usize>;
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, S: Session, IO: Read + Write> Stream<'a, S, IO> {
|
||||||
|
pub fn new(session: &'a mut S, io: &'a mut IO) -> Self {
|
||||||
|
Stream { session, io }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn complete_io(&mut self) -> io::Result<(usize, usize)> {
|
||||||
|
// fork from https://github.com/ctz/rustls/blob/master/src/session.rs#L161
|
||||||
|
|
||||||
|
let until_handshaked = self.session.is_handshaking();
|
||||||
|
let mut eof = false;
|
||||||
|
let mut wrlen = 0;
|
||||||
|
let mut rdlen = 0;
|
||||||
|
|
||||||
|
loop {
|
||||||
|
while self.session.wants_write() {
|
||||||
|
wrlen += self.write_tls()?;
|
||||||
|
}
|
||||||
|
|
||||||
|
if !until_handshaked && wrlen > 0 {
|
||||||
|
return Ok((rdlen, wrlen));
|
||||||
|
}
|
||||||
|
|
||||||
|
if !eof && self.session.wants_read() {
|
||||||
|
match self.session.read_tls(self.io)? {
|
||||||
|
0 => eof = true,
|
||||||
|
n => rdlen += n
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
match self.session.process_new_packets() {
|
||||||
|
Ok(_) => {},
|
||||||
|
Err(e) => {
|
||||||
|
// In case we have an alert to send describing this error,
|
||||||
|
// try a last-gasp write -- but don't predate the primary
|
||||||
|
// error.
|
||||||
|
let _ignored = self.write_tls();
|
||||||
|
|
||||||
|
return Err(io::Error::new(io::ErrorKind::InvalidData, e));
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
match (eof, until_handshaked, self.session.is_handshaking()) {
|
||||||
|
(_, true, false) => return Ok((rdlen, wrlen)),
|
||||||
|
(_, false, _) => return Ok((rdlen, wrlen)),
|
||||||
|
(true, true, true) => return Err(io::Error::from(io::ErrorKind::UnexpectedEof)),
|
||||||
|
(..) => ()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(not(feature = "nightly"))]
|
||||||
|
impl<'a, S: Session, IO: Read + Write> WriteTls<'a, S, IO> for Stream<'a, S, IO> {
|
||||||
|
fn write_tls(&mut self) -> io::Result<usize> {
|
||||||
|
self.session.write_tls(self.io)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "nightly")]
|
||||||
|
impl<'a, S: Session, IO: Read + Write> WriteTls<'a, S, IO> for Stream<'a, S, IO> {
|
||||||
|
default fn write_tls(&mut self) -> io::Result<usize> {
|
||||||
|
self.session.write_tls(self.io)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "nightly")]
|
||||||
|
#[cfg(feature = "tokio-support")]
|
||||||
|
impl<'a, S: Session, IO: Read + AsyncWrite> WriteTls<'a, S, IO> for Stream<'a, S, IO> {
|
||||||
|
fn write_tls(&mut self) -> io::Result<usize> {
|
||||||
|
use tokio::prelude::Async;
|
||||||
|
use self::vecbuf::VecBuf;
|
||||||
|
|
||||||
|
struct V<'a, IO: 'a>(&'a mut IO);
|
||||||
|
|
||||||
|
impl<'a, IO: AsyncWrite> WriteV for V<'a, IO> {
|
||||||
|
fn writev(&mut self, vbytes: &[&[u8]]) -> io::Result<usize> {
|
||||||
|
let mut vbytes = VecBuf::new(vbytes);
|
||||||
|
match self.0.write_buf(&mut vbytes) {
|
||||||
|
Ok(Async::Ready(n)) => Ok(n),
|
||||||
|
Ok(Async::NotReady) => Err(io::ErrorKind::WouldBlock.into()),
|
||||||
|
Err(err) => Err(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut vecbuf = V(self.io);
|
||||||
|
self.session.writev_tls(&mut vecbuf)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, S: Session, IO: Read + Write> Read for Stream<'a, S, IO> {
|
||||||
|
#[cfg(feature = "nightly")]
|
||||||
|
unsafe fn initializer(&self) -> Initializer {
|
||||||
|
Initializer::nop()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
|
||||||
|
while self.session.wants_read() {
|
||||||
|
if let (0, 0) = self.complete_io()? {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
self.session.read(buf)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, S: Session, IO: Read + Write> io::Write for Stream<'a, S, IO> {
|
||||||
|
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
|
||||||
|
let len = self.session.write(buf)?;
|
||||||
|
while self.session.wants_write() {
|
||||||
|
match self.complete_io() {
|
||||||
|
Ok(_) => (),
|
||||||
|
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock && len != 0 => break,
|
||||||
|
Err(err) => return Err(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(len)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn flush(&mut self) -> io::Result<()> {
|
||||||
|
self.session.flush()?;
|
||||||
|
if self.session.wants_write() {
|
||||||
|
self.complete_io()?;
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod test_stream;
|
161
src/common/test_stream.rs
Normal file
161
src/common/test_stream.rs
Normal file
@ -0,0 +1,161 @@
|
|||||||
|
use std::sync::Arc;
|
||||||
|
use std::io::{ self, Read, Write, BufReader, Cursor };
|
||||||
|
use webpki::DNSNameRef;
|
||||||
|
use rustls::internal::pemfile::{ certs, rsa_private_keys };
|
||||||
|
use rustls::{
|
||||||
|
ServerConfig, ClientConfig,
|
||||||
|
ServerSession, ClientSession,
|
||||||
|
Session, NoClientAuth
|
||||||
|
};
|
||||||
|
use super::Stream;
|
||||||
|
|
||||||
|
|
||||||
|
struct Good<'a>(&'a mut Session);
|
||||||
|
|
||||||
|
impl<'a> Read for Good<'a> {
|
||||||
|
fn read(&mut self, mut buf: &mut [u8]) -> io::Result<usize> {
|
||||||
|
self.0.write_tls(buf.by_ref())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a> Write for Good<'a> {
|
||||||
|
fn write(&mut self, mut buf: &[u8]) -> io::Result<usize> {
|
||||||
|
let len = self.0.read_tls(buf.by_ref())?;
|
||||||
|
self.0.process_new_packets()
|
||||||
|
.map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?;
|
||||||
|
Ok(len)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn flush(&mut self) -> io::Result<()> {
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct Bad(bool);
|
||||||
|
|
||||||
|
impl Read for Bad {
|
||||||
|
fn read(&mut self, _: &mut [u8]) -> io::Result<usize> {
|
||||||
|
Ok(0)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Write for Bad {
|
||||||
|
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
|
||||||
|
if self.0 {
|
||||||
|
Err(io::ErrorKind::WouldBlock.into())
|
||||||
|
} else {
|
||||||
|
Ok(buf.len())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn flush(&mut self) -> io::Result<()> {
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn stream_good() -> io::Result<()> {
|
||||||
|
const FILE: &'static [u8] = include_bytes!("../../README.md");
|
||||||
|
|
||||||
|
let (mut server, mut client) = make_pair();
|
||||||
|
do_handshake(&mut client, &mut server);
|
||||||
|
io::copy(&mut Cursor::new(FILE), &mut server)?;
|
||||||
|
|
||||||
|
{
|
||||||
|
let mut good = Good(&mut server);
|
||||||
|
let mut stream = Stream::new(&mut client, &mut good);
|
||||||
|
|
||||||
|
let mut buf = Vec::new();
|
||||||
|
stream.read_to_end(&mut buf)?;
|
||||||
|
assert_eq!(buf, FILE);
|
||||||
|
stream.write_all(b"Hello World!")?
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut buf = String::new();
|
||||||
|
server.read_to_string(&mut buf)?;
|
||||||
|
assert_eq!(buf, "Hello World!");
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn stream_bad() -> io::Result<()> {
|
||||||
|
let (mut server, mut client) = make_pair();
|
||||||
|
do_handshake(&mut client, &mut server);
|
||||||
|
client.set_buffer_limit(1024);
|
||||||
|
|
||||||
|
let mut bad = Bad(true);
|
||||||
|
let mut stream = Stream::new(&mut client, &mut bad);
|
||||||
|
assert_eq!(stream.write(&[0x42; 8])?, 8);
|
||||||
|
assert_eq!(stream.write(&[0x42; 8])?, 8);
|
||||||
|
let r = stream.write(&[0x00; 1024])?; // fill buffer
|
||||||
|
assert!(r < 1024);
|
||||||
|
assert_eq!(
|
||||||
|
stream.write(&[0x01]).unwrap_err().kind(),
|
||||||
|
io::ErrorKind::WouldBlock
|
||||||
|
);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn stream_handshake() -> io::Result<()> {
|
||||||
|
let (mut server, mut client) = make_pair();
|
||||||
|
|
||||||
|
{
|
||||||
|
let mut good = Good(&mut server);
|
||||||
|
let mut stream = Stream::new(&mut client, &mut good);
|
||||||
|
let (r, w) = stream.complete_io()?;
|
||||||
|
|
||||||
|
assert!(r > 0);
|
||||||
|
assert!(w > 0);
|
||||||
|
|
||||||
|
stream.complete_io()?; // finish server handshake
|
||||||
|
}
|
||||||
|
|
||||||
|
assert!(!server.is_handshaking());
|
||||||
|
assert!(!client.is_handshaking());
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn stream_handshake_eof() -> io::Result<()> {
|
||||||
|
let (_, mut client) = make_pair();
|
||||||
|
|
||||||
|
let mut bad = Bad(false);
|
||||||
|
let mut stream = Stream::new(&mut client, &mut bad);
|
||||||
|
let r = stream.complete_io();
|
||||||
|
|
||||||
|
assert_eq!(r.unwrap_err().kind(), io::ErrorKind::UnexpectedEof);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn make_pair() -> (ServerSession, ClientSession) {
|
||||||
|
const CERT: &str = include_str!("../../tests/end.cert");
|
||||||
|
const CHAIN: &str = include_str!("../../tests/end.chain");
|
||||||
|
const RSA: &str = include_str!("../../tests/end.rsa");
|
||||||
|
|
||||||
|
let cert = certs(&mut BufReader::new(Cursor::new(CERT))).unwrap();
|
||||||
|
let mut keys = rsa_private_keys(&mut BufReader::new(Cursor::new(RSA))).unwrap();
|
||||||
|
let mut sconfig = ServerConfig::new(NoClientAuth::new());
|
||||||
|
sconfig.set_single_cert(cert, keys.pop().unwrap()).unwrap();
|
||||||
|
let server = ServerSession::new(&Arc::new(sconfig));
|
||||||
|
|
||||||
|
let domain = DNSNameRef::try_from_ascii_str("localhost").unwrap();
|
||||||
|
let mut cconfig = ClientConfig::new();
|
||||||
|
let mut chain = BufReader::new(Cursor::new(CHAIN));
|
||||||
|
cconfig.root_store.add_pem_file(&mut chain).unwrap();
|
||||||
|
let client = ClientSession::new(&Arc::new(cconfig), domain);
|
||||||
|
|
||||||
|
(server, client)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn do_handshake(client: &mut ClientSession, server: &mut ServerSession) {
|
||||||
|
let mut good = Good(server);
|
||||||
|
let mut stream = Stream::new(client, &mut good);
|
||||||
|
stream.complete_io().unwrap();
|
||||||
|
stream.complete_io().unwrap();
|
||||||
|
}
|
122
src/common/vecbuf.rs
Normal file
122
src/common/vecbuf.rs
Normal file
@ -0,0 +1,122 @@
|
|||||||
|
use std::cmp::{ self, Ordering };
|
||||||
|
use bytes::Buf;
|
||||||
|
use iovec::IoVec;
|
||||||
|
|
||||||
|
pub struct VecBuf<'a, 'b: 'a> {
|
||||||
|
pos: usize,
|
||||||
|
cur: usize,
|
||||||
|
inner: &'a [&'b [u8]]
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, 'b> VecBuf<'a, 'b> {
|
||||||
|
pub fn new(vbytes: &'a [&'b [u8]]) -> Self {
|
||||||
|
VecBuf { pos: 0, cur: 0, inner: vbytes }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, 'b> Buf for VecBuf<'a, 'b> {
|
||||||
|
fn remaining(&self) -> usize {
|
||||||
|
let sum = self.inner
|
||||||
|
.iter()
|
||||||
|
.skip(self.pos)
|
||||||
|
.map(|bytes| bytes.len())
|
||||||
|
.sum::<usize>();
|
||||||
|
sum - self.cur
|
||||||
|
}
|
||||||
|
|
||||||
|
fn bytes(&self) -> &[u8] {
|
||||||
|
&self.inner[self.pos][self.cur..]
|
||||||
|
}
|
||||||
|
|
||||||
|
fn advance(&mut self, cnt: usize) {
|
||||||
|
let current = self.inner[self.pos].len();
|
||||||
|
match (self.cur + cnt).cmp(¤t) {
|
||||||
|
Ordering::Equal => if self.pos + 1 < self.inner.len() {
|
||||||
|
self.pos += 1;
|
||||||
|
self.cur = 0;
|
||||||
|
} else {
|
||||||
|
self.cur += cnt;
|
||||||
|
},
|
||||||
|
Ordering::Greater => {
|
||||||
|
if self.pos + 1 < self.inner.len() {
|
||||||
|
self.pos += 1;
|
||||||
|
}
|
||||||
|
let remaining = self.cur + cnt - current;
|
||||||
|
self.advance(remaining);
|
||||||
|
},
|
||||||
|
Ordering::Less => self.cur += cnt,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(feature = "cargo-clippy", allow(needless_range_loop))]
|
||||||
|
fn bytes_vec<'c>(&'c self, dst: &mut [&'c IoVec]) -> usize {
|
||||||
|
let len = cmp::min(self.inner.len() - self.pos, dst.len());
|
||||||
|
|
||||||
|
if len > 0 {
|
||||||
|
dst[0] = self.bytes().into();
|
||||||
|
}
|
||||||
|
|
||||||
|
for i in 1..len {
|
||||||
|
dst[i] = self.inner[self.pos + i].into();
|
||||||
|
}
|
||||||
|
|
||||||
|
len
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod test_vecbuf {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_fresh_cursor_vec() {
|
||||||
|
let mut buf = VecBuf::new(&[b"he", b"llo"]);
|
||||||
|
|
||||||
|
assert_eq!(buf.remaining(), 5);
|
||||||
|
assert_eq!(buf.bytes(), b"he");
|
||||||
|
|
||||||
|
buf.advance(2);
|
||||||
|
|
||||||
|
assert_eq!(buf.remaining(), 3);
|
||||||
|
assert_eq!(buf.bytes(), b"llo");
|
||||||
|
|
||||||
|
buf.advance(3);
|
||||||
|
|
||||||
|
assert_eq!(buf.remaining(), 0);
|
||||||
|
assert_eq!(buf.bytes(), b"");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_get_u8() {
|
||||||
|
let mut buf = VecBuf::new(&[b"\x21z", b"omg"]);
|
||||||
|
assert_eq!(0x21, buf.get_u8());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_get_u16() {
|
||||||
|
let mut buf = VecBuf::new(&[b"\x21\x54z", b"omg"]);
|
||||||
|
assert_eq!(0x2154, buf.get_u16_be());
|
||||||
|
let mut buf = VecBuf::new(&[b"\x21\x54z", b"omg"]);
|
||||||
|
assert_eq!(0x5421, buf.get_u16_le());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
#[should_panic]
|
||||||
|
fn test_get_u16_buffer_underflow() {
|
||||||
|
let mut buf = VecBuf::new(&[b"\x21"]);
|
||||||
|
buf.get_u16_be();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_bufs_vec() {
|
||||||
|
let buf = VecBuf::new(&[b"he", b"llo"]);
|
||||||
|
|
||||||
|
let b1: &[u8] = &mut [0];
|
||||||
|
let b2: &[u8] = &mut [0];
|
||||||
|
|
||||||
|
let mut dst: [&IoVec; 2] =
|
||||||
|
[b1.into(), b2.into()];
|
||||||
|
|
||||||
|
assert_eq!(2, buf.bytes_vec(&mut dst[..]));
|
||||||
|
}
|
||||||
|
}
|
@ -1,170 +0,0 @@
|
|||||||
extern crate futures_core;
|
|
||||||
extern crate futures_io;
|
|
||||||
|
|
||||||
use super::*;
|
|
||||||
use self::futures_core::{ Future, Poll, Async };
|
|
||||||
use self::futures_core::task::Context;
|
|
||||||
use self::futures_io::{ Error, AsyncRead, AsyncWrite };
|
|
||||||
|
|
||||||
|
|
||||||
impl<S: AsyncRead + AsyncWrite> Future for ConnectAsync<S> {
|
|
||||||
type Item = TlsStream<S, ClientSession>;
|
|
||||||
type Error = io::Error;
|
|
||||||
|
|
||||||
fn poll(&mut self, ctx: &mut Context) -> Poll<Self::Item, Self::Error> {
|
|
||||||
self.0.poll(ctx)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<S: AsyncRead + AsyncWrite> Future for AcceptAsync<S> {
|
|
||||||
type Item = TlsStream<S, ServerSession>;
|
|
||||||
type Error = io::Error;
|
|
||||||
|
|
||||||
fn poll(&mut self, ctx: &mut Context) -> Poll<Self::Item, Self::Error> {
|
|
||||||
self.0.poll(ctx)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
macro_rules! async {
|
|
||||||
( to $r:expr ) => {
|
|
||||||
match $r {
|
|
||||||
Ok(Async::Ready(n)) => Ok(n),
|
|
||||||
Ok(Async::Pending) => Err(io::ErrorKind::WouldBlock.into()),
|
|
||||||
Err(e) => Err(e)
|
|
||||||
}
|
|
||||||
};
|
|
||||||
( from $r:expr ) => {
|
|
||||||
match $r {
|
|
||||||
Ok(n) => Ok(Async::Ready(n)),
|
|
||||||
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => Ok(Async::Pending),
|
|
||||||
Err(e) => Err(e)
|
|
||||||
}
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
struct TaskStream<'a, 'b: 'a, S: 'a> {
|
|
||||||
io: &'a mut S,
|
|
||||||
task: &'a mut Context<'b>
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'a, 'b, S> io::Read for TaskStream<'a, 'b, S>
|
|
||||||
where S: AsyncRead + AsyncWrite
|
|
||||||
{
|
|
||||||
#[inline]
|
|
||||||
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
|
|
||||||
async!(to self.io.poll_read(self.task, buf))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'a, 'b, S> io::Write for TaskStream<'a, 'b, S>
|
|
||||||
where S: AsyncRead + AsyncWrite
|
|
||||||
{
|
|
||||||
#[inline]
|
|
||||||
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
|
|
||||||
async!(to self.io.poll_write(self.task, buf))
|
|
||||||
}
|
|
||||||
|
|
||||||
#[inline]
|
|
||||||
fn flush(&mut self) -> io::Result<()> {
|
|
||||||
async!(to self.io.poll_flush(self.task))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<S, C> Future for MidHandshake<S, C>
|
|
||||||
where S: AsyncRead + AsyncWrite, C: Session
|
|
||||||
{
|
|
||||||
type Item = TlsStream<S, C>;
|
|
||||||
type Error = io::Error;
|
|
||||||
|
|
||||||
fn poll(&mut self, ctx: &mut Context) -> Poll<Self::Item, Self::Error> {
|
|
||||||
loop {
|
|
||||||
let stream = self.inner.as_mut().unwrap();
|
|
||||||
if !stream.session.is_handshaking() { break };
|
|
||||||
|
|
||||||
let (io, session) = stream.get_mut();
|
|
||||||
let mut taskio = TaskStream { io, task: ctx };
|
|
||||||
|
|
||||||
match session.complete_io(&mut taskio) {
|
|
||||||
Ok(_) => (),
|
|
||||||
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => return Ok(Async::Pending),
|
|
||||||
Err(e) => return Err(e)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(Async::Ready(self.inner.take().unwrap()))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<S, C> AsyncRead for TlsStream<S, C>
|
|
||||||
where
|
|
||||||
S: AsyncRead + AsyncWrite,
|
|
||||||
C: Session
|
|
||||||
{
|
|
||||||
fn poll_read(&mut self, ctx: &mut Context, buf: &mut [u8]) -> Poll<usize, Error> {
|
|
||||||
if self.eof {
|
|
||||||
return Ok(Async::Ready(0));
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO nll
|
|
||||||
let result = {
|
|
||||||
let (io, session) = self.get_mut();
|
|
||||||
let mut taskio = TaskStream { io, task: ctx };
|
|
||||||
let mut stream = Stream::new(session, &mut taskio);
|
|
||||||
io::Read::read(&mut stream, buf)
|
|
||||||
};
|
|
||||||
|
|
||||||
match result {
|
|
||||||
Ok(0) => { self.eof = true; Ok(Async::Ready(0)) },
|
|
||||||
Ok(n) => Ok(Async::Ready(n)),
|
|
||||||
Err(ref e) if e.kind() == io::ErrorKind::ConnectionAborted => {
|
|
||||||
self.eof = true;
|
|
||||||
self.is_shutdown = true;
|
|
||||||
self.session.send_close_notify();
|
|
||||||
Ok(Async::Ready(0))
|
|
||||||
},
|
|
||||||
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => Ok(Async::Pending),
|
|
||||||
Err(e) => Err(e)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<S, C> AsyncWrite for TlsStream<S, C>
|
|
||||||
where
|
|
||||||
S: AsyncRead + AsyncWrite,
|
|
||||||
C: Session
|
|
||||||
{
|
|
||||||
fn poll_write(&mut self, ctx: &mut Context, buf: &[u8]) -> Poll<usize, Error> {
|
|
||||||
let (io, session) = self.get_mut();
|
|
||||||
let mut taskio = TaskStream { io, task: ctx };
|
|
||||||
let mut stream = Stream::new(session, &mut taskio);
|
|
||||||
|
|
||||||
async!(from io::Write::write(&mut stream, buf))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn poll_flush(&mut self, ctx: &mut Context) -> Poll<(), Error> {
|
|
||||||
let (io, session) = self.get_mut();
|
|
||||||
let mut taskio = TaskStream { io, task: ctx };
|
|
||||||
|
|
||||||
{
|
|
||||||
let mut stream = Stream::new(session, &mut taskio);
|
|
||||||
async!(from io::Write::flush(&mut stream))?;
|
|
||||||
}
|
|
||||||
|
|
||||||
async!(from io::Write::flush(&mut taskio))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn poll_close(&mut self, ctx: &mut Context) -> Poll<(), Error> {
|
|
||||||
if !self.is_shutdown {
|
|
||||||
self.session.send_close_notify();
|
|
||||||
self.is_shutdown = true;
|
|
||||||
}
|
|
||||||
|
|
||||||
{
|
|
||||||
let (io, session) = self.get_mut();
|
|
||||||
let mut taskio = TaskStream { io, task: ctx };
|
|
||||||
async!(from session.complete_io(&mut taskio))?;
|
|
||||||
}
|
|
||||||
|
|
||||||
self.io.poll_close(ctx)
|
|
||||||
}
|
|
||||||
}
|
|
18
src/lib.rs
18
src/lib.rs
@ -1,10 +1,22 @@
|
|||||||
//! Asynchronous TLS/SSL streams for Tokio using [Rustls](https://github.com/ctz/rustls).
|
//! Asynchronous TLS/SSL streams for Tokio using [Rustls](https://github.com/ctz/rustls).
|
||||||
|
|
||||||
|
#![cfg_attr(feature = "nightly", feature(specialization, read_initializer))]
|
||||||
|
|
||||||
pub extern crate rustls;
|
pub extern crate rustls;
|
||||||
pub extern crate webpki;
|
pub extern crate webpki;
|
||||||
|
|
||||||
#[cfg(feature = "tokio")] mod tokio_impl;
|
#[cfg(feature = "tokio-support")]
|
||||||
#[cfg(feature = "unstable-futures")] mod futures_impl;
|
extern crate tokio;
|
||||||
|
#[cfg(feature = "nightly")]
|
||||||
|
#[cfg(feature = "tokio-support")]
|
||||||
|
extern crate bytes;
|
||||||
|
#[cfg(feature = "nightly")]
|
||||||
|
#[cfg(feature = "tokio-support")]
|
||||||
|
extern crate iovec;
|
||||||
|
|
||||||
|
|
||||||
|
mod common;
|
||||||
|
#[cfg(feature = "tokio-support")] mod tokio_impl;
|
||||||
|
|
||||||
use std::io;
|
use std::io;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
@ -12,8 +24,8 @@ use webpki::DNSNameRef;
|
|||||||
use rustls::{
|
use rustls::{
|
||||||
Session, ClientSession, ServerSession,
|
Session, ClientSession, ServerSession,
|
||||||
ClientConfig, ServerConfig,
|
ClientConfig, ServerConfig,
|
||||||
Stream
|
|
||||||
};
|
};
|
||||||
|
use common::Stream;
|
||||||
|
|
||||||
|
|
||||||
pub struct TlsConnector {
|
pub struct TlsConnector {
|
||||||
|
@ -1,9 +1,8 @@
|
|||||||
extern crate tokio;
|
|
||||||
|
|
||||||
use super::*;
|
use super::*;
|
||||||
use self::tokio::prelude::*;
|
use tokio::prelude::*;
|
||||||
use self::tokio::io::{ AsyncRead, AsyncWrite };
|
use tokio::io::{ AsyncRead, AsyncWrite };
|
||||||
use self::tokio::prelude::Poll;
|
use tokio::prelude::Poll;
|
||||||
|
use common::Stream;
|
||||||
|
|
||||||
|
|
||||||
impl<S: AsyncRead + AsyncWrite> Future for Connect<S> {
|
impl<S: AsyncRead + AsyncWrite> Future for Connect<S> {
|
||||||
@ -31,16 +30,17 @@ impl<S, C> Future for MidHandshake<S, C>
|
|||||||
type Error = io::Error;
|
type Error = io::Error;
|
||||||
|
|
||||||
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
|
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
|
||||||
loop {
|
{
|
||||||
let stream = self.inner.as_mut().unwrap();
|
let stream = self.inner.as_mut().unwrap();
|
||||||
if !stream.session.is_handshaking() { break };
|
if stream.session.is_handshaking() {
|
||||||
|
let (io, session) = stream.get_mut();
|
||||||
|
let mut stream = Stream::new(session, io);
|
||||||
|
|
||||||
let (io, session) = stream.get_mut();
|
match stream.complete_io() {
|
||||||
|
Ok(_) => (),
|
||||||
match session.complete_io(io) {
|
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => return Ok(Async::NotReady),
|
||||||
Ok(_) => (),
|
Err(e) => return Err(e)
|
||||||
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => return Ok(Async::NotReady),
|
}
|
||||||
Err(e) => return Err(e)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -52,7 +52,11 @@ impl<S, C> AsyncRead for TlsStream<S, C>
|
|||||||
where
|
where
|
||||||
S: AsyncRead + AsyncWrite,
|
S: AsyncRead + AsyncWrite,
|
||||||
C: Session
|
C: Session
|
||||||
{}
|
{
|
||||||
|
unsafe fn prepare_uninitialized_buffer(&self, _: &mut [u8]) -> bool {
|
||||||
|
false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl<S, C> AsyncWrite for TlsStream<S, C>
|
impl<S, C> AsyncWrite for TlsStream<S, C>
|
||||||
where
|
where
|
||||||
|
143
tests/test.rs
143
tests/test.rs
@ -1,81 +1,89 @@
|
|||||||
|
#[macro_use] extern crate lazy_static;
|
||||||
extern crate rustls;
|
extern crate rustls;
|
||||||
extern crate tokio;
|
extern crate tokio;
|
||||||
extern crate tokio_rustls;
|
extern crate tokio_rustls;
|
||||||
extern crate webpki;
|
extern crate webpki;
|
||||||
|
|
||||||
#[cfg(feature = "unstable-futures")] extern crate futures;
|
|
||||||
|
|
||||||
use std::{ io, thread };
|
use std::{ io, thread };
|
||||||
use std::io::{ BufReader, Cursor };
|
use std::io::{ BufReader, Cursor };
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::sync::mpsc::channel;
|
use std::sync::mpsc::channel;
|
||||||
use std::net::{ SocketAddr, IpAddr, Ipv4Addr };
|
use std::net::SocketAddr;
|
||||||
use tokio::net::{ TcpListener, TcpStream };
|
use tokio::net::{ TcpListener, TcpStream };
|
||||||
use rustls::{ Certificate, PrivateKey, ServerConfig, ClientConfig };
|
use rustls::{ ServerConfig, ClientConfig };
|
||||||
use rustls::internal::pemfile::{ certs, rsa_private_keys };
|
use rustls::internal::pemfile::{ certs, rsa_private_keys };
|
||||||
use tokio_rustls::{ TlsConnector, TlsAcceptor };
|
use tokio_rustls::{ TlsConnector, TlsAcceptor };
|
||||||
|
|
||||||
const CERT: &str = include_str!("end.cert");
|
const CERT: &str = include_str!("end.cert");
|
||||||
const CHAIN: &str = include_str!("end.chain");
|
const CHAIN: &str = include_str!("end.chain");
|
||||||
const RSA: &str = include_str!("end.rsa");
|
const RSA: &str = include_str!("end.rsa");
|
||||||
const HELLO_WORLD: &[u8] = b"Hello world!";
|
|
||||||
|
|
||||||
|
lazy_static!{
|
||||||
|
static ref TEST_SERVER: (SocketAddr, &'static str, &'static str) = {
|
||||||
|
use tokio::prelude::*;
|
||||||
|
use tokio::io as aio;
|
||||||
|
|
||||||
fn start_server(cert: Vec<Certificate>, rsa: PrivateKey) -> SocketAddr {
|
let cert = certs(&mut BufReader::new(Cursor::new(CERT))).unwrap();
|
||||||
use tokio::prelude::*;
|
let mut keys = rsa_private_keys(&mut BufReader::new(Cursor::new(RSA))).unwrap();
|
||||||
use tokio::io as aio;
|
|
||||||
|
|
||||||
let mut config = ServerConfig::new(rustls::NoClientAuth::new());
|
let mut config = ServerConfig::new(rustls::NoClientAuth::new());
|
||||||
config.set_single_cert(cert, rsa)
|
config.set_single_cert(cert, keys.pop().unwrap())
|
||||||
.expect("invalid key or certificate");
|
.expect("invalid key or certificate");
|
||||||
let config = TlsAcceptor::from(Arc::new(config));
|
let config = TlsAcceptor::from(Arc::new(config));
|
||||||
|
|
||||||
let (send, recv) = channel();
|
let (send, recv) = channel();
|
||||||
|
|
||||||
thread::spawn(move || {
|
thread::spawn(move || {
|
||||||
let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 0);
|
let addr = SocketAddr::from(([127, 0, 0, 1], 0));
|
||||||
let listener = TcpListener::bind(&addr).unwrap();
|
let listener = TcpListener::bind(&addr).unwrap();
|
||||||
|
|
||||||
send.send(listener.local_addr().unwrap()).unwrap();
|
send.send(listener.local_addr().unwrap()).unwrap();
|
||||||
|
|
||||||
let done = listener.incoming()
|
let done = listener.incoming()
|
||||||
.for_each(move |stream| {
|
.for_each(move |stream| {
|
||||||
let done = config.accept(stream)
|
let done = config.accept(stream)
|
||||||
.and_then(|stream| aio::read_exact(stream, vec![0; HELLO_WORLD.len()]))
|
.and_then(|stream| {
|
||||||
.and_then(|(stream, buf)| {
|
let (reader, writer) = stream.split();
|
||||||
assert_eq!(buf, HELLO_WORLD);
|
aio::copy(reader, writer)
|
||||||
aio::write_all(stream, HELLO_WORLD)
|
})
|
||||||
})
|
.then(|_| Ok(()));
|
||||||
.then(|_| Ok(()));
|
|
||||||
|
|
||||||
tokio::spawn(done);
|
tokio::spawn(done);
|
||||||
Ok(())
|
Ok(())
|
||||||
})
|
})
|
||||||
.map_err(|err| panic!("{:?}", err));
|
.map_err(|err| panic!("{:?}", err));
|
||||||
|
|
||||||
tokio::run(done);
|
tokio::run(done);
|
||||||
});
|
});
|
||||||
|
|
||||||
recv.recv().unwrap()
|
let addr = recv.recv().unwrap();
|
||||||
|
(addr, "localhost", CHAIN)
|
||||||
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
fn start_client(addr: &SocketAddr, domain: &str, chain: Option<BufReader<Cursor<&str>>>) -> io::Result<()> {
|
|
||||||
|
fn start_server() -> &'static (SocketAddr, &'static str, &'static str) {
|
||||||
|
&*TEST_SERVER
|
||||||
|
}
|
||||||
|
|
||||||
|
fn start_client(addr: &SocketAddr, domain: &str, chain: &str) -> io::Result<()> {
|
||||||
use tokio::prelude::*;
|
use tokio::prelude::*;
|
||||||
use tokio::io as aio;
|
use tokio::io as aio;
|
||||||
|
|
||||||
|
const FILE: &'static [u8] = include_bytes!("../README.md");
|
||||||
|
|
||||||
let domain = webpki::DNSNameRef::try_from_ascii_str(domain).unwrap();
|
let domain = webpki::DNSNameRef::try_from_ascii_str(domain).unwrap();
|
||||||
let mut config = ClientConfig::new();
|
let mut config = ClientConfig::new();
|
||||||
if let Some(mut chain) = chain {
|
let mut chain = BufReader::new(Cursor::new(chain));
|
||||||
config.root_store.add_pem_file(&mut chain).unwrap();
|
config.root_store.add_pem_file(&mut chain).unwrap();
|
||||||
}
|
|
||||||
let config = TlsConnector::from(Arc::new(config));
|
let config = TlsConnector::from(Arc::new(config));
|
||||||
|
|
||||||
let done = TcpStream::connect(addr)
|
let done = TcpStream::connect(addr)
|
||||||
.and_then(|stream| config.connect(domain, stream))
|
.and_then(|stream| config.connect(domain, stream))
|
||||||
.and_then(|stream| aio::write_all(stream, HELLO_WORLD))
|
.and_then(|stream| aio::write_all(stream, FILE))
|
||||||
.and_then(|(stream, _)| aio::read_exact(stream, vec![0; HELLO_WORLD.len()]))
|
.and_then(|(stream, _)| aio::read_exact(stream, vec![0; FILE.len()]))
|
||||||
.and_then(|(stream, buf)| {
|
.and_then(|(stream, buf)| {
|
||||||
assert_eq!(buf, HELLO_WORLD);
|
assert_eq!(buf, FILE);
|
||||||
aio::shutdown(stream)
|
aio::shutdown(stream)
|
||||||
})
|
})
|
||||||
.map(drop);
|
.map(drop);
|
||||||
@ -83,62 +91,17 @@ fn start_client(addr: &SocketAddr, domain: &str, chain: Option<BufReader<Cursor<
|
|||||||
done.wait()
|
done.wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(feature = "unstable-futures")]
|
|
||||||
fn start_client2(addr: &SocketAddr, domain: &str, chain: Option<BufReader<Cursor<&str>>>) -> io::Result<()> {
|
|
||||||
use futures::FutureExt;
|
|
||||||
use futures::io::{ AsyncReadExt, AsyncWriteExt };
|
|
||||||
use futures::executor::block_on;
|
|
||||||
|
|
||||||
let domain = webpki::DNSNameRef::try_from_ascii_str(domain).unwrap();
|
|
||||||
let mut config = ClientConfig::new();
|
|
||||||
if let Some(mut chain) = chain {
|
|
||||||
config.root_store.add_pem_file(&mut chain).unwrap();
|
|
||||||
}
|
|
||||||
let config = TlsConnector::from(Arc::new(config));
|
|
||||||
|
|
||||||
let done = TcpStream::connect(addr)
|
|
||||||
.and_then(|stream| config.connect(domain, stream))
|
|
||||||
.and_then(|stream| stream.write_all(HELLO_WORLD))
|
|
||||||
.and_then(|(stream, _)| stream.read_exact(vec![0; HELLO_WORLD.len()]))
|
|
||||||
.and_then(|(stream, buf)| {
|
|
||||||
assert_eq!(buf, HELLO_WORLD);
|
|
||||||
stream.close()
|
|
||||||
})
|
|
||||||
.map(drop);
|
|
||||||
|
|
||||||
block_on(done)
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn pass() {
|
fn pass() {
|
||||||
let cert = certs(&mut BufReader::new(Cursor::new(CERT))).unwrap();
|
let (addr, domain, chain) = start_server();
|
||||||
let mut keys = rsa_private_keys(&mut BufReader::new(Cursor::new(RSA))).unwrap();
|
|
||||||
let chain = BufReader::new(Cursor::new(CHAIN));
|
|
||||||
|
|
||||||
let addr = start_server(cert, keys.pop().unwrap());
|
start_client(addr, domain, chain).unwrap();
|
||||||
start_client(&addr, "localhost", Some(chain)).unwrap();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(feature = "unstable-futures")]
|
|
||||||
#[test]
|
|
||||||
fn pass2() {
|
|
||||||
let cert = certs(&mut BufReader::new(Cursor::new(CERT))).unwrap();
|
|
||||||
let mut keys = rsa_private_keys(&mut BufReader::new(Cursor::new(RSA))).unwrap();
|
|
||||||
let chain = BufReader::new(Cursor::new(CHAIN));
|
|
||||||
|
|
||||||
let addr = start_server(cert, keys.pop().unwrap());
|
|
||||||
start_client2(&addr, "localhost", Some(chain)).unwrap();
|
|
||||||
}
|
|
||||||
|
|
||||||
#[should_panic]
|
|
||||||
#[test]
|
#[test]
|
||||||
fn fail() {
|
fn fail() {
|
||||||
let cert = certs(&mut BufReader::new(Cursor::new(CERT))).unwrap();
|
let (addr, domain, chain) = start_server();
|
||||||
let mut keys = rsa_private_keys(&mut BufReader::new(Cursor::new(RSA))).unwrap();
|
|
||||||
let chain = BufReader::new(Cursor::new(CHAIN));
|
|
||||||
|
|
||||||
let addr = start_server(cert, keys.pop().unwrap());
|
assert_ne!(domain, &"google.com");
|
||||||
|
assert!(start_client(addr, "google.com", chain).is_err());
|
||||||
start_client(&addr, "google.com", Some(chain)).unwrap();
|
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user