tokio-rustls/src/client.rs

203 lines
6.5 KiB
Rust
Raw Normal View History

2019-02-18 12:41:52 +00:00
use super::*;
use rustls::Session;
/// A wrapper around an underlying raw stream which implements the TLS or SSL
/// protocol.
#[derive(Debug)]
pub struct TlsStream<IO> {
    /// The underlying transport the TLS records are read from and written to.
    pub(crate) io: IO,
    /// The client-side TLS session driving the protocol state machine.
    pub(crate) session: ClientSession,
    /// Connection phase: early-data, open stream, or which halves have been
    /// shut down (see `readable()` / `writeable()` / `shutdown_*()` uses).
    pub(crate) state: TlsState,

    /// Early (0-RTT) data bookkeeping.
    ///
    /// `.1` holds every byte accepted by the session's early-data writer;
    /// `.0` counts how many of those bytes have already been re-sent as
    /// ordinary application data after the server did not accept 0-RTT.
    #[cfg(feature = "early-data")]
    pub(crate) early_data: (usize, Vec<u8>),
}
pub(crate) enum MidHandshake<IO> {
Handshaking(TlsStream<IO>),
#[cfg(feature = "early-data")]
EarlyData(TlsStream<IO>),
End,
2019-02-18 12:41:52 +00:00
}
impl<IO> TlsStream<IO> {
#[inline]
pub fn get_ref(&self) -> (&IO, &ClientSession) {
(&self.io, &self.session)
}
#[inline]
pub fn get_mut(&mut self) -> (&mut IO, &mut ClientSession) {
(&mut self.io, &mut self.session)
}
#[inline]
pub fn into_inner(self) -> (IO, ClientSession) {
(self.io, self.session)
}
}
impl<IO> Future for MidHandshake<IO>
where
2019-05-18 08:05:10 +00:00
IO: AsyncRead + AsyncWrite + Unpin,
2019-02-18 12:41:52 +00:00
{
2019-05-18 08:05:10 +00:00
type Output = io::Result<TlsStream<IO>>;
2019-02-18 12:41:52 +00:00
#[inline]
2019-05-19 16:28:27 +00:00
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
2019-05-18 10:18:26 +00:00
let this = self.get_mut();
if let MidHandshake::Handshaking(stream) = this {
let eof = !stream.state.readable();
2019-02-22 17:48:09 +00:00
let (io, session) = stream.get_mut();
2019-05-18 10:18:26 +00:00
let mut stream = Stream::new(io, session).set_eof(eof);
2019-02-18 12:41:52 +00:00
2019-02-22 17:48:09 +00:00
if stream.session.is_handshaking() {
2019-05-18 08:05:10 +00:00
try_ready!(stream.complete_io(cx));
2019-02-22 17:48:09 +00:00
}
2019-02-18 12:41:52 +00:00
2019-02-22 17:48:09 +00:00
if stream.session.wants_write() {
2019-05-18 08:05:10 +00:00
try_ready!(stream.complete_io(cx));
2019-02-22 17:48:09 +00:00
}
2019-02-18 12:41:52 +00:00
}
2019-05-18 10:18:26 +00:00
match mem::replace(this, MidHandshake::End) {
2019-05-18 08:05:10 +00:00
MidHandshake::Handshaking(stream) => Poll::Ready(Ok(stream)),
2019-02-25 15:48:06 +00:00
#[cfg(feature = "early-data")]
2019-05-18 08:05:10 +00:00
MidHandshake::EarlyData(stream) => Poll::Ready(Ok(stream)),
MidHandshake::End => panic!(),
2019-02-18 12:41:52 +00:00
}
}
}
2019-05-18 08:05:10 +00:00
impl<IO> AsyncRead for TlsStream<IO>
where
2019-05-18 08:05:10 +00:00
IO: AsyncRead + AsyncWrite + Unpin,
2019-02-18 12:41:52 +00:00
{
2019-05-18 08:05:10 +00:00
unsafe fn initializer(&self) -> Initializer {
// TODO
Initializer::nop()
}
2019-05-19 16:28:27 +00:00
fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> {
2019-02-18 12:41:52 +00:00
match self.state {
2019-02-25 15:48:06 +00:00
#[cfg(feature = "early-data")]
2019-02-18 12:41:52 +00:00
TlsState::EarlyData => {
2019-05-18 08:05:10 +00:00
let this = self.get_mut();
2019-02-18 12:41:52 +00:00
2019-05-18 10:18:26 +00:00
let mut stream = Stream::new(&mut this.io, &mut this.session)
.set_eof(!this.state.readable());
2019-05-18 08:05:10 +00:00
let (pos, data) = &mut this.early_data;
2019-02-18 12:41:52 +00:00
2019-05-18 08:05:10 +00:00
// complete handshake
if stream.session.is_handshaking() {
try_ready!(stream.complete_io(cx));
}
2019-03-26 02:44:38 +00:00
2019-05-18 08:05:10 +00:00
// write early data (fallback)
if !stream.session.is_early_data_accepted() {
while *pos < data.len() {
2019-05-20 17:47:50 +00:00
let len = try_ready!(stream.pin().poll_write(cx, &data[*pos..]));
2019-05-18 08:05:10 +00:00
*pos += len;
}
2019-02-18 12:41:52 +00:00
}
2019-05-18 08:05:10 +00:00
// end
this.state = TlsState::Stream;
data.clear();
Pin::new(this).poll_read(cx, buf)
}
TlsState::Stream | TlsState::WriteShutdown => {
2019-05-18 08:05:10 +00:00
let this = self.get_mut();
2019-05-18 10:18:26 +00:00
let mut stream = Stream::new(&mut this.io, &mut this.session)
.set_eof(!this.state.readable());
2019-03-26 02:44:38 +00:00
2019-05-20 17:47:50 +00:00
match stream.pin().poll_read(cx, buf) {
2019-05-18 08:05:10 +00:00
Poll::Ready(Ok(0)) => {
this.state.shutdown_read();
Poll::Ready(Ok(0))
}
2019-05-18 08:05:10 +00:00
Poll::Ready(Ok(n)) => Poll::Ready(Ok(n)),
Poll::Ready(Err(ref e)) if e.kind() == io::ErrorKind::ConnectionAborted => {
this.state.shutdown_read();
if this.state.writeable() {
stream.session.send_close_notify();
2019-05-18 08:05:10 +00:00
this.state.shutdown_write();
}
2019-05-18 08:05:10 +00:00
Poll::Ready(Ok(0))
}
2019-05-18 08:05:10 +00:00
Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
Poll::Pending => Poll::Pending
2019-03-26 02:44:38 +00:00
}
}
2019-05-18 08:05:10 +00:00
TlsState::ReadShutdown | TlsState::FullyShutdown => Poll::Ready(Ok(0)),
2019-02-18 12:41:52 +00:00
}
}
}
2019-05-18 08:05:10 +00:00
impl<IO> AsyncWrite for TlsStream<IO>
where
2019-05-18 08:05:10 +00:00
IO: AsyncRead + AsyncWrite + Unpin,
2019-02-18 12:41:52 +00:00
{
2019-05-19 16:28:27 +00:00
fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
2019-05-18 08:05:10 +00:00
let this = self.get_mut();
2019-05-18 10:18:26 +00:00
let mut stream = Stream::new(&mut this.io, &mut this.session)
.set_eof(!this.state.readable());
2019-02-18 12:41:52 +00:00
2019-05-18 08:05:10 +00:00
match this.state {
2019-02-25 15:48:06 +00:00
#[cfg(feature = "early-data")]
2019-02-18 12:41:52 +00:00
TlsState::EarlyData => {
2019-05-18 16:48:56 +00:00
use std::io::Write;
2019-05-18 08:05:10 +00:00
let (pos, data) = &mut this.early_data;
2019-02-18 12:41:52 +00:00
// write early data
if let Some(mut early_data) = stream.session.early_data() {
2019-05-18 08:05:10 +00:00
let len = early_data.write(buf)?; // TODO check pending
2019-02-18 12:41:52 +00:00
data.extend_from_slice(&buf[..len]);
2019-05-18 08:05:10 +00:00
return Poll::Ready(Ok(len));
2019-02-18 12:41:52 +00:00
}
// complete handshake
if stream.session.is_handshaking() {
2019-05-18 08:05:10 +00:00
try_ready!(stream.complete_io(cx));
2019-02-18 12:41:52 +00:00
}
// write early data (fallback)
if !stream.session.is_early_data_accepted() {
while *pos < data.len() {
2019-05-20 17:47:50 +00:00
let len = try_ready!(stream.pin().poll_write(cx, &data[*pos..]));
2019-02-18 12:41:52 +00:00
*pos += len;
}
}
// end
2019-05-18 08:05:10 +00:00
this.state = TlsState::Stream;
2019-02-18 12:41:52 +00:00
data.clear();
2019-05-20 17:47:50 +00:00
stream.pin().poll_write(cx, buf)
}
2019-05-20 17:47:50 +00:00
_ => stream.pin().poll_write(cx, buf),
2019-02-18 12:41:52 +00:00
}
}
2019-05-19 16:28:27 +00:00
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
2019-05-18 08:05:10 +00:00
let this = self.get_mut();
2019-05-20 17:47:50 +00:00
let mut stream = Stream::new(&mut this.io, &mut this.session)
.set_eof(!this.state.readable());
stream.pin().poll_flush(cx)
2019-02-18 12:41:52 +00:00
}
2019-05-19 16:28:27 +00:00
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
2019-04-22 04:08:13 +00:00
if self.state.writeable() {
self.session.send_close_notify();
self.state.shutdown_write();
2019-02-18 12:41:52 +00:00
}
2019-05-18 08:05:10 +00:00
let this = self.get_mut();
2019-05-18 10:18:26 +00:00
let mut stream = Stream::new(&mut this.io, &mut this.session)
.set_eof(!this.state.readable());
2019-05-20 17:47:50 +00:00
stream.pin().poll_close(cx)
2019-02-18 12:41:52 +00:00
}
}