Convert result in with_context function [tokio-native-tls] (#13)

This commit is contained in:
Kirill Fomichev 2020-05-06 23:48:43 +03:00 committed by GitHub
parent 9af6ed39a6
commit bd749ed734
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -113,14 +113,17 @@ impl<S> AllowStd<S>
where
S: Unpin,
{
fn with_context<F, R>(&mut self, f: F) -> R
fn with_context<F, R>(&mut self, f: F) -> io::Result<R>
where
F: FnOnce(&mut Context<'_>, Pin<&mut S>) -> R,
F: FnOnce(&mut Context<'_>, Pin<&mut S>) -> Poll<io::Result<R>>,
{
unsafe {
assert!(!self.context.is_null());
let waker = &mut *(self.context as *mut _);
f(waker, Pin::new(&mut self.inner))
match f(waker, Pin::new(&mut self.inner)) {
Poll::Ready(r) => r,
Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)),
}
}
}
}
@ -130,10 +133,7 @@ where
S: AsyncRead + Unpin,
{
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
match self.with_context(|ctx, stream| stream.poll_read(ctx, buf)) {
Poll::Ready(r) => r,
Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)),
}
self.with_context(|ctx, stream| stream.poll_read(ctx, buf))
}
}
@ -142,37 +142,27 @@ where
S: AsyncWrite + Unpin,
{
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
match self.with_context(|ctx, stream| stream.poll_write(ctx, buf)) {
Poll::Ready(r) => r,
Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)),
}
self.with_context(|ctx, stream| stream.poll_write(ctx, buf))
}
fn flush(&mut self) -> io::Result<()> {
match self.with_context(|ctx, stream| stream.poll_flush(ctx)) {
Poll::Ready(r) => r,
Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)),
}
}
}
fn cvt<T>(r: io::Result<T>) -> Poll<io::Result<T>> {
match r {
Ok(v) => Poll::Ready(Ok(v)),
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
Err(e) => Poll::Ready(Err(e)),
self.with_context(|ctx, stream| stream.poll_flush(ctx))
}
}
impl<S> TlsStream<S> {
fn with_context<F, R>(&mut self, ctx: &mut Context<'_>, f: F) -> R
fn with_context<F, R>(&mut self, ctx: &mut Context<'_>, f: F) -> Poll<io::Result<R>>
where
F: FnOnce(&mut native_tls::TlsStream<AllowStd<S>>) -> R,
F: FnOnce(&mut native_tls::TlsStream<AllowStd<S>>) -> io::Result<R>,
AllowStd<S>: Read + Write,
{
self.0.get_mut().context = ctx as *mut _ as *mut ();
let g = Guard(self);
f(&mut (g.0).0)
match f(&mut (g.0).0) {
Ok(v) => Poll::Ready(Ok(v)),
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
Err(e) => Poll::Ready(Err(e)),
}
}
/// Returns a shared reference to the inner stream.
@ -208,7 +198,7 @@ where
ctx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
self.with_context(ctx, |s| cvt(s.read(buf)))
self.with_context(ctx, |s| s.read(buf))
}
}
@ -221,19 +211,15 @@ where
ctx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
self.with_context(ctx, |s| cvt(s.write(buf)))
self.with_context(ctx, |s| s.write(buf))
}
fn poll_flush(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<io::Result<()>> {
self.with_context(ctx, |s| cvt(s.flush()))
self.with_context(ctx, |s| s.flush())
}
fn poll_shutdown(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<io::Result<()>> {
match self.with_context(ctx, |s| s.shutdown()) {
Ok(()) => Poll::Ready(Ok(())),
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
Err(e) => Poll::Ready(Err(e)),
}
self.with_context(ctx, |s| s.shutdown())
}
}