package node import ( "context" "math/rand" "net" "strings" "time" "github.com/miekg/dns" ) const ( zoneTTL = 21600 nsTTL = 3600 recordTTL = 5 maxRecords = 5 ) func newDNSHandler(n *Node, origin string, nameservers []string) dns.Handler { if !strings.HasSuffix(origin, ".") { origin += "." } dnsz := newDNSZone(n.lb, origin, nameservers) mux := dns.NewServeMux() mux.Handle(origin, dnsz) return mux } // Serve DNS records for our zone. type dnsZone struct { lb *loadBalancer soa *dns.SOA nameservers []string origin string originNumParts int } func newDNSZone(lb *loadBalancer, origin string, nameservers []string) *dnsZone { // Create a SOA record for the zone. Some entries will be bogus. soa := &dns.SOA{ Hdr: dns.RR_Header{ Name: origin, Rrtype: dns.TypeSOA, Class: dns.ClassINET, Ttl: zoneTTL, }, Ns: "ns1." + origin, Mbox: "hostmaster." + origin, Serial: uint32(time.Now().Unix()), Refresh: 43200, Retry: 3600, Expire: uint32(zoneTTL), Minttl: uint32(zoneTTL), } return &dnsZone{ lb: lb, soa: soa, nameservers: nameservers, origin: origin, originNumParts: len(strings.Split(origin, ".")), } } func (d *dnsZone) ServeDNS(w dns.ResponseWriter, req *dns.Msg) { m := new(dns.Msg) ednsFromRequest(req, m) // Only consider the first question. q := req.Question[0] name := d.getQuestionName(q) switch { case name == "" && q.Qtype == dns.TypeSOA: m.Answer = append(m.Answer, d.soa) case name == "" && q.Qtype == dns.TypeNS: for _, ns := range d.nameservers { m.Answer = append(m.Answer, &dns.NS{ Hdr: dns.RR_Header{ Name: d.withOrigin(name), Rrtype: dns.TypeNS, Class: dns.ClassINET, Ttl: nsTTL, }, Ns: ns, }) } case q.Qtype == dns.TypeA || q.Qtype == dns.TypeAAAA: if name != "" && name != "stream" { goto nxDomain } for _, ip := range d.getNodeIPs(q) { var rec dns.RR if q.Qtype == dns.TypeAAAA { rec = d.newAAAA(name, ip) } else { rec = d.newA(name, ip) } m.Answer = append(m.Answer, rec) } case q.Qtype == dns.TypeANY: goto servFail default: goto nxDomain } m.SetReply(req) m.MsgHdr.Authoritative = true w.WriteMsg(m) //nolint return servFail: m.SetRcode(req, dns.RcodeServerFailure) w.WriteMsg(m) //nolint return nxDomain: m.SetRcode(req, dns.RcodeNameError) w.WriteMsg(m) //nolint } func (d *dnsZone) getNodeIPs(q dns.Question) []net.IP { // Pick all known endpoint IPs, filtering those that match the // protocol in the DNS request. var ips []net.IP for _, ns := range d.lb.getNodes() { for _, ipp := range filterIPByProto(ns.parsedAddrs, (q.Qtype == dns.TypeAAAA)) { ips = append(ips, ipp.ip) } } // Shuffle the IP list in-place. rand.Shuffle(len(ips), func(i, j int) { ips[i], ips[j] = ips[j], ips[i] }) // Trim it to a maximum of maxRecords. if len(ips) > maxRecords { ips = ips[:maxRecords] } return ips } // Strip the origin from the query. func (d *dnsZone) getQuestionName(q dns.Question) string { lx := dns.SplitDomainName(q.Name) ql := lx[0 : len(lx)-d.originNumParts] return strings.ToLower(strings.Join(ql, ".")) } // Add the origin to a query. func (d *dnsZone) withOrigin(name string) string { if name == "" { return d.origin } return name + "." + d.origin } // Create an A resource record. func (d *dnsZone) newA(name string, ip net.IP) dns.RR { return &dns.A{ Hdr: dns.RR_Header{ Name: d.withOrigin(name), Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: recordTTL, }, A: ip, } } // Create an AAAA resource record. func (d *dnsZone) newAAAA(name string, ip net.IP) dns.RR { return &dns.AAAA{ Hdr: dns.RR_Header{ Name: d.withOrigin(name), Rrtype: dns.TypeAAAA, Class: dns.ClassINET, Ttl: recordTTL, }, AAAA: ip, } } // Create skeleton edns opt RR from the query and add it to the // message m. func ednsFromRequest(req, m *dns.Msg) { for _, r := range req.Extra { if r.Header().Rrtype == dns.TypeOPT { m.SetEdns0(4096, r.(*dns.OPT).Do()) return } } } // Wrapper to make the dns.Server match the genericServer interface. type dnsServer struct { *dns.Server name string } func newDNSServer(name, addr, proto string, h dns.Handler) *dnsServer { return &dnsServer{ Server: &dns.Server{ Addr: addr, Net: proto, Handler: h, }, name: name, } } func (s *dnsServer) Name() string { return s.name } func (s *dnsServer) Serve() error { return s.Server.ListenAndServe() } func (s *dnsServer) Stop() { ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) s.Server.ShutdownContext(ctx) //nolint cancel() }