package httpsso

import (
	"crypto/tls"
	"io"
	"io/ioutil"
	"net/http"
	"net/http/cookiejar"
	"net/http/httptest"
	"net/url"
	"testing"
	"time"

	"github.com/gorilla/mux"
	"github.com/gorilla/securecookie"
	"golang.org/x/crypto/ed25519"

	sso "git.autistici.org/id/go-sso"
)

// newTestHTTPClient returns an HTTP client for talking to the httptest TLS
// server used in these tests. The client:
//
//   - stores cookies in an in-memory jar, so cookies set by the server are
//     sent back on later requests;
//   - skips TLS certificate verification (InsecureSkipVerify), since the
//     test server's certificate is not trusted by the system roots;
//   - never follows redirects (CheckRedirect returns http.ErrUseLastResponse),
//     so callers can inspect each 3xx response and its Location header;
//   - bounds every request with a timeout, so a stuck handler fails the test
//     instead of hanging it.
func newTestHTTPClient() *http.Client {
	jar, err := cookiejar.New(nil)
	if err != nil {
		// Not expected with nil options; fail loudly rather than silently
		// running without a cookie jar (which would break session tests).
		panic(err)
	}
	return &http.Client{
		Jar: jar,
		// Keep a misbehaving server from blocking the test until the global
		// `go test` deadline.
		Timeout: 30 * time.Second,
		Transport: &http.Transport{
			TLSClientConfig: &tls.Config{
				InsecureSkipVerify: true,
			},
		},
		// This client will not follow redirects.
		CheckRedirect: func(req *http.Request, via []*http.Request) error {
			return http.ErrUseLastResponse
		},
	}
}

// Hostnames and identifiers shared by the SSO tests. Every host lives under
// testDomain, so the related names are derived from it rather than repeated.
const (
	// testDomain is the SSO domain that tickets are issued for.
	testDomain = "example.com"

	// testHost is the host serving the protected application.
	testHost = "service." + testDomain

	// testService is the SSO service identifier of the protected application.
	testService = testHost + "/"

	// testLoginServer is the host that unauthenticated users are sent to.
	testLoginServer = "login." + testDomain
)

// makeAuthRequest drives a full SSO login round-trip against the test server
// at base and returns the body of the final, authenticated response to path.
//
// Client c must keep cookies and must not follow redirects (see
// newTestHTTPClient). The steps are:
//
//  1. GET base+path, expecting a 302 to a URL on testLoginServer;
//  2. sign a ticket for "user" with priv, for the given service and domain,
//     carrying the nonce found in the "n" query parameter of that redirect,
//     and GET base+"/sso_login" with it, expecting a 302 back to
//     "https://"+testHost+path;
//  3. GET base+path again, expecting 200.
//
// Any unexpected result aborts the test via t.Fatal/t.Fatalf.
func makeAuthRequest(t testing.TB, c *http.Client, base, path, service, domain string, priv []byte) []byte {
	t.Helper()

	// Step 1: the unauthenticated request must bounce to the login server.
	resp, err := c.Get(base + path)
	if err != nil {
		t.Fatalf("Get(%s): %v", path, err)
	}
	// Only status and headers are needed; close the body immediately so it
	// is not leaked if one of the checks below aborts the test.
	resp.Body.Close()
	if resp.StatusCode != http.StatusFound {
		t.Fatalf("Get(%s) expected 302, got %d", path, resp.StatusCode)
	}
	loc, err := url.Parse(resp.Header.Get("Location"))
	if err != nil {
		t.Fatalf("Get(%s) redirects to unparsable URL %s: %v", path, resp.Header.Get("Location"), err)
	}
	if loc.Host != testLoginServer {
		t.Fatalf("Get(%s) got bad redirect: %s", path, loc)
	}

	// Step 2: sign a ticket, pretending we are the SSO server, then make
	// a new request to the sso_login endpoint.
	signer, err := sso.NewSigner(priv)
	if err != nil {
		t.Fatal(err)
	}
	nonce := loc.Query().Get("n")
	tkt := sso.NewTicket("user", service, domain, nonce, nil, 300*time.Second)
	signed, err := signer.Sign(tkt)
	if err != nil {
		t.Fatal("Sign():", err)
	}
	// Send the user back to the page they originally asked for, rather than
	// a fixed path, so the helper works for any path.
	destURL := "https://" + testHost + path
	u := make(url.Values)
	u.Set("d", destURL)
	u.Set("t", signed)
	resp, err = c.Get(base + "/sso_login?" + u.Encode())
	if err != nil {
		t.Fatal("Get(/sso_login):", err)
	}
	resp.Body.Close()
	if resp.StatusCode != http.StatusFound {
		t.Fatalf("Get(/sso_login) expected 302, got %d", resp.StatusCode)
	}
	if s := resp.Header.Get("Location"); s != destURL {
		t.Fatalf("Get(/sso_login) redirects to unexpected location %s", s)
	}

	// Step 3: requesting the original URL should work now.
	resp, err = c.Get(base + path)
	if err != nil {
		t.Fatalf("Get(%s, post-auth): %v", path, err)
	}
	data, err := ioutil.ReadAll(resp.Body)
	resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		t.Fatalf("Get(%s, post-auth) expected 200, got %d", path, resp.StatusCode)
	}
	if err != nil {
		t.Fatalf("Get(%s, post-auth): reading body: %v", path, err)
	}
	return data
}

// TestSSOWrapper exercises the login flow through an SSOWrapper protecting a
// trivial "/test" handler. It checks that, once authenticated via
// makeAuthRequest, the wrapped handler's "OK" response comes through.
func TestSSOWrapper(t *testing.T) {
	pub, priv, err := ed25519.GenerateKey(nil)
	if err != nil {
		t.Fatal(err)
	}

	// The protected application. A gorilla Mux is used deliberately: per
	// the original note, cookie-based sessions won't work otherwise.
	router := mux.NewRouter()
	router.HandleFunc("/test", func(rw http.ResponseWriter, _ *http.Request) {
		io.WriteString(rw, "OK")
	})

	// Random keys for the wrapper, generated fresh for each run (presumably
	// the session cookie auth/encryption keys; confirm against NewSSOWrapper).
	loginURL := "https://" + testLoginServer + "/"
	key64 := securecookie.GenerateRandomKey(64)
	key32 := securecookie.GenerateRandomKey(32)
	wrapper, err := NewSSOWrapper(loginURL, pub, testDomain, key64, key32)
	if err != nil {
		t.Fatal("NewSSOWrapper():", err)
	}

	// Serve the wrapped application from a local HTTPS test server.
	srv := httptest.NewTLSServer(wrapper.Wrap(router, testService, nil))
	defer srv.Close()

	client := newTestHTTPClient()
	body := makeAuthRequest(t, client, srv.URL, "/test", testService, testDomain, priv)
	if got := string(body); got != "OK" {
		t.Fatalf("Get() returned bad data: %s", got)
	}
}