diff --git a/cmd/redirectord/redirectord.go b/cmd/redirectord/redirectord.go index 386eea1b6892eeb529828fbdeea6e64c758f58b4..bba08a672cc885e93589008b4c5307a996dc5e55 100644 --- a/cmd/redirectord/redirectord.go +++ b/cmd/redirectord/redirectord.go @@ -11,22 +11,35 @@ import ( ) var ( - httpPort = flag.Int("port", 80, "TCP port to bind to") + domain = flag.String("domain", "", "DNS domain to serve") + dnsPort = flag.Int("dns-port", 53, "DNS port") + httpPort = flag.Int("http-port", 80, "HTTP port") + publicIp = flag.String("ip", "127.0.0.1", "Public IP for this machine") + + // Default DNS TTL (seconds). + dnsTtl = 5 ) func main() { flag.Parse() + if *domain == "" { + log.Fatal("Must specify --domain") + } + client := radioai.NewEtcdClient() api := radioai.NewRadioAPI(client) red := radioai.NewHttpRedirector(api) - server := &http.Server{ + dnsRed := radioai.NewDnsRedirector(api, *domain, *publicIp, dnsTtl) + dnsRed.Run(fmt.Sprintf(":%d", *dnsPort)) + + httpServer := &http.Server{ Addr: fmt.Sprintf(":%d", *httpPort), Handler: red, ReadTimeout: 10 * time.Second, WriteTimeout: 10 * time.Second, } - log.Fatal(server.ListenAndServe()) -} \ No newline at end of file + log.Fatal(httpServer.ListenAndServe()) +} diff --git a/dns.go b/dns.go new file mode 100644 index 0000000000000000000000000000000000000000..d25384f05fc497c280d8aef0b322394fb5424b14 --- /dev/null +++ b/dns.go @@ -0,0 +1,176 @@ +package radioai + +import ( + "fmt" + "log" + "math/rand" + "strconv" + "strings" + "time" + + "github.com/miekg/dns" +) + +var ( + // Max number of results for an A query. + maxResults = 4 + + // The names that we are serving. Currently, all services are + // mapped to all the active nodes in the cluster. + validNames = []string{ + "", + "www", + "stream", + "etcd", + } +) + +type DnsRedirector struct { + client *RadioAPI + nodeCache *activeNodesCache + + origin string + originNumParts int + publicIp string + ttl int + soa dns.RR +} + +func NewDnsRedirector(client *RadioAPI, origin, publicIp string, ttl int) *DnsRedirector { + if !strings.HasSuffix(origin, ".") { + origin += "." + } + + // Create a SOA record for the zone. + serialNo := strconv.FormatInt(time.Now().Unix(), 10) + soaRec := fmt.Sprintf("%s %d IN SOA localhost. hostmaster.%s %s 43200 3600 2419200 %d", origin, ttl, origin, serialNo, ttl) + soa, err := dns.NewRR(soaRec) + if err != nil { + log.Fatalf("Could not generate SOA record: %s", err) + } + + return &DnsRedirector{ + client: client, + nodeCache: newActiveNodesCache(client), + origin: origin, + originNumParts: len(dns.SplitDomainName(origin)), + publicIp: publicIp, + ttl: ttl, + soa: soa, + } +} + +// Randomly shuffle a list of strings. +func shuffle(list []string) []string { + out := make([]string, len(list)) + for dst, src := range rand.Perm(len(list)) { + out[dst] = list[src] + } + return out +} + +func isValidQuery(query string) bool { + for _, q := range validNames { + if query == q { + return true + } + } + return false +} + +// Create skeleton edns opt RR from the query and add it to the +// message m. +func ednsFromRequest(req, m *dns.Msg) { + for _, r := range req.Extra { + if r.Header().Rrtype == dns.TypeOPT { + m.SetEdns0(4096, r.(*dns.OPT).Do()) + return + } + } + return +} + +// Create a RR string for an IP. +func recordForIp(query string, ttl int, recType string, ip string) string { + return fmt.Sprintf("%s %d IN %s %s", query, ttl, recType, ip) +} + +func (d *DnsRedirector) getQuestionName(req *dns.Msg) string { + lx := dns.SplitDomainName(req.Question[0].Name) + ql := lx[0 : len(lx)-d.originNumParts] + return strings.ToLower(strings.Join(ql, ".")) +} + +func (d *DnsRedirector) serveDNS(w dns.ResponseWriter, req *dns.Msg) { + m := new(dns.Msg) + + // Just NACK ANYs + if req.Question[0].Qtype == dns.TypeANY { + m.SetRcode(req, dns.RcodeServerFailure) + ednsFromRequest(req, m) + w.WriteMsg(m) + return + } + + query := d.getQuestionName(req) + log.Printf("[zone %s] incoming %s %s %d from %s\n", d.origin, query, dns.TypeToString[req.Question[0].Qtype], req.MsgHdr.Id, w.RemoteAddr()) + + switch { + case query == "" && req.Question[0].Qtype == dns.TypeSOA: + // Serve SOA record. + m.SetReply(req) + m.MsgHdr.Authoritative = true + m.Answer = append(m.Answer, d.soa) + + case req.Question[0].Qtype == dns.TypeA: + // Return an NXDOMAIN for unknown queries. + if !isValidQuery(query) { + m.SetRcode(req, dns.RcodeNameError) + log.Printf("Query(%s): NXDOMAIN", query) + } + + // Serve all active nodes on every request. + ips := d.nodeCache.GetNodes() + + if ips == nil || len(ips) == 0 { + // In case of errors retrieving the list of + // active nodes, fall back to serving our + // public IP (just to avoid returning an empty + // reply, which might be cached for longer). + ips = []string{d.publicIp} + } + + // Shuffle the list in random order, and keep only the + // first N results. + ips = shuffle(ips) + if len(ips) > maxResults { + ips = ips[:maxResults] + } + + m.SetReply(req) + m.MsgHdr.Authoritative = true + for _, ip := range ips { + rec := recordForIp(query, d.ttl, "A", ip) + answer, _ := dns.NewRR(rec) + m.Answer = append(m.Answer, answer) + } + log.Printf("Query(%s): %v", query, ips) + } + + ednsFromRequest(req, m) + w.WriteMsg(m) +} + +func (d *DnsRedirector) Run(addr string) { + dns.HandleFunc(d.origin, func(w dns.ResponseWriter, r *dns.Msg) { + d.serveDNS(w, r) + }) + + for _, proto := range []string{"tcp", "udp"} { + go func(proto string) { + server := &dns.Server{Addr: addr, Net: proto} + log.Printf("Starting DNS server on %s/53", proto) + log.Fatal(server.ListenAndServe()) + }(proto) + } +}