package proxy

import (
	"crypto/rand"
	"crypto/tls"
	"io"
	"io/ioutil"
	"net"
	"net/http"
	"net/http/cookiejar"
	"net/http/httptest"
	"net/url"
	"os"
	"testing"
	"time"

	sso "git.autistici.org/id/go-sso"
	"github.com/gorilla/securecookie"
	"golang.org/x/crypto/ed25519"
)

// createTestServer starts an unencrypted httptest.Server that replies to
// every request, regardless of method or path, with the body "OK". The
// caller owns the returned server and must Close it. t is currently unused.
func createTestServer(t testing.TB) *httptest.Server {
	okHandler := func(w http.ResponseWriter, _ *http.Request) {
		// Write errors are irrelevant for this fixed test response.
		_, _ = io.WriteString(w, "OK")
	}
	return httptest.NewServer(http.HandlerFunc(okHandler))
}

// TestProxy runs an end-to-end SSO login through the proxy. It does four things:
//   - generates an ed25519 key pair and writes the public half to a file
//     referenced by Config.SSOPublicKeyFile;
//   - starts a plain HTTP backend that always answers "OK";
//   - fronts that backend with a TLS proxy serving host test.example.com;
//   - checks that, after the SSO handshake performed by makeAuthRequest, the
//     backend's response is relayed to the client.
func TestProxy(t *testing.T) {
	tmpdir, err := ioutil.TempDir("", "")
	if err != nil {
		t.Fatal(err)
	}
	defer os.RemoveAll(tmpdir)

	pub, priv, err := ed25519.GenerateKey(rand.Reader)
	if err != nil {
		t.Fatal(err)
	}
	publicKeyFile := tmpdir + "/public.key"
	// The proxy config points at this file (presumably read by NewProxy),
	// so a failed write must stop the test here with a clear error.
	if err := ioutil.WriteFile(publicKeyFile, pub, 0644); err != nil {
		t.Fatal(err)
	}

	targetSrv := createTestServer(t)
	defer targetSrv.Close()
	targetURL, err := url.Parse(targetSrv.URL)
	if err != nil {
		t.Fatal(err)
	}

	config := &Config{
		SessionAuthKey:    string(securecookie.GenerateRandomKey(64)),
		SessionEncKey:     string(securecookie.GenerateRandomKey(32)),
		SSOLoginServerURL: "https://login.example.com/",
		SSOPublicKeyFile:  publicKeyFile,
		SSODomain:         "example.com",
		Backends: []*Backend{
			{
				Host:     "test.example.com",
				Upstream: []string{targetURL.Host},
			},
		},
	}

	proxy, err := NewProxy(config)
	if err != nil {
		t.Fatal(err)
	}
	proxySrv := httptest.NewTLSServer(proxy)
	defer proxySrv.Close()

	// Every request goes to proxySrv, whatever host the URL names.
	c := newTestHTTPClient(proxySrv.URL)

	data := string(makeAuthRequest(t, c, "https://test.example.com", "/", "test.example.com/", "example.com", priv))
	if data != "OK" {
		t.Fatalf("bad response: %s", data)
	}
}

// newTestHTTPClient returns an http.Client pinned to the server at uri.
// Whatever host a request URL names, the transport opens its connection to
// uri's host:port instead. Tests can therefore address a virtual host such
// as https://test.example.com while actually talking to a local test server.
//
// The returned client has the following behaviour:
//   - It stores cookies in an in-memory jar, so cookies set by one response
//     are sent on later matching requests.
//   - It skips TLS certificate verification, because test servers such as
//     httptest.NewTLSServer use certificates the client does not trust.
//     This is for test use only.
//   - It never follows redirects. The 3xx response itself is returned, so
//     callers can inspect its status and Location header.
//   - It aborts any single request after 30 seconds, so a hung server fails
//     the test instead of blocking until the go test deadline.
//
// newTestHTTPClient panics if uri cannot be parsed or has no host, since it
// has no error return and a bad address would otherwise only surface as an
// obscure dial failure later.
func newTestHTTPClient(uri string) *http.Client {
	u, err := url.Parse(uri)
	if err != nil {
		panic("newTestHTTPClient: invalid server URL " + uri + ": " + err.Error())
	}
	if u.Host == "" {
		panic("newTestHTTPClient: server URL has no host: " + uri)
	}
	addr := u.Host

	// Not expected to fail with nil options, but don't silently drop it.
	jar, err := cookiejar.New(nil)
	if err != nil {
		panic("newTestHTTPClient: creating cookie jar: " + err.Error())
	}

	return &http.Client{
		Jar: jar,
		Transport: &http.Transport{
			// Ignore the requested address and always connect to addr.
			// NOTE: Transport.Dial is deprecated in favour of DialContext;
			// it still works and is kept here for simplicity.
			Dial: func(network, _ string) (net.Conn, error) {
				return net.Dial(network, addr)
			},
			TLSClientConfig: &tls.Config{
				InsecureSkipVerify: true,
			},
		},
		// Returning ErrUseLastResponse hands the redirect response itself
		// back to the caller instead of following it.
		CheckRedirect: func(req *http.Request, via []*http.Request) error {
			return http.ErrUseLastResponse
		},
		Timeout: 30 * time.Second,
	}
}

// makeAuthRequest performs the complete SSO login flow for base+path
// through client c. It returns the body of the final, authenticated
// response.
//
// Steps:
//  1. GET base+path and expect a 302 redirect, presumably to the SSO login
//     server. The nonce is read from the redirect's "n" query parameter.
//  2. Acting as the SSO server, sign a ticket with priv. The ticket is for
//     user "user" with the given service, domain and nonce, and is valid
//     for 300s. Then GET base+"/sso_login" with the destination URL ("d")
//     and the signed ticket ("t"), and expect a 302 back to base+path.
//  3. GET base+path again and expect a 200.
//
// Requirements and assumptions:
//   - c must not follow redirects (as with newTestHTTPClient); otherwise the
//     302 responses cannot be observed.
//   - Step 3 presumably relies on a session cookie that c's cookie jar
//     picked up in step 2.
//   - priv is handed to sso.NewSigner; TestProxy passes an ed25519 private
//     key.
//
// Any failure aborts the calling test via t.Fatal.
func makeAuthRequest(t testing.TB, c *http.Client, base, path, service, domain string, priv []byte) []byte {
	t.Helper()

	// closeBody drains and closes a response body we don't need, so the
	// underlying connection can be reused by the next request.
	closeBody := func(resp *http.Response) {
		_, _ = io.Copy(ioutil.Discard, resp.Body)
		resp.Body.Close()
	}

	// Step 1: unauthenticated request. Only status and headers matter, so
	// release the body before any check that might call t.Fatal.
	resp, err := c.Get(base + path)
	if err != nil {
		t.Fatalf("Get(%s): %v", path, err)
	}
	closeBody(resp)
	if resp.StatusCode != http.StatusFound {
		t.Fatalf("Get(%s) expected 302, got %d", path, resp.StatusCode)
	}
	loc, err := url.Parse(resp.Header.Get("Location"))
	if err != nil {
		t.Fatalf("Get(%s) redirects to unparsable URL %s: %v", path, resp.Header.Get("Location"), err)
	}
	// Disabled check, kept for reference: the redirect target host is not
	// currently verified.
	// if loc.Host != testLoginServer {
	// 	t.Fatalf("Get(%s) got bad redirect: %s", path, loc)
	// }

	// Step 2: sign a ticket, pretending we are the SSO server, then make
	// a new request to the sso_login endpoint.
	signer, err := sso.NewSigner(priv)
	if err != nil {
		t.Fatal(err)
	}
	nonce := loc.Query().Get("n")
	tkt := sso.NewTicket("user", service, domain, nonce, nil, 300*time.Second)
	signed, err := signer.Sign(tkt)
	if err != nil {
		t.Fatal("Sign():", err)
	}
	destURL := base + path
	params := make(url.Values)
	params.Set("d", destURL)
	params.Set("t", signed)
	resp, err = c.Get(base + "/sso_login?" + params.Encode())
	if err != nil {
		t.Fatal("Get(/sso_login):", err)
	}
	closeBody(resp)
	if resp.StatusCode != http.StatusFound {
		t.Fatalf("Get(/sso_login) expected 302, got %d", resp.StatusCode)
	}
	if s := resp.Header.Get("Location"); s != destURL {
		t.Fatalf("Get(/sso_login) redirects to unexpected location %s", s)
	}

	// Step 3: requesting the original URL should work now. Read and close
	// the body before checking the status, so it is released on every path.
	resp, err = c.Get(destURL)
	if err != nil {
		t.Fatalf("Get(%s, post-auth): %v", path, err)
	}
	data, err := ioutil.ReadAll(resp.Body)
	resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		t.Fatalf("Get(%s, post-auth) expected 200, got %d", path, resp.StatusCode)
	}
	if err != nil {
		t.Fatalf("Get(%s, post-auth): reading body: %v", path, err)
	}
	return data
}