Skip to content
Snippets Groups Projects
Commit 9afc1865 authored by ale's avatar ale
Browse files

Add tests for certificate generation

parent 21a99ace
Branches
No related tags found
No related merge requests found
......@@ -24,6 +24,11 @@ import (
type certInfo struct {
domains []string
retryDeadline time.Time
// A write-only attribute (not part of the logic) to indicate
// whether we think we have a valid certificate or not. Used
// for monitoring and debugging.
valid bool
}
type certStorage interface {
......@@ -123,7 +128,7 @@ var (
func (m *Manager) updateAllCerts(ctx context.Context) {
for _, certInfo := range m.certs {
if certInfo.retryDeadline.Before(time.Now()) {
if certInfo.retryDeadline.After(time.Now()) {
continue
}
uctx, cancel := context.WithTimeout(ctx, renewalTimeout)
......@@ -131,7 +136,9 @@ func (m *Manager) updateAllCerts(ctx context.Context) {
cancel()
if err != nil {
log.Printf("error updating %s: %v", certInfo.domains[0], err)
// Retry in a little while.
certInfo.retryDeadline = time.Now().Add(errorRetryTimeout)
certInfo.valid = false
}
}
}
......@@ -161,6 +168,7 @@ func (m *Manager) updateCert(ctx context.Context, certInfo *certInfo) error {
}
certInfo.retryDeadline = renewalDeadline(leaf)
certInfo.valid = true
return nil
}
......@@ -182,6 +190,7 @@ func (m *Manager) loadConfig(certDomains [][]string) {
if err == nil {
log.Printf("cert for %s loaded from storage", cn)
certInfo.retryDeadline = renewalDeadline(leaf)
certInfo.valid = true
} else {
log.Printf("cert for %s loaded from storage but parameters have changed", cn)
}
......@@ -191,12 +200,18 @@ func (m *Manager) loadConfig(certDomains [][]string) {
m.certs = certs
}
// This channel is used by the testing code to trigger an update,
// without having to wait for the timer to tick.
var testUpdateCh = make(chan bool)
func (m *Manager) loop(ctx context.Context) {
tick := time.NewTicker(5 * time.Minute)
for {
select {
case <-tick.C:
m.updateAllCerts(ctx)
case <-testUpdateCh:
m.updateAllCerts(ctx)
case certDomains := <-m.configCh:
m.loadConfig(certDomains)
case <-m.stopCh:
......
......@@ -3,9 +3,14 @@ package acmeserver
import (
"context"
"crypto"
"crypto/ecdsa"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"errors"
"crypto/x509/pkix"
"io/ioutil"
"log"
"math/big"
"os"
"path/filepath"
"testing"
......@@ -15,16 +20,65 @@ import (
type fakeACME struct {
}
func (f *fakeACME) GetCertificate(_ context.Context, key crypto.Signer, domains []string) (der [][]byte, leaf *x509.Certificate, err error) {
return nil, nil, errors.New("unimplemented")
func (f *fakeACME) GetCertificate(_ context.Context, priv crypto.Signer, domains []string) ([][]byte, *x509.Certificate, error) {
notBefore := time.Now()
notAfter := notBefore.AddDate(1, 0, 0)
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
if err != nil {
return nil, nil, err
}
template := x509.Certificate{
SerialNumber: serialNumber,
Subject: pkix.Name{
CommonName: domains[0],
},
NotBefore: notBefore,
NotAfter: notAfter,
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
}
if len(domains) > 1 {
template.DNSNames = domains[1:]
}
der, err := x509.CreateCertificate(rand.Reader, &template, &template, publicKey(priv), priv)
if err != nil {
return nil, nil, err
}
cert, err := x509.ParseCertificate(der)
if err != nil {
return nil, nil, err
}
log.Printf("created certificate for %s", domains[0])
return [][]byte{der}, cert, nil
}
func TestManager(t *testing.T) {
func publicKey(priv interface{}) interface{} {
switch k := priv.(type) {
case *rsa.PrivateKey:
return &k.PublicKey
case *ecdsa.PrivateKey:
return &k.PublicKey
default:
return nil
}
}
func newTestManager(t testing.TB) (func(), *Manager) {
dir, err := ioutil.TempDir("", "")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(dir)
os.Mkdir(filepath.Join(dir, "config"), 0700)
ioutil.WriteFile(
filepath.Join(dir, "config", "test.yml"),
......@@ -39,7 +93,6 @@ func TestManager(t *testing.T) {
if err != nil {
t.Fatal(err)
}
defer m.Stop()
if err := m.Start(context.Background()); err != nil {
t.Fatal("Start:", err)
......@@ -48,6 +101,16 @@ func TestManager(t *testing.T) {
// Wait just a little bit to give a chance to m.loop() to run.
time.Sleep(50 * time.Millisecond)
return func() {
m.Stop()
os.RemoveAll(dir)
}, m
}
func TestManager_Reload(t *testing.T) {
cleanup, m := newTestManager(t)
defer cleanup()
// Data race: we read data owned by another goroutine!
if len(m.certs) < 1 {
t.Fatal("configuration not loaded?")
......@@ -67,3 +130,48 @@ func TestManager(t *testing.T) {
t.Fatalf("certs[0].domains[0] is %s, expected example.com", m.certs[0].domains[0])
}
}
func TestManager_NewCert(t *testing.T) {
cleanup, m := newTestManager(t)
defer cleanup()
now := time.Now()
ci := m.certs[0]
if ci.retryDeadline.After(now) {
t.Fatalf("retry deadline is in the future: %v", ci.retryDeadline)
}
testUpdateCh <- true
time.Sleep(100 * time.Millisecond)
// Verify that the retry/renewal timestamp is in the future.
if ci.retryDeadline.Before(now) {
t.Fatalf("retry deadline is in the past after renewal: %v", ci.retryDeadline)
}
// Do we think we have a valid certificate?
if !ci.valid {
t.Fatal("we don't have a valid certificate")
}
// Verify that the credentials have successfully been written
// to storage.
p := filepath.Join(m.configDir, "../certs/example.com/cert.pem")
if _, err := os.Stat(p); err != nil {
t.Fatalf("file not created: %v", err)
}
p = filepath.Join(m.configDir, "../certs/example.com/private_key.pem")
if _, err := os.Stat(p); err != nil {
t.Fatalf("file not created: %v", err)
}
// By triggering a reload now, we should cause the Manager to
// reload the certificate from storage.
m.Reload()
time.Sleep(50 * time.Millisecond)
ci = m.certs[0]
if !ci.valid {
t.Fatal("certificate is invalid after a reload")
}
}
......@@ -10,6 +10,7 @@ import (
"encoding/pem"
"errors"
"io/ioutil"
"log"
"os"
"path/filepath"
"strings"
......@@ -46,7 +47,7 @@ func (d *dirStorage) PutCert(cn string, der [][]byte, key crypto.Signer) error {
dir := filepath.Join(d.root, cn)
if _, err := os.Stat(dir); err != nil && os.IsNotExist(err) {
if err = os.Mkdir(dir, 0755); err != nil {
if err = os.MkdirAll(dir, 0755); err != nil {
return err
}
}
......@@ -56,7 +57,8 @@ func (d *dirStorage) PutCert(cn string, der [][]byte, key crypto.Signer) error {
if strings.HasSuffix(path, "private_key.pem") {
mode = 0400
}
if err := ioutil.WriteFile(path, data, mode); err != nil {
log.Printf("writing %s (%03o)", path, mode)
if err := ioutil.WriteFile(filepath.Join(d.root, path), data, mode); err != nil {
return err
}
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment