Skip to content
Snippets Groups Projects
dns_challenge.go 3.39 KiB
Newer Older
  • Learn to ignore specific revisions
  • ale's avatar
    ale committed
    package acmeserver
    
    import (
    	"context"
    	"errors"
    	"fmt"
    	"log"
    	"strings"
    	"time"
    
    	"github.com/miekg/dns"
    	"golang.org/x/crypto/acme"
    )
    
    // rfc2136Timeout is the TTL, in seconds, assigned to the challenge TXT
    // records sent to the nameservers in RFC 2136 dynamic updates.
    const rfc2136Timeout = 600

    // tsigFudgeSeconds is the "fudge" value, in seconds, passed to SetTsig when
    // TSIG signing of updates is enabled.
    const tsigFudgeSeconds = 300
    
    // dnsValidator fulfills dns-01 challenges by sending RFC 2136 dynamic
    // updates to a fixed set of nameservers, optionally signed with TSIG.
    type dnsValidator struct {
    	// Nameserver addresses that receive the updates.
    	nameservers []string

    	// TSIG settings; the key fields are only used when enableTSIG is true.
    	enableTSIG                  bool
    	keyName, keyAlgo, keySecret string
    }
    
    // newDNSValidator builds a dns-01 challenge validator from the DNS section
    // of config.
    //
    // At least one nameserver must be configured. TSIG signing of the updates
    // is enabled when tsig_key_name, tsig_key_algo and tsig_key_secret are all
    // set; setting only some of them is reported as an error.
    //
    // The key name and the algorithm name are both normalized to fully
    // qualified form (trailing dot). github.com/miekg/dns spells its TSIG
    // algorithms that way (e.g. dns.HmacSHA256 == "hmac-sha256."), so a
    // configured "hmac-sha256" would otherwise not be recognized when signing.
    func newDNSValidator(config *Config) (*dnsValidator, error) {
    	if len(config.DNS.Nameservers) == 0 {
    		return nil, errors.New("no nameservers configured")
    	}

    	// Check that the TSIG parameters are consistent, if provided at all.
    	n := 0
    	for _, param := range []string{
    		config.DNS.TSIGKeyName,
    		config.DNS.TSIGKeyAlgo,
    		config.DNS.TSIGKeySecret,
    	} {
    		if param != "" {
    			n++
    		}
    	}
    	if n != 0 && n != 3 {
    		return nil, errors.New("either none or all of 'tsig_key_name', 'tsig_key_algo' and 'tsig_key_secret' must be set")
    	}

    	return &dnsValidator{
    		nameservers: config.DNS.Nameservers,
    		enableTSIG:  n > 0,
    		keyName:     dns.Fqdn(config.DNS.TSIGKeyName),
    		keyAlgo:     dns.Fqdn(config.DNS.TSIGKeyAlgo),
    		keySecret:   config.DNS.TSIGKeySecret,
    	}, nil
    }
    
    // makeRR returns a one-element record slice holding a TXT record for fqdn,
    // class IN, with the given TTL and value as its only string.
    func (d *dnsValidator) makeRR(fqdn, value string, ttl int) []dns.RR {
    	txt := &dns.TXT{
    		Hdr: dns.RR_Header{
    			Name:   fqdn,
    			Rrtype: dns.TypeTXT,
    			Class:  dns.ClassINET,
    			Ttl:    uint32(ttl),
    		},
    		Txt: []string{value},
    	}
    	return []dns.RR{txt}
    }
    
    // makeMsg builds a dynamic-update message for zone.
    //
    // With remove set, the message deletes exactly the records in rrs.
    // Otherwise it first deletes the RRset(s) named by rrs and then inserts
    // rrs, so the new records replace whatever was there. When TSIG is enabled
    // the message also gets a TSIG record stamped with the current time.
    func (d *dnsValidator) makeMsg(zone string, rrs []dns.RR, remove bool) *dns.Msg {
    	msg := new(dns.Msg)
    	msg.SetUpdate(zone)

    	switch {
    	case remove:
    		msg.Remove(rrs)
    	default:
    		msg.RemoveRRset(rrs)
    		msg.Insert(rrs)
    	}

    	if !d.enableTSIG {
    		return msg
    	}
    	msg.SetTsig(d.keyName, d.keyAlgo, tsigFudgeSeconds, time.Now().Unix())
    	return msg
    }
    
    // client returns a fresh DNS client with SingleInflight enabled. When TSIG
    // is enabled, the client is also given the key secret under keyName so
    // that TSIG-carrying messages can be signed.
    func (d *dnsValidator) client() *dns.Client {
    	c := &dns.Client{SingleInflight: true}
    	if !d.enableTSIG {
    		return c
    	}
    	c.TsigSecret = map[string]string{d.keyName: d.keySecret}
    	return c
    }
    
    func (d *dnsValidator) Fulfill(ctx context.Context, client *acme.Client, domain string, chal *acme.Challenge) (func(), error) {
    	zone := domain[strings.Index(domain, ".")+1:]
    	fqdn := dns.Fqdn(domain)
    	value, err := client.DNS01ChallengeRecord(chal.Token)
    	if err != nil {
    		return nil, err
    	}
    
    	if err := d.onAllNS(ctx, func(ns string) error {
    		return d.updateNS(ctx, ns, zone, fqdn, value, false)
    	}); err != nil {
    		return nil, err
    	}
    
    	return func() {
    		// nolint
    		d.onAllNS(ctx, func(ns string) error {
    			return d.updateNS(ctx, ns, zone, fqdn, value, true)
    		})
    	}, nil
    }
    
    // updateNS sends one dynamic update to the nameserver at ns.
    //
    // The update adds, or with remove deletes, the TXT record fqdn=value in
    // zone. It returns the transport error, if any, or an error naming the
    // response code when the server answered with anything but NOERROR.
    //
    // NOTE(review): ctx is accepted but not consulted here.
    func (d *dnsValidator) updateNS(ctx context.Context, ns, zone, fqdn, value string, remove bool) error {
    	msg := d.makeMsg(zone, d.makeRR(fqdn, value, rfc2136Timeout), remove)

    	resp, _, err := d.client().Exchange(msg, ns)
    	switch {
    	case err != nil:
    		return err
    	case resp == nil, resp.Rcode == dns.RcodeSuccess:
    		return nil
    	default:
    		return fmt.Errorf("DNS server replied: %s", dns.RcodeToString[resp.Rcode])
    	}
    }
    
    // Run a function for each configured nameserver.
    func (d *dnsValidator) onAllNS(ctx context.Context, f func(string) error) error {
    	ch := make(chan error, len(d.nameservers))
    	defer close(ch)
    
    	for _, ns := range d.nameservers {
    		if !strings.Contains(ns, ":") {
    			ns += ":53"
    		}
    		go func(ns string) {
    			err := f(ns)
    			if err != nil {
    				log.Printf("error updating DNS server %s: %v", ns, err)
    			}
    			ch <- err
    		}(ns)
    	}
    
    	var ok bool
    	for i := 0; i < len(d.nameservers); i++ {
    		if err := <-ch; err == nil {
    			ok = true
    		}
    	}
    	if !ok {
    		return errors.New("all nameservers failed")
    	}
    	return nil
    }