diff --git a/serverutil/tls_test.go b/serverutil/tls_test.go new file mode 100644 index 0000000000000000000000000000000000000000..9cf3ef9aeb3d19ef62107f8d5121004b2f63c121 --- /dev/null +++ b/serverutil/tls_test.go @@ -0,0 +1,176 @@ +package serverutil + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "fmt" + "io" + "io/ioutil" + "math/big" + "net/http" + "os" + "testing" + "time" + + "git.autistici.org/ai3/go-common" +) + +func saveCertificate(cert *x509.Certificate, path string) { + data := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: cert.Raw}) + ioutil.WriteFile(path, data, 0644) +} + +func savePrivateKey(pkey *ecdsa.PrivateKey, path string) { + der, _ := x509.MarshalECPrivateKey(pkey) + data := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: der}) + ioutil.WriteFile(path, data, 0600) +} + +func generateTestCA(t testing.TB, dir string) (*x509.Certificate, *ecdsa.PrivateKey) { + pkey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatal(err) + } + + now := time.Now().UTC() + template := x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{CommonName: "CA"}, + NotBefore: now.Add(-5 * time.Minute), + NotAfter: now.AddDate(5, 0, 0), // 5 years. + SignatureAlgorithm: x509.ECDSAWithSHA256, + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, + BasicConstraintsValid: true, + IsCA: true, + MaxPathLen: 1, + } + der, err := x509.CreateCertificate(rand.Reader, &template, &template, pkey.Public(), pkey) + if err != nil { + t.Fatal(err) + } + cert, _ := x509.ParseCertificate(der) + saveCertificate(cert, dir+"/ca.pem") + return cert, pkey +} + +func generateTestCert(t testing.TB, cacert *x509.Certificate, cakey *ecdsa.PrivateKey, dir, name string, subj pkix.Name, altNames []string, isClient, isServer bool) { + pkey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatal(err) + } + + serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) + serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) + if err != nil { + t.Fatal(err) + } + + var extUsage []x509.ExtKeyUsage + if isServer { + extUsage = append(extUsage, x509.ExtKeyUsageServerAuth) + } + if isClient { + extUsage = append(extUsage, x509.ExtKeyUsageClientAuth) + } + + now := time.Now().UTC() + template := x509.Certificate{ + SerialNumber: serialNumber, + Subject: subj, + DNSNames: altNames, + NotBefore: now.Add(-5 * time.Minute), + NotAfter: now.Add(24 * time.Hour), + SignatureAlgorithm: x509.ECDSAWithSHA256, + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: extUsage, + PublicKey: pkey.PublicKey, + BasicConstraintsValid: true, + } + der, err := x509.CreateCertificate(rand.Reader, &template, cacert, pkey.Public(), cakey) + if err != nil { + t.Fatal(err) + } + cert, _ := x509.ParseCertificate(der) + + savePrivateKey(pkey, fmt.Sprintf("%s/%s_key.pem", dir, name)) + saveCertificate(cert, fmt.Sprintf("%s/%s_cert.pem", dir, name)) +} + +func generateTestPKI(t *testing.T) string { + dir, err := ioutil.TempDir("", "") + if err != nil { + t.Fatal(err) + } + + cacert, cakey := generateTestCA(t, dir) + generateTestCert(t, cacert, cakey, dir, "server", pkix.Name{CommonName: "server"}, []string{"server", "localhost"}, false, true) + generateTestCert(t, cacert, cakey, dir, "client1", pkix.Name{CommonName: "client1"}, nil, true, false) + generateTestCert(t, cacert, cakey, dir, "client2", pkix.Name{CommonName: "client2"}, nil, true, false) + return dir +} + +func newTestClient(t testing.TB, dir, name string) *http.Client { + cert, err := tls.LoadX509KeyPair( + fmt.Sprintf("%s/%s_cert.pem", dir, name), + fmt.Sprintf("%s/%s_key.pem", dir, name), + ) + if err != nil { + t.Fatal(err) + } + cas, err := common.LoadCA(dir + "/ca.pem") + if err != nil { + t.Fatal(err) + } + return &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{ + Certificates: []tls.Certificate{cert}, + RootCAs: cas, + }, + }, + } +} + +func TestTLS_Serve(t *testing.T) { + dir := generateTestPKI(t) + defer os.RemoveAll(dir) + + config := &ServerConfig{ + TLS: &TLSServerConfig{ + Cert: dir + "/server_cert.pem", + Key: dir + "/server_key.pem", + CA: dir + "/ca.pem", + }, + } + h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + io.WriteString(w, "OK") + }) + + go Serve(h, config, ":19898") + time.Sleep(100 * time.Millisecond) + + // A client without a certificate should get a transport-level error. + c := &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + }, + } + _, err := c.Get("https://localhost:19898/") + if err == nil { + t.Fatal("client without certificate got a successful reply") + } + + // A client with a properly signed cert will get a successful reply. + c = newTestClient(t, dir, "client1") + _, err = c.Get("https://localhost:19898/") + if err != nil { + t.Fatalf("client with cert got error: %v", err) + } +}