add 0-RTT test
This commit is contained in:
parent
821d1c129f
commit
369c13d6a5
@ -113,6 +113,8 @@ impl<IO> AsyncWrite for TlsStream<IO>
|
||||
where
|
||||
IO: AsyncRead + AsyncWrite + Unpin,
|
||||
{
|
||||
/// Note: that it does not guarantee the final data to be sent.
|
||||
/// To be cautious, you must manually call `flush`.
|
||||
fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
|
||||
let this = self.get_mut();
|
||||
let mut stream = Stream::new(&mut this.io, &mut this.session)
|
||||
|
@ -195,7 +195,3 @@ impl<IO: AsyncRead + AsyncWrite + Unpin> Future for Accept<IO> {
|
||||
Pin::new(&mut self.0).poll(cx)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "early-data")]
|
||||
#[cfg(test)]
|
||||
mod test_0rtt;
|
||||
|
@ -105,6 +105,8 @@ impl<IO> AsyncWrite for TlsStream<IO>
|
||||
where
|
||||
IO: AsyncRead + AsyncWrite + Unpin,
|
||||
{
|
||||
/// Note: that it does not guarantee the final data to be sent.
|
||||
/// To be cautious, you must manually call `flush`.
|
||||
fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
|
||||
let this = self.get_mut();
|
||||
let mut stream = Stream::new(&mut this.io, &mut this.session)
|
||||
|
118
tests/early-data.rs
Normal file
118
tests/early-data.rs
Normal file
@ -0,0 +1,118 @@
|
||||
#![cfg(feature = "early-data")]
|
||||
|
||||
use std::io::{ self, BufReader, BufRead, Cursor };
|
||||
use std::process::{ Command, Child, Stdio };
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
use std::marker::Unpin;
|
||||
use std::pin::{ Pin };
|
||||
use std::task::{ Context, Poll };
|
||||
use std::time::Duration;
|
||||
use tokio::prelude::*;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::io::split;
|
||||
use tokio::timer::delay_for;
|
||||
use futures_util::{ future, ready };
|
||||
use rustls::ClientConfig;
|
||||
use tokio_rustls::{ TlsConnector, client::TlsStream };
|
||||
|
||||
|
||||
struct Read1<T>(T);
|
||||
|
||||
impl<T: AsyncRead + Unpin> Future for Read1<T> {
|
||||
type Output = io::Result<()>;
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
let mut buf = [0];
|
||||
ready!(Pin::new(&mut self.0).poll_read(cx, &mut buf))?;
|
||||
Poll::Pending
|
||||
}
|
||||
}
|
||||
|
||||
async fn send(config: Arc<ClientConfig>, addr: SocketAddr, data: &[u8])
|
||||
-> io::Result<TlsStream<TcpStream>>
|
||||
{
|
||||
let connector = TlsConnector::from(config)
|
||||
.early_data(true);
|
||||
let stream = TcpStream::connect(&addr).await?;
|
||||
let domain = webpki::DNSNameRef::try_from_ascii_str("testserver.com").unwrap();
|
||||
|
||||
let mut stream = connector.connect(domain, stream).await?;
|
||||
stream.write_all(data).await?;
|
||||
stream.flush().await?;
|
||||
|
||||
let (r, mut w) = split(stream);
|
||||
let fut = Read1(r);
|
||||
let fut2 = async move {
|
||||
// sleep 3s
|
||||
//
|
||||
// see https://www.mail-archive.com/openssl-users@openssl.org/msg84451.html
|
||||
delay_for(Duration::from_secs(3)).await;
|
||||
w.shutdown().await?;
|
||||
Ok(w) as io::Result<_>
|
||||
};
|
||||
|
||||
let stream = match future::select(fut, fut2.boxed()).await {
|
||||
future::Either::Left(_) => unreachable!(),
|
||||
future::Either::Right((Ok(w), Read1(r))) => r.unsplit(w),
|
||||
future::Either::Right((Err(err), _)) => return Err(err)
|
||||
};
|
||||
|
||||
Ok(stream)
|
||||
}
|
||||
|
||||
struct DropKill(Child);
|
||||
|
||||
impl Drop for DropKill {
|
||||
fn drop(&mut self) {
|
||||
self.0.kill().unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_0rtt() -> io::Result<()> {
|
||||
let mut handle = Command::new("openssl")
|
||||
.arg("s_server")
|
||||
.arg("-early_data")
|
||||
.arg("-tls1_3")
|
||||
.args(&["-cert", "./tests/end.cert"])
|
||||
.args(&["-key", "./tests/end.rsa"])
|
||||
.args(&["-port", "12354"])
|
||||
.stdout(Stdio::piped())
|
||||
.spawn()
|
||||
.map(DropKill)?;
|
||||
|
||||
// wait openssl server
|
||||
delay_for(Duration::from_secs(3)).await;
|
||||
|
||||
let mut config = ClientConfig::new();
|
||||
let mut chain = BufReader::new(Cursor::new(include_str!("end.chain")));
|
||||
config.root_store.add_pem_file(&mut chain).unwrap();
|
||||
config.versions = vec![rustls::ProtocolVersion::TLSv1_3];
|
||||
config.enable_early_data = true;
|
||||
let config = Arc::new(config);
|
||||
let addr = SocketAddr::from(([127, 0, 0, 1], 12354));
|
||||
|
||||
let io = send(config.clone(), addr, b"hello").await?;
|
||||
assert!(!io.get_ref().1.is_early_data_accepted());
|
||||
|
||||
let io = send(config, addr, b"world!").await?;
|
||||
assert!(io.get_ref().1.is_early_data_accepted());
|
||||
|
||||
let stdout = handle.0.stdout.as_mut().unwrap();
|
||||
let mut lines = BufReader::new(stdout).lines();
|
||||
|
||||
for line in lines.by_ref() {
|
||||
if line?.contains("hello") {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
for line in lines.by_ref() {
|
||||
if line?.contains("world!") {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
@ -7,6 +7,7 @@ use lazy_static::lazy_static;
|
||||
use tokio::prelude::*;
|
||||
use tokio::runtime::current_thread;
|
||||
use tokio::net::{ TcpListener, TcpStream };
|
||||
use tokio::io::split;
|
||||
use futures_util::try_future::TryFutureExt;
|
||||
use rustls::{ ServerConfig, ClientConfig };
|
||||
use rustls::internal::pemfile::{ certs, rsa_private_keys };
|
||||
@ -42,17 +43,10 @@ lazy_static!{
|
||||
while let Some(stream) = incoming.next().await {
|
||||
let acceptor = acceptor.clone();
|
||||
let fut = async move {
|
||||
let mut stream = acceptor.accept(stream?).await?;
|
||||
let stream = acceptor.accept(stream?).await?;
|
||||
|
||||
// TODO split
|
||||
//
|
||||
// let (mut reader, mut write) = stream.split();
|
||||
// reader.copy(&mut write).await?;
|
||||
|
||||
let mut buf = vec![0; 8192];
|
||||
let n = stream.read(&mut buf).await?;
|
||||
stream.write(&buf[..n]).await?;
|
||||
stream.flush().await?;
|
||||
let (mut reader, mut writer) = split(stream);
|
||||
reader.copy(&mut writer).await?;
|
||||
|
||||
Ok(()) as io::Result<()>
|
||||
}.unwrap_or_else(|err| eprintln!("server: {:?}", err));
|
||||
@ -67,7 +61,7 @@ lazy_static!{
|
||||
});
|
||||
|
||||
let addr = recv.recv().unwrap();
|
||||
(addr, "localhost", CHAIN)
|
||||
(addr, "testserver.com", CHAIN)
|
||||
};
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user