diff --git a/serverutil/proxy_headers.go b/serverutil/proxy_headers.go new file mode 100644 index 0000000000000000000000000000000000000000..00480b93ec3fb54b4843c7cc8709cf750591f363 --- /dev/null +++ b/serverutil/proxy_headers.go @@ -0,0 +1,78 @@ +package serverutil + +import ( + "fmt" + "net" + "net/http" + + "github.com/gorilla/handlers" +) + +type proxyHeaders struct { + wrap, phWrap http.Handler + forwarders []net.IPNet +} + +func newProxyHeaders(h http.Handler, trustedForwarders []string) (http.Handler, error) { + f, err := parseIPNetList(trustedForwarders) + if err != nil { + return nil, err + } + return &proxyHeaders{ + wrap: h, + phWrap: handlers.ProxyHeaders(h), + forwarders: f, + }, nil +} + +func (p *proxyHeaders) ServeHTTP(w http.ResponseWriter, r *http.Request) { + host, _, err := net.SplitHostPort(r.RemoteAddr) + if err != nil { + host = r.RemoteAddr + } + ip := net.ParseIP(host) + if ip != nil && matchIPNetList(ip, p.forwarders) { + p.phWrap.ServeHTTP(w, r) + return + } + p.wrap.ServeHTTP(w, r) +} + +func fullMask(ip net.IP) net.IPMask { + if ip.To4() == nil { + return net.CIDRMask(128, 128) + } + return net.CIDRMask(32, 32) +} + +// ParseIPNetList turns a comma-separated list of IP addresses or CIDR +// networks into a net.IPNet slice. +func parseIPNetList(iplist []string) ([]net.IPNet, error) { + var nets []net.IPNet + for _, s := range iplist { + if s == "" { + continue + } + _, ipnet, err := net.ParseCIDR(s) + if err != nil { + ip := net.ParseIP(s) + if ip == nil { + return nil, fmt.Errorf("could not parse '%s'", s) + } + ipnet = &net.IPNet{IP: ip, Mask: fullMask(ip)} + } + nets = append(nets, *ipnet) + } + return nets, nil +} + +// MatchIPNetList returns true if the given IP address matches one of +// the specified networks. +func matchIPNetList(ip net.IP, nets []net.IPNet) bool { + for _, n := range nets { + if n.Contains(ip) { + return true + } + } + return false +} diff --git a/serverutil/proxy_headers_test.go b/serverutil/proxy_headers_test.go new file mode 100644 index 0000000000000000000000000000000000000000..4ec647d5afe967ebc8cbb96faf91ffa7b2cf6cec --- /dev/null +++ b/serverutil/proxy_headers_test.go @@ -0,0 +1,66 @@ +package serverutil + +import ( + "io" + "io/ioutil" + "net" + "net/http" + "net/http/httptest" + "testing" +) + +func TestProxyHeaders(t *testing.T) { + h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + host, _, err := net.SplitHostPort(r.RemoteAddr) + if err != nil { + host = r.RemoteAddr + } + io.WriteString(w, host) + }) + + p1, err := newProxyHeaders(h, []string{"1.2.3.4/16"}) + if err != nil { + t.Fatal(err) + } + srv1 := httptest.NewServer(p1) + defer srv1.Close() + + p2, err := newProxyHeaders(h, []string{"::1/32", "127.0.0.1/8"}) + if err != nil { + t.Fatal(err) + } + srv2 := httptest.NewServer(p2) + defer srv2.Close() + + resp := doProxyRequest(t, srv1, map[string]string{ + "X-Real-IP": "1.2.3.4", + }) + if resp != "127.0.0.1" && resp != "::1" { + t.Errorf("request1 returned addr=%v", resp) + } + + resp = doProxyRequest(t, srv2, map[string]string{ + "X-Real-IP": "1.2.3.4", + }) + if resp != "1.2.3.4" { + t.Errorf("request2 returned addr=%v", resp) + } +} + +func doProxyRequest(t testing.TB, s *httptest.Server, hdr map[string]string) string { + req, err := http.NewRequest("GET", s.URL, nil) + if err != nil { + t.Fatalf("NewRequest(%s): %v", s.URL, err) + } + for k, v := range hdr { + req.Header.Set(k, v) + } + c := &http.Client{} + resp, err := c.Do(req) + if err != nil { + t.Fatalf("GET(%s): %v", s.URL, err) + } + defer resp.Body.Close() + data, _ := ioutil.ReadAll(resp.Body) + return string(data) +}