diff --git a/node/dns.go b/node/dns.go index 0864f826b00cc0cb6f5c8797b25a7fee5f8bf3dc..3aaeeff0f69f372d41f64c498ef6a0556bdfaf93 100644 --- a/node/dns.go +++ b/node/dns.go @@ -13,7 +13,7 @@ import ( const ( zoneTTL = 21600 nsTTL = 3600 - recordTTL = 5 + recordTTL = 300 maxRecords = 5 ) @@ -61,7 +61,7 @@ func newDNSZone(lb *loadBalancer, origin string, nameservers []string) *dnsZone soa: soa, nameservers: nameservers, origin: origin, - originNumParts: len(strings.Split(origin, ".")), + originNumParts: dns.CountLabel(origin), } } @@ -70,8 +70,12 @@ func (d *dnsZone) ServeDNS(w dns.ResponseWriter, req *dns.Msg) { ednsFromRequest(req, m) // Only consider the first question. + var name string q := req.Question[0] - name := d.getQuestionName(q) + if !dns.IsSubDomain(d.origin, q.Name) { + goto nxDomain + } + name = d.stripOrigin(q.Name) switch { case name == "" && q.Qtype == dns.TypeSOA: @@ -149,9 +153,9 @@ func (d *dnsZone) getNodeIPs(q dns.Question) []net.IP { } // Strip the origin from the query. -func (d *dnsZone) getQuestionName(q dns.Question) string { - lx := dns.SplitDomainName(q.Name) - ql := lx[0 : len(lx)-d.originNumParts] +func (d *dnsZone) stripOrigin(name string) string { + lx := dns.SplitDomainName(name) + ql := lx[:len(lx)-d.originNumParts] return strings.ToLower(strings.Join(ql, ".")) }