Skip to content
Snippets Groups Projects
Commit 9afc1865 authored by ale's avatar ale
Browse files

Add tests for certificate generation

parent 21a99ace
No related branches found
No related tags found
No related merge requests found
...@@ -24,6 +24,11 @@ import ( ...@@ -24,6 +24,11 @@ import (
type certInfo struct { type certInfo struct {
domains []string domains []string
retryDeadline time.Time retryDeadline time.Time
// A write-only attribute (not part of the logic) to indicate
// whether we think we have a valid certificate or not. Used
// for monitoring and debugging.
valid bool
} }
type certStorage interface { type certStorage interface {
...@@ -123,7 +128,7 @@ var ( ...@@ -123,7 +128,7 @@ var (
func (m *Manager) updateAllCerts(ctx context.Context) { func (m *Manager) updateAllCerts(ctx context.Context) {
for _, certInfo := range m.certs { for _, certInfo := range m.certs {
if certInfo.retryDeadline.Before(time.Now()) { if certInfo.retryDeadline.After(time.Now()) {
continue continue
} }
uctx, cancel := context.WithTimeout(ctx, renewalTimeout) uctx, cancel := context.WithTimeout(ctx, renewalTimeout)
...@@ -131,7 +136,9 @@ func (m *Manager) updateAllCerts(ctx context.Context) { ...@@ -131,7 +136,9 @@ func (m *Manager) updateAllCerts(ctx context.Context) {
cancel() cancel()
if err != nil { if err != nil {
log.Printf("error updating %s: %v", certInfo.domains[0], err) log.Printf("error updating %s: %v", certInfo.domains[0], err)
// Retry in a little while.
certInfo.retryDeadline = time.Now().Add(errorRetryTimeout) certInfo.retryDeadline = time.Now().Add(errorRetryTimeout)
certInfo.valid = false
} }
} }
} }
...@@ -161,6 +168,7 @@ func (m *Manager) updateCert(ctx context.Context, certInfo *certInfo) error { ...@@ -161,6 +168,7 @@ func (m *Manager) updateCert(ctx context.Context, certInfo *certInfo) error {
} }
certInfo.retryDeadline = renewalDeadline(leaf) certInfo.retryDeadline = renewalDeadline(leaf)
certInfo.valid = true
return nil return nil
} }
...@@ -182,6 +190,7 @@ func (m *Manager) loadConfig(certDomains [][]string) { ...@@ -182,6 +190,7 @@ func (m *Manager) loadConfig(certDomains [][]string) {
if err == nil { if err == nil {
log.Printf("cert for %s loaded from storage", cn) log.Printf("cert for %s loaded from storage", cn)
certInfo.retryDeadline = renewalDeadline(leaf) certInfo.retryDeadline = renewalDeadline(leaf)
certInfo.valid = true
} else { } else {
log.Printf("cert for %s loaded from storage but parameters have changed", cn) log.Printf("cert for %s loaded from storage but parameters have changed", cn)
} }
...@@ -191,12 +200,18 @@ func (m *Manager) loadConfig(certDomains [][]string) { ...@@ -191,12 +200,18 @@ func (m *Manager) loadConfig(certDomains [][]string) {
m.certs = certs m.certs = certs
} }
// This channel is used by the testing code to trigger an update,
// without having to wait for the timer to tick.
var testUpdateCh = make(chan bool)
func (m *Manager) loop(ctx context.Context) { func (m *Manager) loop(ctx context.Context) {
tick := time.NewTicker(5 * time.Minute) tick := time.NewTicker(5 * time.Minute)
for { for {
select { select {
case <-tick.C: case <-tick.C:
m.updateAllCerts(ctx) m.updateAllCerts(ctx)
case <-testUpdateCh:
m.updateAllCerts(ctx)
case certDomains := <-m.configCh: case certDomains := <-m.configCh:
m.loadConfig(certDomains) m.loadConfig(certDomains)
case <-m.stopCh: case <-m.stopCh:
......
...@@ -3,9 +3,14 @@ package acmeserver ...@@ -3,9 +3,14 @@ package acmeserver
import ( import (
"context" "context"
"crypto" "crypto"
"crypto/ecdsa"
"crypto/rand"
"crypto/rsa"
"crypto/x509" "crypto/x509"
"errors" "crypto/x509/pkix"
"io/ioutil" "io/ioutil"
"log"
"math/big"
"os" "os"
"path/filepath" "path/filepath"
"testing" "testing"
...@@ -15,16 +20,65 @@ import ( ...@@ -15,16 +20,65 @@ import (
type fakeACME struct { type fakeACME struct {
} }
func (f *fakeACME) GetCertificate(_ context.Context, key crypto.Signer, domains []string) (der [][]byte, leaf *x509.Certificate, err error) { func (f *fakeACME) GetCertificate(_ context.Context, priv crypto.Signer, domains []string) ([][]byte, *x509.Certificate, error) {
return nil, nil, errors.New("unimplemented") notBefore := time.Now()
notAfter := notBefore.AddDate(1, 0, 0)
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
if err != nil {
return nil, nil, err
}
template := x509.Certificate{
SerialNumber: serialNumber,
Subject: pkix.Name{
CommonName: domains[0],
},
NotBefore: notBefore,
NotAfter: notAfter,
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
}
if len(domains) > 1 {
template.DNSNames = domains[1:]
}
der, err := x509.CreateCertificate(rand.Reader, &template, &template, publicKey(priv), priv)
if err != nil {
return nil, nil, err
} }
func TestManager(t *testing.T) { cert, err := x509.ParseCertificate(der)
if err != nil {
return nil, nil, err
}
log.Printf("created certificate for %s", domains[0])
return [][]byte{der}, cert, nil
}
func publicKey(priv interface{}) interface{} {
switch k := priv.(type) {
case *rsa.PrivateKey:
return &k.PublicKey
case *ecdsa.PrivateKey:
return &k.PublicKey
default:
return nil
}
}
func newTestManager(t testing.TB) (func(), *Manager) {
dir, err := ioutil.TempDir("", "") dir, err := ioutil.TempDir("", "")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
defer os.RemoveAll(dir)
os.Mkdir(filepath.Join(dir, "config"), 0700) os.Mkdir(filepath.Join(dir, "config"), 0700)
ioutil.WriteFile( ioutil.WriteFile(
filepath.Join(dir, "config", "test.yml"), filepath.Join(dir, "config", "test.yml"),
...@@ -39,7 +93,6 @@ func TestManager(t *testing.T) { ...@@ -39,7 +93,6 @@ func TestManager(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
defer m.Stop()
if err := m.Start(context.Background()); err != nil { if err := m.Start(context.Background()); err != nil {
t.Fatal("Start:", err) t.Fatal("Start:", err)
...@@ -48,6 +101,16 @@ func TestManager(t *testing.T) { ...@@ -48,6 +101,16 @@ func TestManager(t *testing.T) {
// Wait just a little bit to give a chance to m.loop() to run. // Wait just a little bit to give a chance to m.loop() to run.
time.Sleep(50 * time.Millisecond) time.Sleep(50 * time.Millisecond)
return func() {
m.Stop()
os.RemoveAll(dir)
}, m
}
func TestManager_Reload(t *testing.T) {
cleanup, m := newTestManager(t)
defer cleanup()
// Data race: we read data owned by another goroutine! // Data race: we read data owned by another goroutine!
if len(m.certs) < 1 { if len(m.certs) < 1 {
t.Fatal("configuration not loaded?") t.Fatal("configuration not loaded?")
...@@ -67,3 +130,48 @@ func TestManager(t *testing.T) { ...@@ -67,3 +130,48 @@ func TestManager(t *testing.T) {
t.Fatalf("certs[0].domains[0] is %s, expected example.com", m.certs[0].domains[0]) t.Fatalf("certs[0].domains[0] is %s, expected example.com", m.certs[0].domains[0])
} }
} }
func TestManager_NewCert(t *testing.T) {
cleanup, m := newTestManager(t)
defer cleanup()
now := time.Now()
ci := m.certs[0]
if ci.retryDeadline.After(now) {
t.Fatalf("retry deadline is in the future: %v", ci.retryDeadline)
}
testUpdateCh <- true
time.Sleep(100 * time.Millisecond)
// Verify that the retry/renewal timestamp is in the future.
if ci.retryDeadline.Before(now) {
t.Fatalf("retry deadline is in the past after renewal: %v", ci.retryDeadline)
}
// Do we think we have a valid certificate?
if !ci.valid {
t.Fatal("we don't have a valid certificate")
}
// Verify that the credentials have successfully been written
// to storage.
p := filepath.Join(m.configDir, "../certs/example.com/cert.pem")
if _, err := os.Stat(p); err != nil {
t.Fatalf("file not created: %v", err)
}
p = filepath.Join(m.configDir, "../certs/example.com/private_key.pem")
if _, err := os.Stat(p); err != nil {
t.Fatalf("file not created: %v", err)
}
// By triggering a reload now, we should cause the Manager to
// reload the certificate from storage.
m.Reload()
time.Sleep(50 * time.Millisecond)
ci = m.certs[0]
if !ci.valid {
t.Fatal("certificate is invalid after a reload")
}
}
...@@ -10,6 +10,7 @@ import ( ...@@ -10,6 +10,7 @@ import (
"encoding/pem" "encoding/pem"
"errors" "errors"
"io/ioutil" "io/ioutil"
"log"
"os" "os"
"path/filepath" "path/filepath"
"strings" "strings"
...@@ -46,7 +47,7 @@ func (d *dirStorage) PutCert(cn string, der [][]byte, key crypto.Signer) error { ...@@ -46,7 +47,7 @@ func (d *dirStorage) PutCert(cn string, der [][]byte, key crypto.Signer) error {
dir := filepath.Join(d.root, cn) dir := filepath.Join(d.root, cn)
if _, err := os.Stat(dir); err != nil && os.IsNotExist(err) { if _, err := os.Stat(dir); err != nil && os.IsNotExist(err) {
if err = os.Mkdir(dir, 0755); err != nil { if err = os.MkdirAll(dir, 0755); err != nil {
return err return err
} }
} }
...@@ -56,7 +57,8 @@ func (d *dirStorage) PutCert(cn string, der [][]byte, key crypto.Signer) error { ...@@ -56,7 +57,8 @@ func (d *dirStorage) PutCert(cn string, der [][]byte, key crypto.Signer) error {
if strings.HasSuffix(path, "private_key.pem") { if strings.HasSuffix(path, "private_key.pem") {
mode = 0400 mode = 0400
} }
if err := ioutil.WriteFile(path, data, mode); err != nil { log.Printf("writing %s (%03o)", path, mode)
if err := ioutil.WriteFile(filepath.Join(d.root, path), data, mode); err != nil {
return err return err
} }
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment