package serverutil

import (
	"crypto/ecdsa"
	"crypto/elliptic"
	"crypto/rand"
	"crypto/tls"
	"crypto/x509"
	"crypto/x509/pkix"
	"encoding/pem"
	"fmt"
	"io"
	"io/ioutil"
	"math/big"
	"net/http"
	"os"
	"testing"
	"time"

	"git.autistici.org/ai3/go-common"
)

// saveCertificate PEM-encodes the DER bytes of cert (cert.Raw) as a
// "CERTIFICATE" block and writes them to path with mode 0644.
//
// This is a best-effort test helper: a failed write is silently ignored.
func saveCertificate(cert *x509.Certificate, path string) {
	block := &pem.Block{
		Type:  "CERTIFICATE",
		Bytes: cert.Raw,
	}
	encoded := pem.EncodeToMemory(block)
	_ = ioutil.WriteFile(path, encoded, 0644)
}

// savePrivateKey serializes pkey with x509.MarshalECPrivateKey, wraps it in
// an "EC PRIVATE KEY" PEM block and writes it to path with owner-only 0600
// permissions.
//
// This is a best-effort test helper: both the marshalling error and the
// write error are ignored.
func savePrivateKey(pkey *ecdsa.PrivateKey, path string) {
	keyDER, _ := x509.MarshalECPrivateKey(pkey)
	encoded := pem.EncodeToMemory(&pem.Block{
		Type:  "EC PRIVATE KEY",
		Bytes: keyDER,
	})
	_ = ioutil.WriteFile(path, encoded, 0600)
}

// generateTestCA creates a self-signed ECDSA P-256 CA certificate, writes it
// in PEM form to <dir>/ca.pem and returns the parsed certificate together
// with its private key.
//
// The certificate is valid from five minutes ago (presumably to tolerate
// small clock differences) until five years from now, and MaxPathLen 1
// allows at most one intermediate CA below it. Any failure aborts the
// calling test via t.Fatal.
func generateTestCA(t testing.TB, dir string) (*x509.Certificate, *ecdsa.PrivateKey) {
	t.Helper()

	pkey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		t.Fatal(err)
	}

	now := time.Now().UTC()
	template := x509.Certificate{
		SerialNumber:          big.NewInt(1),
		Subject:               pkix.Name{CommonName: "CA"},
		NotBefore:             now.Add(-5 * time.Minute),
		NotAfter:              now.AddDate(5, 0, 0), // 5 years.
		SignatureAlgorithm:    x509.ECDSAWithSHA256,
		KeyUsage:              x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
		BasicConstraintsValid: true,
		IsCA:                  true,
		MaxPathLen:            1,
	}
	// Self-signed: the template doubles as its own parent.
	der, err := x509.CreateCertificate(rand.Reader, &template, &template, pkey.Public(), pkey)
	if err != nil {
		t.Fatal(err)
	}
	// Check the parse error explicitly: a nil certificate would otherwise
	// panic inside saveCertificate with a far less useful message.
	cert, err := x509.ParseCertificate(der)
	if err != nil {
		t.Fatal(err)
	}
	saveCertificate(cert, dir+"/ca.pem")
	return cert, pkey
}

// generateTestCert issues a leaf certificate signed by cacert/cakey and saves
// it, together with its freshly generated ECDSA P-256 private key, as
// <dir>/<name>_cert.pem and <dir>/<name>_key.pem.
//
// subj becomes the certificate Subject and altNames its DNS SANs. isServer
// and isClient add the TLS server-auth and client-auth extended key usages
// respectively. The certificate has a random 128-bit serial number and is
// valid from five minutes ago until 24 hours from now. Any failure aborts
// the calling test via t.Fatal.
func generateTestCert(t testing.TB, cacert *x509.Certificate, cakey *ecdsa.PrivateKey, dir, name string, subj pkix.Name, altNames []string, isClient, isServer bool) {
	t.Helper()

	pkey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		t.Fatal(err)
	}

	serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
	serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
	if err != nil {
		t.Fatal(err)
	}

	var extUsage []x509.ExtKeyUsage
	if isServer {
		extUsage = append(extUsage, x509.ExtKeyUsageServerAuth)
	}
	if isClient {
		extUsage = append(extUsage, x509.ExtKeyUsageClientAuth)
	}

	now := time.Now().UTC()
	// The subject public key is taken from the pub argument of
	// CreateCertificate, so template.PublicKey is intentionally left unset.
	template := x509.Certificate{
		SerialNumber:          serialNumber,
		Subject:               subj,
		DNSNames:              altNames,
		NotBefore:             now.Add(-5 * time.Minute),
		NotAfter:              now.Add(24 * time.Hour),
		SignatureAlgorithm:    x509.ECDSAWithSHA256,
		KeyUsage:              x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
		ExtKeyUsage:           extUsage,
		BasicConstraintsValid: true,
	}
	der, err := x509.CreateCertificate(rand.Reader, &template, cacert, pkey.Public(), cakey)
	if err != nil {
		t.Fatal(err)
	}
	// Check the parse error explicitly: a nil certificate would otherwise
	// panic inside saveCertificate with a far less useful message.
	cert, err := x509.ParseCertificate(der)
	if err != nil {
		t.Fatal(err)
	}

	savePrivateKey(pkey, fmt.Sprintf("%s/%s_key.pem", dir, name))
	saveCertificate(cert, fmt.Sprintf("%s/%s_cert.pem", dir, name))
}

// generateTestPKI creates a fresh temporary directory holding a throwaway
// test PKI and returns its path:
//   - ca.pem: a self-signed CA certificate;
//   - server_cert.pem / server_key.pem: a server certificate for the DNS
//     names "server" and "localhost";
//   - client1_cert.pem / client1_key.pem and client2_cert.pem /
//     client2_key.pem: client certificates with CommonNames "client1" and
//     "client2".
//
// All certificates are signed by the CA. The caller owns the returned
// directory and must remove it. If generation fails part-way, the directory
// is removed here before the test aborts.
func generateTestPKI(t *testing.T) string {
	t.Helper()

	dir, err := ioutil.TempDir("", "")
	if err != nil {
		t.Fatal(err)
	}

	// t.Fatal in the helpers below exits through runtime.Goexit, which
	// still runs deferred calls. Use that to clean up a half-built
	// directory the caller would otherwise never learn about.
	complete := false
	defer func() {
		if !complete {
			os.RemoveAll(dir) // nolint: errcheck
		}
	}()

	cacert, cakey := generateTestCA(t, dir)
	generateTestCert(t, cacert, cakey, dir, "server", pkix.Name{CommonName: "server"}, []string{"server", "localhost"}, false, true)
	generateTestCert(t, cacert, cakey, dir, "client1", pkix.Name{CommonName: "client1"}, nil, true, false)
	generateTestCert(t, cacert, cakey, dir, "client2", pkix.Name{CommonName: "client2"}, nil, true, false)

	complete = true
	return dir
}

// newTestClient returns an HTTPS client for the test PKI in dir.
//
// The client presents the key pair <dir>/<name>_cert.pem and
// <dir>/<name>_key.pem, and uses the CA loaded from <dir>/ca.pem (via
// common.LoadCA) as its root set for server verification. It has an overall
// request timeout so a misbehaving server fails the test instead of hanging
// it. Load failures abort the calling test via t.Fatal.
func newTestClient(t testing.TB, dir, name string) *http.Client {
	t.Helper()

	cert, err := tls.LoadX509KeyPair(
		fmt.Sprintf("%s/%s_cert.pem", dir, name),
		fmt.Sprintf("%s/%s_key.pem", dir, name),
	)
	if err != nil {
		t.Fatal(err)
	}
	cas, err := common.LoadCA(dir + "/ca.pem")
	if err != nil {
		t.Fatal(err)
	}
	return &http.Client{
		Transport: &http.Transport{
			TLSClientConfig: &tls.Config{
				Certificates: []tls.Certificate{cert},
				RootCAs:      cas,
			},
		},
		Timeout: 10 * time.Second,
	}
}

// TestTLS_Serve starts Serve with a TLS config that requires client
// certificates and applies per-path CommonName ACLs, then checks that:
//   - a client presenting no certificate fails;
//   - client1 and client2 (both signed by the test CA) can fetch "/";
//   - only client1 matches the ACL for "/testpath".
func TestTLS_Serve(t *testing.T) {
	dir := generateTestPKI(t)
	defer os.RemoveAll(dir)

	config := &ServerConfig{
		TLS: &TLSServerConfig{
			Cert: dir + "/server_cert.pem",
			Key:  dir + "/server_key.pem",
			CA:   dir + "/ca.pem",
			Auth: &TLSAuthConfig{
				Allow: []*TLSAuthACL{
					{
						Path:       "/testpath",
						CommonName: "client1.*",
					},
					{
						Path:       ".*",
						CommonName: ".*",
					},
				},
			},
		},
	}
	h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		io.WriteString(w, "OK") // nolint: errcheck
	})

	const baseURL = "https://localhost:19898"
	go Serve(h, config, ":19898") // nolint: errcheck

	// A client without a certificate should get a transport-level error.
	// InsecureSkipVerify only relaxes verification of the *server*; the
	// point of this client is that it presents no certificate of its own.
	c := &http.Client{
		Transport: &http.Transport{
			TLSClientConfig: &tls.Config{
				InsecureSkipVerify: true,
			},
		},
		Timeout: 10 * time.Second,
	}

	// Clients with properly signed certificates.
	c1 := newTestClient(t, dir, "client1")
	c2 := newTestClient(t, dir, "client2")

	// Wait until the server actually answers instead of sleeping for a fixed
	// interval, which is racy on slow or loaded machines. Any completed HTTP
	// exchange, whatever its status, means the server is up.
	deadline := time.Now().Add(5 * time.Second)
	for {
		resp, err := c1.Get(baseURL + "/")
		if err == nil {
			io.Copy(ioutil.Discard, resp.Body) // nolint: errcheck
			resp.Body.Close()                  // nolint: errcheck
			break
		}
		if time.Now().After(deadline) {
			t.Fatalf("server did not become ready: %v", err)
		}
		time.Sleep(20 * time.Millisecond)
	}

	testdata := []struct {
		tag        string
		client     *http.Client
		uri        string
		expectedOk bool
	}{
		{"no-cert", c, "/", false},
		{"client1", c1, "/", true},
		{"client2", c2, "/", true},
		{"client1", c1, "/testpath", true},
		{"client2", c2, "/testpath", false},
	}

	for _, td := range testdata {
		resp, err := td.client.Get(baseURL + td.uri)
		ok := false
		if err == nil {
			// Drain and close the body so the connection is released;
			// only the status code matters for the outcome.
			io.Copy(ioutil.Discard, resp.Body) // nolint: errcheck
			resp.Body.Close()                  // nolint: errcheck
			if resp.StatusCode == http.StatusOK {
				ok = true
			} else {
				err = fmt.Errorf("HTTP status %s", resp.Status)
			}
		}
		if ok != td.expectedOk {
			t.Errorf("client %s requesting %s got ok=%v, expected=%v (err=%v)", td.tag, td.uri, ok, td.expectedOk, err)
		}
	}
}