package acmeserver import ( "context" "errors" "fmt" "log" "strings" "time" "github.com/miekg/dns" "golang.org/x/crypto/acme" "golang.org/x/net/publicsuffix" ) const ( rfc2136Timeout = 600 tsigFudgeSeconds = 300 ) type dnsValidator struct { nameservers []string enableTSIG bool keyName string keyAlgo string keySecret string } func newDNSValidator(config *Config) (*dnsValidator, error) { if len(config.DNS.Nameservers) == 0 { return nil, errors.New("no nameservers configured") } // Check that the TSIG parameters are consistent, if provided at all. n := 0 if config.DNS.TSIGKeyName != "" { n++ } if config.DNS.TSIGKeyAlgo != "" { n++ } if config.DNS.TSIGKeySecret != "" { n++ } if n != 0 && n != 3 { return nil, errors.New("either none or all of 'tsig_key_name', 'tsig_key_algo' and 'tsig_key_secret' must be set") } return &dnsValidator{ nameservers: config.DNS.Nameservers, enableTSIG: n > 0, keyName: dns.Fqdn(config.DNS.TSIGKeyName), keyAlgo: config.DNS.TSIGKeyAlgo, keySecret: config.DNS.TSIGKeySecret, }, nil } func (d *dnsValidator) makeRR(fqdn, value string, ttl int) []dns.RR { rr := new(dns.TXT) rr.Hdr = dns.RR_Header{Name: fqdn, Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: uint32(ttl)} rr.Txt = []string{value} return []dns.RR{rr} } func (d *dnsValidator) makeMsg(zone string, rrs []dns.RR, remove bool) *dns.Msg { m := new(dns.Msg) m.SetUpdate(zone) if remove { m.Remove(rrs) } else { m.RemoveRRset(rrs) m.Insert(rrs) } if d.enableTSIG { m.SetTsig(d.keyName, d.keyAlgo, tsigFudgeSeconds, time.Now().Unix()) } return m } func (d *dnsValidator) client() *dns.Client { c := new(dns.Client) c.SingleInflight = true if d.enableTSIG { // TSIG authentication / msg signing. c.TsigSecret = map[string]string{d.keyName: d.keySecret} } return c } func (d *dnsValidator) Fulfill(ctx context.Context, client *acme.Client, domain string, chal *acme.Challenge) (func(), error) { domain = strings.TrimPrefix(domain, "*.") zone, err := publicsuffix.EffectiveTLDPlusOne(domain) if err != nil { return nil, fmt.Errorf("could not determine effective tld: %w", err) } zone = dns.Fqdn(zone) fqdn := dns.Fqdn(domain) value, err := client.DNS01ChallengeRecord(chal.Token) if err != nil { return nil, err } if err := d.onAllNS(ctx, func(ns string) error { return d.updateNS(ctx, ns, zone, fqdn, value, false) }); err != nil { return nil, err } return func() { // nolint d.onAllNS(ctx, func(ns string) error { return d.updateNS(ctx, ns, zone, fqdn, value, true) }) }, nil } func (d *dnsValidator) updateNS(ctx context.Context, ns, zone, fqdn, value string, remove bool) error { log.Printf("updateNS(%s, %s, %s, %s, %v)", ns, zone, fqdn, value, remove) rrs := d.makeRR(fqdn, value, rfc2136Timeout) m := d.makeMsg(zone, rrs, remove) c := d.client() // Send the query reply, _, err := c.Exchange(m, ns) if err != nil { return err } if reply != nil && reply.Rcode != dns.RcodeSuccess { return fmt.Errorf("DNS server replied: %s", dns.RcodeToString[reply.Rcode]) } return nil } // Run a function for each configured nameserver. func (d *dnsValidator) onAllNS(ctx context.Context, f func(string) error) error { ch := make(chan error, len(d.nameservers)) defer close(ch) for _, ns := range d.nameservers { if !strings.Contains(ns, ":") { ns += ":53" } go func(ns string) { err := f(ns) if err != nil { log.Printf("error updating DNS server %s: %v", ns, err) } ch <- err }(ns) } var ok bool for i := 0; i < len(d.nameservers); i++ { if err := <-ch; err == nil { ok = true } } if !ok { return errors.New("all nameservers failed") } return nil }