Skip to content
Snippets Groups Projects
server.go 7.42 KiB
Newer Older
ale's avatar
ale committed
package acmeserver

import (
	"context"
	"crypto"
	"crypto/ecdsa"
	"crypto/elliptic"
	"crypto/rand"
	"crypto/rsa"
	"crypto/x509"
	"crypto/x509/pkix"
	"encoding/pem"
	"errors"
	"fmt"
	"io"
	"log"
	"path/filepath"
	"time"

	"git.autistici.org/ai3/go-common/clientutil"
	"golang.org/x/crypto/acme"
)

// certInfo is the Manager's bookkeeping record for one configured
// certificate.
type certInfo struct {
	// domains are the names the certificate must cover. The first entry
	// is used as the storage key and as the subject common name.
	domains []string

	// retryDeadline is set to the renewal deadline of the current
	// certificate, or to a retry time after a failed renewal. It stays at
	// the zero time when no valid certificate was loaded from storage.
	retryDeadline time.Time
}

// certStorage persists certificate chains, as lists of DER blocks, together
// with their private keys, indexed by name.
type certStorage interface {
	GetCert(name string) (chain [][]byte, key crypto.Signer, err error)
	PutCert(name string, chain [][]byte, key crypto.Signer) error
}

// CertGenerator obtains a certificate for a list of domains, bound to the
// supplied private key. GetCertificate returns the certificate chain as a
// list of DER blocks (the form stored via certStorage.PutCert) and the
// parsed leaf certificate.
type CertGenerator interface {
	GetCertificate(ctx context.Context, key crypto.Signer, domains []string) (chain [][]byte, leaf *x509.Certificate, err error)
}

// Manager keeps a set of configured certificates up to date, renewing them
// from a background loop before they expire.
type Manager struct {
	// configDir is the directory the certificate configuration is read from.
	configDir string
	// useRSA selects RSA private keys instead of ECDSA ones.
	useRSA bool
	// storage persists certificates and their private keys.
	storage certStorage
	// certs is the currently configured set of certificates.
	certs []*certInfo
	// certGen obtains new certificates.
	certGen CertGenerator

	// configCh delivers new configurations to the background loop.
	configCh chan [][]string
	// stopCh is closed by Stop to terminate the background loop.
	stopCh chan bool
	// doneCh is closed once the background loop has exited.
	doneCh chan bool
}

// NewManager validates config and returns a Manager that reads its
// certificate configuration from <config.Dir>/config and keeps certificates
// under <config.Dir>/certs. When config.ReplDS is set, the local directory
// storage is wrapped in a replStorage that uses a clientutil backend built
// from config.ReplDS. certGen is used to obtain new certificates.
func NewManager(config *Config, certGen CertGenerator) (*Manager, error) {
	if config.Dir == "" {
		return nil, errors.New("configuration parameter 'cert_dir' is unset")
	}

	local := &dirStorage{root: filepath.Join(config.Dir, "certs")}
	var storage certStorage = local
	if config.ReplDS != nil {
		backend, err := clientutil.NewBackend(config.ReplDS)
		if err != nil {
			return nil, err
		}
		storage = &replStorage{
			dirStorage: local,
			replClient: backend,
		}
	}

	return &Manager{
		configDir: filepath.Join(config.Dir, "config"),
		useRSA:    config.UseRSA,
		storage:   storage,
		certGen:   certGen,
		configCh:  make(chan [][]string, 1),
		stopCh:    make(chan bool),
		doneCh:    make(chan bool),
	}, nil
}

// Start reads the initial certificate configuration from the config
// directory, queues it for the background loop, and launches that loop in a
// new goroutine. Canceling ctx, or calling Stop, terminates the loop. If the
// configuration cannot be read, the error is returned and no goroutine is
// started.
func (m *Manager) Start(ctx context.Context) error {
	initial, err := readCertConfigsFromDir(m.configDir)
	if err != nil {
		return err
	}
	m.configCh <- initial

	go func() {
		// doneCh signals Stop that the loop has returned.
		defer close(m.doneCh)
		m.loop(ctx)
	}()
	return nil
}

// Stop signals the background loop to exit and waits until it has done so.
//
// NOTE(review): stopCh is closed unconditionally, so a second call panics.
// doneCh is only closed by the goroutine launched in Start, so calling Stop
// without a successful Start blocks forever.
func (m *Manager) Stop() {
	stop, done := m.stopCh, m.doneCh
	close(stop)
	<-done
}

// Reload re-reads the certificate configuration from the config directory
// and hands it to the background loop. If the configuration cannot be read,
// the error is logged and the current configuration is kept.
//
// configCh only buffers a single configuration. Once the background loop is
// no longer running, nobody drains it, so a plain send would block forever.
// Reload therefore gives up when the manager has been stopped or its loop has
// exited; in that case the new configuration is never applied.
func (m *Manager) Reload() {
	domains, err := readCertConfigsFromDir(m.configDir)
	if err != nil {
		log.Printf("error reading config: %v", err)
		return
	}
	select {
	case m.configCh <- domains:
		log.Printf("configuration reloaded")
	case <-m.stopCh:
		log.Printf("manager stopped, configuration reload ignored")
	case <-m.doneCh:
		log.Printf("manager stopped, configuration reload ignored")
	}
}

var (
	// renewalTimeout bounds a single certificate renewal attempt.
	renewalTimeout = 10 * time.Minute
	// errorRetryTimeout is how long to wait before retrying a failed renewal.
	errorRetryTimeout = 10 * time.Minute
)

// updateAllCerts renews every configured certificate whose retryDeadline
// has been reached. Certificates with a zero retryDeadline (no valid
// certificate known) are always processed. Each attempt is bounded by
// renewalTimeout. A failed attempt is logged and rescheduled
// errorRetryTimeout into the future. Processing stops early once ctx is
// canceled.
func (m *Manager) updateAllCerts(ctx context.Context) {
	for _, ci := range m.certs {
		// On cancellation every remaining attempt would fail immediately
		// and be pushed into error backoff; just stop instead.
		if ctx.Err() != nil {
			return
		}
		// Skip certificates whose deadline is still in the future. (The
		// previous check was inverted: it skipped exactly the certificates
		// that were due, including ones that had never been obtained.)
		if time.Now().Before(ci.retryDeadline) {
			continue
		}
		uctx, cancel := context.WithTimeout(ctx, renewalTimeout)
		err := m.updateCert(uctx, ci)
		cancel()
		if err != nil {
			log.Printf("error updating %s: %v", ci.domains[0], err)
			ci.retryDeadline = time.Now().Add(errorRetryTimeout)
		}
	}
}

// updateCert obtains a new certificate for certInfo.domains:
//  1. It generates a fresh private key: RSA-2048 when m.useRSA is set,
//     ECDSA P-256 otherwise.
//  2. It asks m.certGen for a certificate bound to that key.
//  3. It stores the chain and key under the first domain.
//
// On success, certInfo.retryDeadline is moved to the new certificate's
// renewal deadline. Any error is returned unchanged and certInfo is left
// untouched.
func (m *Manager) updateCert(ctx context.Context, certInfo *certInfo) error {
	var key crypto.Signer
	var err error
	switch {
	case m.useRSA:
		key, err = rsa.GenerateKey(rand.Reader, 2048)
	default:
		key, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	}
	if err != nil {
		return err
	}

	chain, leaf, err := m.certGen.GetCertificate(ctx, key, certInfo.domains)
	if err != nil {
		return err
	}

	name := certInfo.domains[0]
	if err = m.storage.PutCert(name, chain, key); err != nil {
		return err
	}

	certInfo.retryDeadline = renewalDeadline(leaf)
	return nil
}

// loadConfig replaces the current configuration with one managed
// certificate per domain list in certDomains. The first domain of each list
// is the key used to look the certificate up in storage.
//
// When a stored certificate is found and still passes validCert for the
// configured domains, its retryDeadline is set to its renewal deadline.
// Otherwise retryDeadline is left at the zero time. validCert also catches
// changed subjectAltNames, expiry, and key mismatches. Entries with an empty
// domain list are skipped, since they have no name to key on.
func (m *Manager) loadConfig(certDomains [][]string) {
	certs := make([]*certInfo, 0, len(certDomains))
	for _, domains := range certDomains {
		if len(domains) == 0 {
			log.Printf("ignoring certificate configuration with no domains")
			continue
		}
		cn := domains[0]
		ci := &certInfo{
			domains: domains,
		}
		pub, priv, err := m.storage.GetCert(cn)
		if err != nil {
			// Include the storage error: GetCert may fail for reasons other
			// than the certificate simply not existing.
			log.Printf("cert for %s is missing: %v", cn, err)
		} else if leaf, verr := validCert(domains, pub, priv); verr != nil {
			log.Printf("cert for %s loaded from storage but is not valid for the current configuration: %v", cn, verr)
		} else {
			log.Printf("cert for %s loaded from storage", cn)
			ci.retryDeadline = renewalDeadline(leaf)
		}
		certs = append(certs, ci)
	}
	m.certs = certs
}

// loop is the body of the Manager's background goroutine. It applies new
// configurations received on configCh and calls updateAllCerts every five
// minutes. It returns when stopCh is closed or ctx is canceled.
func (m *Manager) loop(ctx context.Context) {
	tick := time.NewTicker(5 * time.Minute)
	// Release the ticker when the loop exits; it was previously never stopped.
	defer tick.Stop()
	for {
		select {
		case <-tick.C:
			m.updateAllCerts(ctx)
		case certDomains := <-m.configCh:
			m.loadConfig(certDomains)
		case <-m.stopCh:
			return
		case <-ctx.Done():
			return
		}
	}
}

// renewalDays is how many calendar days before a certificate's NotAfter
// time renewalDeadline places its renewal deadline.
var renewalDays = 15

// renewalDeadline returns the time renewalDays calendar days before cert
// stops being valid.
func renewalDeadline(cert *x509.Certificate) time.Time {
	expiry := cert.NotAfter
	return expiry.AddDate(0, 0, -renewalDays)
}

// concatDER joins a list of DER blocks into one contiguous buffer, the
// input format x509.ParseCertificates expects. The result is never nil.
func concatDER(der [][]byte) []byte {
	size := 0
	for _, block := range der {
		size += len(block)
	}
	joined := make([]byte, 0, size)
	for _, block := range der {
		joined = append(joined, block...)
	}
	return joined
}

// validCert parses the chain in der and checks that its first certificate
// (the leaf):
//   - is currently inside its NotBefore/NotAfter validity window;
//   - verifies for every name in domains;
//   - carries the public half of key (RSA or ECDSA only).
//
// It returns the leaf on success, or an error describing the first failed
// check.
func validCert(domains []string, der [][]byte, key crypto.Signer) (leaf *x509.Certificate, err error) {
	certs, err := x509.ParseCertificates(concatDER(der))
	if err != nil {
		return nil, err
	}
	if len(certs) < 1 {
		return nil, errors.New("no public key found")
	}
	first := certs[0]

	// Reject certificates outside their validity window.
	switch now := time.Now(); {
	case now.Before(first.NotBefore):
		return nil, errors.New("certificate isn't valid yet")
	case now.After(first.NotAfter):
		return nil, errors.New("certificate expired")
	}

	// Every configured name must be covered by the certificate.
	for _, d := range domains {
		if first.VerifyHostname(d) != nil {
			return nil, fmt.Errorf("certificate does not match domain %q", d)
		}
	}

	// The private key must be of the same algorithm as, and correspond
	// to, the certificate's public key.
	var mismatch bool
	switch pub := first.PublicKey.(type) {
	case *rsa.PublicKey:
		priv, ok := key.(*rsa.PrivateKey)
		if !ok {
			return nil, errors.New("private key type does not match public key type")
		}
		mismatch = pub.N.Cmp(priv.N) != 0
	case *ecdsa.PublicKey:
		priv, ok := key.(*ecdsa.PrivateKey)
		if !ok {
			return nil, errors.New("private key type does not match public key type")
		}
		mismatch = pub.X.Cmp(priv.X) != 0 || pub.Y.Cmp(priv.Y) != 0
	default:
		return nil, errors.New("unsupported public key algorithm")
	}
	if mismatch {
		return nil, errors.New("private key does not match public key")
	}
	return first, nil
}

// certRequest builds a DER-encoded PKCS#10 certificate signing request for
// domains, signed with key. The first domain becomes the subject common name.
//
// Every domain, including the first, is listed in the subjectAltName
// extension. Name matching is done against SANs; RFC 5280 and the
// CA/Browser Forum Baseline Requirements treat a CN that is absent from the
// SANs as legacy. Previously the first domain appeared only in the CN.
// An error is returned if domains is empty.
func certRequest(key crypto.Signer, domains []string) ([]byte, error) {
	if len(domains) == 0 {
		return nil, errors.New("no domains specified for certificate request")
	}
	req := &x509.CertificateRequest{
		Subject:  pkix.Name{CommonName: domains[0]},
		DNSNames: domains,
	}
	return x509.CreateCertificateRequest(rand.Reader, req, key)
}

// parsePrivateKey decodes a DER private key. It tries, in order, PKCS#1
// (RSA), PKCS#8 (accepting only RSA or ECDSA contents) and SEC 1 (EC).
// If PKCS#8 parsing succeeds but yields another key type, an error is
// returned without trying SEC 1.
func parsePrivateKey(der []byte) (crypto.Signer, error) {
	rsaKey, pkcs1Err := x509.ParsePKCS1PrivateKey(der)
	if pkcs1Err == nil {
		return rsaKey, nil
	}

	if wrapped, pkcs8Err := x509.ParsePKCS8PrivateKey(der); pkcs8Err == nil {
		switch wrapped.(type) {
		case *rsa.PrivateKey, *ecdsa.PrivateKey:
			return wrapped.(crypto.Signer), nil
		}
		return nil, errors.New("unknown private key type in PKCS#8 wrapping")
	}

	ecKey, ecErr := x509.ParseECPrivateKey(der)
	if ecErr != nil {
		return nil, errors.New("failed to parse private key")
	}
	return ecKey, nil
}

// encodeECDSAKey writes key to w as a single PEM block of type
// "EC PRIVATE KEY". The block contains the x509.MarshalECPrivateKey
// encoding. Marshalling and write errors are returned.
func encodeECDSAKey(w io.Writer, key *ecdsa.PrivateKey) error {
	der, err := x509.MarshalECPrivateKey(key)
	if err != nil {
		return err
	}
	return pem.Encode(w, &pem.Block{
		Type:  "EC PRIVATE KEY",
		Bytes: der,
	})
}

// pickChallenge returns the first challenge in chal whose Type equals typ,
// or nil when none matches.
func pickChallenge(typ string, chal []*acme.Challenge) *acme.Challenge {
	for i := range chal {
		if chal[i].Type == typ {
			return chal[i]
		}
	}
	return nil
}