Skip to content
Snippets Groups Projects
Commit 5e055959 authored by ale's avatar ale
Browse files

Refactor Manager update mechanic onto a separate utility function

parent 10370f68
Branches
No related tags found
No related merge requests found
......@@ -20,6 +20,11 @@ import (
"github.com/prometheus/client_golang/prometheus"
)
var (
defaultRenewalDays = 21
updateInterval = 1 * time.Minute
)
// certInfo represents what we know about the state of the certificate
// at runtime.
type certInfo struct {
......@@ -71,18 +76,19 @@ type CertGenerator interface {
GetCertificate(context.Context, crypto.Signer, *certConfig) ([][]byte, *x509.Certificate, error)
}
// Manager periodically renews certificates before they expire, and
// responds to http-01 validation requests.
// Manager periodically renews certificates before they expire.
type Manager struct {
configDirs []string
useRSA bool
storage certStorage
certs []*certInfo
certGen CertGenerator
renewalDays int
configCh chan []*certConfig
doneCh chan bool
mx sync.Mutex
certs []*certInfo
}
// NewManager creates a new Manager with the given configuration.
......@@ -94,6 +100,9 @@ func NewManager(config *Config, certGen CertGenerator) (*Manager, error) {
if config.Output.Path == "" {
return nil, errors.New("'output.path' is unset")
}
if config.RenewalDays <= 0 {
config.RenewalDays = defaultRenewalDays
}
m := &Manager{
useRSA: config.UseRSA,
......@@ -103,9 +112,6 @@ func NewManager(config *Config, certGen CertGenerator) (*Manager, error) {
certGen: certGen,
renewalDays: config.RenewalDays,
}
if m.renewalDays <= 0 {
m.renewalDays = 15
}
ds := &dirStorage{root: config.Output.Path}
if config.Output.ReplDS == nil {
......@@ -241,71 +247,41 @@ func (m *Manager) loadConfig(certs []*certConfig) []*certInfo {
return out
}
func (m *Manager) getCerts() []*certInfo {
m.mx.Lock()
defer m.mx.Unlock()
return m.certs
}
func (m *Manager) setCerts(certs []*certInfo) {
m.mx.Lock()
m.certs = certs
m.mx.Unlock()
}
// This channel is used by the testing code to trigger an update,
// without having to wait for the timer to tick.
var testUpdateCh = make(chan bool)
func (m *Manager) loop(ctx context.Context) {
// Updates are long-term jobs, so they should be
// interruptible. We run updates in a separate goroutine, and
// cancel them when the configuration is reloaded or on exit.
// A simple channel is used as a semaphore, so that only one
// update goroutine can be running at any given time (without
// other ones piling up).
var upCancel context.CancelFunc
var wg sync.WaitGroup
sem := make(chan struct{}, 1)
startUpdate := func(certs []*certInfo) context.CancelFunc {
// Acquire the semaphore, return if we fail to.
select {
case sem <- struct{}{}:
default:
return nil
}
defer func() {
<-sem
}()
upCtx, cancel := context.WithCancel(ctx)
wg.Add(1)
go func() {
m.updateAllCerts(upCtx, certs)
wg.Done()
}()
return cancel
}
// Cancel the running update, if any. Called on config
// updates, when exiting.
cancelUpdate := func() {
if upCancel != nil {
upCancel()
upCancel = nil
reloadCh := make(chan interface{}, 1)
go func() {
for config := range m.configCh {
certs := m.loadConfig(config)
m.setCerts(certs)
reloadCh <- certs
}
wg.Wait()
}
defer cancelUpdate()
}()
tick := time.NewTicker(5 * time.Minute)
defer tick.Stop()
for {
var c func()
select {
case <-tick.C:
c = startUpdate(m.certs)
case <-testUpdateCh:
c = startUpdate(m.certs)
case certDomains := <-m.configCh:
cancelUpdate()
m.certs = m.loadConfig(certDomains)
case <-ctx.Done():
return
}
if c != nil {
upCancel = nil
}
}
runWithUpdates(
ctx,
func(ctx context.Context, value interface{}) {
certs := value.([]*certInfo)
m.updateAllCerts(ctx, certs)
},
reloadCh,
updateInterval,
)
}
func concatDER(der [][]byte) []byte {
......@@ -324,7 +300,7 @@ func concatDER(der [][]byte) []byte {
func certRequest(key crypto.Signer, domains []string) ([]byte, error) {
req := &x509.CertificateRequest{
Subject: pkix.Name{CommonName: domains[0]},
Subject: pkix.Name{CommonName: domains[0]},
DNSNames: domains,
}
return x509.CreateCertificateRequest(rand.Reader, req, key)
......
......@@ -77,36 +77,44 @@ func TestManager_Reload(t *testing.T) {
defer cleanup()
// Data race: we read data owned by another goroutine!
if len(m.certs) < 1 {
certs := m.getCerts()
if len(certs) < 1 {
t.Fatal("configuration not loaded?")
}
if m.certs[0].cn() != "example.com" {
t.Fatalf("certs[0].cn() is %s, expected example.com", m.certs[0].cn())
if certs[0].cn() != "example.com" {
t.Fatalf("certs[0].cn() is %s, expected example.com", certs[0].cn())
}
// Try a reload, catch obvious errors.
m.Reload()
time.Sleep(50 * time.Millisecond)
certs = m.getCerts()
if len(m.certs) != 1 {
if len(certs) != 1 {
t.Fatalf("certs count is %d, expected 1", len(m.certs))
}
if m.certs[0].cn() != "example.com" {
t.Fatalf("certs[0].cn() is %s, expected example.com", m.certs[0].cn())
if certs[0].cn() != "example.com" {
t.Fatalf("certs[0].cn() is %s, expected example.com", certs[0].cn())
}
}
func TestManager_NewCert(t *testing.T) {
var oldUpdateInterval time.Duration
oldUpdateInterval, updateInterval = updateInterval, 50*time.Millisecond
defer func() {
updateInterval = oldUpdateInterval
}()
cleanup, _, m := newTestManager(t, NewSelfSignedCertGenerator())
defer cleanup()
now := time.Now()
ci := m.certs[0]
certs := m.getCerts()
ci := certs[0]
if ci.retryDeadline.After(now) {
t.Fatalf("retry deadline is in the future: %v", ci.retryDeadline)
}
testUpdateCh <- true
time.Sleep(100 * time.Millisecond)
// Verify that the retry/renewal timestamp is in the future.
......@@ -135,7 +143,7 @@ func TestManager_NewCert(t *testing.T) {
m.Reload()
time.Sleep(50 * time.Millisecond)
ci = m.certs[0]
ci = m.getCerts()[0]
if !ci.valid {
t.Fatal("certificate is invalid after a reload")
}
......
util.go 0 → 100644
package acmeserver
import (
"context"
"sync"
"time"
)
// Updates are long-term jobs, so they should be interruptible. We run
// updates in a separate goroutine, and cancel them when the
// configuration is reloaded or on exit. A semaphore ensures that only
// one update goroutine will be running at any given time (without
// other ones piling up).
func runWithUpdates(ctx context.Context, fn func(context.Context, interface{}), reloadCh <-chan interface{}, updateInterval time.Duration) {
// Function to cancel the current update, and the associated
// WaitGroup to wait for its termination.
var upCancel context.CancelFunc
var wg sync.WaitGroup
sem := make(chan struct{}, 1)
startUpdate := func(value interface{}) context.CancelFunc {
// Acquire the semaphore, return if we fail to.
// Equivalent to a 'try-lock' construct.
select {
case sem <- struct{}{}:
default:
return nil
}
defer func() {
<-sem
}()
ctx, cancel := context.WithCancel(ctx)
wg.Add(1)
go func() {
fn(ctx, value)
wg.Done()
}()
return cancel
}
// Cancel the running update, if any. Called on config
// updates, when exiting.
cancelUpdate := func() {
if upCancel != nil {
upCancel()
upCancel = nil
}
wg.Wait()
}
defer cancelUpdate()
var cur interface{}
tick := time.NewTicker(updateInterval)
defer tick.Stop()
for {
select {
case <-tick.C:
// Do not cancel running update when running the ticker.
if cancel := startUpdate(cur); cancel != nil {
upCancel = cancel
}
case value := <-reloadCh:
// Cancel the running update when configuration is reloaded.
cancelUpdate()
cur = value
case <-ctx.Done():
return
}
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment