diff --git a/api.go b/api.go index 57372bdc3ee8cd58cd806549e2fcf39724a67068..a6b81b6bd14770868428ba5c78dd82a822acdf30 100644 --- a/api.go +++ b/api.go @@ -264,7 +264,7 @@ func (ns *NodeStatus) NumListeners() int { // Client is the actual API to the streaming cluster's database. type Client struct { - client EtcdClient + EtcdClient EtcdClient presenceCache *presence.Cache } @@ -274,7 +274,7 @@ func NewClient(client EtcdClient) *Client { // since it is likely that it will be accessed quite often (in // the case of redirectord, on every request). return &Client{ - client: client, + EtcdClient: client, presenceCache: presence.NewCache(client, NodePrefix, 2*time.Second, func(data []string) interface{} { // Convert a list of JSON-encoded NodeStatus // objects into a lisce of *NodeStatus @@ -293,6 +293,11 @@ func NewClient(client EtcdClient) *Client { } } +// WaitForNodes waits until the node presence cache is initialized. +func (r *Client) WaitForNodes() { + r.presenceCache.WaitForInit() +} + // GetNodes returns the list of active cluster nodes. func (r *Client) GetNodes() ([]*NodeStatus, error) { data, err := r.presenceCache.Data() @@ -305,7 +310,7 @@ func (r *Client) GetNodes() ([]*NodeStatus, error) { // GetMount returns data on a specific mountpoint (returns nil if not // found). func (r *Client) GetMount(mountName string) (*Mount, error) { - response, err := r.client.Get(mountEtcdPath(mountName), false, false) + response, err := r.EtcdClient.Get(mountEtcdPath(mountName), false, false) if err != nil || response.Node == nil { return nil, err } @@ -327,19 +332,19 @@ func (r *Client) SetMount(m *Mount) error { return err } - _, err := r.client.Set(mountEtcdPath(m.Name), buf.String(), 0) + _, err := r.EtcdClient.Set(mountEtcdPath(m.Name), buf.String(), 0) return err } // DelMount removes a mountpoint. func (r *Client) DelMount(mountName string) error { - _, err := r.client.Delete(mountEtcdPath(mountName), false) + _, err := r.EtcdClient.Delete(mountEtcdPath(mountName), false) return err } // ListMounts returns a list of all the configured mountpoints. func (r *Client) ListMounts() ([]*Mount, error) { - response, err := r.client.Get(MountPrefix, true, false) + response, err := r.EtcdClient.Get(MountPrefix, true, false) if err != nil || response.Node == nil { return nil, err } @@ -384,7 +389,7 @@ func (m MasterNodeInfo) GetInternalIP() []net.IP { // GetMasterInfo returns the address of the current master server. func (r *Client) GetMasterInfo() (*MasterNodeInfo, error) { - response, err := r.client.Get(MasterElectionPath, false, false) + response, err := r.EtcdClient.Get(MasterElectionPath, false, false) if err != nil || response.Node == nil { return nil, err } diff --git a/coordination/etcdtest/fake_etcd.go b/coordination/etcdtest/fake_etcd.go index a19f46886bc4ef052bb48b72cb44856161811ac1..f5c34b19b58659d8a160243318524e93a57ccc4b 100644 --- a/coordination/etcdtest/fake_etcd.go +++ b/coordination/etcdtest/fake_etcd.go @@ -286,3 +286,11 @@ func (s *FakeEtcdClient) Watch(key string, index uint64, recursive bool, respch } return resp, nil } + +func (s *FakeEtcdClient) GetCluster() []string { + return []string{"http://localhost:2379"} +} + +func (s *FakeEtcdClient) SyncCluster() bool { + return false +} diff --git a/coordination/presence/cache.go b/coordination/presence/cache.go index 3c67b2e9edbb68d0b7fef8c4d4f8fb3858952016..aa6aa8551cdde8963fd342502ec813d9ac87dd7a 100644 --- a/coordination/presence/cache.go +++ b/coordination/presence/cache.go @@ -74,7 +74,7 @@ func (c *Cache) run(refresh time.Duration) { } doUpdate() - c.loaded <- struct{}{} + close(c.loaded) for { select { case <-tick.C: diff --git a/coordination/presence/presence.go b/coordination/presence/presence.go index 2b5cd0679c707f5e576990e8dc6fc6665605b94f..24d2246ea1c35cc1829b03e62f6d1adfbfc308ac 100644 --- a/coordination/presence/presence.go +++ b/coordination/presence/presence.go @@ -7,6 +7,7 @@ package presence import ( "log" + "strings" "time" "git.autistici.org/ale/autoradio/Godeps/_workspace/src/github.com/coreos/go-etcd/etcd" @@ -33,7 +34,7 @@ type Client struct { func NewClient(client EtcdClient, path string) *Client { return &Client{ client: client, - path: path, + path: strings.TrimRight(path, "/"), } } diff --git a/etcd_client.go b/etcd_client.go index d95a49e7c0ee7661dcd68ab3a63f8c605bd7f0bf..aa721ff25fdb70bfc02f2fe59cd5eed8e59b6078 100644 --- a/etcd_client.go +++ b/etcd_client.go @@ -94,7 +94,9 @@ type EtcdClient interface { CompareAndSwap(string, string, uint64, string, uint64) (*etcd.Response, error) Delete(string, bool) (*etcd.Response, error) Get(string, bool, bool) (*etcd.Response, error) + GetCluster() []string Set(string, string, uint64) (*etcd.Response, error) + SyncCluster() bool Update(string, string, uint64) (*etcd.Response, error) Watch(string, uint64, bool, chan *etcd.Response, chan bool) (*etcd.Response, error) } diff --git a/fe/dns.go b/fe/dns.go index 7977bd2ca1619cd56956de977b592c4bdc96e534..afa681b9a91febc91b3c72086034cdfe62f09a22 100644 --- a/fe/dns.go +++ b/fe/dns.go @@ -5,6 +5,8 @@ import ( "log" "math/rand" "net" + "net/url" + "strconv" "strings" "time" @@ -17,27 +19,78 @@ var ( // Max number of results for an A query. maxResults = 3 - // The names that we are serving. Currently, all services are - // mapped to all the active nodes in the cluster. - validNames = []string{ - "", - "www", - "stream", - "etcd", - } - dnsQueryStats = instrumentation.NewCounter("dns.status") dnsTargetStats = instrumentation.NewCounter("dns.target") ) +type etcdAddr struct { + host string + port int + ips []net.IP +} + +// Parsed etcd cluster state. +type etcdClusterState struct { + addrs []etcdAddr + ssl bool +} + +func (s *etcdClusterState) IPs() []net.IP { + var out []net.IP + for _, addr := range s.addrs { + out = append(out, addr.ips...) + } + return out +} + +func parseEtcdClusterState(urls []string) *etcdClusterState { + if len(urls) == 0 { + return nil + } + var state etcdClusterState + for _, u := range urls { + parsedURL, err := url.Parse(u) + if err != nil { + continue + } + host, portStr, err := net.SplitHostPort(parsedURL.Host) + if err != nil { + host = parsedURL.Host + portStr = "2379" + } + port, err := strconv.Atoi(portStr) + if err != nil { + continue + } + if parsedURL.Scheme == "https" { + state.ssl = true + } + + // Resolve the hostname if necessary. + var ips []net.IP + if ip := net.ParseIP(host); ip != nil { + ips = []net.IP{ip} + } else if resolved, err := net.LookupIP(host); err == nil { + ips = resolved + } else { + continue + } + + state.addrs = append(state.addrs, etcdAddr{host: dns.Fqdn(host), port: port, ips: ips}) + } + return &state +} + // DNSRedirector sends clients to backends using DNS. type DNSRedirector struct { client *autoradio.Client origin string originNumParts int publicIps []net.IP + etcdCluster *etcdClusterState ttl int soa dns.RR + queryTable map[string]ipFunc } // NewDNSRedirector returns a DNS server for the given origin and @@ -71,9 +124,31 @@ func NewDNSRedirector(client *autoradio.Client, origin string, publicIps []net.I publicIps: publicIps, ttl: ttl, soa: soa, + queryTable: map[string]ipFunc{ + "": getAutoradioIPs, + "www": getAutoradioIPs, + "stream": getAutoradioIPs, + "etcd": getEtcdIPs, + }, } } +// Periodically update the list of etcd nodes. We need to parse the +// etcd URLs and resolve to a list of IPs. +func (d *DNSRedirector) updateEtcdCluster() { + d.etcdCluster = parseEtcdClusterState(d.client.EtcdClient.GetCluster()) + go func() { + for range time.Tick(60 * time.Second) { + if !d.client.EtcdClient.SyncCluster() { + continue + } + if s := parseEtcdClusterState(d.client.EtcdClient.GetCluster()); s != nil { + d.etcdCluster = s + } + } + }() +} + // Randomly shuffle a list of addresses. func shuffle(list []net.IP) []net.IP { out := make([]net.IP, len(list)) @@ -83,15 +158,6 @@ func shuffle(list []net.IP) []net.IP { return out } -func isValidQuery(query string) bool { - for _, q := range validNames { - if query == q { - return true - } - } - return false -} - // Create skeleton edns opt RR from the query and add it to the // message m. func ednsFromRequest(req, m *dns.Msg) { @@ -138,8 +204,8 @@ func (d *DNSRedirector) newAAAA(name string, ip net.IP) dns.RR { } // Strip the origin from the query. -func (d *DNSRedirector) getQuestionName(req *dns.Msg) string { - lx := dns.SplitDomainName(req.Question[0].Name) +func (d *DNSRedirector) getQuestionName(q dns.Question) string { + lx := dns.SplitDomainName(q.Name) ql := lx[0 : len(lx)-d.originNumParts] return strings.ToLower(strings.Join(ql, ".")) } @@ -153,69 +219,95 @@ func flattenIPs(nodes []*autoradio.NodeStatus) []net.IP { return ips } -func (d *DNSRedirector) serveDNS(w dns.ResponseWriter, req *dns.Msg) { - m := new(dns.Msg) +type ipFunc func(*DNSRedirector, dns.Question) []net.IP - // Just NACK ANYs - if req.Question[0].Qtype == dns.TypeANY { - m.SetRcode(req, dns.RcodeServerFailure) - ednsFromRequest(req, m) - w.WriteMsg(m) - return +func getEtcdIPs(d *DNSRedirector, q dns.Question) []net.IP { + if d.etcdCluster == nil { + return nil + } + + // Always serve all IPs. + return filterIPByProto(d.etcdCluster.IPs(), q.Qtype == dns.TypeAAAA) +} + +func getAutoradioIPs(d *DNSRedirector, q dns.Question) []net.IP { + // Serve all active nodes on every request. We don't really + // care about errors from GetNodes as long as some nodes are + // returned (i.e. stale data from the cache is accepted). + // Also, we need to filter the resulting list for nodes whose + // IP address protocol version matches the request type (IPv4 + // for A requests, IPv6 for AAAA). + nodes, err := d.client.GetNodes() + if err != nil { + log.Printf("error fetching list of nodes: %v", err) + return nil } - query := d.getQuestionName(req) - var responseMsg string + // Shuffle the list in random order, and keep only the first N + // results. + ips := shuffle(filterIPByProto(flattenIPs(nodes), q.Qtype == dns.TypeAAAA)) + if len(ips) > maxResults { + ips = ips[:maxResults] + } + return ips +} +func (d *DNSRedirector) handleQuestion(q dns.Question, m *dns.Msg) bool { + query := d.getQuestionName(q) switch { - case query == "" && req.Question[0].Qtype == dns.TypeSOA: - // Serve SOA record. - m.SetReply(req) - m.MsgHdr.Authoritative = true + case query == "" && q.Qtype == dns.TypeSOA: m.Answer = append(m.Answer, d.soa) - responseMsg = "SOA" - - case req.Question[0].Qtype == dns.TypeA || req.Question[0].Qtype == dns.TypeAAAA: - // Return an NXDOMAIN for unknown queries. - if !isValidQuery(query) { - m.SetRcode(req, dns.RcodeNameError) - responseMsg = "NXDOMAIN" - break + return true + + case q.Qtype == dns.TypeSRV: + if d.etcdCluster == nil { + return false } - // Serve all active nodes on every request. We don't - // really care about errors from GetNodes as long as - // some nodes are returned (i.e. stale data from the - // cache is accepted). Also, we need to filter the - // resulting list for nodes whose IP address protocol - // version matches the request type (IPv4 for A - // requests, IPv6 for AAAA). - var ips []net.IP - nodes, _ := d.client.GetNodes() - if len(nodes) > 0 { - ips = flattenIPs(nodes) - } else { - // In case of errors retrieving the list of - // active nodes, fall back to serving our - // public IP (just to avoid returning an empty - // reply, which might be cached for longer). - ips = d.publicIps + // Serve the right name depending on whether the etcd + // cluster is set up to use SSL or not. + if !((query == "_etcd-server-ssl._tcp" && d.etcdCluster.ssl) || (query == "_etcd-server._tcp" && !d.etcdCluster.ssl)) { + return false + } + + for _, addr := range d.etcdCluster.addrs { + rec := &dns.SRV{ + Hdr: dns.RR_Header{ + Name: d.withOrigin(query), + Rrtype: dns.TypeSRV, + Class: dns.ClassINET, + Ttl: uint32(d.ttl), + }, + Target: addr.host, + Port: uint16(addr.port), + } + m.Answer = append(m.Answer, rec) } - isV6 := (req.Question[0].Qtype == dns.TypeAAAA) - ips = filterIPByProto(ips, isV6) - - // Shuffle the list in random order, and keep only the - // first N results. - ips = shuffle(ips) - if len(ips) > maxResults { - ips = ips[:maxResults] + return true + + case q.Qtype == dns.TypeA || q.Qtype == dns.TypeAAAA: + // Look up the requested name in the routing table, + // and retrieve the list of IPs for the reply. Return + // an NXDOMAIN for unknown queries. + ipfn, ok := d.queryTable[query] + if !ok { + return false + } + + // Get the IPs for the reply. + ips := ipfn(d, q) + if len(ips) == 0 { + // In case of errors, if the IP list is empty, fall + // back to serving our public IP (just to avoid + // returning an empty reply, which might be cached for + // longer). + log.Printf("fallback to localhost") + ips = d.publicIps } - m.SetReply(req) - m.MsgHdr.Authoritative = true for _, ip := range ips { var rec dns.RR - if isV6 { + if q.Qtype == dns.TypeAAAA { rec = d.newAAAA(query, ip) } else { rec = d.newA(query, ip) @@ -223,32 +315,69 @@ func (d *DNSRedirector) serveDNS(w dns.ResponseWriter, req *dns.Msg) { m.Answer = append(m.Answer, rec) dnsTargetStats.IncrVar(ipToMetric(ip)) } - responseMsg = fmt.Sprintf("%v", ips) + return true + } + return false +} - default: - // Return an error for anything else. - m.SetRcode(req, dns.RcodeNameError) - responseMsg = "NXDOMAIN" +func responseToString(m *dns.Msg) string { + if m.MsgHdr.Rcode != dns.RcodeSuccess { + return dns.RcodeToString[m.MsgHdr.Rcode] + } + var out []string + for _, ans := range m.Answer { + var s string + switch t := ans.(type) { + case *dns.A: + s = fmt.Sprintf("A %s", t.A) + case *dns.AAAA: + s = fmt.Sprintf("AAAA %s", t.AAAA) + case *dns.SRV: + s = fmt.Sprintf("SRV %s:%d", strings.TrimRight(t.Target, "."), t.Port) + default: + s = t.String() + } + out = append(out, s) } + return strings.Join(out, ", ") +} - if responseMsg == "NXDOMAIN" { +func (d *DNSRedirector) ServeDNS(w dns.ResponseWriter, req *dns.Msg) { + m := new(dns.Msg) + + // Only consider the first question. + q := req.Question[0] + + // Just NACK ANYs + if q.Qtype == dns.TypeANY { + m.SetRcode(req, dns.RcodeServerFailure) + ednsFromRequest(req, m) + w.WriteMsg(m) + return + } + + // Handle the request. + ednsFromRequest(req, m) + if !d.handleQuestion(q, m) { + m.SetRcode(req, dns.RcodeNameError) dnsQueryStats.IncrVar("error") } else { + m.SetReply(req) + m.MsgHdr.Authoritative = true dnsQueryStats.IncrVar("ok") } - log.Printf("[%d] %s.%s %s (from %s) -> %s", req.MsgHdr.Id, query, d.origin, dns.TypeToString[req.Question[0].Qtype], w.RemoteAddr(), responseMsg) + log.Printf("[%d] %s %s %s -> %s", req.MsgHdr.Id, w.RemoteAddr(), q.Name, dns.TypeToString[q.Qtype], responseToString(m)) - ednsFromRequest(req, m) w.WriteMsg(m) } // Start the DNS servers on the given address (both tcp and udp). // It creates new goroutines and returns immediately. func (d *DNSRedirector) Start(addr string) { - dns.HandleFunc(d.origin, func(w dns.ResponseWriter, r *dns.Msg) { - d.serveDNS(w, r) - }) + d.updateEtcdCluster() + + dns.Handle(d.origin, d) for _, proto := range []string{"tcp", "udp"} { go func(proto string) { diff --git a/fe/dns_test.go b/fe/dns_test.go new file mode 100644 index 0000000000000000000000000000000000000000..2df777b635776a3c650172f6bc0bf19dc72caa89 --- /dev/null +++ b/fe/dns_test.go @@ -0,0 +1,173 @@ +package fe + +import ( + "encoding/json" + "log" + "net" + "testing" + + "git.autistici.org/ale/autoradio/Godeps/_workspace/src/github.com/miekg/dns" + + "git.autistici.org/ale/autoradio" + "git.autistici.org/ale/autoradio/coordination/etcdtest" +) + +func createTestDNSRedirector(t testing.TB, withNode bool) *DNSRedirector { + etcd := etcdtest.NewClient() + if withNode { + // Create fake presence data, so that GetNodes() returns something. + nodeData, _ := json.Marshal(&autoradio.NodeStatus{ + Name: "node1", + IcecastUp: true, + IP: []net.IP{ + net.ParseIP("1.2.3.4"), + net.ParseIP("2001:a:b::1"), + }, + }) + log.Printf("creating %s", autoradio.NodePrefix+"0001") + if _, err := etcd.Create(autoradio.NodePrefix+"0001", string(nodeData), 86400); err != nil { + t.Fatalf("etcd.Create(): %v", err) + } + } + + client := autoradio.NewClient(etcd) + client.WaitForNodes() + d := NewDNSRedirector(client, "example.com", []net.IP{net.ParseIP("2.3.4.5")}, 30) + d.updateEtcdCluster() + return d +} + +type testNetAddr struct{} + +func (n testNetAddr) Network() string { return "tcp" } +func (n testNetAddr) String() string { return "127.0.0.1" } + +// Fake dns.ResponseWriter that records messages written to it. +type testDNSResponseWriter struct { + messages []*dns.Msg +} + +// WriteMsg writes a reply back to the client. +func (w *testDNSResponseWriter) WriteMsg(m *dns.Msg) error { + w.messages = append(w.messages, m) + return nil +} + +// LocalAddr returns the net.Addr of the server +func (w *testDNSResponseWriter) LocalAddr() net.Addr { return &testNetAddr{} } + +// RemoteAddr returns the net.Addr of the client that sent the current request. +func (w *testDNSResponseWriter) RemoteAddr() net.Addr { return &testNetAddr{} } + +// Write writes a raw buffer back to the client. +func (w *testDNSResponseWriter) Write([]byte) (int, error) { return 0, nil } + +// Close closes the connection. +func (w *testDNSResponseWriter) Close() error { return nil } + +// TsigStatus returns the status of the Tsig. +func (w *testDNSResponseWriter) TsigStatus() error { return nil } + +// TsigTimersOnly sets the tsig timers only boolean. +func (w *testDNSResponseWriter) TsigTimersOnly(bool) {} + +// Hijack lets the caller take over the connection. +// After a call to Hijack(), the DNS package will not do anything with the connection. +func (w *testDNSResponseWriter) Hijack() {} + +func TestDNSRedirector_A(t *testing.T) { + testQueryA(t, true, "stream.example.com.", "1.2.3.4") +} + +func TestDNSRedirector_A_Etcd(t *testing.T) { + testQueryA(t, false, "etcd.example.com.", "127.0.0.1") +} + +func TestDNSRedirector_A_Fallback(t *testing.T) { + testQueryA(t, false, "stream.example.com.", "2.3.4.5") +} + +func doTestQuery(t testing.TB, withNode bool, q *dns.Msg) *dns.Msg { + d := createTestDNSRedirector(t, withNode) + w := &testDNSResponseWriter{} + d.ServeDNS(w, q) + if len(w.messages) != 1 { + t.Fatal("no reply") + } + return w.messages[0] +} + +func testQueryA(t testing.TB, withNode bool, query, expected string) { + q := new(dns.Msg) + q.SetQuestion(query, dns.TypeA) + response := doTestQuery(t, withNode, q) + + answer, ok := response.Answer[0].(*dns.A) + if !ok { + t.Fatalf("expected A, got: %s:", response.Answer[0]) + } + if answer.A.String() != expected { + t.Fatalf("bad IP: %s, expected %s", answer.A, expected) + } +} + +func TestDNSRedirector_AAAA(t *testing.T) { + q := new(dns.Msg) + q.SetQuestion("stream.example.com", dns.TypeAAAA) + response := doTestQuery(t, true, q) + + answer, ok := response.Answer[0].(*dns.AAAA) + if !ok { + t.Fatalf("bad reply (not AAAA): %s:", response.Answer[0]) + } + expected := "2001:a:b::1" + if answer.AAAA.String() != expected { + t.Fatalf("bad IP: %s, expected %s", answer.AAAA, expected) + } +} + +func TestDNSRedirector_NXDOMAIN(t *testing.T) { + q := new(dns.Msg) + q.SetQuestion("nonexisting.example.com", dns.TypeA) + response := doTestQuery(t, false, q) + if response.MsgHdr.Rcode != dns.RcodeNameError { + t.Fatalf("expected NXDOMAIN, got: %s", response) + } +} + +func TestDNSRedirector_NXDOMAIN_WrongZone(t *testing.T) { + q := new(dns.Msg) + q.SetQuestion("foo.bar.com", dns.TypeA) + response := doTestQuery(t, false, q) + if response.MsgHdr.Rcode != dns.RcodeNameError { + t.Fatalf("expected NXDOMAIN, got: %s", response) + } +} + +func TestDNSRedirector_SRV_Etcd(t *testing.T) { + q := new(dns.Msg) + q.SetQuestion("_etcd-server._tcp.example.com", dns.TypeSRV) + response := doTestQuery(t, false, q) + if len(response.Answer) != 1 { + t.Fatalf("expected 1 answer, got: %s", response) + } + srv, ok := response.Answer[0].(*dns.SRV) + if !ok { + t.Fatalf("expected SRV, got: %s", response) + } + if srv.Port != 2379 { + t.Fatalf("expected port 2379, got: %s", srv) + } + if srv.Target != "localhost." { + t.Fatalf("expected target localhost, got: %s", srv) + } +} + +func TestDNSRedirector_SRV_Etcd_BadScheme(t *testing.T) { + q := new(dns.Msg) + q.SetQuestion("_etcd-server-ssl._tcp.example.com", dns.TypeSRV) + response := doTestQuery(t, false, q) + if response.MsgHdr.Rcode != dns.RcodeNameError { + t.Fatalf("expected NXDOMAIN, got: %s", response) + } +}