From 32a44a2033334e8198a85f72a65056dd955ff533 Mon Sep 17 00:00:00 2001 From: Brian Picciano Date: Thu, 9 Aug 2018 13:20:16 -0600 Subject: [PATCH] mhttp: implement AddXForwardedFor --- mhttp/mhttp.go | 18 ++++++++++++++++++ mhttp/mhttp_test.go | 41 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 59 insertions(+) create mode 100644 mhttp/mhttp_test.go diff --git a/mhttp/mhttp.go b/mhttp/mhttp.go index f281e50..d518714 100644 --- a/mhttp/mhttp.go +++ b/mhttp/mhttp.go @@ -4,11 +4,14 @@ package mhttp import ( "context" + "net" "net/http" + "strings" "github.com/mediocregopher/mediocre-go-lib/m" "github.com/mediocregopher/mediocre-go-lib/mcfg" "github.com/mediocregopher/mediocre-go-lib/mlog" + "github.com/mediocregopher/mediocre-go-lib/mnet" ) // CfgServer initializes and returns an *http.Server which will initialize on @@ -36,3 +39,18 @@ func CfgServer(cfg *mcfg.Cfg, h http.Handler) *http.Server { // TODO shutdown logic return &srv } + +// AddXForwardedFor populates the X-Forwarded-For header on the Request to +// convey that the request is being proxied for IP. +// +// If the IP is invalid, loopback, or otherwise part of a reserved range, this +// does nothing. +func AddXForwardedFor(r *http.Request, ipStr string) { + const xff = "X-Forwarded-For" + ip := net.ParseIP(ipStr) + if ip == nil || mnet.IsReservedIP(ip) { // IsReservedIP includes loopback + return + } + prev, _ := r.Header[xff] + r.Header.Set(xff, strings.Join(append(prev, ip.String()), ", ")) +} diff --git a/mhttp/mhttp_test.go b/mhttp/mhttp_test.go new file mode 100644 index 0000000..af92d05 --- /dev/null +++ b/mhttp/mhttp_test.go @@ -0,0 +1,41 @@ +package mhttp + +import ( + "net/http/httptest" + . "testing" + + "github.com/mediocregopher/mediocre-go-lib/mtest/massert" +) + +func TestAddXForwardedFor(t *T) { + assertXFF := func(prev []string, ipStr, expected string) massert.Assertion { + r := httptest.NewRequest("GET", "/", nil) + for i := range prev { + r.Header.Add("X-Forwarded-For", prev[i]) + } + AddXForwardedFor(r, ipStr) + var a massert.Assertion + if expected == "" { + a = massert.Len(r.Header["X-Forwarded-For"], 0) + } else { + a = massert.All( + massert.Len(r.Header["X-Forwarded-For"], 1), + massert.Equal(expected, r.Header["X-Forwarded-For"][0]), + ) + } + return massert.Comment(a, "prev:%#v ipStr:%q", prev, ipStr) + } + + massert.Fatal(t, massert.All( + assertXFF(nil, "invalid", ""), + assertXFF(nil, "::1", ""), + assertXFF([]string{"8.0.0.0"}, "invalid", "8.0.0.0"), + assertXFF([]string{"8.0.0.0"}, "::1", "8.0.0.0"), + + assertXFF(nil, "8.0.0.0", "8.0.0.0"), + assertXFF([]string{"8.0.0.0"}, "8.0.0.1", "8.0.0.0, 8.0.0.1"), + assertXFF([]string{"8.0.0.0, 8.0.0.1"}, "8.0.0.2", "8.0.0.0, 8.0.0.1, 8.0.0.2"), + assertXFF([]string{"8.0.0.0, 8.0.0.1", "8.0.0.2"}, "8.0.0.3", + "8.0.0.0, 8.0.0.1, 8.0.0.2, 8.0.0.3"), + )) +}