Add tests for forwarded for header

This commit is contained in:
Brendan Zabarauskas 2017-07-15 16:54:46 +10:00
parent 4bbb0340ff
commit e0eb5dfe18
2 changed files with 88 additions and 20 deletions

View File

@ -9,9 +9,8 @@ extern crate unicase;
extern crate void; extern crate void;
use futures::future::Future; use futures::future::Future;
use hyper::{Body, Request, Response, StatusCode}; use hyper::{Body, Headers, Request, Response, StatusCode};
use hyper::server::Service; use hyper::server::Service;
use hyper::header::Headers;
use std::marker::PhantomData; use std::marker::PhantomData;
use std::net::IpAddr; use std::net::IpAddr;
use void::Void; use void::Void;
@ -61,6 +60,24 @@ header! {
/// * `203.0.113.195` /// * `203.0.113.195`
/// * `203.0.113.195, 70.41.3.18, 150.172.238.178` /// * `203.0.113.195, 70.41.3.18, 150.172.238.178`
/// ///
/// # Examples
///
/// ```
/// # extern crate hyper;
/// # extern crate hyper_reverse_proxy;
/// use hyper::Headers;
/// use hyper_reverse_proxy::XForwardedFor;
/// use std::net::{Ipv4Addr, Ipv6Addr};
///
/// # fn main() {
/// let mut headers = Headers::new();
/// headers.set(XForwardedFor(vec![
/// Ipv4Addr::new(127, 0, 0, 1).into(),
/// Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1).into(),
/// ]));
/// # }
/// ```
///
/// # References /// # References
/// ///
/// - [MDN](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/X-Forwarded-For) /// - [MDN](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/X-Forwarded-For)
@ -93,16 +110,16 @@ fn create_proxied_response<B>(mut response: Response<B>) -> Response<B> {
/// [`httputil.ReverseProxy`]: https://golang.org/pkg/net/http/httputil/#ReverseProxy /// [`httputil.ReverseProxy`]: https://golang.org/pkg/net/http/httputil/#ReverseProxy
pub struct ReverseProxy<C: Service, B = Body> { pub struct ReverseProxy<C: Service, B = Body> {
client: C, client: C,
remote_ip_addr: Option<IpAddr>, remote_ip: Option<IpAddr>,
_pantom_data: PhantomData<B>, _pantom_data: PhantomData<B>,
} }
impl<C: Service, B> ReverseProxy<C, B> { impl<C: Service, B> ReverseProxy<C, B> {
/// Construct a reverse proxy that dispatches to the given client. /// Construct a reverse proxy that dispatches to the given client.
pub fn new(client: C, remote_ip_addr: Option<IpAddr>) -> ReverseProxy<C, B> { pub fn new(client: C, remote_ip: Option<IpAddr>) -> ReverseProxy<C, B> {
ReverseProxy { ReverseProxy {
client, client,
remote_ip_addr, remote_ip,
_pantom_data: PhantomData, _pantom_data: PhantomData,
} }
} }
@ -111,15 +128,15 @@ impl<C: Service, B> ReverseProxy<C, B> {
*request.headers_mut() = remove_hop_headers(request.headers()); *request.headers_mut() = remove_hop_headers(request.headers());
// Add forwarding information in the headers // Add forwarding information in the headers
if let Some(ip_addr) = self.remote_ip_addr { if let Some(ip) = self.remote_ip {
// This is kind of ugly because of borrowing. Maybe hyper's `Headers` object // This is kind of ugly because of borrowing. Maybe hyper's `Headers` object
// could use an entry API like `std::collections::HashMap`? // could use an entry API like `std::collections::HashMap`?
if request.headers().has::<XForwardedFor>() { if request.headers().has::<XForwardedFor>() {
if let Some(prior) = request.headers_mut().get_mut::<XForwardedFor>() { if let Some(prior) = request.headers_mut().get_mut::<XForwardedFor>() {
prior.0.push(ip_addr); prior.0.push(ip);
} }
} else { } else {
let header = XForwardedFor(vec![ip_addr]); let header = XForwardedFor(vec![ip]);
request.headers_mut().set(header); request.headers_mut().set(header);
} }
} }

View File

@ -22,10 +22,57 @@ impl<F: Fn(Request) -> Response> Service for MockService<F> {
} }
#[test] #[test]
#[ignore] fn begins_forwarded_for_header() {
fn adds_forwarded_for_header() { use hyper_reverse_proxy::XForwardedFor;
// TODO: https://github.com/hyperium/hyper/issues/1258 use std::net::Ipv6Addr;
unimplemented!()
let mut request = Request::new(Get, "/".parse().unwrap());
request.set_body("request");
let remote_ip = Ipv6Addr::new(1, 2, 3, 4, 5, 6, 7, 8);
let client = MockService(|request| {
assert_eq!(
request.headers().get::<XForwardedFor>(),
Some(&XForwardedFor(
vec![Ipv6Addr::new(1, 2, 3, 4, 5, 6, 7, 8).into()],
))
);
Response::new()
});
let service = ReverseProxy::new(client, Some(remote_ip.into()));
service.call(request).wait().unwrap();
}
#[test]
fn continues_forwarded_for_header() {
use hyper_reverse_proxy::XForwardedFor;
use std::net::{Ipv4Addr, Ipv6Addr};
let mut request = Request::new(Get, "/".parse().unwrap());
request.set_body("request");
request.headers_mut().set(XForwardedFor(vec![
Ipv4Addr::new(127, 0, 0, 1).into(),
Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1).into(),
]));
let remote_ip = Ipv6Addr::new(1, 2, 3, 4, 5, 6, 7, 8);
let client = MockService(|request| {
assert_eq!(
request.headers().get::<XForwardedFor>(),
Some(&XForwardedFor(vec![
Ipv4Addr::new(127, 0, 0, 1).into(),
Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1).into(),
Ipv6Addr::new(1, 2, 3, 4, 5, 6, 7, 8).into(),
]))
);
Response::new()
});
let service = ReverseProxy::new(client, Some(remote_ip.into()));
service.call(request).wait().unwrap();
} }
#[test] #[test]
@ -35,12 +82,13 @@ fn forwards_the_bodies() {
let mut request = Request::new(Get, "/".parse().unwrap()); let mut request = Request::new(Get, "/".parse().unwrap());
request.set_body("request"); request.set_body("request");
let service = ReverseProxy::new(MockService(|request| { let client = MockService(|request| {
let body = request.body().concat2().wait().unwrap(); let body = request.body().concat2().wait().unwrap();
assert_eq!(body.as_ref(), b"request"); assert_eq!(body.as_ref(), b"request");
Response::new().with_body("response") Response::new().with_body("response")
})); });
let service = ReverseProxy::new(client, None);
let response = service.call(request).wait().unwrap(); let response = service.call(request).wait().unwrap();
let body = response.body().concat2().wait().unwrap(); let body = response.body().concat2().wait().unwrap();
@ -56,13 +104,14 @@ fn clones_headers() {
request.headers_mut().set(XTestHeader1("Test1".to_owned())); request.headers_mut().set(XTestHeader1("Test1".to_owned()));
request.headers_mut().set(XTestHeader2("Test2".to_owned())); request.headers_mut().set(XTestHeader2("Test2".to_owned()));
let service = ReverseProxy::new(MockService(|request| { let client = MockService(|request| {
let header1 = request.headers().get::<XTestHeader1>().unwrap(); let header1 = request.headers().get::<XTestHeader1>().unwrap();
let header2 = request.headers().get::<XTestHeader2>().unwrap(); let header2 = request.headers().get::<XTestHeader2>().unwrap();
assert_eq!(header1, &XTestHeader1("Test1".to_owned())); assert_eq!(header1, &XTestHeader1("Test1".to_owned()));
assert_eq!(header2, &XTestHeader2("Test2".to_owned())); assert_eq!(header2, &XTestHeader2("Test2".to_owned()));
Response::new() Response::new()
})); });
let service = ReverseProxy::new(client, None);
service.call(request).wait().unwrap(); service.call(request).wait().unwrap();
} }
@ -81,7 +130,7 @@ fn removes_request_hop_headers() {
request.headers_mut().set(TransferEncoding(vec![])); request.headers_mut().set(TransferEncoding(vec![]));
request.headers_mut().set(Upgrade(vec![])); request.headers_mut().set(Upgrade(vec![]));
let service = ReverseProxy::new(MockService(|request| { let client = MockService(|request| {
assert_eq!(request.headers().get::<Connection>(), None); assert_eq!(request.headers().get::<Connection>(), None);
assert_eq!(request.headers().get_raw("Keep-Alive"), None); assert_eq!(request.headers().get_raw("Keep-Alive"), None);
assert_eq!(request.headers().get_raw("Proxy-Authenticate"), None); assert_eq!(request.headers().get_raw("Proxy-Authenticate"), None);
@ -91,7 +140,8 @@ fn removes_request_hop_headers() {
assert_eq!(request.headers().get::<TransferEncoding>(), None); assert_eq!(request.headers().get::<TransferEncoding>(), None);
assert_eq!(request.headers().get::<Upgrade>(), None); assert_eq!(request.headers().get::<Upgrade>(), None);
Response::new() Response::new()
})); });
let service = ReverseProxy::new(client, None);
service.call(request).wait().unwrap(); service.call(request).wait().unwrap();
} }
@ -102,7 +152,7 @@ fn removes_response_hop_headers() {
let request = Request::new(Get, "/".parse().unwrap()); let request = Request::new(Get, "/".parse().unwrap());
let service = ReverseProxy::new(MockService(|_| { let client = MockService(|_| {
let mut response = Response::new(); let mut response = Response::new();
response.headers_mut().set(Connection(vec![])); response.headers_mut().set(Connection(vec![]));
response.headers_mut().set_raw("Keep-Alive", ""); response.headers_mut().set_raw("Keep-Alive", "");
@ -113,7 +163,8 @@ fn removes_response_hop_headers() {
response.headers_mut().set(TransferEncoding(vec![])); response.headers_mut().set(TransferEncoding(vec![]));
response.headers_mut().set(Upgrade(vec![])); response.headers_mut().set(Upgrade(vec![]));
response response
})); });
let service = ReverseProxy::new(client, None);
let response = service.call(request).wait().unwrap(); let response = service.call(request).wait().unwrap();
assert_eq!(response.headers().get::<Connection>(), None); assert_eq!(response.headers().get::<Connection>(), None);