Skip to content
Snippets Groups Projects
dns_challenge.go 3.39 KiB
Newer Older
ale's avatar
ale committed
package acmeserver

import (
	"context"
	"errors"
	"fmt"
	"log"
	"net"
	"strings"
	"time"

	"github.com/miekg/dns"
	"golang.org/x/crypto/acme"
)

// rfc2136Timeout is the TTL, in seconds, given to the challenge TXT records
// published through RFC 2136 dynamic updates.
const rfc2136Timeout = 600

// tsigFudgeSeconds is the TSIG "fudge" value, in seconds, attached to signed
// update messages.
const tsigFudgeSeconds = 300

// dnsValidator fulfills ACME DNS-01 challenges by publishing TXT records via
// RFC 2136 dynamic updates sent to every configured nameserver.
type dnsValidator struct {
	// nameservers receive the updates, given as "host" or "host:port";
	// the standard DNS port is assumed when none is specified.
	nameservers []string

	// enableTSIG is set when a TSIG key is configured; update messages
	// are then signed with it.
	enableTSIG bool

	// TSIG key name (kept fully qualified), algorithm and secret.
	keyName, keyAlgo, keySecret string
}

// newDNSValidator builds a dnsValidator from the DNS section of config.
//
// At least one nameserver is required. The TSIG parameters (tsig_key_name,
// tsig_key_algo, tsig_key_secret) are all-or-nothing: when all three are set,
// update messages are signed; when none is set, they are sent unsigned; any
// other combination is rejected as a configuration error.
//
// Both the key name and the algorithm name are stored fully qualified. The
// dns library's algorithm constants carry a trailing dot (dns.HmacSHA256 is
// "hmac-sha256."), so a config value written as "hmac-sha256" may otherwise
// be rejected when the message is signed. dns.Fqdn leaves already-dotted
// names unchanged.
func newDNSValidator(config *Config) (*dnsValidator, error) {
	if len(config.DNS.Nameservers) == 0 {
		return nil, errors.New("no nameservers configured")
	}

	// Check that the TSIG parameters are consistent, if provided at all.
	n := 0
	for _, s := range []string{
		config.DNS.TSIGKeyName,
		config.DNS.TSIGKeyAlgo,
		config.DNS.TSIGKeySecret,
	} {
		if s != "" {
			n++
		}
	}
	if n != 0 && n != 3 {
		return nil, errors.New("either none or all of 'tsig_key_name', 'tsig_key_algo' and 'tsig_key_secret' must be set")
	}

	return &dnsValidator{
		// Copy the slice so later changes to the config cannot leak in.
		nameservers: append([]string(nil), config.DNS.Nameservers...),
		enableTSIG:  n > 0,
		keyName:     dns.Fqdn(config.DNS.TSIGKeyName),
		keyAlgo:     dns.Fqdn(config.DNS.TSIGKeyAlgo),
		keySecret:   config.DNS.TSIGKeySecret,
	}, nil
}

// makeRR returns a one-element record list holding an IN TXT record for fqdn.
// The record's single character-string is value, and its TTL is ttl seconds.
func (d *dnsValidator) makeRR(fqdn, value string, ttl int) []dns.RR {
	txt := &dns.TXT{
		Hdr: dns.RR_Header{
			Name:   fqdn,
			Rrtype: dns.TypeTXT,
			Class:  dns.ClassINET,
			Ttl:    uint32(ttl),
		},
		Txt: []string{value},
	}
	return []dns.RR{txt}
}

// makeMsg builds an RFC 2136 UPDATE message for zone.
//
// With remove set, only the listed records are deleted. Otherwise the whole
// RRset matching the records' name and type is deleted first and the records
// are then inserted, so publishing replaces whatever was there before.
// When TSIG is enabled, the message is marked for signing with the configured
// key, stamped with the current time.
func (d *dnsValidator) makeMsg(zone string, rrs []dns.RR, remove bool) *dns.Msg {
	msg := &dns.Msg{}
	msg.SetUpdate(zone)

	switch {
	case remove:
		msg.Remove(rrs)
	default:
		msg.RemoveRRset(rrs)
		msg.Insert(rrs)
	}

	if !d.enableTSIG {
		return msg
	}
	msg.SetTsig(d.keyName, d.keyAlgo, tsigFudgeSeconds, time.Now().Unix())
	return msg
}

// client returns a new DNS client for a single update exchange. When TSIG is
// enabled, the client is given the secret for d.keyName so that it can sign
// outgoing messages.
//
// SingleInflight is intentionally not set. It only deduplicates concurrent
// queries issued through one Client, and every call here builds a fresh
// Client that is used for a single Exchange, so it never had any effect.
// Recent versions of the dns library also deprecate it as a no-op.
func (d *dnsValidator) client() *dns.Client {
	c := new(dns.Client)
	if d.enableTSIG {
		// TSIG authentication / msg signing.
		c.TsigSecret = map[string]string{d.keyName: d.keySecret}
	}
	return c
}

// Fulfill publishes the DNS-01 TXT record for chal on every configured
// nameserver. On success it returns a cleanup function that removes the record
// again.
//
// The record is placed at "_acme-challenge.<domain>", as required by RFC 8555
// section 8.4. The RFC 2136 update is addressed to the zone obtained by
// dropping the first label of domain ("www.example.com" -> "example.com"), or
// to domain itself if it contains no dot.
// NOTE(review): this zone heuristic is wrong for apex domains and for deeper
// delegations; confirm against the deployed zone layout.
//
// An error is returned if the challenge value cannot be computed or if no
// nameserver accepted the update. Cleanup is best effort: onAllNS logs the
// failure of each nameserver, and the aggregate error is ignored.
func (d *dnsValidator) Fulfill(ctx context.Context, client *acme.Client, domain string, chal *acme.Challenge) (func(), error) {
	zone := domain[strings.Index(domain, ".")+1:]
	// The validation record lives under the "_acme-challenge" label of the
	// domain being validated, not at the domain name itself.
	fqdn := dns.Fqdn("_acme-challenge." + domain)
	value, err := client.DNS01ChallengeRecord(chal.Token)
	if err != nil {
		return nil, err
	}

	publish := func(ns string) error {
		return d.updateNS(ctx, ns, zone, fqdn, value, false)
	}
	if err := d.onAllNS(ctx, publish); err != nil {
		return nil, err
	}

	return func() {
		// Best effort; per-nameserver errors were already logged by onAllNS.
		_ = d.onAllNS(ctx, func(ns string) error {
			return d.updateNS(ctx, ns, zone, fqdn, value, true)
		})
	}, nil
}

// updateNS sends one RFC 2136 UPDATE for zone to the nameserver at address ns.
//
// With remove unset, it publishes a TXT record holding value at fqdn, with a
// TTL of rfc2136Timeout seconds. With remove set, it deletes that record.
// A transport error is returned unchanged. A reply whose rcode is not NOERROR
// becomes an error naming that rcode. ctx is currently not used.
func (d *dnsValidator) updateNS(ctx context.Context, ns, zone, fqdn, value string, remove bool) error {
	msg := d.makeMsg(zone, d.makeRR(fqdn, value, rfc2136Timeout), remove)

	reply, _, err := d.client().Exchange(msg, ns)
	switch {
	case err != nil:
		return err
	case reply == nil, reply.Rcode == dns.RcodeSuccess:
		return nil
	default:
		return fmt.Errorf("DNS server replied: %s", dns.RcodeToString[reply.Rcode])
	}
}

// onAllNS runs f concurrently, once for each configured nameserver, and
// returns after every call has finished.
//
// f receives the nameserver address in host:port form; port 53 is added when
// the configuration omits it. Each failing call is logged. The result is nil
// if at least one call succeeded. Otherwise it is an error wrapping the first
// failure received. ctx is currently unused; f is expected to capture any
// context it needs.
func (d *dnsValidator) onAllNS(ctx context.Context, f func(string) error) error {
	// Buffered to hold every result, so no sender can ever block. All
	// results are drained below, so no goroutine outlives this call.
	ch := make(chan error, len(d.nameservers))

	for _, ns := range d.nameservers {
		go func(addr string) {
			err := f(addr)
			if err != nil {
				log.Printf("error updating DNS server %s: %v", addr, err)
			}
			ch <- err
		}(nameserverAddr(ns))
	}

	var ok bool
	var firstErr error
	for range d.nameservers {
		err := <-ch
		switch {
		case err == nil:
			ok = true
		case firstErr == nil:
			firstErr = err
		}
	}
	if ok {
		return nil
	}
	if firstErr == nil {
		// Only reachable when no nameservers are configured.
		return errors.New("all nameservers failed")
	}
	return fmt.Errorf("all nameservers failed: %w", firstErr)
}

// nameserverAddr returns ns as a host:port address, defaulting to port 53.
// Bare IPv6 literals ("2001:db8::1", optionally bracketed) are handled
// correctly, unlike a simple check for a ':' character.
func nameserverAddr(ns string) string {
	if _, _, err := net.SplitHostPort(ns); err == nil {
		return ns
	}
	return net.JoinHostPort(strings.Trim(ns, "[]"), "53")
}