Skip to content
Snippets Groups Projects
Commit dd6697d2 authored by ale's avatar ale
Browse files

Add a test for TLS serving

parent 0cc06229
No related branches found
No related tags found
No related merge requests found
package serverutil
import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"fmt"
"io"
"io/ioutil"
"math/big"
"net/http"
"os"
"testing"
"time"
"git.autistici.org/ai3/go-common"
)
func saveCertificate(cert *x509.Certificate, path string) {
data := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: cert.Raw})
ioutil.WriteFile(path, data, 0644)
}
func savePrivateKey(pkey *ecdsa.PrivateKey, path string) {
der, _ := x509.MarshalECPrivateKey(pkey)
data := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: der})
ioutil.WriteFile(path, data, 0600)
}
func generateTestCA(t testing.TB, dir string) (*x509.Certificate, *ecdsa.PrivateKey) {
pkey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
t.Fatal(err)
}
now := time.Now().UTC()
template := x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{CommonName: "CA"},
NotBefore: now.Add(-5 * time.Minute),
NotAfter: now.AddDate(5, 0, 0), // 5 years.
SignatureAlgorithm: x509.ECDSAWithSHA256,
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
BasicConstraintsValid: true,
IsCA: true,
MaxPathLen: 1,
}
der, err := x509.CreateCertificate(rand.Reader, &template, &template, pkey.Public(), pkey)
if err != nil {
t.Fatal(err)
}
cert, _ := x509.ParseCertificate(der)
saveCertificate(cert, dir+"/ca.pem")
return cert, pkey
}
func generateTestCert(t testing.TB, cacert *x509.Certificate, cakey *ecdsa.PrivateKey, dir, name string, subj pkix.Name, altNames []string, isClient, isServer bool) {
pkey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
t.Fatal(err)
}
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
if err != nil {
t.Fatal(err)
}
var extUsage []x509.ExtKeyUsage
if isServer {
extUsage = append(extUsage, x509.ExtKeyUsageServerAuth)
}
if isClient {
extUsage = append(extUsage, x509.ExtKeyUsageClientAuth)
}
now := time.Now().UTC()
template := x509.Certificate{
SerialNumber: serialNumber,
Subject: subj,
DNSNames: altNames,
NotBefore: now.Add(-5 * time.Minute),
NotAfter: now.Add(24 * time.Hour),
SignatureAlgorithm: x509.ECDSAWithSHA256,
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: extUsage,
PublicKey: pkey.PublicKey,
BasicConstraintsValid: true,
}
der, err := x509.CreateCertificate(rand.Reader, &template, cacert, pkey.Public(), cakey)
if err != nil {
t.Fatal(err)
}
cert, _ := x509.ParseCertificate(der)
savePrivateKey(pkey, fmt.Sprintf("%s/%s_key.pem", dir, name))
saveCertificate(cert, fmt.Sprintf("%s/%s_cert.pem", dir, name))
}
func generateTestPKI(t *testing.T) string {
dir, err := ioutil.TempDir("", "")
if err != nil {
t.Fatal(err)
}
cacert, cakey := generateTestCA(t, dir)
generateTestCert(t, cacert, cakey, dir, "server", pkix.Name{CommonName: "server"}, []string{"server", "localhost"}, false, true)
generateTestCert(t, cacert, cakey, dir, "client1", pkix.Name{CommonName: "client1"}, nil, true, false)
generateTestCert(t, cacert, cakey, dir, "client2", pkix.Name{CommonName: "client2"}, nil, true, false)
return dir
}
func newTestClient(t testing.TB, dir, name string) *http.Client {
cert, err := tls.LoadX509KeyPair(
fmt.Sprintf("%s/%s_cert.pem", dir, name),
fmt.Sprintf("%s/%s_key.pem", dir, name),
)
if err != nil {
t.Fatal(err)
}
cas, err := common.LoadCA(dir + "/ca.pem")
if err != nil {
t.Fatal(err)
}
return &http.Client{
Transport: &http.Transport{
TLSClientConfig: &tls.Config{
Certificates: []tls.Certificate{cert},
RootCAs: cas,
},
},
}
}
func TestTLS_Serve(t *testing.T) {
dir := generateTestPKI(t)
defer os.RemoveAll(dir)
config := &ServerConfig{
TLS: &TLSServerConfig{
Cert: dir + "/server_cert.pem",
Key: dir + "/server_key.pem",
CA: dir + "/ca.pem",
},
}
h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
io.WriteString(w, "OK")
})
go Serve(h, config, ":19898")
time.Sleep(100 * time.Millisecond)
// A client without a certificate should get a transport-level error.
c := &http.Client{
Transport: &http.Transport{
TLSClientConfig: &tls.Config{
InsecureSkipVerify: true,
},
},
}
_, err := c.Get("https://localhost:19898/")
if err == nil {
t.Fatal("client without certificate got a successful reply")
}
// A client with a properly signed cert will get a successful reply.
c = newTestClient(t, dir, "client1")
_, err = c.Get("https://localhost:19898/")
if err != nil {
t.Fatalf("client with cert got error: %v", err)
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment