diff --git a/manager.go b/manager.go
index eef527c79c4358182074ffeef426bfaf3695ec71..67a767f43f9748358cd8488f5b6be810e119cc83 100644
--- a/manager.go
+++ b/manager.go
@@ -20,6 +20,11 @@ import (
 	"github.com/prometheus/client_golang/prometheus"
 )
 
+var (
+	defaultRenewalDays = 21
+	updateInterval     = 1 * time.Minute
+)
+
 // certInfo represents what we know about the state of the certificate
 // at runtime.
 type certInfo struct {
@@ -71,18 +76,19 @@ type CertGenerator interface {
 	GetCertificate(context.Context, crypto.Signer, *certConfig) ([][]byte, *x509.Certificate, error)
 }
 
-// Manager periodically renews certificates before they expire, and
-// responds to http-01 validation requests.
+// Manager periodically renews certificates before they expire.
 type Manager struct {
 	configDirs []string
 	useRSA     bool
 	storage    certStorage
-	certs      []*certInfo
 	certGen    CertGenerator
 
 	renewalDays int
 	configCh    chan []*certConfig
 	doneCh      chan bool
+
+	mx    sync.Mutex
+	certs []*certInfo
 }
 
 // NewManager creates a new Manager with the given configuration.
@@ -94,6 +100,9 @@ func NewManager(config *Config, certGen CertGenerator) (*Manager, error) {
 	if config.Output.Path == "" {
 		return nil, errors.New("'output.path' is unset")
 	}
+	if config.RenewalDays <= 0 {
+		config.RenewalDays = defaultRenewalDays
+	}
 
 	m := &Manager{
 		useRSA:      config.UseRSA,
@@ -103,9 +112,6 @@ func NewManager(config *Config, certGen CertGenerator) (*Manager, error) {
 		certGen:     certGen,
 		renewalDays: config.RenewalDays,
 	}
-	if m.renewalDays <= 0 {
-		m.renewalDays = 15
-	}
 
 	ds := &dirStorage{root: config.Output.Path}
 	if config.Output.ReplDS == nil {
@@ -241,71 +247,41 @@ func (m *Manager) loadConfig(certs []*certConfig) []*certInfo {
 	return out
 }
 
+func (m *Manager) getCerts() []*certInfo {
+	m.mx.Lock()
+	defer m.mx.Unlock()
+	return m.certs
+}
+
+func (m *Manager) setCerts(certs []*certInfo) {
+	m.mx.Lock()
+	m.certs = certs
+	m.mx.Unlock()
+}
+
 // This channel is used by the testing code to trigger an update,
 // without having to wait for the timer to tick.
 var testUpdateCh = make(chan bool)
 
 func (m *Manager) loop(ctx context.Context) {
-	// Updates are long-term jobs, so they should be
-	// interruptible. We run updates in a separate goroutine, and
-	// cancel them when the configuration is reloaded or on exit.
-	// A simple channel is used as a semaphore, so that only one
-	// update goroutine can be running at any given time (without
-	// other ones piling up).
-	var upCancel context.CancelFunc
-	var wg sync.WaitGroup
-	sem := make(chan struct{}, 1)
-
-	startUpdate := func(certs []*certInfo) context.CancelFunc {
-		// Acquire the semaphore, return if we fail to.
-		select {
-		case sem <- struct{}{}:
-		default:
-			return nil
-		}
-		defer func() {
-			<-sem
-		}()
-
-		upCtx, cancel := context.WithCancel(ctx)
-		wg.Add(1)
-		go func() {
-			m.updateAllCerts(upCtx, certs)
-			wg.Done()
-		}()
-		return cancel
-	}
-
-	// Cancel the running update, if any. Called on config
-	// updates, when exiting.
-	cancelUpdate := func() {
-		if upCancel != nil {
-			upCancel()
-			upCancel = nil
+	reloadCh := make(chan interface{}, 1)
+	go func() {
+		for config := range m.configCh {
+			certs := m.loadConfig(config)
+			m.setCerts(certs)
+			reloadCh <- certs
 		}
-		wg.Wait()
-	}
-	defer cancelUpdate()
+	}()
 
-	tick := time.NewTicker(5 * time.Minute)
-	defer tick.Stop()
-	for {
-		var c func()
-		select {
-		case <-tick.C:
-			c = startUpdate(m.certs)
-		case <-testUpdateCh:
-			c = startUpdate(m.certs)
-		case certDomains := <-m.configCh:
-			cancelUpdate()
-			m.certs = m.loadConfig(certDomains)
-		case <-ctx.Done():
-			return
-		}
-		if c != nil {
-			upCancel = nil
-		}
-	}
+	runWithUpdates(
+		ctx,
+		func(ctx context.Context, value interface{}) {
+			certs := value.([]*certInfo)
+			m.updateAllCerts(ctx, certs)
+		},
+		reloadCh,
+		updateInterval,
+	)
 }
 
 func concatDER(der [][]byte) []byte {
@@ -324,7 +300,7 @@ func concatDER(der [][]byte) []byte {
 
 func certRequest(key crypto.Signer, domains []string) ([]byte, error) {
 	req := &x509.CertificateRequest{
-		Subject: pkix.Name{CommonName: domains[0]},
+		Subject:  pkix.Name{CommonName: domains[0]},
 		DNSNames: domains,
 	}
 	return x509.CreateCertificateRequest(rand.Reader, req, key)
diff --git a/manager_test.go b/manager_test.go
index 4ee0ce2a36dd4a9104e2a6cdf43310a0ded5da0d..1c735361ce0878e3761f2807e9c79a5adaab745b 100644
--- a/manager_test.go
+++ b/manager_test.go
@@ -77,36 +77,44 @@ func TestManager_Reload(t *testing.T) {
 	defer cleanup()
 
-	// Data race: we read data owned by another goroutine!
-	if len(m.certs) < 1 {
+	certs := m.getCerts()
+	if len(certs) < 1 {
 		t.Fatal("configuration not loaded?")
 	}
-	if m.certs[0].cn() != "example.com" {
-		t.Fatalf("certs[0].cn() is %s, expected example.com", m.certs[0].cn())
+	if certs[0].cn() != "example.com" {
+		t.Fatalf("certs[0].cn() is %s, expected example.com", certs[0].cn())
 	}
 
 	// Try a reload, catch obvious errors.
 	m.Reload()
 	time.Sleep(50 * time.Millisecond)
+	certs = m.getCerts()
 
-	if len(m.certs) != 1 {
-		t.Fatalf("certs count is %d, expected 1", len(m.certs))
+	if len(certs) != 1 {
+		t.Fatalf("certs count is %d, expected 1", len(certs))
 	}
-	if m.certs[0].cn() != "example.com" {
-		t.Fatalf("certs[0].cn() is %s, expected example.com", m.certs[0].cn())
+	if certs[0].cn() != "example.com" {
+		t.Fatalf("certs[0].cn() is %s, expected example.com", certs[0].cn())
 	}
 }
 
 func TestManager_NewCert(t *testing.T) {
+	var oldUpdateInterval time.Duration
+	oldUpdateInterval, updateInterval = updateInterval, 50*time.Millisecond
+	defer func() {
+		updateInterval = oldUpdateInterval
+	}()
+
 	cleanup, _, m := newTestManager(t, NewSelfSignedCertGenerator())
 	defer cleanup()
 
 	now := time.Now()
-	ci := m.certs[0]
+	certs := m.getCerts()
+	ci := certs[0]
 	if ci.retryDeadline.After(now) {
 		t.Fatalf("retry deadline is in the future: %v", ci.retryDeadline)
 	}
 
-	testUpdateCh <- true
 	time.Sleep(100 * time.Millisecond)
 
 	// Verify that the retry/renewal timestamp is in the future.
@@ -135,7 +143,7 @@ func TestManager_NewCert(t *testing.T) {
 	m.Reload()
 	time.Sleep(50 * time.Millisecond)
 
-	ci = m.certs[0]
+	ci = m.getCerts()[0]
 	if !ci.valid {
 		t.Fatal("certificate is invalid after a reload")
 	}
diff --git a/util.go b/util.go
new file mode 100644
index 0000000000000000000000000000000000000000..9801990310b630f08d9b47ec1ca334743883972f
--- /dev/null
+++ b/util.go
@@ -0,0 +1,78 @@
+package acmeserver
+
+import (
+	"context"
+	"sync"
+	"time"
+)
+
+// Updates are long-term jobs, so they should be interruptible. We run
+// updates in a separate goroutine, and cancel them when the
+// configuration is reloaded or on exit. A semaphore ensures that only
+// one update goroutine will be running at any given time (without
+// other ones piling up).
+func runWithUpdates(ctx context.Context, fn func(context.Context, interface{}), reloadCh <-chan interface{}, updateInterval time.Duration) {
+	// Function to cancel the current update, and the associated
+	// WaitGroup to wait for its termination.
+	var upCancel context.CancelFunc
+	var wg sync.WaitGroup
+	sem := make(chan struct{}, 1)
+
+	startUpdate := func(value interface{}) context.CancelFunc {
+		// Acquire the semaphore, return nil if we fail to.
+		// Equivalent to a 'try-lock' construct.
+		select {
+		case sem <- struct{}{}:
+		default:
+			return nil
+		}
+
+		ctx, cancel := context.WithCancel(ctx)
+		wg.Add(1)
+		go func() {
+			fn(ctx, value)
+			// Release the semaphore only once the update has
+			// terminated, so that updates cannot pile up.
+			<-sem
+			wg.Done()
+		}()
+		return cancel
+	}
+
+	// Cancel the running update, if any, and wait for it to
+	// terminate. Called on config reloads and when exiting.
+	cancelUpdate := func() {
+		if upCancel != nil {
+			upCancel()
+			upCancel = nil
+		}
+		wg.Wait()
+	}
+	defer cancelUpdate()
+
+	var cur interface{}
+	tick := time.NewTicker(updateInterval)
+	defer tick.Stop()
+	for {
+		select {
+		case <-tick.C:
+			// Skip ticks until the first value arrives.
+			if cur == nil {
+				continue
+			}
+			// A tick does not cancel a running update:
+			// startUpdate just returns nil if one is still
+			// in flight.
+			if cancel := startUpdate(cur); cancel != nil {
+				upCancel = cancel
+			}
+		case value := <-reloadCh:
+			// Cancel the running update when the configuration
+			// is reloaded.
+			cancelUpdate()
+			cur = value
+		case <-ctx.Done():
+			return
+		}
+	}
+}
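
A minimal usage sketch of runWithUpdates follows (hypothetical, not part of the patch: exampleRunWithUpdates, the []string payload, and the 50ms interval are made up for illustration). It shows the intended lifecycle: a value arriving on the reload channel cancels any in-flight update and becomes the new input, ticks start an update only when none is running, and cancelling the outer context shuts the loop down.

	package acmeserver

	import (
		"context"
		"log"
		"time"
	)

	func exampleRunWithUpdates() {
		// Bound the whole run; in real code this would be the
		// process lifetime context.
		ctx, cancel := context.WithTimeout(context.Background(), 300*time.Millisecond)
		defer cancel()

		reloadCh := make(chan interface{}, 1)
		reloadCh <- []string{"example.com"} // initial "configuration"

		// Blocks until ctx expires; the deferred cancelUpdate inside
		// runWithUpdates then interrupts and waits for the last update.
		runWithUpdates(ctx, func(ctx context.Context, value interface{}) {
			domains := value.([]string)
			log.Printf("updating %v", domains)
			select {
			case <-time.After(time.Second): // simulate a slow update...
			case <-ctx.Done(): // ...that honors cancellation
			}
		}, reloadCh, 50*time.Millisecond)
	}

Because startUpdate releases the semaphore only after fn returns, the later ticks during the simulated one-second update are no-ops rather than extra goroutines; sending a new value on reloadCh is the only way to preempt a running update.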