From dd6697d2ba984c610fc227595aa5e8cbb4f60f12 Mon Sep 17 00:00:00 2001
From: ale <ale@incal.net>
Date: Thu, 14 Dec 2017 09:02:33 +0000
Subject: [PATCH] Add a test for TLS serving

---
 serverutil/tls_test.go | 176 +++++++++++++++++++++++++++++++++++++++++
 1 file changed, 176 insertions(+)
 create mode 100644 serverutil/tls_test.go

diff --git a/serverutil/tls_test.go b/serverutil/tls_test.go
new file mode 100644
index 0000000..9cf3ef9
--- /dev/null
+++ b/serverutil/tls_test.go
@@ -0,0 +1,176 @@
+package serverutil
+
+import (
+	"crypto/ecdsa"
+	"crypto/elliptic"
+	"crypto/rand"
+	"crypto/tls"
+	"crypto/x509"
+	"crypto/x509/pkix"
+	"encoding/pem"
+	"fmt"
+	"io"
+	"io/ioutil"
+	"math/big"
+	"net/http"
+	"os"
+	"testing"
+	"time"
+
+	"git.autistici.org/ai3/go-common"
+)
+
+func saveCertificate(cert *x509.Certificate, path string) {
+	data := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: cert.Raw})
+	ioutil.WriteFile(path, data, 0644)
+}
+
+func savePrivateKey(pkey *ecdsa.PrivateKey, path string) {
+	der, _ := x509.MarshalECPrivateKey(pkey)
+	data := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: der})
+	ioutil.WriteFile(path, data, 0600)
+}
+
+func generateTestCA(t testing.TB, dir string) (*x509.Certificate, *ecdsa.PrivateKey) {
+	pkey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	now := time.Now().UTC()
+	template := x509.Certificate{
+		SerialNumber:          big.NewInt(1),
+		Subject:               pkix.Name{CommonName: "CA"},
+		NotBefore:             now.Add(-5 * time.Minute),
+		NotAfter:              now.AddDate(5, 0, 0), // 5 years.
+		SignatureAlgorithm:    x509.ECDSAWithSHA256,
+		KeyUsage:              x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
+		BasicConstraintsValid: true,
+		IsCA:       true,
+		MaxPathLen: 1,
+	}
+	der, err := x509.CreateCertificate(rand.Reader, &template, &template, pkey.Public(), pkey)
+	if err != nil {
+		t.Fatal(err)
+	}
+	cert, _ := x509.ParseCertificate(der)
+	saveCertificate(cert, dir+"/ca.pem")
+	return cert, pkey
+}
+
+func generateTestCert(t testing.TB, cacert *x509.Certificate, cakey *ecdsa.PrivateKey, dir, name string, subj pkix.Name, altNames []string, isClient, isServer bool) {
+	pkey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
+	serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	var extUsage []x509.ExtKeyUsage
+	if isServer {
+		extUsage = append(extUsage, x509.ExtKeyUsageServerAuth)
+	}
+	if isClient {
+		extUsage = append(extUsage, x509.ExtKeyUsageClientAuth)
+	}
+
+	now := time.Now().UTC()
+	template := x509.Certificate{
+		SerialNumber:          serialNumber,
+		Subject:               subj,
+		DNSNames:              altNames,
+		NotBefore:             now.Add(-5 * time.Minute),
+		NotAfter:              now.Add(24 * time.Hour),
+		SignatureAlgorithm:    x509.ECDSAWithSHA256,
+		KeyUsage:              x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
+		ExtKeyUsage:           extUsage,
+		PublicKey:             pkey.PublicKey,
+		BasicConstraintsValid: true,
+	}
+	der, err := x509.CreateCertificate(rand.Reader, &template, cacert, pkey.Public(), cakey)
+	if err != nil {
+		t.Fatal(err)
+	}
+	cert, _ := x509.ParseCertificate(der)
+
+	savePrivateKey(pkey, fmt.Sprintf("%s/%s_key.pem", dir, name))
+	saveCertificate(cert, fmt.Sprintf("%s/%s_cert.pem", dir, name))
+}
+
+func generateTestPKI(t *testing.T) string {
+	dir, err := ioutil.TempDir("", "")
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	cacert, cakey := generateTestCA(t, dir)
+	generateTestCert(t, cacert, cakey, dir, "server", pkix.Name{CommonName: "server"}, []string{"server", "localhost"}, false, true)
+	generateTestCert(t, cacert, cakey, dir, "client1", pkix.Name{CommonName: "client1"}, nil, true, false)
+	generateTestCert(t, cacert, cakey, dir, "client2", pkix.Name{CommonName: "client2"}, nil, true, false)
+	return dir
+}
+
+func newTestClient(t testing.TB, dir, name string) *http.Client {
+	cert, err := tls.LoadX509KeyPair(
+		fmt.Sprintf("%s/%s_cert.pem", dir, name),
+		fmt.Sprintf("%s/%s_key.pem", dir, name),
+	)
+	if err != nil {
+		t.Fatal(err)
+	}
+	cas, err := common.LoadCA(dir + "/ca.pem")
+	if err != nil {
+		t.Fatal(err)
+	}
+	return &http.Client{
+		Transport: &http.Transport{
+			TLSClientConfig: &tls.Config{
+				Certificates: []tls.Certificate{cert},
+				RootCAs:      cas,
+			},
+		},
+	}
+}
+
+func TestTLS_Serve(t *testing.T) {
+	dir := generateTestPKI(t)
+	defer os.RemoveAll(dir)
+
+	config := &ServerConfig{
+		TLS: &TLSServerConfig{
+			Cert: dir + "/server_cert.pem",
+			Key:  dir + "/server_key.pem",
+			CA:   dir + "/ca.pem",
+		},
+	}
+	h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		io.WriteString(w, "OK")
+	})
+
+	go Serve(h, config, ":19898")
+	time.Sleep(100 * time.Millisecond)
+
+	// A client without a certificate should get a transport-level error.
+	c := &http.Client{
+		Transport: &http.Transport{
+			TLSClientConfig: &tls.Config{
+				InsecureSkipVerify: true,
+			},
+		},
+	}
+	_, err := c.Get("https://localhost:19898/")
+	if err == nil {
+		t.Fatal("client without certificate got a successful reply")
+	}
+
+	// A client with a properly signed cert will get a successful reply.
+	c = newTestClient(t, dir, "client1")
+	_, err = c.Get("https://localhost:19898/")
+	if err != nil {
+		t.Fatalf("client with cert got error: %v", err)
+	}
+}
-- 
GitLab