package serverutil

import (
	"fmt"
	"net"
	"net/http"
	"strings"

	"github.com/gorilla/handlers"
)

// proxyHeaders is an http.Handler that routes each request to one of two
// handlers depending on the request's immediate peer address: peers inside
// one of the trusted forwarder networks get the handler wrapped with
// gorilla's handlers.ProxyHeaders, everyone else gets the plain handler.
type proxyHeaders struct {
	wrap       http.Handler // original handler, used for untrusted peers
	phWrap     http.Handler // wrap decorated with handlers.ProxyHeaders
	forwarders []net.IPNet  // networks whose peers are trusted forwarders
}

var _ http.Handler = (*proxyHeaders)(nil)

// newProxyHeaders wraps h so that handlers.ProxyHeaders is only applied to
// requests whose peer address matches one of trustedForwarders. Each entry
// may be a CIDR network ("10.0.0.0/8") or a single IP address ("10.1.2.3");
// empty entries are ignored. An error is returned if any entry cannot be
// parsed.
func newProxyHeaders(h http.Handler, trustedForwarders []string) (http.Handler, error) {
	f, err := parseIPNetList(trustedForwarders)
	if err != nil {
		return nil, err
	}
	return &proxyHeaders{
		wrap:       h,
		phWrap:     handlers.ProxyHeaders(h),
		forwarders: f,
	}, nil
}

// ServeHTTP dispatches r to the proxy-header-aware handler when the peer in
// r.RemoteAddr is a trusted forwarder, and to the plain handler otherwise.
func (p *proxyHeaders) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	// RemoteAddr is normally "host:port"; fall back to the raw value when
	// it carries no port.
	host, _, err := net.SplitHostPort(r.RemoteAddr)
	if err != nil {
		host = r.RemoteAddr
	}
	// The trust decision is made on RemoteAddr (set by the server), never on
	// request headers; an unparseable address is treated as untrusted.
	ip := net.ParseIP(host)
	if ip != nil && matchIPNetList(ip, p.forwarders) {
		p.phWrap.ServeHTTP(w, r)
		return
	}
	p.wrap.ServeHTTP(w, r)
}

// fullMask returns a mask with every bit set for ip's address family:
// /32 for IPv4 (including IPv4-mapped IPv6) and /128 for IPv6.
func fullMask(ip net.IP) net.IPMask {
	if ip.To4() == nil {
		return net.CIDRMask(128, 128)
	}
	return net.CIDRMask(32, 32)
}

// parseIPNetList converts a slice of IP addresses and/or CIDR networks into
// a net.IPNet slice. Surrounding whitespace is trimmed from each entry and
// empty entries are skipped. A bare address becomes a single-host network
// (/32 for IPv4, /128 for IPv6). It returns an error naming the first entry
// that is neither a valid CIDR nor a valid IP address.
func parseIPNetList(iplist []string) ([]net.IPNet, error) {
	nets := make([]net.IPNet, 0, len(iplist))
	for _, raw := range iplist {
		s := strings.TrimSpace(raw)
		if s == "" {
			continue
		}
		_, ipnet, err := net.ParseCIDR(s)
		if err != nil {
			ip := net.ParseIP(s)
			if ip == nil {
				return nil, fmt.Errorf("could not parse '%s'", s)
			}
			// Store IPv4 addresses in 4-byte form so they match the
			// representation net.ParseCIDR produces.
			if v4 := ip.To4(); v4 != nil {
				ip = v4
			}
			ipnet = &net.IPNet{IP: ip, Mask: fullMask(ip)}
		}
		nets = append(nets, *ipnet)
	}
	return nets, nil
}

// matchIPNetList reports whether ip is contained in any of nets.
// It returns false for an empty nets slice.
func matchIPNetList(ip net.IP, nets []net.IPNet) bool {
	for i := range nets {
		if nets[i].Contains(ip) {
			return true
		}
	}
	return false
}