Skip to content
Snippets Groups Projects
tls_test.go 4.79 KiB
package serverutil

import (
	"crypto/ecdsa"
	"crypto/elliptic"
	"crypto/rand"
	"crypto/tls"
	"crypto/x509"
	"crypto/x509/pkix"
	"encoding/pem"
	"fmt"
	"io"
	"io/ioutil"
	"math/big"
	"net/http"
	"os"
	"testing"
	"time"

	"git.autistici.org/ai3/go-common"
)

// saveCertificate writes cert to path as a single PEM "CERTIFICATE" block
// (file mode 0644). Only the raw DER bytes in cert.Raw are encoded.
//
// The helper has no testing.TB to report through, so a write failure panics
// rather than being silently dropped; otherwise the test would fail much
// later with a confusing "no such file" or parse error far from the cause.
func saveCertificate(cert *x509.Certificate, path string) {
	data := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: cert.Raw})
	if err := ioutil.WriteFile(path, data, 0644); err != nil {
		panic(fmt.Sprintf("writing certificate to %s: %v", path, err))
	}
}

// savePrivateKey writes pkey to path as a PEM "EC PRIVATE KEY" block
// (SEC 1 / RFC 5915 DER, as produced by x509.MarshalECPrivateKey). The file
// is created with mode 0600 because it holds secret key material.
//
// Like saveCertificate, failures panic instead of being ignored, so that a
// broken fixture is reported where it happens rather than as a later,
// unrelated-looking error.
func savePrivateKey(pkey *ecdsa.PrivateKey, path string) {
	der, err := x509.MarshalECPrivateKey(pkey)
	if err != nil {
		panic(fmt.Sprintf("marshaling EC private key: %v", err))
	}
	data := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: der})
	if err := ioutil.WriteFile(path, data, 0600); err != nil {
		panic(fmt.Sprintf("writing private key to %s: %v", path, err))
	}
}

// generateTestCA creates a self-signed ECDSA P-256 CA certificate (serial 1,
// CN "CA", valid for five years, MaxPathLen 1) and writes it to dir/ca.pem.
// It returns the parsed certificate and its private key so the caller can
// sign leaf certificates with them. Any failure aborts the test via t.Fatal.
//
// KeyUsage deliberately omits KeyEncipherment: that bit is only meaningful
// for RSA keys (RSA key transport), not for ECDSA.
func generateTestCA(t testing.TB, dir string) (*x509.Certificate, *ecdsa.PrivateKey) {
	t.Helper()

	pkey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		t.Fatal(err)
	}

	// NotBefore is backdated a few minutes to tolerate small clock skew.
	now := time.Now().UTC()
	template := x509.Certificate{
		SerialNumber:          big.NewInt(1),
		Subject:               pkix.Name{CommonName: "CA"},
		NotBefore:             now.Add(-5 * time.Minute),
		NotAfter:              now.AddDate(5, 0, 0), // 5 years.
		SignatureAlgorithm:    x509.ECDSAWithSHA256,
		KeyUsage:              x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
		BasicConstraintsValid: true,
		IsCA:                  true,
		MaxPathLen:            1,
	}
	// Self-signed: the template is both subject and issuer.
	der, err := x509.CreateCertificate(rand.Reader, &template, &template, pkey.Public(), pkey)
	if err != nil {
		t.Fatal(err)
	}
	cert, err := x509.ParseCertificate(der)
	if err != nil {
		t.Fatal(err)
	}
	saveCertificate(cert, dir+"/ca.pem")
	return cert, pkey
}

// generateTestCert issues an ECDSA P-256 leaf certificate signed by
// cacert/cakey, valid for 24 hours. It writes the certificate to
// dir/<name>_cert.pem and its private key to dir/<name>_key.pem.
//
// subj becomes the certificate subject and altNames its DNS SANs. isServer
// adds the ServerAuth extended key usage; isClient adds ClientAuth. Any
// failure aborts the test via t.Fatal.
func generateTestCert(t testing.TB, cacert *x509.Certificate, cakey *ecdsa.PrivateKey, dir, name string, subj pkix.Name, altNames []string, isClient, isServer bool) {
	t.Helper()

	pkey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		t.Fatal(err)
	}

	// A random 128-bit serial keeps certificates from the same CA distinct
	// with overwhelming probability.
	serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
	serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
	if err != nil {
		t.Fatal(err)
	}

	var extUsage []x509.ExtKeyUsage
	if isServer {
		extUsage = append(extUsage, x509.ExtKeyUsageServerAuth)
	}
	if isClient {
		extUsage = append(extUsage, x509.ExtKeyUsageClientAuth)
	}

	// The subject public key is passed to CreateCertificate explicitly, so
	// the template does not set PublicKey. KeyEncipherment is omitted because
	// it only applies to RSA keys.
	now := time.Now().UTC()
	template := x509.Certificate{
		SerialNumber:          serialNumber,
		Subject:               subj,
		DNSNames:              altNames,
		NotBefore:             now.Add(-5 * time.Minute),
		NotAfter:              now.Add(24 * time.Hour),
		SignatureAlgorithm:    x509.ECDSAWithSHA256,
		KeyUsage:              x509.KeyUsageDigitalSignature,
		ExtKeyUsage:           extUsage,
		BasicConstraintsValid: true,
	}
	der, err := x509.CreateCertificate(rand.Reader, &template, cacert, pkey.Public(), cakey)
	if err != nil {
		t.Fatal(err)
	}
	cert, err := x509.ParseCertificate(der)
	if err != nil {
		t.Fatal(err)
	}

	savePrivateKey(pkey, fmt.Sprintf("%s/%s_key.pem", dir, name))
	saveCertificate(cert, fmt.Sprintf("%s/%s_cert.pem", dir, name))
}

// generateTestPKI creates a fresh temporary directory holding a throwaway PKI.
// The directory contains:
//   - ca.pem: the CA certificate.
//   - server_cert.pem / server_key.pem: a server certificate for the DNS
//     names "server" and "localhost".
//   - client1_* and client2_*: two client certificates.
//
// It returns the directory path. On success the caller owns the directory
// and must remove it. If generation fails, the directory is removed here.
func generateTestPKI(t *testing.T) string {
	t.Helper()

	dir, err := ioutil.TempDir("", "")
	if err != nil {
		t.Fatal(err)
	}
	// t.Fatal (via runtime.Goexit) or a panic unwinds this function before the
	// caller ever receives dir, so clean up here unless we reach the return.
	done := false
	defer func() {
		if !done {
			os.RemoveAll(dir)
		}
	}()

	cacert, cakey := generateTestCA(t, dir)
	generateTestCert(t, cacert, cakey, dir, "server", pkix.Name{CommonName: "server"}, []string{"server", "localhost"}, false, true)
	generateTestCert(t, cacert, cakey, dir, "client1", pkix.Name{CommonName: "client1"}, nil, true, false)
	generateTestCert(t, cacert, cakey, dir, "client2", pkix.Name{CommonName: "client2"}, nil, true, false)

	done = true
	return dir
}

// newTestClient returns an HTTP client for talking to the test server.
// The client:
//   - presents the key pair dir/<name>_cert.pem + dir/<name>_key.pem as its
//     TLS client certificate;
//   - trusts the roots loaded from dir/ca.pem via common.LoadCA;
//   - gives up on a request after a fixed timeout, so a wedged server fails
//     the test instead of hanging it until the global `go test` deadline.
//
// Any loading failure aborts the test via t.Fatal.
func newTestClient(t testing.TB, dir, name string) *http.Client {
	t.Helper()

	const requestTimeout = 10 * time.Second

	cert, err := tls.LoadX509KeyPair(
		fmt.Sprintf("%s/%s_cert.pem", dir, name),
		fmt.Sprintf("%s/%s_key.pem", dir, name),
	)
	if err != nil {
		t.Fatal(err)
	}
	cas, err := common.LoadCA(dir + "/ca.pem")
	if err != nil {
		t.Fatal(err)
	}
	return &http.Client{
		Timeout: requestTimeout,
		Transport: &http.Transport{
			TLSClientConfig: &tls.Config{
				Certificates: []tls.Certificate{cert},
				RootCAs:      cas,
			},
		},
	}
}

// TestTLS_Serve starts Serve with a TLS config whose CA is the test CA.
// It checks two cases:
//   - a client that presents no certificate gets a transport-level error;
//   - a client presenting a CA-signed certificate gets the handler's
//     200 "OK" reply.
func TestTLS_Serve(t *testing.T) {
	dir := generateTestPKI(t)
	defer os.RemoveAll(dir)

	config := &ServerConfig{
		TLS: &TLSServerConfig{
			Cert: dir + "/server_cert.pem",
			Key:  dir + "/server_key.pem",
			CA:   dir + "/ca.pem",
		},
	}
	h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		io.WriteString(w, "OK")
	})

	// NOTE(review): the server goroutine is never stopped, and readiness is
	// approximated with a fixed sleep; improving either depends on Serve's
	// API, which is not visible here.
	go Serve(h, config, ":19898")
	time.Sleep(100 * time.Millisecond)

	// A client without a certificate should get a transport-level error.
	// InsecureSkipVerify only skips server verification, so any error comes
	// from the missing client certificate.
	c := &http.Client{
		Timeout: 10 * time.Second,
		Transport: &http.Transport{
			TLSClientConfig: &tls.Config{
				InsecureSkipVerify: true,
			},
		},
	}
	resp, err := c.Get("https://localhost:19898/")
	if err == nil {
		resp.Body.Close()
		t.Fatal("client without certificate got a successful reply")
	}

	// A client with a properly signed cert will get a successful reply.
	c = newTestClient(t, dir, "client1")
	resp, err = c.Get("https://localhost:19898/")
	if err != nil {
		t.Fatalf("client with cert got error: %v", err)
	}
	defer resp.Body.Close()

	body, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		t.Fatalf("reading response body: %v", err)
	}
	if resp.StatusCode != http.StatusOK {
		t.Fatalf("client with cert got status %d, want %d", resp.StatusCode, http.StatusOK)
	}
	if string(body) != "OK" {
		t.Fatalf("client with cert got body %q, want %q", body, "OK")
	}
}