From be4da336395850dd80cfaf36bfa0e364ec22a30f Mon Sep 17 00:00:00 2001 From: ale <ale@incal.net> Date: Sun, 14 Jan 2018 09:57:18 +0000 Subject: [PATCH] Fix HTTP request when behind a reverse proxy Allows users to specify a set of trusted forwarders in the configuration, and then using github.com/gorilla/handlers.ProxyHeaders to rewrite the http.Request parameters according to X-Forwarding-* and X-Real-IP headers. --- serverutil/proxy_headers.go | 78 ++++++++++++++++++++++++++++++++ serverutil/proxy_headers_test.go | 66 +++++++++++++++++++++++++++ 2 files changed, 144 insertions(+) create mode 100644 serverutil/proxy_headers.go create mode 100644 serverutil/proxy_headers_test.go diff --git a/serverutil/proxy_headers.go b/serverutil/proxy_headers.go new file mode 100644 index 0000000..00480b9 --- /dev/null +++ b/serverutil/proxy_headers.go @@ -0,0 +1,78 @@ +package serverutil + +import ( + "fmt" + "net" + "net/http" + + "github.com/gorilla/handlers" +) + +type proxyHeaders struct { + wrap, phWrap http.Handler + forwarders []net.IPNet +} + +func newProxyHeaders(h http.Handler, trustedForwarders []string) (http.Handler, error) { + f, err := parseIPNetList(trustedForwarders) + if err != nil { + return nil, err + } + return &proxyHeaders{ + wrap: h, + phWrap: handlers.ProxyHeaders(h), + forwarders: f, + }, nil +} + +func (p *proxyHeaders) ServeHTTP(w http.ResponseWriter, r *http.Request) { + host, _, err := net.SplitHostPort(r.RemoteAddr) + if err != nil { + host = r.RemoteAddr + } + ip := net.ParseIP(host) + if ip != nil && matchIPNetList(ip, p.forwarders) { + p.phWrap.ServeHTTP(w, r) + return + } + p.wrap.ServeHTTP(w, r) +} + +func fullMask(ip net.IP) net.IPMask { + if ip.To4() == nil { + return net.CIDRMask(128, 128) + } + return net.CIDRMask(32, 32) +} + +// ParseIPNetList turns a comma-separated list of IP addresses or CIDR +// networks into a net.IPNet slice. +func parseIPNetList(iplist []string) ([]net.IPNet, error) { + var nets []net.IPNet + for _, s := range iplist { + if s == "" { + continue + } + _, ipnet, err := net.ParseCIDR(s) + if err != nil { + ip := net.ParseIP(s) + if ip == nil { + return nil, fmt.Errorf("could not parse '%s'", s) + } + ipnet = &net.IPNet{IP: ip, Mask: fullMask(ip)} + } + nets = append(nets, *ipnet) + } + return nets, nil +} + +// MatchIPNetList returns true if the given IP address matches one of +// the specified networks. +func matchIPNetList(ip net.IP, nets []net.IPNet) bool { + for _, n := range nets { + if n.Contains(ip) { + return true + } + } + return false +} diff --git a/serverutil/proxy_headers_test.go b/serverutil/proxy_headers_test.go new file mode 100644 index 0000000..4ec647d --- /dev/null +++ b/serverutil/proxy_headers_test.go @@ -0,0 +1,66 @@ +package serverutil + +import ( + "io" + "io/ioutil" + "net" + "net/http" + "net/http/httptest" + "testing" +) + +func TestProxyHeaders(t *testing.T) { + h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + host, _, err := net.SplitHostPort(r.RemoteAddr) + if err != nil { + host = r.RemoteAddr + } + io.WriteString(w, host) + }) + + p1, err := newProxyHeaders(h, []string{"1.2.3.4/16"}) + if err != nil { + t.Fatal(err) + } + srv1 := httptest.NewServer(p1) + defer srv1.Close() + + p2, err := newProxyHeaders(h, []string{"::1/32", "127.0.0.1/8"}) + if err != nil { + t.Fatal(err) + } + srv2 := httptest.NewServer(p2) + defer srv2.Close() + + resp := doProxyRequest(t, srv1, map[string]string{ + "X-Real-IP": "1.2.3.4", + }) + if resp != "127.0.0.1" && resp != "::1" { + t.Errorf("request1 returned addr=%v", resp) + } + + resp = doProxyRequest(t, srv2, map[string]string{ + "X-Real-IP": "1.2.3.4", + }) + if resp != "1.2.3.4" { + t.Errorf("request2 returned addr=%v", resp) + } +} + +func doProxyRequest(t testing.TB, s *httptest.Server, hdr map[string]string) string { + req, err := http.NewRequest("GET", s.URL, nil) + if err != nil { + t.Fatalf("NewRequest(%s): %v", s.URL, err) + } + for k, v := range hdr { + req.Header.Set(k, v) + } + c := &http.Client{} + resp, err := c.Do(req) + if err != nil { + t.Fatalf("GET(%s): %v", s.URL, err) + } + defer resp.Body.Close() + data, _ := ioutil.ReadAll(resp.Body) + return string(data) +} -- GitLab