package serverutil

import (
	"fmt"
	"net"
	"net/http"
	"strings"

	"github.com/gorilla/handlers"
)

// proxyHeaders is an http.Handler that picks one of two handlers per request.
// Requests whose peer address lies in forwarders go to phWrap. All other
// requests go to wrap.
type proxyHeaders struct {
	// wrap is the caller-supplied handler, used for untrusted peers.
	wrap http.Handler

	// phWrap is wrap decorated by handlers.ProxyHeaders, used for
	// trusted peers.
	phWrap http.Handler

	// forwarders are the trusted peer networks.
	forwarders []net.IPNet
}

// newProxyHeaders returns a handler that serves requests with h. For
// requests arriving from an address covered by trustedForwarders, it serves
// them with handlers.ProxyHeaders(h) instead.
//
// trustedForwarders entries are parsed by parseIPNetList. The first parse
// error is returned unchanged, with a nil handler.
func newProxyHeaders(h http.Handler, trustedForwarders []string) (http.Handler, error) {
	trusted, parseErr := parseIPNetList(trustedForwarders)
	if parseErr != nil {
		return nil, parseErr
	}

	ph := &proxyHeaders{forwarders: trusted}
	ph.wrap = h
	ph.phWrap = handlers.ProxyHeaders(h)
	return ph, nil
}

// ServeHTTP chooses the handler for r from its immediate peer address
// (r.RemoteAddr). If that address is an IP literal inside one of the trusted
// forwarder networks, r goes to phWrap; otherwise it goes to wrap.
//
// A RemoteAddr with no port is treated as a bare host. An address that still
// does not parse as an IP is treated as untrusted.
func (p *proxyHeaders) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	peer := r.RemoteAddr
	if host, _, splitErr := net.SplitHostPort(peer); splitErr == nil {
		peer = host
	}

	next := p.wrap
	if addr := net.ParseIP(peer); addr != nil && matchIPNetList(addr, p.forwarders) {
		next = p.phWrap
	}
	next.ServeHTTP(w, r)
}

// fullMask returns an all-ones (single-host) mask sized for ip. The mask is
// 32 bits when ip has an IPv4 form (ip.To4() != nil, which includes
// IPv4-mapped IPv6 addresses). Otherwise it is 128 bits, including for a
// nil or malformed ip.
func fullMask(ip net.IP) net.IPMask {
	width := 8 * net.IPv6len
	if ip.To4() != nil {
		width = 8 * net.IPv4len
	}
	return net.CIDRMask(width, width)
}

// parseIPNetList converts a list of IP addresses and/or CIDR networks into a
// net.IPNet slice. Example entries: "10.0.0.1", "192.168.0.0/16",
// "2001:db8::/32".
//
// Leading and trailing whitespace on each entry is ignored, and blank entries
// are skipped. This lets the result of splitting a comma-separated setting
// such as "10.0.0.1, 10.0.0.2" be passed in directly.
//
// A bare address becomes a single-host network: /32 for IPv4 (stored in
// 4-byte form, like net.ParseCIDR produces) and /128 for IPv6.
//
// It returns a nil slice and an error naming the first entry that is neither
// a valid CIDR network nor a valid IP address. If iplist has no non-blank
// entries, the result is a nil slice and a nil error.
func parseIPNetList(iplist []string) ([]net.IPNet, error) {
	var nets []net.IPNet
	for _, raw := range iplist {
		s := strings.TrimSpace(raw)
		if s == "" {
			continue
		}
		if _, ipnet, err := net.ParseCIDR(s); err == nil {
			nets = append(nets, *ipnet)
			continue
		}
		ip := net.ParseIP(s)
		if ip == nil {
			return nil, fmt.Errorf("could not parse '%s'", raw)
		}
		// net.ParseIP returns IPv4 addresses in 16-byte form, while
		// net.ParseCIDR yields 4-byte IPv4 networks. Normalize so every
		// IPv4 entry has the same shape and its IP length matches the
		// 4-byte mask fullMask returns.
		if v4 := ip.To4(); v4 != nil {
			ip = v4
		}
		nets = append(nets, net.IPNet{IP: ip, Mask: fullMask(ip)})
	}
	return nets, nil
}

// matchIPNetList reports whether ip lies inside at least one network in nets.
// It stops at the first match. A nil or empty nets never matches.
func matchIPNetList(ip net.IP, nets []net.IPNet) bool {
	matched := false
	for i := 0; i < len(nets) && !matched; i++ {
		matched = nets[i].Contains(ip)
	}
	return matched
}