Skip to content
Snippets Groups Projects
Commit 3bd78197 authored by ale's avatar ale
Browse files

add the DNS redirector

parent 7e4eac4d
No related branches found
No related tags found
No related merge requests found
......@@ -11,22 +11,35 @@ import (
)
var (
httpPort = flag.Int("port", 80, "TCP port to bind to")
domain = flag.String("domain", "", "DNS domain to serve")
dnsPort = flag.Int("dns-port", 53, "DNS port")
httpPort = flag.Int("http-port", 80, "HTTP port")
publicIp = flag.String("ip", "127.0.0.1", "Public IP for this machine")
// Default DNS TTL (seconds).
dnsTtl = 5
)
func main() {
flag.Parse()
if *domain == "" {
log.Fatal("Must specify --domain")
}
client := radioai.NewEtcdClient()
api := radioai.NewRadioAPI(client)
red := radioai.NewHttpRedirector(api)
server := &http.Server{
dnsRed := radioai.NewDnsRedirector(api, *domain, *publicIp, dnsTtl)
dnsRed.Run(fmt.Sprintf(":%d", *dnsPort))
httpServer := &http.Server{
Addr: fmt.Sprintf(":%d", *httpPort),
Handler: red,
ReadTimeout: 10 * time.Second,
WriteTimeout: 10 * time.Second,
}
log.Fatal(server.ListenAndServe())
}
\ No newline at end of file
log.Fatal(httpServer.ListenAndServe())
}
dns.go 0 → 100644
package radioai
import (
"fmt"
"log"
"math/rand"
"strconv"
"strings"
"time"
"github.com/miekg/dns"
)
var (
// Max number of results for an A query.
maxResults = 4
// The names that we are serving. Currently, all services are
// mapped to all the active nodes in the cluster.
validNames = []string{
"",
"www",
"stream",
"etcd",
}
)
type DnsRedirector struct {
client *RadioAPI
nodeCache *activeNodesCache
origin string
originNumParts int
publicIp string
ttl int
soa dns.RR
}
func NewDnsRedirector(client *RadioAPI, origin, publicIp string, ttl int) *DnsRedirector {
if !strings.HasSuffix(origin, ".") {
origin += "."
}
// Create a SOA record for the zone.
serialNo := strconv.FormatInt(time.Now().Unix(), 10)
soaRec := fmt.Sprintf("%s %d IN SOA localhost. hostmaster.%s %s 43200 3600 2419200 %d", origin, ttl, origin, serialNo, ttl)
soa, err := dns.NewRR(soaRec)
if err != nil {
log.Fatalf("Could not generate SOA record: %s", err)
}
return &DnsRedirector{
client: client,
nodeCache: newActiveNodesCache(client),
origin: origin,
originNumParts: len(dns.SplitDomainName(origin)),
publicIp: publicIp,
ttl: ttl,
soa: soa,
}
}
// Randomly shuffle a list of strings.
func shuffle(list []string) []string {
out := make([]string, len(list))
for dst, src := range rand.Perm(len(list)) {
out[dst] = list[src]
}
return out
}
func isValidQuery(query string) bool {
for _, q := range validNames {
if query == q {
return true
}
}
return false
}
// Create skeleton edns opt RR from the query and add it to the
// message m.
func ednsFromRequest(req, m *dns.Msg) {
for _, r := range req.Extra {
if r.Header().Rrtype == dns.TypeOPT {
m.SetEdns0(4096, r.(*dns.OPT).Do())
return
}
}
return
}
// Create a RR string for an IP.
func recordForIp(query string, ttl int, recType string, ip string) string {
return fmt.Sprintf("%s %d IN %s %s", query, ttl, recType, ip)
}
func (d *DnsRedirector) getQuestionName(req *dns.Msg) string {
lx := dns.SplitDomainName(req.Question[0].Name)
ql := lx[0 : len(lx)-d.originNumParts]
return strings.ToLower(strings.Join(ql, "."))
}
func (d *DnsRedirector) serveDNS(w dns.ResponseWriter, req *dns.Msg) {
m := new(dns.Msg)
// Just NACK ANYs
if req.Question[0].Qtype == dns.TypeANY {
m.SetRcode(req, dns.RcodeServerFailure)
ednsFromRequest(req, m)
w.WriteMsg(m)
return
}
query := d.getQuestionName(req)
log.Printf("[zone %s] incoming %s %s %d from %s\n", d.origin, query, dns.TypeToString[req.Question[0].Qtype], req.MsgHdr.Id, w.RemoteAddr())
switch {
case query == "" && req.Question[0].Qtype == dns.TypeSOA:
// Serve SOA record.
m.SetReply(req)
m.MsgHdr.Authoritative = true
m.Answer = append(m.Answer, d.soa)
case req.Question[0].Qtype == dns.TypeA:
// Return an NXDOMAIN for unknown queries.
if !isValidQuery(query) {
m.SetRcode(req, dns.RcodeNameError)
log.Printf("Query(%s): NXDOMAIN", query)
}
// Serve all active nodes on every request.
ips := d.nodeCache.GetNodes()
if ips == nil || len(ips) == 0 {
// In case of errors retrieving the list of
// active nodes, fall back to serving our
// public IP (just to avoid returning an empty
// reply, which might be cached for longer).
ips = []string{d.publicIp}
}
// Shuffle the list in random order, and keep only the
// first N results.
ips = shuffle(ips)
if len(ips) > maxResults {
ips = ips[:maxResults]
}
m.SetReply(req)
m.MsgHdr.Authoritative = true
for _, ip := range ips {
rec := recordForIp(query, d.ttl, "A", ip)
answer, _ := dns.NewRR(rec)
m.Answer = append(m.Answer, answer)
}
log.Printf("Query(%s): %v", query, ips)
}
ednsFromRequest(req, m)
w.WriteMsg(m)
}
func (d *DnsRedirector) Run(addr string) {
dns.HandleFunc(d.origin, func(w dns.ResponseWriter, r *dns.Msg) {
d.serveDNS(w, r)
})
for _, proto := range []string{"tcp", "udp"} {
go func(proto string) {
server := &dns.Server{Addr: addr, Net: proto}
log.Printf("Starting DNS server on %s/53", proto)
log.Fatal(server.ListenAndServe())
}(proto)
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment