tokio-rustls/tokio-native-tls/tests/smoke.rs
Lucio Franco 7e41beaff4
Rename more tests (#1)
* Rename more tests

* Clean up smoke test

* fmt

* Clean up ci and remove all-features test
2020-02-27 18:32:52 -05:00

146 lines
4.4 KiB
Rust

use futures::join;
use native_tls::{Certificate, Identity};
use std::io::Error;
use tokio::{
io::{AsyncReadExt, AsyncWrite, AsyncWriteExt},
net::{TcpListener, TcpStream},
};
use tokio_native_tls::{TlsAcceptor, TlsConnector};
/// Client streams `AMT` bytes over TLS; the server reads until EOF and must
/// receive exactly that payload.
#[tokio::test]
async fn client_to_server() {
    // Bind to an ephemeral port and learn which one was chosen.
    let mut srv = TcpListener::bind("127.0.0.1:0").await.unwrap();
    let addr = srv.local_addr().unwrap();
    let (server_tls, client_tls) = context();

    // Accept one socket, complete the TLS handshake, then read all data until EOF.
    let server = async move {
        let (socket, _) = srv.accept().await.unwrap();
        let mut socket = server_tls.accept(socket).await.unwrap();
        let mut data = Vec::new();
        socket.read_to_end(&mut data).await.unwrap();
        data
    };

    // Connect, complete the TLS handshake, then write the payload.
    // NOTE(review): "foobar.com" presumably matches the fixture certificate's name.
    let client = async move {
        let socket = TcpStream::connect(&addr).await.unwrap();
        let socket = client_tls.connect("foobar.com", socket).await.unwrap();
        copy_data(socket).await
    };

    let (data, written) = join!(server, client);
    // Surface client-side I/O errors instead of silently discarding them.
    assert_eq!(written.unwrap(), AMT);
    // Length + contents check avoids dumping a 128 KiB vector on failure.
    assert_eq!(data.len(), AMT);
    assert!(data.iter().all(|&b| b == 9));
}
/// Server streams `AMT` bytes over TLS; the client reads until EOF and must
/// receive exactly that payload.
#[tokio::test]
async fn server_to_client() {
    // Bind to an ephemeral port and learn which one was chosen.
    let mut srv = TcpListener::bind("127.0.0.1:0").await.unwrap();
    let addr = srv.local_addr().unwrap();
    let (server_tls, client_tls) = context();

    // Accept one socket, complete the TLS handshake, then write the payload.
    let server = async move {
        let (socket, _) = srv.accept().await.unwrap();
        let socket = server_tls.accept(socket).await.unwrap();
        copy_data(socket).await
    };

    // Connect, complete the TLS handshake, then read all data until EOF.
    // NOTE(review): "foobar.com" presumably matches the fixture certificate's name.
    let client = async move {
        let socket = TcpStream::connect(&addr).await.unwrap();
        let mut socket = client_tls.connect("foobar.com", socket).await.unwrap();
        let mut data = Vec::new();
        socket.read_to_end(&mut data).await.unwrap();
        data
    };

    let (written, data) = join!(server, client);
    // Surface server-side I/O errors instead of silently discarding them.
    assert_eq!(written.unwrap(), AMT);
    // Length + contents check avoids dumping a 128 KiB vector on failure.
    assert_eq!(data.len(), AMT);
    assert!(data.iter().all(|&b| b == 9));
}
/// Server writes `AMT` bytes one `write_all` call at a time; the client reads
/// them back one byte at a time until EOF.
#[tokio::test]
async fn one_byte_at_a_time() {
    // Shadows the file-level `AMT`: keep the byte-by-byte exchange short.
    const AMT: usize = 1024;

    let mut srv = TcpListener::bind("127.0.0.1:0").await.unwrap();
    let addr = srv.local_addr().unwrap();
    let (server_tls, client_tls) = context();

    let server = async move {
        let (socket, _) = srv.accept().await.unwrap();
        let mut socket = server_tls.accept(socket).await.unwrap();
        let mut amt = 0;
        for _ in 0..AMT {
            socket.write_all(&[9u8]).await.unwrap();
            amt += 1;
        }
        amt
    };

    let client = async move {
        let socket = TcpStream::connect(&addr).await.unwrap();
        let mut socket = client_tls.connect("foobar.com", socket).await.unwrap();
        let mut data = Vec::new();
        // `read_exact` fully overwrites the buffer on success, so it can be reused.
        let mut buf = [0u8; 1];
        loop {
            match socket.read_exact(&mut buf).await {
                Ok(_) => data.extend_from_slice(&buf),
                // The peer closing the stream is the normal end of the transfer.
                Err(ref err) if err.kind() == std::io::ErrorKind::UnexpectedEof => break,
                // Format the error so the panic message is readable (a bare
                // `panic!(err)` yields an opaque `Box<dyn Any>` payload).
                Err(err) => panic!("read failed: {}", err),
            }
        }
        data
    };

    let (amt, data) = join!(server, client);
    assert_eq!(amt, AMT);
    assert_eq!(data, vec![9u8; AMT]);
}
fn context() -> (TlsAcceptor, TlsConnector) {
// Certs borrowed from `rust-native-tls/tests`
let pkcs12 = include_bytes!("identity.p12");
let der = include_bytes!("root-ca.der");
let identity = Identity::from_pkcs12(pkcs12, "mypass").unwrap();
let acceptor = native_tls::TlsAcceptor::builder(identity).build().unwrap();
let cert = Certificate::from_der(der).unwrap();
let connector = native_tls::TlsConnector::builder()
.add_root_certificate(cert)
.build()
.unwrap();
(acceptor.into(), connector.into())
}
/// Payload size, in bytes, used by the streaming tests.
const AMT: usize = 128 * 1024;

/// Writes `AMT` bytes, each with value `9`, to `w`, then flushes it.
///
/// Returns the number of bytes written, which is always `AMT` on success.
///
/// # Errors
///
/// Propagates any I/O error from `w`. A writer that accepts zero bytes yields
/// `ErrorKind::WriteZero` instead of looping forever.
async fn copy_data<W: AsyncWrite + Unpin>(mut w: W) -> Result<usize, Error> {
    let data = vec![9u8; AMT];
    // `write_all` retries partial writes until the buffer is drained and
    // reports a zero-length write as an error rather than spinning.
    w.write_all(&data).await?;
    // Push out anything the writer may still be buffering before it is dropped.
    w.flush().await?;
    Ok(data.len())
}