package serverutil

import (
	"crypto/ecdsa"
	"crypto/elliptic"
	"crypto/rand"
	"crypto/tls"
	"crypto/x509"
	"crypto/x509/pkix"
	"encoding/pem"
	"fmt"
	"io"
	"io/ioutil"
	"math/big"
	"net"
	"net/http"
	"os"
	"path/filepath"
	"testing"
	"time"

	"git.autistici.org/ai3/go-common"
)

// saveCertificate writes cert to path as a PEM "CERTIFICATE" block with
// file mode 0644, returning any error from writing the file.
func saveCertificate(cert *x509.Certificate, path string) error {
	data := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: cert.Raw})
	return ioutil.WriteFile(path, data, 0644)
}

// savePrivateKey writes pkey to path as a PEM "EC PRIVATE KEY" block (the
// DER bytes produced by x509.MarshalECPrivateKey) with file mode 0600, so
// that only the owner can read it. It returns any marshaling or write error.
func savePrivateKey(pkey *ecdsa.PrivateKey, path string) error {
	der, err := x509.MarshalECPrivateKey(pkey)
	if err != nil {
		return err
	}
	data := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: der})
	return ioutil.WriteFile(path, data, 0600)
}

// generateTestCA creates a self-signed ECDSA P-256 CA certificate (serial 1,
// CN "CA", valid from five minutes ago for five years, path length 1),
// writes it to dir/ca.pem and returns the parsed certificate and its
// private key. Any failure aborts the test via t.Fatal.
func generateTestCA(t testing.TB, dir string) (*x509.Certificate, *ecdsa.PrivateKey) {
	pkey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		t.Fatal(err)
	}

	// NotBefore is backdated a few minutes, presumably to tolerate small
	// clock differences. KeyEncipherment is deliberately not set: it only
	// applies to RSA keys, and an ECDSA CA needs DigitalSignature+CertSign.
	now := time.Now().UTC()
	template := x509.Certificate{
		SerialNumber:          big.NewInt(1),
		Subject:               pkix.Name{CommonName: "CA"},
		NotBefore:             now.Add(-5 * time.Minute),
		NotAfter:              now.AddDate(5, 0, 0), // 5 years.
		SignatureAlgorithm:    x509.ECDSAWithSHA256,
		KeyUsage:              x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
		BasicConstraintsValid: true,
		IsCA:                  true,
		MaxPathLen:            1,
	}

	der, err := x509.CreateCertificate(rand.Reader, &template, &template, pkey.Public(), pkey)
	if err != nil {
		t.Fatal(err)
	}
	cert, err := x509.ParseCertificate(der)
	if err != nil {
		t.Fatalf("parsing CA certificate: %v", err)
	}

	caPath := filepath.Join(dir, "ca.pem")
	if err := saveCertificate(cert, caPath); err != nil {
		t.Fatalf("saving CA certificate to %s: %v", caPath, err)
	}
	return cert, pkey
}

// generateTestCert creates a new ECDSA P-256 key and a certificate for subj
// signed by cacert/cakey. The certificate has:
//   - altNames as DNS SANs;
//   - a random serial number below 2^128;
//   - a validity window from five minutes ago to 24 hours from now.
//
// The extended key usage includes ServerAuth when isServer is set and
// ClientAuth when isClient is set. The key and certificate are written to
// dir/<name>_key.pem and dir/<name>_cert.pem. Any failure aborts the test
// via t.Fatal.
func generateTestCert(t testing.TB, cacert *x509.Certificate, cakey *ecdsa.PrivateKey, dir, name string, subj pkix.Name, altNames []string, isClient, isServer bool) {
	pkey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		t.Fatal(err)
	}

	serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
	serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
	if err != nil {
		t.Fatal(err)
	}

	var extUsage []x509.ExtKeyUsage
	if isServer {
		extUsage = append(extUsage, x509.ExtKeyUsageServerAuth)
	}
	if isClient {
		extUsage = append(extUsage, x509.ExtKeyUsageClientAuth)
	}

	// The subject public key comes from the pub argument to
	// CreateCertificate, so the template's PublicKey field is left unset.
	// KeyEncipherment is omitted because it only applies to RSA keys.
	now := time.Now().UTC()
	template := x509.Certificate{
		SerialNumber:          serialNumber,
		Subject:               subj,
		DNSNames:              altNames,
		NotBefore:             now.Add(-5 * time.Minute),
		NotAfter:              now.Add(24 * time.Hour),
		SignatureAlgorithm:    x509.ECDSAWithSHA256,
		KeyUsage:              x509.KeyUsageDigitalSignature,
		ExtKeyUsage:           extUsage,
		BasicConstraintsValid: true,
	}

	der, err := x509.CreateCertificate(rand.Reader, &template, cacert, pkey.Public(), cakey)
	if err != nil {
		t.Fatal(err)
	}
	cert, err := x509.ParseCertificate(der)
	if err != nil {
		t.Fatalf("parsing %s certificate: %v", name, err)
	}

	keyPath := filepath.Join(dir, name+"_key.pem")
	if err := savePrivateKey(pkey, keyPath); err != nil {
		t.Fatalf("saving private key to %s: %v", keyPath, err)
	}
	certPath := filepath.Join(dir, name+"_cert.pem")
	if err := saveCertificate(cert, certPath); err != nil {
		t.Fatalf("saving certificate to %s: %v", certPath, err)
	}
}

// generateTestPKI creates a fresh temporary directory containing:
//   - a test CA (ca.pem);
//   - a server certificate (SANs "server" and "localhost");
//   - two client certificates, "client1" and "client2".
//
// It returns the directory path. The caller is responsible for removing
// the directory once the test is done.
func generateTestPKI(t *testing.T) string {
	dir, err := ioutil.TempDir("", "")
	if err != nil {
		t.Fatal(err)
	}

	// t.Fatal in the helpers below stops this goroutine via runtime.Goexit,
	// which still runs deferred calls: remove the directory unless every
	// file was generated successfully.
	ok := false
	defer func() {
		if !ok {
			os.RemoveAll(dir)
		}
	}()

	cacert, cakey := generateTestCA(t, dir)
	generateTestCert(t, cacert, cakey, dir, "server", pkix.Name{CommonName: "server"}, []string{"server", "localhost"}, false, true)
	generateTestCert(t, cacert, cakey, dir, "client1", pkix.Name{CommonName: "client1"}, nil, true, false)
	generateTestCert(t, cacert, cakey, dir, "client2", pkix.Name{CommonName: "client2"}, nil, true, false)

	ok = true
	return dir
}

// newTestClient returns an HTTP client that presents the <name> certificate
// and key from dir and trusts the CA certificate in dir/ca.pem.
func newTestClient(t testing.TB, dir, name string) *http.Client {
	cert, err := tls.LoadX509KeyPair(
		filepath.Join(dir, name+"_cert.pem"),
		filepath.Join(dir, name+"_key.pem"),
	)
	if err != nil {
		t.Fatal(err)
	}
	cas, err := common.LoadCA(filepath.Join(dir, "ca.pem"))
	if err != nil {
		t.Fatal(err)
	}
	return &http.Client{
		Transport: &http.Transport{
			TLSClientConfig: &tls.Config{
				Certificates: []tls.Certificate{cert},
				RootCAs:      cas,
			},
		},
	}
}

// TestTLS_Serve starts Serve with a TLS config that includes a client CA.
// It then checks two things:
//   - a client presenting no certificate gets an error;
//   - a client with a CA-signed certificate receives the handler's "OK"
//     reply.
func TestTLS_Serve(t *testing.T) {
	dir := generateTestPKI(t)
	defer os.RemoveAll(dir)

	// NOTE(review): fixed port; the test fails if something else is bound
	// to it.
	const port = 19898
	listenAddr := fmt.Sprintf(":%d", port)
	dialAddr := fmt.Sprintf("localhost:%d", port)
	url := fmt.Sprintf("https://localhost:%d/", port)

	config := &ServerConfig{
		TLS: &TLSServerConfig{
			Cert: filepath.Join(dir, "server_cert.pem"),
			Key:  filepath.Join(dir, "server_key.pem"),
			CA:   filepath.Join(dir, "ca.pem"),
		},
	}
	h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		io.WriteString(w, "OK")
	})

	// NOTE(review): Serve's result is discarded and nothing here stops the
	// server, so its goroutine outlives this test.
	go Serve(h, config, listenAddr)

	// Poll until the server accepts TCP connections rather than sleeping a
	// fixed amount, which is racy on slow or loaded machines.
	deadline := time.Now().Add(5 * time.Second)
	for {
		conn, err := net.DialTimeout("tcp", dialAddr, 100*time.Millisecond)
		if err == nil {
			conn.Close()
			break
		}
		if time.Now().After(deadline) {
			t.Fatalf("server did not start listening on %s: %v", dialAddr, err)
		}
		time.Sleep(10 * time.Millisecond)
	}

	// A client without a certificate should get a transport-level error.
	c := &http.Client{
		Transport: &http.Transport{
			TLSClientConfig: &tls.Config{
				InsecureSkipVerify: true,
			},
		},
	}
	resp, err := c.Get(url)
	if err == nil {
		resp.Body.Close()
		t.Fatal("client without certificate got a successful reply")
	}

	// A client with a properly signed cert will get a successful reply.
	c = newTestClient(t, dir, "client1")
	c.Timeout = 10 * time.Second
	resp, err = c.Get(url)
	if err != nil {
		t.Fatalf("client with cert got error: %v", err)
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		t.Fatalf("client with cert got HTTP status %d", resp.StatusCode)
	}
	body, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		t.Fatalf("reading response body: %v", err)
	}
	if string(body) != "OK" {
		t.Fatalf("client with cert got body %q, expected %q", body, "OK")
	}
}