Convert result in with_context function [tokio-native-tls] (#13)

This commit is contained in:
Kirill Fomichev 2020-05-06 23:48:43 +03:00 committed by GitHub
parent 9af6ed39a6
commit bd749ed734
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -113,14 +113,17 @@ impl<S> AllowStd<S>
where where
S: Unpin, S: Unpin,
{ {
fn with_context<F, R>(&mut self, f: F) -> R fn with_context<F, R>(&mut self, f: F) -> io::Result<R>
where where
F: FnOnce(&mut Context<'_>, Pin<&mut S>) -> R, F: FnOnce(&mut Context<'_>, Pin<&mut S>) -> Poll<io::Result<R>>,
{ {
unsafe { unsafe {
assert!(!self.context.is_null()); assert!(!self.context.is_null());
let waker = &mut *(self.context as *mut _); let waker = &mut *(self.context as *mut _);
f(waker, Pin::new(&mut self.inner)) match f(waker, Pin::new(&mut self.inner)) {
Poll::Ready(r) => r,
Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)),
}
} }
} }
} }
@ -130,10 +133,7 @@ where
S: AsyncRead + Unpin, S: AsyncRead + Unpin,
{ {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
match self.with_context(|ctx, stream| stream.poll_read(ctx, buf)) { self.with_context(|ctx, stream| stream.poll_read(ctx, buf))
Poll::Ready(r) => r,
Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)),
}
} }
} }
@ -142,37 +142,27 @@ where
S: AsyncWrite + Unpin, S: AsyncWrite + Unpin,
{ {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> { fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
match self.with_context(|ctx, stream| stream.poll_write(ctx, buf)) { self.with_context(|ctx, stream| stream.poll_write(ctx, buf))
Poll::Ready(r) => r,
Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)),
}
} }
fn flush(&mut self) -> io::Result<()> { fn flush(&mut self) -> io::Result<()> {
match self.with_context(|ctx, stream| stream.poll_flush(ctx)) { self.with_context(|ctx, stream| stream.poll_flush(ctx))
Poll::Ready(r) => r,
Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)),
}
}
}
fn cvt<T>(r: io::Result<T>) -> Poll<io::Result<T>> {
match r {
Ok(v) => Poll::Ready(Ok(v)),
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
Err(e) => Poll::Ready(Err(e)),
} }
} }
impl<S> TlsStream<S> { impl<S> TlsStream<S> {
fn with_context<F, R>(&mut self, ctx: &mut Context<'_>, f: F) -> R fn with_context<F, R>(&mut self, ctx: &mut Context<'_>, f: F) -> Poll<io::Result<R>>
where where
F: FnOnce(&mut native_tls::TlsStream<AllowStd<S>>) -> R, F: FnOnce(&mut native_tls::TlsStream<AllowStd<S>>) -> io::Result<R>,
AllowStd<S>: Read + Write, AllowStd<S>: Read + Write,
{ {
self.0.get_mut().context = ctx as *mut _ as *mut (); self.0.get_mut().context = ctx as *mut _ as *mut ();
let g = Guard(self); let g = Guard(self);
f(&mut (g.0).0) match f(&mut (g.0).0) {
Ok(v) => Poll::Ready(Ok(v)),
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
Err(e) => Poll::Ready(Err(e)),
}
} }
/// Returns a shared reference to the inner stream. /// Returns a shared reference to the inner stream.
@ -208,7 +198,7 @@ where
ctx: &mut Context<'_>, ctx: &mut Context<'_>,
buf: &mut [u8], buf: &mut [u8],
) -> Poll<io::Result<usize>> { ) -> Poll<io::Result<usize>> {
self.with_context(ctx, |s| cvt(s.read(buf))) self.with_context(ctx, |s| s.read(buf))
} }
} }
@ -221,19 +211,15 @@ where
ctx: &mut Context<'_>, ctx: &mut Context<'_>,
buf: &[u8], buf: &[u8],
) -> Poll<io::Result<usize>> { ) -> Poll<io::Result<usize>> {
self.with_context(ctx, |s| cvt(s.write(buf))) self.with_context(ctx, |s| s.write(buf))
} }
fn poll_flush(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<io::Result<()>> { fn poll_flush(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<io::Result<()>> {
self.with_context(ctx, |s| cvt(s.flush())) self.with_context(ctx, |s| s.flush())
} }
fn poll_shutdown(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<io::Result<()>> { fn poll_shutdown(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<io::Result<()>> {
match self.with_context(ctx, |s| s.shutdown()) { self.with_context(ctx, |s| s.shutdown())
Ok(()) => Poll::Ready(Ok(())),
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
Err(e) => Poll::Ready(Err(e)),
}
} }
} }