Compare revisions

Changes are shown as if the source revision was being merged into the target revision.

Commits on Source (90)
Showing with 1744 additions and 146 deletions
include: "https://git.autistici.org/ai3/build-deb/raw/master/ci-common.yml" include:
- "https://git.autistici.org/pipelines/debian/raw/master/common.yml"
- "https://git.autistici.org/pipelines/images/test/golang/raw/master/ci.yml"
@@ -133,10 +133,16 @@ func (a *ACME) acmeClient(ctx context.Context) (*acme.Client, error) {
 	// account is already registered we get a StatusConflict,
 	// which we can ignore.
 	_, err = client.Register(ctx, ac, func(_ string) bool { return true })
-	if ae, ok := err.(*acme.Error); err == nil || ok && ae.StatusCode == http.StatusConflict {
+	if ae, ok := err.(*acme.Error); err == nil || err == acme.ErrAccountAlreadyExists || (ok && ae.StatusCode == http.StatusConflict) {
 		a.client = client
 		err = nil
 	}
+
+	// Fetch account info and display it.
+	if acct, err := client.GetReg(ctx, ""); err == nil {
+		log.Printf("ACME account %s", acct.URI)
+	}
 	return a.client, err
 }
@@ -149,7 +155,8 @@ func (a *ACME) GetCertificate(ctx context.Context, key crypto.Signer, c *certCon
 		return nil, nil, err
 	}
-	if err = a.verifyAll(ctx, client, c); err != nil {
+	o, err := a.verifyAll(ctx, client, c)
+	if err != nil {
 		return nil, nil, err
 	}
@@ -157,7 +164,7 @@ func (a *ACME) GetCertificate(ctx context.Context, key crypto.Signer, c *certCon
 	if err != nil {
 		return nil, nil, err
 	}
-	der, _, err = client.CreateCert(ctx, csr, 0, true)
+	der, _, err = client.CreateOrderCert(ctx, o.FinalizeURL, csr, true)
 	if err != nil {
 		return nil, nil, err
 	}
@@ -168,56 +175,67 @@ func (a *ACME) GetCertificate(ctx context.Context, key crypto.Signer, c *certCon
 	return der, leaf, nil
 }
 
-func (a *ACME) verifyAll(ctx context.Context, client *acme.Client, c *certConfig) error {
-	for _, domain := range c.Names {
-		if err := a.verify(ctx, client, c, domain); err != nil {
-			return err
-		}
-	}
-	return nil
-}
-
-func (a *ACME) verify(ctx context.Context, client *acme.Client, c *certConfig, domain string) error {
+func (a *ACME) verifyAll(ctx context.Context, client *acme.Client, c *certConfig) (*acme.Order, error) {
 	// Make an authorization request to the ACME server, and
 	// verify that it returns a valid response with challenges.
-	authz, err := client.Authorize(ctx, domain)
+	o, err := client.AuthorizeOrder(ctx, acme.DomainIDs(c.Names...))
 	if err != nil {
-		return err
+		return nil, fmt.Errorf("AuthorizeOrder failed: %v", err)
 	}
-	switch authz.Status {
-	case acme.StatusValid:
-		return nil // already authorized
-	case acme.StatusInvalid:
-		return fmt.Errorf("invalid authorization %q", authz.URI)
+	switch o.Status {
+	case acme.StatusReady:
+		return o, nil // already authorized
+	case acme.StatusPending:
+	default:
+		return nil, fmt.Errorf("invalid new order status %q", o.Status)
 	}
 
-	// Pick a challenge that matches our preferences and the
-	// available validators. The validator fulfills the challenge,
-	// and returns a cleanup function that we're going to call
-	// before we return. All steps are sequential and idempotent.
-	chal := a.pickChallenge(authz.Challenges, c)
-	if chal == nil {
-		return fmt.Errorf("unable to authorize %q", domain)
-	}
-	v, ok := a.validators[chal.Type]
-	if !ok {
-		return fmt.Errorf("challenge type '%s' is not available", chal.Type)
-	}
-	cleanup, err := v.Fulfill(ctx, client, domain, chal)
-	if err != nil {
-		return err
-	}
-	defer cleanup()
+	for _, zurl := range o.AuthzURLs {
+		z, err := client.GetAuthorization(ctx, zurl)
+		if err != nil {
+			return nil, fmt.Errorf("GetAuthorization(%s) failed: %v", zurl, err)
+		}
+		if z.Status != acme.StatusPending {
+			continue
+		}
+
+		// Pick a challenge that matches our preferences and the
+		// available validators. The validator fulfills the challenge,
+		// and returns a cleanup function that we're going to call
+		// before we return. All steps are sequential and idempotent.
+		chal := a.pickChallenge(z.Challenges, c)
+		if chal == nil {
+			return nil, fmt.Errorf("unable to authorize %q", c.Names)
+		}
+		v, ok := a.validators[chal.Type]
+		if !ok {
+			return nil, fmt.Errorf("challenge type '%s' is not available", chal.Type)
+		}
+
+		log.Printf("attempting fulfillment for %q (identifier: %+v)", c.Names, z.Identifier)
+		for _, domain := range c.Names {
+			cleanup, err := v.Fulfill(ctx, client, domain, chal)
+			if err != nil {
+				return nil, fmt.Errorf("fulfillment failed: %v", err)
+			}
+			defer cleanup()
+		}
+
+		if _, err := client.Accept(ctx, chal); err != nil {
+			return nil, fmt.Errorf("challenge accept failed: %v", err)
+		}
+		if _, err := client.WaitAuthorization(ctx, z.URI); err != nil {
+			return nil, fmt.Errorf("WaitAuthorization(%s) failed: %v", z.URI, err)
+		}
+	}
 
-	// Tell the ACME server that we've accepted the challenge, and
-	// then wait, possibly for some time, until there is an
-	// authorization response (either successful or not) from the
-	// server.
-	if _, err = client.Accept(ctx, chal); err != nil {
-		return err
+	// Authorizations are satisfied, wait for the CA
+	// to update the order status.
+	if _, err = client.WaitOrder(ctx, o.URI); err != nil {
+		return nil, err
 	}
-	_, err = client.WaitAuthorization(ctx, authz.URI)
-	return err
+	return o, nil
 }
 
 // Pick a challenge with the right type from the Challenge response
......
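For reference, a condensed sketch of the order-based (RFC 8555) issuance flow these hunks move to, using golang.org/x/crypto/acme. The helper name issue and the injected fulfill callback are illustrative only, not part of the repository; error handling is kept minimal and validator cleanup is omitted.

package acmesketch

import (
	"context"
	"log"

	"golang.org/x/crypto/acme"
)

// issue walks one order through the flow: authorize, satisfy the pending
// authorizations, wait for the order to become ready, then finalize it.
// fulfill is whatever publishes the challenge response (HTTP or DNS).
func issue(ctx context.Context, client *acme.Client, names []string, csr []byte,
	fulfill func(*acme.Client, string, *acme.Challenge) error) ([][]byte, error) {

	o, err := client.AuthorizeOrder(ctx, acme.DomainIDs(names...))
	if err != nil {
		return nil, err
	}
	for _, zurl := range o.AuthzURLs {
		z, err := client.GetAuthorization(ctx, zurl)
		if err != nil {
			return nil, err
		}
		if z.Status != acme.StatusPending {
			continue
		}
		// Use the first offered challenge for brevity.
		chal := z.Challenges[0]
		if err := fulfill(client, z.Identifier.Value, chal); err != nil {
			return nil, err
		}
		if _, err := client.Accept(ctx, chal); err != nil {
			return nil, err
		}
		if _, err := client.WaitAuthorization(ctx, z.URI); err != nil {
			return nil, err
		}
	}
	if _, err := client.WaitOrder(ctx, o.URI); err != nil {
		return nil, err
	}
	der, _, err := client.CreateOrderCert(ctx, o.FinalizeURL, csr, true)
	if err != nil {
		return nil, err
	}
	log.Printf("issued certificate chain with %d elements", len(der))
	return der, nil
}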
@@ -3,16 +3,15 @@ package main
 import (
 	"context"
 	"flag"
-	"io/ioutil"
 	"log"
 	"net/http"
 	"os"
 	"os/signal"
 	"syscall"
 
-	"git.autistici.org/ai3/tools/acmeserver"
 	"git.autistici.org/ai3/go-common/serverutil"
-	"gopkg.in/yaml.v2"
+	"git.autistici.org/ai3/tools/acmeserver"
+	"gopkg.in/yaml.v3"
 )
 
 var (
@@ -29,12 +28,13 @@ type Config struct {
 func loadConfig(path string) (*Config, error) {
 	// Read YAML config.
-	data, err := ioutil.ReadFile(path) // nolint: gosec
+	f, err := os.Open(path)
 	if err != nil {
 		return nil, err
 	}
+	defer f.Close()
 
 	var config Config
-	if err := yaml.Unmarshal(data, &config); err != nil {
+	if err := yaml.NewDecoder(f).Decode(&config); err != nil {
 		return nil, err
 	}
 	return &config, nil
......
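The ioutil.ReadFile plus yaml.Unmarshal to os.Open plus yaml.NewDecoder conversion above recurs in config.go below. As a standalone sketch of the pattern (the struct and file name are illustrative only):

package main

import (
	"log"
	"os"

	"gopkg.in/yaml.v3"
)

// exampleConfig is an illustrative stand-in for the real Config struct.
type exampleConfig struct {
	Addr string `yaml:"addr"`
}

func main() {
	// Stream-decode the YAML straight from the file instead of
	// reading it fully into memory first.
	f, err := os.Open("config.yml") // illustrative path
	if err != nil {
		log.Fatal(err)
	}
	defer f.Close()

	var cfg exampleConfig
	if err := yaml.NewDecoder(f).Decode(&cfg); err != nil {
		log.Fatal(err)
	}
	log.Printf("addr=%s", cfg.Addr)
}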
@@ -3,11 +3,11 @@ package acmeserver
 import (
 	"errors"
 	"fmt"
-	"io/ioutil"
 	"log"
+	"os"
 	"path/filepath"
 
-	"gopkg.in/yaml.v2"
+	"gopkg.in/yaml.v3"
 
 	"git.autistici.org/ai3/go-common/clientutil"
 )
@@ -77,12 +77,13 @@ func (c *certConfig) check() error {
 }
 
 func readCertConfigs(path string) ([]*certConfig, error) {
-	data, err := ioutil.ReadFile(path) // nolint: gosec
+	f, err := os.Open(path)
 	if err != nil {
 		return nil, err
 	}
+	defer f.Close()
 
 	var cc []*certConfig
-	if err := yaml.Unmarshal(data, &cc); err != nil {
+	if err := yaml.NewDecoder(f).Decode(&cc); err != nil {
 		return nil, err
 	}
 	return cc, nil
......
@@ -2,4 +2,4 @@ acmeserver (2.0) unstable; urgency=medium
 
   * Initial Release.
 
- -- Autistici/Inventati <debian@autistici.org> Sat, 15 Jun 2018 09:23:40 +0000
+ -- Autistici/Inventati <debian@autistici.org> Fri, 15 Jun 2018 09:23:40 +0000
@@ -2,12 +2,12 @@ Source: acmeserver
 Section: admin
 Priority: optional
 Maintainer: Autistici/Inventati <debian@autistici.org>
-Build-Depends: debhelper (>=9), golang-any (>=1.11), dh-systemd, dh-golang
+Build-Depends: debhelper-compat (= 13), golang-any (>= 1.11), dh-golang
 Standards-Version: 3.9.6
 
 Package: acmeserver
 Architecture: any
 Depends: ${shlibs:Depends}, ${misc:Depends}
+Built-Using: ${misc:Built-Using}
 Description: ACME server
  Automatically manages and renews public SSL certificates.
@@ -2,18 +2,17 @@
 export DH_GOPKG = git.autistici.org/ai3/tools/acmeserver
 export DH_GOLANG_EXCLUDES = vendor
+export DH_GOLANG_INSTALL_ALL := 1
 
 %:
-	dh $@ --with systemd --with golang --buildsystem golang
+	dh $@ --with golang --buildsystem golang
 
 override_dh_auto_install:
-	dh_auto_install
-	$(RM) -rv debian/acmeserver/usr/share/gocode
+	dh_auto_install -- --no-source
 
-override_dh_systemd_enable:
-	dh_systemd_enable --no-enable
+override_dh_installsystemd:
+	dh_installsystemd --no-enable
 
-override_dh_systemd_start:
-	dh_systemd_start --no-start
-
-override_dh_installsystemd:
-	dh_installsystemd --no-start
@@ -10,6 +10,7 @@ import (
 	"github.com/miekg/dns"
 	"golang.org/x/crypto/acme"
+	"golang.org/x/net/publicsuffix"
 )
 
 const (
@@ -87,8 +88,16 @@ func (d *dnsValidator) client() *dns.Client {
 }
 
 func (d *dnsValidator) Fulfill(ctx context.Context, client *acme.Client, domain string, chal *acme.Challenge) (func(), error) {
-	zone := domain[strings.Index(domain, ".")+1:]
-	fqdn := dns.Fqdn(domain)
+	domain = strings.TrimPrefix(domain, "*.")
+
+	zone, err := publicsuffix.EffectiveTLDPlusOne(domain)
+	if err != nil {
+		return nil, fmt.Errorf("could not determine effective tld: %w", err)
+	}
+	zone = dns.Fqdn(zone)
+	fqdn := dns.Fqdn("_acme-challenge." + domain)
+
 	value, err := client.DNS01ChallengeRecord(chal.Token)
 	if err != nil {
 		return nil, err
@@ -109,6 +118,7 @@ func (d *dnsValidator) Fulfill(ctx context.Context, client *acme.Client, domain
 }
 
 func (d *dnsValidator) updateNS(ctx context.Context, ns, zone, fqdn, value string, remove bool) error {
+	log.Printf("updateNS(%s, %s, %s, %s, %v)", ns, zone, fqdn, value, remove)
 	rrs := d.makeRR(fqdn, value, rfc2136Timeout)
 	m := d.makeMsg(zone, rrs, remove)
 	c := d.client()
......
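A standalone sketch of the new record-name derivation in Fulfill: strip any wildcard label, take the registered domain as the RFC 2136 update zone, and prepend _acme-challenge for the TXT record name. The helper name challengeNames is illustrative only.

package main

import (
	"fmt"
	"log"
	"strings"

	"github.com/miekg/dns"
	"golang.org/x/net/publicsuffix"
)

// challengeNames mirrors the derivation in the new Fulfill code.
func challengeNames(domain string) (zone, fqdn string, err error) {
	domain = strings.TrimPrefix(domain, "*.")
	etld1, err := publicsuffix.EffectiveTLDPlusOne(domain)
	if err != nil {
		return "", "", err
	}
	return dns.Fqdn(etld1), dns.Fqdn("_acme-challenge." + domain), nil
}

func main() {
	zone, fqdn, err := challengeNames("*.www.example.co.uk")
	if err != nil {
		log.Fatal(err)
	}
	// zone: example.co.uk.   fqdn: _acme-challenge.www.example.co.uk.
	fmt.Println(zone, fqdn)
}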
module git.autistici.org/ai3/tools/acmeserver
go 1.19
require (
git.autistici.org/ai3/go-common v0.0.0-20230816213645-b3aa3fb514d6
git.autistici.org/ai3/tools/replds v0.0.0-20230923170339-b6e6e3cc032b
github.com/miekg/dns v1.1.50
github.com/prometheus/client_golang v1.12.2
golang.org/x/crypto v0.24.0
golang.org/x/net v0.26.0
gopkg.in/yaml.v3 v3.0.1
)
require (
github.com/NYTimes/gziphandler v1.1.1 // indirect
github.com/beorn7/perks v1.0.1 // indirect
github.com/cenkalti/backoff/v4 v4.1.3 // indirect
github.com/cespare/xxhash/v2 v2.1.2 // indirect
github.com/coreos/go-systemd/v22 v22.5.0 // indirect
github.com/felixge/httpsnoop v1.0.3 // indirect
github.com/go-logr/logr v1.2.3 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/golang/protobuf v1.5.2 // indirect
github.com/matttproud/golang_protobuf_extensions v1.0.1 // indirect
github.com/openzipkin/zipkin-go v0.4.0 // indirect
github.com/prometheus/client_model v0.2.0 // indirect
github.com/prometheus/common v0.32.1 // indirect
github.com/prometheus/procfs v0.7.3 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.34.0 // indirect
go.opentelemetry.io/contrib/propagators/b3 v1.9.0 // indirect
go.opentelemetry.io/otel v1.10.0 // indirect
go.opentelemetry.io/otel/exporters/zipkin v1.9.0 // indirect
go.opentelemetry.io/otel/metric v0.31.0 // indirect
go.opentelemetry.io/otel/sdk v1.10.0 // indirect
go.opentelemetry.io/otel/trace v1.10.0 // indirect
golang.org/x/mod v0.4.2 // indirect
golang.org/x/sync v0.3.0 // indirect
golang.org/x/sys v0.21.0 // indirect
golang.org/x/tools v0.1.6-0.20210726203631-07bc1bf47fb2 // indirect
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect
google.golang.org/protobuf v1.27.1 // indirect
)
@@ -16,10 +16,15 @@ import (
 	"sync"
 	"time"
 
-	"git.autistici.org/ai3/replds"
+	"git.autistici.org/ai3/tools/replds"
 	"github.com/prometheus/client_golang/prometheus"
 )
 
+var (
+	defaultRenewalDays = 21
+	updateInterval     = 1 * time.Minute
+)
+
 // certInfo represents what we know about the state of the certificate
 // at runtime.
 type certInfo struct {
@@ -71,18 +76,19 @@ type CertGenerator interface {
 	GetCertificate(context.Context, crypto.Signer, *certConfig) ([][]byte, *x509.Certificate, error)
 }
 
-// Manager periodically renews certificates before they expire, and
-// responds to http-01 validation requests.
+// Manager periodically renews certificates before they expire.
 type Manager struct {
 	configDirs []string
 	useRSA     bool
 	storage    certStorage
-	certs      []*certInfo
 	certGen    CertGenerator
 
 	renewalDays int
 	configCh    chan []*certConfig
 	doneCh      chan bool
+
+	mx    sync.Mutex
+	certs []*certInfo
 }
 // NewManager creates a new Manager with the given configuration.
@@ -94,6 +100,9 @@ func NewManager(config *Config, certGen CertGenerator) (*Manager, error) {
 	if config.Output.Path == "" {
 		return nil, errors.New("'output.path' is unset")
 	}
+	if config.RenewalDays <= 0 {
+		config.RenewalDays = defaultRenewalDays
+	}
 
 	m := &Manager{
 		useRSA: config.UseRSA,
@@ -103,9 +112,6 @@ func NewManager(config *Config, certGen CertGenerator) (*Manager, error) {
 		certGen:     certGen,
 		renewalDays: config.RenewalDays,
 	}
-	if m.renewalDays <= 0 {
-		m.renewalDays = 15
-	}
 
 	ds := &dirStorage{root: config.Output.Path}
 	if config.Output.ReplDS == nil {
@@ -158,7 +164,7 @@ func (m *Manager) Reload() {
 var (
 	renewalTimeout    = 10 * time.Minute
-	errorRetryTimeout = 10 * time.Minute
+	errorRetryTimeout = 6 * time.Hour
 )
 
 func (m *Manager) updateAllCerts(ctx context.Context, certs []*certInfo) {
@@ -241,55 +247,41 @@ func (m *Manager) loadConfig(certs []*certConfig) []*certInfo {
 	return out
 }
 
+func (m *Manager) getCerts() []*certInfo {
+	m.mx.Lock()
+	defer m.mx.Unlock()
+	return m.certs
+}
+
+func (m *Manager) setCerts(certs []*certInfo) {
+	m.mx.Lock()
+	m.certs = certs
+	m.mx.Unlock()
+}
+
 // This channel is used by the testing code to trigger an update,
 // without having to wait for the timer to tick.
 var testUpdateCh = make(chan bool)
 func (m *Manager) loop(ctx context.Context) {
-	// Updates are long-term jobs, so they should be
-	// interruptible. We run updates in a separate goroutine, and
-	// cancel them when the configuration is reloaded or on exit.
-	var upCancel context.CancelFunc
-	var wg sync.WaitGroup
-
-	startUpdate := func(certs []*certInfo) context.CancelFunc {
-		// Ensure the previous update has finished.
-		wg.Wait()
-		upCtx, cancel := context.WithCancel(ctx)
-		wg.Add(1)
-		go func() {
-			m.updateAllCerts(upCtx, certs)
-			wg.Done()
-		}()
-		return cancel
-	}
-
-	// Cancel the running update, if any. Called on config
-	// updates, when exiting.
-	cancelUpdate := func() {
-		if upCancel != nil {
-			upCancel()
-		}
-		wg.Wait()
-	}
-	defer cancelUpdate()
-
-	tick := time.NewTicker(5 * time.Minute)
-	defer tick.Stop()
-	for {
-		select {
-		case <-tick.C:
-			upCancel = startUpdate(m.certs)
-		case <-testUpdateCh:
-			upCancel = startUpdate(m.certs)
-		case certDomains := <-m.configCh:
-			cancelUpdate()
-			m.certs = m.loadConfig(certDomains)
-		case <-ctx.Done():
-			return
-		}
-	}
+	reloadCh := make(chan interface{}, 1)
+	go func() {
+		for config := range m.configCh {
+			certs := m.loadConfig(config)
+			m.setCerts(certs)
+			reloadCh <- certs
+		}
+	}()
+
+	runWithUpdates(
+		ctx,
+		func(ctx context.Context, value interface{}) {
+			certs := value.([]*certInfo)
+			m.updateAllCerts(ctx, certs)
+		},
+		reloadCh,
+		updateInterval,
+	)
 }
 func concatDER(der [][]byte) []byte {
@@ -308,10 +300,8 @@ func concatDER(der [][]byte) []byte {
 func certRequest(key crypto.Signer, domains []string) ([]byte, error) {
 	req := &x509.CertificateRequest{
 		Subject:  pkix.Name{CommonName: domains[0]},
-	}
-	if len(domains) > 1 {
-		req.DNSNames = domains[1:]
+		DNSNames: domains,
 	}
 	return x509.CreateCertificateRequest(rand.Reader, req, key)
 }
......
@@ -77,36 +77,44 @@ func TestManager_Reload(t *testing.T) {
 	defer cleanup()
 
-	// Data race: we read data owned by another goroutine!
-	if len(m.certs) < 1 {
+	certs := m.getCerts()
+	if len(certs) < 1 {
 		t.Fatal("configuration not loaded?")
 	}
-	if m.certs[0].cn() != "example.com" {
-		t.Fatalf("certs[0].cn() is %s, expected example.com", m.certs[0].cn())
+	if certs[0].cn() != "example.com" {
+		t.Fatalf("certs[0].cn() is %s, expected example.com", certs[0].cn())
 	}
 
 	// Try a reload, catch obvious errors.
 	m.Reload()
 	time.Sleep(50 * time.Millisecond)
 
-	if len(m.certs) != 1 {
+	certs = m.getCerts()
+	if len(certs) != 1 {
 		t.Fatalf("certs count is %d, expected 1", len(m.certs))
 	}
-	if m.certs[0].cn() != "example.com" {
-		t.Fatalf("certs[0].cn() is %s, expected example.com", m.certs[0].cn())
+	if certs[0].cn() != "example.com" {
+		t.Fatalf("certs[0].cn() is %s, expected example.com", certs[0].cn())
 	}
 }
 
 func TestManager_NewCert(t *testing.T) {
+	var oldUpdateInterval time.Duration
+	oldUpdateInterval, updateInterval = updateInterval, 50*time.Millisecond
+	defer func() {
+		updateInterval = oldUpdateInterval
+	}()
+
 	cleanup, _, m := newTestManager(t, NewSelfSignedCertGenerator())
 	defer cleanup()
 
 	now := time.Now()
-	ci := m.certs[0]
+	certs := m.getCerts()
+	ci := certs[0]
 	if ci.retryDeadline.After(now) {
 		t.Fatalf("retry deadline is in the future: %v", ci.retryDeadline)
 	}
 
-	testUpdateCh <- true
 	time.Sleep(100 * time.Millisecond)
 
 	// Verify that the retry/renewal timestamp is in the future.
@@ -135,7 +143,7 @@ func TestManager_NewCert(t *testing.T) {
 	m.Reload()
 	time.Sleep(50 * time.Millisecond)
 
-	ci = m.certs[0]
+	ci = m.getCerts()[0]
 	if !ci.valid {
 		t.Fatal("certificate is invalid after a reload")
 	}
......
{
"$schema": "https://docs.renovatebot.com/renovate-schema.json",
"extends": [
"config:base"
]
}
@@ -16,7 +16,7 @@ import (
 	"strings"
 	"time"
 
-	"git.autistici.org/ai3/replds"
+	"git.autistici.org/ai3/tools/replds"
 )
 
 type dirStorage struct {
@@ -79,6 +79,14 @@ func dumpCertsAndKey(cn string, der [][]byte, key crypto.Signer) (map[string][]b
 	}
 	m[filepath.Join(cn, "cert.pem")] = data
 
+	if len(der) > 1 {
+		data, err = encodeCerts(der[1:])
+		if err != nil {
+			return nil, err
+		}
+		m[filepath.Join(cn, "chain.pem")] = data
+	}
+
 	data, err = encodePrivateKey(key)
 	if err != nil {
 		return nil, err
@@ -104,7 +112,7 @@ func (d *replStorage) PutCert(cn string, der [][]byte, key crypto.Signer) error
 	now := time.Now()
 	var req replds.SetNodesRequest
 	for path, data := range filemap {
-		req.Nodes = append(req.Nodes, replds.Node{
+		req.Nodes = append(req.Nodes, &replds.Node{
 			Path:      path,
 			Value:     data,
 			Timestamp: now,
......
package acmeserver

import (
	"context"
	"sync"
	"time"
)

// Updates are long-term jobs, so they should be interruptible. We run
// updates in a separate goroutine, and cancel them when the
// configuration is reloaded or on exit. A semaphore ensures that only
// one update goroutine will be running at any given time (without
// other ones piling up).
func runWithUpdates(ctx context.Context, fn func(context.Context, interface{}), reloadCh <-chan interface{}, updateInterval time.Duration) {
	// Function to cancel the current update, and the associated
	// WaitGroup to wait for its termination.
	var upCancel context.CancelFunc
	var wg sync.WaitGroup

	sem := make(chan struct{}, 1)

	startUpdate := func(value interface{}) context.CancelFunc {
		// Acquire the semaphore, return if we fail to.
		// Equivalent to a 'try-lock' construct.
		select {
		case sem <- struct{}{}:
		default:
			return nil
		}
		defer func() {
			<-sem
		}()

		ctx, cancel := context.WithCancel(ctx)
		wg.Add(1)
		go func() {
			fn(ctx, value)
			wg.Done()
		}()
		return cancel
	}

	// Cancel the running update, if any. Called on config
	// updates, when exiting.
	cancelUpdate := func() {
		if upCancel != nil {
			upCancel()
			upCancel = nil
		}
		wg.Wait()
	}
	defer cancelUpdate()

	var cur interface{}

	tick := time.NewTicker(updateInterval)
	defer tick.Stop()
	for {
		select {
		case <-tick.C:
			// Do not cancel running update when running the ticker.
			if cancel := startUpdate(cur); cancel != nil {
				upCancel = cancel
			}
		case value := <-reloadCh:
			// Cancel the running update when configuration is reloaded.
			cancelUpdate()
			cur = value
		case <-ctx.Done():
			return
		}
	}
}
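A minimal usage sketch for runWithUpdates, with an illustrative payload and worker: a reload interrupts the running update and replaces the work item, while the ticker restarts work without cancelling an update that is already in flight.

package acmeserver

import (
	"context"
	"log"
	"time"
)

// exampleUpdateLoop is illustrative only: it drives runWithUpdates with
// a string payload instead of the []*certInfo used by the Manager.
func exampleUpdateLoop(ctx context.Context) {
	reloadCh := make(chan interface{}, 1)
	reloadCh <- "initial config" // seed the first work item

	go func() {
		// A later reload pushes a new payload; the running update
		// (if any) gets cancelled before the next one starts.
		time.Sleep(3 * time.Second)
		reloadCh <- "reloaded config"
	}()

	runWithUpdates(ctx, func(ctx context.Context, value interface{}) {
		log.Printf("update starting with %v", value)
		select {
		case <-time.After(10 * time.Second): // pretend to do work
		case <-ctx.Done(): // interrupted by reload or shutdown
		}
	}, reloadCh, time.Minute)
}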
stages:
  - test

run_tests:
  stage: test
  image: registry.git.autistici.org/ai3/docker/test/golang:master
  script:
    - run-go-test ./...
  artifacts:
    when: always
    reports:
      coverage_report:
        coverage_format: cobertura
        path: cover.xml
      junit: report.xml
@@ -2,6 +2,7 @@ package clientutil
 import (
 	"context"
+	"net/http"
 )
 
 // BackendConfig specifies the configuration of a service backend.
@@ -16,6 +17,13 @@ type BackendConfig struct {
 	TLSConfig *TLSClientConfig `yaml:"tls"`
 	Sharded   bool             `yaml:"sharded"`
 	Debug     bool             `yaml:"debug"`
+
+	// Connection timeout (if unset, use default value).
+	ConnectTimeout string `yaml:"connect_timeout"`
+
+	// Maximum timeout for each individual request to this backend
+	// (if unset, use the Context timeout).
+	RequestMaxTimeout string `yaml:"request_max_timeout"`
 }
 
 // Backend is a runtime class that provides http Clients for use with
@@ -32,6 +40,13 @@ type Backend interface {
 	// *without* a shard ID on a sharded service is an error.
 	Call(context.Context, string, string, interface{}, interface{}) error
 
+	// Make a simple HTTP GET request to the remote backend,
+	// without parsing the response as JSON.
+	//
+	// Useful for streaming large responses, where the JSON
+	// encoding overhead is undesirable.
+	Get(context.Context, string, string) (*http.Response, error)
+
 	// Close all resources associated with the backend.
 	Close()
 }
......
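An illustrative round-trip of the two new BackendConfig fields through YAML; both are Go duration strings parsed later with time.ParseDuration, and only fields visible in this hunk are set (the backend address and TLS settings live elsewhere in the real struct).

package main

import (
	"fmt"
	"log"

	"gopkg.in/yaml.v3"

	"git.autistici.org/ai3/go-common/clientutil"
)

func main() {
	doc := []byte(`
sharded: false
debug: true
connect_timeout: 5s
request_max_timeout: 30s
`)
	var cfg clientutil.BackendConfig
	if err := yaml.Unmarshal(doc, &cfg); err != nil {
		log.Fatal(err)
	}
	fmt.Println(cfg.ConnectTimeout, cfg.RequestMaxTimeout)
}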
@@ -9,6 +9,7 @@ import (
 	"fmt"
 	"log"
 	"math/rand"
+	"mime"
 	"net/http"
 	"net/url"
 	"os"
@@ -16,7 +17,7 @@ import (
 	"strings"
 	"time"
 
-	"github.com/cenkalti/backoff"
+	"github.com/cenkalti/backoff/v4"
 )
 
 // Our own narrow logger interface.
@@ -60,10 +61,11 @@ func newExponentialBackOff() *backoff.ExponentialBackOff {
 type balancedBackend struct {
 	*backendTracker
 	*transportCache
 	baseURI  *url.URL
 	sharded  bool
 	resolver resolver
 	log      logger
+
+	requestMaxTimeout time.Duration
 }
 func newBalancedBackend(config *BackendConfig, resolver resolver) (*balancedBackend, error) {
@@ -80,17 +82,36 @@ func newBalancedBackend(config *BackendConfig, resolver resolver) (*balancedBack
 		}
 	}
 
+	var connectTimeout time.Duration
+	if config.ConnectTimeout != "" {
+		t, err := time.ParseDuration(config.ConnectTimeout)
+		if err != nil {
+			return nil, fmt.Errorf("error in connect_timeout: %v", err)
+		}
+		connectTimeout = t
+	}
+
+	var reqTimeout time.Duration
+	if config.RequestMaxTimeout != "" {
+		t, err := time.ParseDuration(config.RequestMaxTimeout)
+		if err != nil {
+			return nil, fmt.Errorf("error in request_max_timeout: %v", err)
+		}
+		reqTimeout = t
+	}
+
 	var logger logger = &nilLogger{}
 	if config.Debug {
 		logger = log.New(os.Stderr, fmt.Sprintf("backend %s: ", u.Host), 0)
 	}
 
 	return &balancedBackend{
 		backendTracker: newBackendTracker(u.Host, resolver, logger),
-		transportCache: newTransportCache(tlsConfig),
-		sharded:        config.Sharded,
-		baseURI:        u,
-		resolver:       resolver,
-		log:            logger,
+		transportCache:    newTransportCache(tlsConfig, connectTimeout),
+		requestMaxTimeout: reqTimeout,
+		sharded:           config.Sharded,
+		baseURI:           u,
+		resolver:          resolver,
+		log:               logger,
 	}, nil
 }
@@ -115,6 +136,9 @@ func (b *balancedBackend) Call(ctx context.Context, shard, path string, req, res
 	if deadline, ok := ctx.Deadline(); ok {
 		innerTimeout = time.Until(deadline) / time.Duration(seq.Len())
 	}
+	if b.requestMaxTimeout > 0 && innerTimeout > b.requestMaxTimeout {
+		innerTimeout = b.requestMaxTimeout
+	}
 
 	// Call the backends in the sequence until one succeeds, with an
 	// exponential backoff policy controlled by the outer Context.
@@ -135,7 +159,7 @@ func (b *balancedBackend) Call(ctx context.Context, shard, path string, req, res
 		defer httpResp.Body.Close() // nolint
 
 		// Decode the response, unless the 'resp' output is nil.
-		if httpResp.Header.Get("Content-Type") != "application/json" {
+		if ct, _, _ := mime.ParseMediaType(httpResp.Header.Get("Content-Type")); ct != "application/json" {
 			return errors.New("not a JSON response")
 		}
 		if resp == nil {
@@ -145,6 +169,44 @@ func (b *balancedBackend) Call(ctx context.Context, shard, path string, req, res
 	}, backoff.WithContext(newExponentialBackOff(), ctx))
 }
 
+// Makes a generic HTTP GET request to the backend uri.
+func (b *balancedBackend) Get(ctx context.Context, shard, path string) (*http.Response, error) {
+	// Create the target sequence for this call. If there are multiple
+	// targets, reduce the timeout on each individual call accordingly to
+	// accomodate eventual failover.
+	seq, err := b.makeSequence(shard)
+	if err != nil {
+		return nil, err
+	}
+	innerTimeout := 1 * time.Hour
+	if deadline, ok := ctx.Deadline(); ok {
+		innerTimeout = time.Until(deadline) / time.Duration(seq.Len())
+	}
+	if b.requestMaxTimeout > 0 && innerTimeout > b.requestMaxTimeout {
+		innerTimeout = b.requestMaxTimeout
+	}
+
+	req, err := http.NewRequest("GET", b.getURIForRequest(shard, path), nil)
+	if err != nil {
+		return nil, err
+	}
+
+	// Call the backends in the sequence until one succeeds, with an
+	// exponential backoff policy controlled by the outer Context.
+	var resp *http.Response
+	err = backoff.Retry(func() error {
+		innerCtx, cancel := context.WithTimeout(ctx, innerTimeout)
+		defer cancel()
+
+		// When do() returns successfully, we already know that the
+		// response had an HTTP status of 200.
+		var rerr error
+		resp, rerr = b.do(innerCtx, seq, req)
+		return rerr
+	}, backoff.WithContext(newExponentialBackOff(), ctx))
+	return resp, err
+}
+
 // Initialize a new target sequence.
 func (b *balancedBackend) makeSequence(shard string) (*sequence, error) {
 	var tg targetGenerator = b.backendTracker
@@ -198,7 +260,7 @@ func (b *balancedBackend) do(ctx context.Context, seq *sequence, req *http.Reque
 		client := &http.Client{
 			Transport: b.transportCache.getTransport(target),
 		}
-		resp, err = client.Do(req.WithContext(ctx))
+		resp, err = client.Do(propagateDeadline(ctx, req))
 		if err == nil && resp.StatusCode != 200 {
 			err = remoteErrorFromResponse(resp)
 			if !isStatusTemporary(resp.StatusCode) {
@@ -212,6 +274,19 @@ func (b *balancedBackend) do(ctx context.Context, seq *sequence, req *http.Reque
 	return
 }
 
+const deadlineHeader = "X-RPC-Deadline"
+
+// Propagate context deadline to the server using a HTTP header.
+func propagateDeadline(ctx context.Context, req *http.Request) *http.Request {
+	req = req.WithContext(ctx)
+	if deadline, ok := ctx.Deadline(); ok {
+		req.Header.Set(deadlineHeader, strconv.FormatInt(deadline.UTC().UnixNano(), 10))
+	} else {
+		req.Header.Del(deadlineHeader)
+	}
+	return req
+}
+
 var errNoTargets = errors.New("no available backends")
 
 type targetGenerator interface {
......
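The client now advertises its context deadline in the X-RPC-Deadline header as UnixNano in base 10. A hypothetical server-side middleware, not part of this diff, could honor it like this:

package rpcserver

import (
	"context"
	"net/http"
	"strconv"
	"time"
)

// withPropagatedDeadline is an illustrative middleware that trusts the
// X-RPC-Deadline header and applies it to the request context, matching
// what propagateDeadline sends on the client side.
func withPropagatedDeadline(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if s := r.Header.Get("X-RPC-Deadline"); s != "" {
			if ns, err := strconv.ParseInt(s, 10, 64); err == nil {
				ctx, cancel := context.WithDeadline(r.Context(), time.Unix(0, ns))
				defer cancel()
				r = r.WithContext(ctx)
			}
		}
		next.ServeHTTP(w, r)
	})
}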
// +build go1.9

package clientutil

import (
	"context"
	"net"
	"time"
)

func netDialContext(addr string, connectTimeout time.Duration) func(context.Context, string, string) (net.Conn, error) {
	dialer := &net.Dialer{
		Timeout:   connectTimeout,
		KeepAlive: 30 * time.Second,
		DualStack: true,
	}
	return func(ctx context.Context, net string, _ string) (net.Conn, error) {
		return dialer.DialContext(ctx, net, addr)
	}
}
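A sketch of how a per-target dialer like netDialContext is typically plugged into an http.Transport. newTransportCache itself is not shown in this diff, so the wiring and the field choices below are assumptions for illustration only.

package clientutil

import (
	"crypto/tls"
	"net/http"
	"time"
)

// newTransportSketch shows one plausible way the connect timeout ends up
// on an http.Transport: the per-target dialer returned by netDialContext
// becomes the transport's DialContext. Illustrative only.
func newTransportSketch(addr string, tlsConfig *tls.Config, connectTimeout time.Duration) *http.Transport {
	return &http.Transport{
		DialContext:         netDialContext(addr, connectTimeout),
		TLSClientConfig:     tlsConfig,
		MaxIdleConnsPerHost: 1,
		IdleConnTimeout:     90 * time.Second,
	}
}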