package clientutil import ( "fmt" "sync" "time" ) var dnsWatcherInterval = 1 * time.Minute // A DNSWatcher monitors a DNS name for changes, constantly attempting // to resolve it every minute and notifying a channel when the list of // returned IP addresses changes. All addresses must be in host:port // format. type DNSWatcher struct { hostport string resolver resolver addrs []string updateCh chan []string stopCh chan struct{} } // NewDNSWatcher creates a new DNSWatcher. func NewDNSWatcher(hostport string) (*DNSWatcher, error) { return newDNSWatcherWithResolver(hostport, defaultResolver) } func newDNSWatcherWithResolver(hostport string, resolver resolver) (*DNSWatcher, error) { // Resolve names once before returning. Return a fatal error // when there are no results, as it may indicate a syntax // error in hostport. addrs := resolver.ResolveIP(hostport) if len(addrs) == 0 { return nil, fmt.Errorf("can't resolve %s", hostport) } w := &DNSWatcher{ hostport: hostport, resolver: resolver, addrs: addrs, updateCh: make(chan []string, 10), stopCh: make(chan struct{}), } w.updateCh <- addrs go w.loop() return w, nil } // Stop the watcher. func (w *DNSWatcher) Stop() { close(w.stopCh) } // Changes returns a channel where the resolved addresses are sent // whenever they change. func (w *DNSWatcher) Changes() <-chan []string { return w.updateCh } func (w *DNSWatcher) check() { addrs := w.resolver.ResolveIP(w.hostport) if len(addrs) > 0 && !addrListEqual(addrs, w.addrs) { w.addrs = addrs w.updateCh <- addrs } } func (w *DNSWatcher) loop() { defer close(w.updateCh) tick := time.NewTicker(dnsWatcherInterval) defer tick.Stop() for { select { case <-tick.C: w.check() case <-w.stopCh: return } } } type multiDNSUpdate struct { hostport string addrs []string } // A MultiDNSWatcher watches multiple addresses for DNS changes. The // results are merged and returned as a list of addresses. type MultiDNSWatcher struct { watchers []*DNSWatcher addrmap map[string][]string faninCh chan multiDNSUpdate updateCh chan []string } // NewMultiDNSWatcher creates a new MultiDNSWatcher. func NewMultiDNSWatcher(hostports []string) (*MultiDNSWatcher, error) { return newMultiDNSWatcherWithResolver(hostports, defaultResolver) } func newMultiDNSWatcherWithResolver(hostports []string, resolver resolver) (*MultiDNSWatcher, error) { mw := &MultiDNSWatcher{ addrmap: make(map[string][]string), faninCh: make(chan multiDNSUpdate, 10), updateCh: make(chan []string, 10), } // All the MultiDNSWatcher does is multiplex updates from the // individual DNSWatchers onto faninCh, then merging those // updates with all the others and sending the result to // updateCh. go func() { defer close(mw.updateCh) for up := range mw.faninCh { mw.addrmap[up.hostport] = up.addrs mw.updateCh <- mw.allAddrs() } }() var wg sync.WaitGroup for _, hostport := range hostports { w, err := newDNSWatcherWithResolver(hostport, resolver) if err != nil { return nil, err } mw.watchers = append(mw.watchers, w) wg.Add(1) go func(hostport string) { for addrs := range w.Changes() { mw.faninCh <- multiDNSUpdate{ hostport: hostport, addrs: addrs, } } wg.Done() }(hostport) } go func() { wg.Wait() close(mw.faninCh) }() return mw, nil } func (mw *MultiDNSWatcher) allAddrs() []string { var out []string for _, addrs := range mw.addrmap { out = append(out, addrs...) } return out } // Stop the watcher. func (mw *MultiDNSWatcher) Stop() { for _, w := range mw.watchers { w.Stop() } } // Changes returns a channel where the aggregate resolved addresses // are sent whenever they change. func (mw *MultiDNSWatcher) Changes() <-chan []string { return mw.updateCh } func addrListEqual(a, b []string) bool { if len(a) != len(b) { return false } tmp := make(map[string]struct{}) for _, aa := range a { tmp[aa] = struct{}{} } for _, bb := range b { if _, ok := tmp[bb]; !ok { return false } delete(tmp, bb) } return len(tmp) == 0 }