package serverutil

import (
	"io"
	"io/ioutil"
	"net"
	"net/http"
	"net/http/httptest"
	"testing"
)

// TestProxyHeaders verifies how the handler returned by newProxyHeaders treats
// the X-Real-IP header depending on the trusted networks it was built with.
// The test expects the header to be ignored when the (loopback) client is not
// in the trusted list, and to be reflected when it is.
func TestProxyHeaders(t *testing.T) {
	// echoRemoteHost writes back the host portion of r.RemoteAddr, or the raw
	// value when it cannot be split into host and port, so the client can see
	// which address the wrapped handler observed.
	echoRemoteHost := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		remote := r.RemoteAddr
		if hostOnly, _, splitErr := net.SplitHostPort(remote); splitErr == nil {
			remote = hostOnly
		}
		io.WriteString(w, remote)
	})

	// startServer wraps echoRemoteHost with newProxyHeaders for the given
	// networks and serves it on a fresh test server.
	startServer := func(networks []string) *httptest.Server {
		wrapped, buildErr := newProxyHeaders(echoRemoteHost, networks)
		if buildErr != nil {
			t.Fatal(buildErr)
		}
		return httptest.NewServer(wrapped)
	}

	untrustedSrv := startServer([]string{"1.2.3.4/16"})
	defer untrustedSrv.Close()

	trustedSrv := startServer([]string{"::1/32", "127.0.0.1/8"})
	defer trustedSrv.Close()

	// The loopback client is outside 1.2.3.4/16, so the echoed address must be
	// the real peer, not the header value.
	switch got := doProxyRequest(t, untrustedSrv, map[string]string{"X-Real-IP": "1.2.3.4"}); got {
	case "127.0.0.1", "::1":
	default:
		t.Errorf("request1 returned addr=%v", got)
	}

	// Loopback is trusted here, so the X-Real-IP value should be echoed.
	if got := doProxyRequest(t, trustedSrv, map[string]string{"X-Real-IP": "1.2.3.4"}); got != "1.2.3.4" {
		t.Errorf("request2 returned addr=%v", got)
	}
}

// doProxyRequest sends a GET request to s.URL with every entry of hdr set as a
// request header, and returns the response body as a string.
//
// It fails the test immediately (via t.Fatalf) if:
//   - the request cannot be built or sent,
//   - the body cannot be read, or
//   - the server answers with a status other than 200 OK.
func doProxyRequest(t testing.TB, s *httptest.Server, hdr map[string]string) string {
	t.Helper()
	req, err := http.NewRequest("GET", s.URL, nil)
	if err != nil {
		t.Fatalf("NewRequest(%s): %v", s.URL, err)
	}
	for k, v := range hdr {
		req.Header.Set(k, v)
	}
	// s.Client() is configured for this test server (e.g. it trusts the
	// certificate of an httptest TLS server), unlike a bare http.Client.
	resp, err := s.Client().Do(req)
	if err != nil {
		t.Fatalf("GET(%s): %v", s.URL, err)
	}
	defer resp.Body.Close()
	data, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		// A partial body would otherwise be compared as if it were complete.
		t.Fatalf("reading response from %s: %v", s.URL, err)
	}
	if resp.StatusCode != http.StatusOK {
		// Surface server-side failures directly instead of returning an error
		// page that callers would misinterpret as an address.
		t.Fatalf("GET(%s): status %d, body %q", s.URL, resp.StatusCode, data)
	}
	return string(data)
}