mhttp: implement AddXForwardedFor

This commit is contained in:
Brian Picciano 2018-08-09 13:20:16 -06:00
parent dcf9f73bcb
commit 32a44a2033
2 changed files with 59 additions and 0 deletions

View File

@ -4,11 +4,14 @@ package mhttp
import ( import (
"context" "context"
"net"
"net/http" "net/http"
"strings"
"github.com/mediocregopher/mediocre-go-lib/m" "github.com/mediocregopher/mediocre-go-lib/m"
"github.com/mediocregopher/mediocre-go-lib/mcfg" "github.com/mediocregopher/mediocre-go-lib/mcfg"
"github.com/mediocregopher/mediocre-go-lib/mlog" "github.com/mediocregopher/mediocre-go-lib/mlog"
"github.com/mediocregopher/mediocre-go-lib/mnet"
) )
// CfgServer initializes and returns an *http.Server which will initialize on // CfgServer initializes and returns an *http.Server which will initialize on
@ -36,3 +39,18 @@ func CfgServer(cfg *mcfg.Cfg, h http.Handler) *http.Server {
// TODO shutdown logic // TODO shutdown logic
return &srv return &srv
} }
// AddXForwardedFor populates the X-Forwarded-For header on the Request to
// convey that the request is being proxied for IP.
//
// If the IP is invalid, loopback, or otherwise part of a reserved range, this
// does nothing.
func AddXForwardedFor(r *http.Request, ipStr string) {
const xff = "X-Forwarded-For"
ip := net.ParseIP(ipStr)
if ip == nil || mnet.IsReservedIP(ip) { // IsReservedIP includes loopback
return
}
prev, _ := r.Header[xff]
r.Header.Set(xff, strings.Join(append(prev, ip.String()), ", "))
}

41
mhttp/mhttp_test.go Normal file
View File

@ -0,0 +1,41 @@
package mhttp
import (
"net/http/httptest"
. "testing"
"github.com/mediocregopher/mediocre-go-lib/mtest/massert"
)
func TestAddXForwardedFor(t *T) {
assertXFF := func(prev []string, ipStr, expected string) massert.Assertion {
r := httptest.NewRequest("GET", "/", nil)
for i := range prev {
r.Header.Add("X-Forwarded-For", prev[i])
}
AddXForwardedFor(r, ipStr)
var a massert.Assertion
if expected == "" {
a = massert.Len(r.Header["X-Forwarded-For"], 0)
} else {
a = massert.All(
massert.Len(r.Header["X-Forwarded-For"], 1),
massert.Equal(expected, r.Header["X-Forwarded-For"][0]),
)
}
return massert.Comment(a, "prev:%#v ipStr:%q", prev, ipStr)
}
massert.Fatal(t, massert.All(
assertXFF(nil, "invalid", ""),
assertXFF(nil, "::1", ""),
assertXFF([]string{"8.0.0.0"}, "invalid", "8.0.0.0"),
assertXFF([]string{"8.0.0.0"}, "::1", "8.0.0.0"),
assertXFF(nil, "8.0.0.0", "8.0.0.0"),
assertXFF([]string{"8.0.0.0"}, "8.0.0.1", "8.0.0.0, 8.0.0.1"),
assertXFF([]string{"8.0.0.0, 8.0.0.1"}, "8.0.0.2", "8.0.0.0, 8.0.0.1, 8.0.0.2"),
assertXFF([]string{"8.0.0.0, 8.0.0.1", "8.0.0.2"}, "8.0.0.3",
"8.0.0.0, 8.0.0.1, 8.0.0.2, 8.0.0.3"),
))
}