package serverutil import ( "fmt" "net" "net/http" "strings" ) type proxyHeaders struct { wrap http.Handler forwarders []net.IPNet } func newProxyHeaders(h http.Handler, trustedForwarders []string) (http.Handler, error) { f, err := parseIPNetList(trustedForwarders) if err != nil { return nil, err } return &proxyHeaders{ wrap: h, forwarders: f, }, nil } func (p *proxyHeaders) ServeHTTP(w http.ResponseWriter, r *http.Request) { host, _, err := net.SplitHostPort(r.RemoteAddr) if err != nil { host = r.RemoteAddr } ip := net.ParseIP(host) if ip != nil && matchIPNetList(ip, p.forwarders) { if fwd := getForwardedIP(r); fwd != "" { r.RemoteAddr = fwd } } p.wrap.ServeHTTP(w, r) } // Parse the X-Real-IP or X-Forwarded-For headers, if present, to get // the original client IP. func getForwardedIP(r *http.Request) string { if s := r.Header.Get("X-Real-IP"); s != "" { return s } if s := r.Header.Get("X-Forwarded-For"); s != "" { if n := strings.IndexByte(s, ','); n > 0 { s = s[:n] } return s } return "" } func fullMask(ip net.IP) net.IPMask { if ip.To4() == nil { return net.CIDRMask(128, 128) } return net.CIDRMask(32, 32) } // ParseIPNetList turns a comma-separated list of IP addresses or CIDR // networks into a net.IPNet slice. func parseIPNetList(iplist []string) ([]net.IPNet, error) { var nets []net.IPNet for _, s := range iplist { if s == "" { continue } _, ipnet, err := net.ParseCIDR(s) if err != nil { ip := net.ParseIP(s) if ip == nil { return nil, fmt.Errorf("could not parse '%s'", s) } ipnet = &net.IPNet{IP: ip, Mask: fullMask(ip)} } nets = append(nets, *ipnet) } return nets, nil } // MatchIPNetList returns true if the given IP address matches one of // the specified networks. func matchIPNetList(ip net.IP, nets []net.IPNet) bool { for _, n := range nets { if n.Contains(ip) { return true } } return false }