package clientutil

import (
	"fmt"
	"sort"
	"sync"
	"time"
)

// dnsWatcherInterval is the polling period used by every DNSWatcher's
// background loop. It is a variable rather than a constant, presumably so
// tests can shorten it — NOTE(review): confirm against the test files.
var dnsWatcherInterval = time.Minute

// A DNSWatcher monitors a DNS name for changes. It re-resolves the name
// every dnsWatcherInterval (one minute by default) and sends the returned
// list of addresses on its Changes channel whenever that list changes.
// A lookup that returns no addresses is ignored. All addresses must be in
// host:port format.
type DNSWatcher struct {
	hostport string        // name being watched, as passed to the constructor
	resolver resolver      // performs each lookup via ResolveIP
	addrs    []string      // most recent non-empty lookup result
	updateCh chan []string // address lists for the consumer; closed when loop exits
	stopCh   chan struct{} // closed by Stop to make loop return
}

// NewDNSWatcher returns a DNSWatcher for hostport that performs its
// lookups with the package's default resolver. It fails if the initial
// lookup yields no addresses.
func NewDNSWatcher(hostport string) (*DNSWatcher, error) {
	r := defaultResolver
	return newDNSWatcherWithResolver(hostport, r)
}

// newDNSWatcherWithResolver is NewDNSWatcher with an injectable resolver.
//
// The name is looked up once, synchronously, before anything is started.
// An empty result is treated as fatal because it may indicate a syntax
// error in hostport. On success the initial address list is already queued
// on the Changes channel and the polling goroutine is running.
func newDNSWatcherWithResolver(hostport string, resolver resolver) (*DNSWatcher, error) {
	initial := resolver.ResolveIP(hostport)
	if len(initial) < 1 {
		return nil, fmt.Errorf("can't resolve %s", hostport)
	}

	// Buffered, so queueing the initial result cannot block before a
	// consumer has called Changes.
	updates := make(chan []string, 10)
	updates <- initial

	w := &DNSWatcher{
		hostport: hostport,
		resolver: resolver,
		addrs:    initial,
		updateCh: updates,
		stopCh:   make(chan struct{}),
	}
	go w.loop()
	return w, nil
}

// Stop stops the watcher by closing stopCh, which tells the background
// loop to return. When the loop returns, it closes the Changes channel.
//
// Calling Stop again after it has returned is a no-op instead of a
// "close of closed channel" panic. Concurrent calls from several
// goroutines still need external synchronization.
func (w *DNSWatcher) Stop() {
	select {
	case <-w.stopCh:
		// Already stopped; closing again would panic.
	default:
		close(w.stopCh)
	}
}

// Changes returns the receive-only channel on which the watcher delivers
// its address list each time it changes, starting with the result of the
// initial lookup.
func (w *DNSWatcher) Changes() <-chan []string {
	var out <-chan []string = w.updateCh
	return out
}

// check performs one lookup. If the result is non-empty and differs from
// the last recorded list (per addrListEqual), check records it and sends
// it on updateCh. An empty result is ignored, so a failed lookup never
// replaces the last known addresses.
//
// The send also watches stopCh. Without that, a consumer that has stopped
// reading Changes could leave check blocked on a full updateCh forever,
// and loop would never get the chance to observe Stop.
func (w *DNSWatcher) check() {
	addrs := w.resolver.ResolveIP(w.hostport)
	if len(addrs) == 0 || addrListEqual(addrs, w.addrs) {
		return
	}
	w.addrs = addrs
	select {
	case w.updateCh <- addrs:
	case <-w.stopCh:
	}
}

// loop calls check on every tick of a dnsWatcherInterval ticker until
// stopCh is closed. On return it stops the ticker and then closes
// updateCh, so a consumer ranging over Changes terminates.
func (w *DNSWatcher) loop() {
	defer close(w.updateCh)

	ticker := time.NewTicker(dnsWatcherInterval)
	defer ticker.Stop()

	for running := true; running; {
		select {
		case <-ticker.C:
			w.check()
		case <-w.stopCh:
			running = false
		}
	}
}

// multiDNSUpdate is one fan-in message sent to a MultiDNSWatcher's merging
// goroutine: the latest address list reported by the DNSWatcher for a
// single hostport.
type multiDNSUpdate struct {
	hostport string   // which watcher produced the update
	addrs    []string // that watcher's new address list
}

// A MultiDNSWatcher watches multiple addresses for DNS changes. The
// results are merged and returned as a list of addresses: every update
// from any underlying DNSWatcher causes the full merged list to be sent on
// the Changes channel.
type MultiDNSWatcher struct {
	watchers []*DNSWatcher       // one per hostport; Stop stops each of them
	addrmap  map[string][]string // latest address list per hostport, updated by the merging goroutine
	faninCh  chan multiDNSUpdate // per-watcher updates funnelled to the merging goroutine
	updateCh chan []string       // merged address lists; returned by Changes
}

// NewMultiDNSWatcher returns a MultiDNSWatcher over hostports that uses the
// package's default resolver. It fails if any hostport cannot be resolved
// on the initial lookup.
func NewMultiDNSWatcher(hostports []string) (*MultiDNSWatcher, error) {
	mw, err := newMultiDNSWatcherWithResolver(hostports, defaultResolver)
	return mw, err
}

// newMultiDNSWatcherWithResolver builds a MultiDNSWatcher whose underlying
// DNSWatchers use resolver for their lookups.
//
// Every hostport is resolved, by creating its DNSWatcher, before any
// goroutine of the MultiDNSWatcher is started. If one of them fails, the
// watchers created so far are stopped and the error is returned, so a
// failed construction does not leave polling goroutines behind.
//
// Data flow: one forwarding goroutine per watcher copies its Changes onto
// faninCh. A single merging goroutine stores each update in addrmap and
// sends the merged list (allAddrs) on updateCh. faninCh is closed once
// every forwarder has finished, which ends the merging goroutine and
// closes updateCh.
//
// NOTE(review): the send on updateCh blocks once its buffer is full. The
// consumer is therefore expected to keep draining Changes (also after
// Stop, until the channel is closed), or these goroutines cannot exit.
func newMultiDNSWatcherWithResolver(hostports []string, resolver resolver) (*MultiDNSWatcher, error) {
	mw := &MultiDNSWatcher{
		addrmap:  make(map[string][]string),
		faninCh:  make(chan multiDNSUpdate, 10),
		updateCh: make(chan []string, 10),
	}

	for _, hostport := range hostports {
		w, err := newDNSWatcherWithResolver(hostport, resolver)
		if err != nil {
			// Don't leak the watchers (and their poll goroutines)
			// that were already started for earlier hostports.
			mw.Stop()
			return nil, err
		}
		mw.watchers = append(mw.watchers, w)
	}

	// Merge every per-hostport update into addrmap and publish the
	// aggregate list.
	go func() {
		defer close(mw.updateCh)
		for up := range mw.faninCh {
			mw.addrmap[up.hostport] = up.addrs
			mw.updateCh <- mw.allAddrs()
		}
	}()

	var wg sync.WaitGroup
	for i, w := range mw.watchers {
		wg.Add(1)
		// The hostport and watcher are passed as arguments so each
		// goroutine gets its own copy regardless of Go version.
		go func(hostport string, w *DNSWatcher) {
			defer wg.Done()
			for addrs := range w.Changes() {
				mw.faninCh <- multiDNSUpdate{
					hostport: hostport,
					addrs:    addrs,
				}
			}
		}(hostports[i], w)
	}

	// Close faninCh only after every forwarder has stopped sending on it.
	go func() {
		wg.Wait()
		close(mw.faninCh)
	}()

	return mw, nil
}

// allAddrs concatenates the latest address lists of every watched hostport.
// Hostports are visited in sorted order, so the merged list is
// deterministic for a given addrmap (Go map iteration order is randomized).
// Each hostport's own list keeps its original order. The result is a newly
// allocated slice (nil when there are no addresses) that does not alias
// addrmap, so callers may keep or modify it.
func (mw *MultiDNSWatcher) allAddrs() []string {
	hostports := make([]string, 0, len(mw.addrmap))
	for hp := range mw.addrmap {
		hostports = append(hostports, hp)
	}
	sort.Strings(hostports)

	var out []string
	for _, hp := range hostports {
		out = append(out, mw.addrmap[hp]...)
	}
	return out
}

// Stop stops every underlying DNSWatcher, in the order in which they were
// created.
func (mw *MultiDNSWatcher) Stop() {
	for i := range mw.watchers {
		mw.watchers[i].Stop()
	}
}

// Changes returns the receive-only channel carrying the merged address
// list. A new list is sent after every update from any underlying watcher.
func (mw *MultiDNSWatcher) Changes() <-chan []string {
	ch := mw.updateCh
	return ch
}

// addrListEqual reports whether a and b contain the same addresses with the
// same multiplicities, ignoring order. Duplicates are counted, so
// ["x", "x"] equals ["x", "x"] but not ["x", "y"] or ["x"]. A nil slice
// and an empty slice are equal.
func addrListEqual(a, b []string) bool {
	if len(a) != len(b) {
		return false
	}
	counts := make(map[string]int, len(a))
	for _, s := range a {
		counts[s]++
	}
	for _, s := range b {
		if counts[s] == 0 {
			return false
		}
		counts[s]--
	}
	// Equal lengths with no element of b left unmatched means every count
	// has returned to zero.
	return true
}