From 9afc1865f1ee521a05527e0ba94a75755890b5d7 Mon Sep 17 00:00:00 2001 From: ale <ale@incal.net> Date: Mon, 18 Jun 2018 06:46:23 +0100 Subject: [PATCH] Add tests for certificate generation --- server.go | 17 ++++++- server_test.go | 120 ++++++++++++++++++++++++++++++++++++++++++++++--- storage.go | 6 ++- 3 files changed, 134 insertions(+), 9 deletions(-) diff --git a/server.go b/server.go index 729bd0d1..5c77d660 100644 --- a/server.go +++ b/server.go @@ -24,6 +24,11 @@ import ( type certInfo struct { domains []string retryDeadline time.Time + + // A write-only attribute (not part of the logic) to indicate + // whether we think we have a valid certificate or not. Used + // for monitoring and debugging. + valid bool } type certStorage interface { @@ -123,7 +128,7 @@ var ( func (m *Manager) updateAllCerts(ctx context.Context) { for _, certInfo := range m.certs { - if certInfo.retryDeadline.Before(time.Now()) { + if certInfo.retryDeadline.After(time.Now()) { continue } uctx, cancel := context.WithTimeout(ctx, renewalTimeout) @@ -131,7 +136,9 @@ func (m *Manager) updateAllCerts(ctx context.Context) { cancel() if err != nil { log.Printf("error updating %s: %v", certInfo.domains[0], err) + // Retry in a little while. certInfo.retryDeadline = time.Now().Add(errorRetryTimeout) + certInfo.valid = false } } } @@ -161,6 +168,7 @@ func (m *Manager) updateCert(ctx context.Context, certInfo *certInfo) error { } certInfo.retryDeadline = renewalDeadline(leaf) + certInfo.valid = true return nil } @@ -182,6 +190,7 @@ func (m *Manager) loadConfig(certDomains [][]string) { if err == nil { log.Printf("cert for %s loaded from storage", cn) certInfo.retryDeadline = renewalDeadline(leaf) + certInfo.valid = true } else { log.Printf("cert for %s loaded from storage but parameters have changed", cn) } @@ -191,12 +200,18 @@ func (m *Manager) loadConfig(certDomains [][]string) { m.certs = certs } +// This channel is used by the testing code to trigger an update, +// without having to wait for the timer to tick. +var testUpdateCh = make(chan bool) + func (m *Manager) loop(ctx context.Context) { tick := time.NewTicker(5 * time.Minute) for { select { case <-tick.C: m.updateAllCerts(ctx) + case <-testUpdateCh: + m.updateAllCerts(ctx) case certDomains := <-m.configCh: m.loadConfig(certDomains) case <-m.stopCh: diff --git a/server_test.go b/server_test.go index d5b48a52..377c1bd3 100644 --- a/server_test.go +++ b/server_test.go @@ -3,9 +3,14 @@ package acmeserver import ( "context" "crypto" + "crypto/ecdsa" + "crypto/rand" + "crypto/rsa" "crypto/x509" - "errors" + "crypto/x509/pkix" "io/ioutil" + "log" + "math/big" "os" "path/filepath" "testing" @@ -15,16 +20,65 @@ import ( type fakeACME struct { } -func (f *fakeACME) GetCertificate(_ context.Context, key crypto.Signer, domains []string) (der [][]byte, leaf *x509.Certificate, err error) { - return nil, nil, errors.New("unimplemented") +func (f *fakeACME) GetCertificate(_ context.Context, priv crypto.Signer, domains []string) ([][]byte, *x509.Certificate, error) { + notBefore := time.Now() + notAfter := notBefore.AddDate(1, 0, 0) + + serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) + serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) + if err != nil { + return nil, nil, err + } + + template := x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{ + CommonName: domains[0], + }, + NotBefore: notBefore, + NotAfter: notAfter, + + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + } + + if len(domains) > 1 { + template.DNSNames = domains[1:] + } + + der, err := x509.CreateCertificate(rand.Reader, &template, &template, publicKey(priv), priv) + if err != nil { + return nil, nil, err + } + + cert, err := x509.ParseCertificate(der) + if err != nil { + return nil, nil, err + } + + log.Printf("created certificate for %s", domains[0]) + + return [][]byte{der}, cert, nil } -func TestManager(t *testing.T) { +func publicKey(priv interface{}) interface{} { + switch k := priv.(type) { + case *rsa.PrivateKey: + return &k.PublicKey + case *ecdsa.PrivateKey: + return &k.PublicKey + default: + return nil + } +} + +func newTestManager(t testing.TB) (func(), *Manager) { dir, err := ioutil.TempDir("", "") if err != nil { t.Fatal(err) } - defer os.RemoveAll(dir) + os.Mkdir(filepath.Join(dir, "config"), 0700) ioutil.WriteFile( filepath.Join(dir, "config", "test.yml"), @@ -39,7 +93,6 @@ func TestManager(t *testing.T) { if err != nil { t.Fatal(err) } - defer m.Stop() if err := m.Start(context.Background()); err != nil { t.Fatal("Start:", err) @@ -48,6 +101,16 @@ func TestManager(t *testing.T) { // Wait just a little bit to give a chance to m.loop() to run. time.Sleep(50 * time.Millisecond) + return func() { + m.Stop() + os.RemoveAll(dir) + }, m +} + +func TestManager_Reload(t *testing.T) { + cleanup, m := newTestManager(t) + defer cleanup() + // Data race: we read data owned by another goroutine! if len(m.certs) < 1 { t.Fatal("configuration not loaded?") @@ -67,3 +130,48 @@ func TestManager(t *testing.T) { t.Fatalf("certs[0].domains[0] is %s, expected example.com", m.certs[0].domains[0]) } } + +func TestManager_NewCert(t *testing.T) { + cleanup, m := newTestManager(t) + defer cleanup() + + now := time.Now() + ci := m.certs[0] + if ci.retryDeadline.After(now) { + t.Fatalf("retry deadline is in the future: %v", ci.retryDeadline) + } + + testUpdateCh <- true + time.Sleep(100 * time.Millisecond) + + // Verify that the retry/renewal timestamp is in the future. + if ci.retryDeadline.Before(now) { + t.Fatalf("retry deadline is in the past after renewal: %v", ci.retryDeadline) + } + + // Do we think we have a valid certificate? + if !ci.valid { + t.Fatal("we don't have a valid certificate") + } + + // Verify that the credentials have successfully been written + // to storage. + p := filepath.Join(m.configDir, "../certs/example.com/cert.pem") + if _, err := os.Stat(p); err != nil { + t.Fatalf("file not created: %v", err) + } + p = filepath.Join(m.configDir, "../certs/example.com/private_key.pem") + if _, err := os.Stat(p); err != nil { + t.Fatalf("file not created: %v", err) + } + + // By triggering a reload now, we should cause the Manager to + // reload the certificate from storage. + m.Reload() + time.Sleep(50 * time.Millisecond) + + ci = m.certs[0] + if !ci.valid { + t.Fatal("certificate is invalid after a reload") + } +} diff --git a/storage.go b/storage.go index 60ccda1e..f0c5f703 100644 --- a/storage.go +++ b/storage.go @@ -10,6 +10,7 @@ import ( "encoding/pem" "errors" "io/ioutil" + "log" "os" "path/filepath" "strings" @@ -46,7 +47,7 @@ func (d *dirStorage) PutCert(cn string, der [][]byte, key crypto.Signer) error { dir := filepath.Join(d.root, cn) if _, err := os.Stat(dir); err != nil && os.IsNotExist(err) { - if err = os.Mkdir(dir, 0755); err != nil { + if err = os.MkdirAll(dir, 0755); err != nil { return err } } @@ -56,7 +57,8 @@ func (d *dirStorage) PutCert(cn string, der [][]byte, key crypto.Signer) error { if strings.HasSuffix(path, "private_key.pem") { mode = 0400 } - if err := ioutil.WriteFile(path, data, mode); err != nil { + log.Printf("writing %s (%03o)", path, mode) + if err := ioutil.WriteFile(filepath.Join(d.root, path), data, mode); err != nil { return err } } -- GitLab