package clientutil

import (
	"math/rand"
	"sync"
	"time"
)

// backendTracker tracks the state of the targets associated with a
// backend, and periodically re-resolves the backend address (DNS, per the
// original design) to pick up target changes.
//
// The background refresh goroutine is started by newBackendTracker and
// signalled to exit by Close.
type backendTracker struct {
	// log receives messages when a target changes between ok and failed.
	log      logger
	// addr is the backend address handed to resolver.ResolveIP.
	addr     string
	// resolver expands addr into the list of target addresses.
	resolver resolver
	// stopCh is closed by Close to make updateProc return.
	stopCh   chan struct{}

	// mx guards resolved and failed, which are accessed both by callers
	// and by the background updateProc goroutine.
	mx       sync.Mutex
	// resolved is the current target list: the constructor's initial
	// lookup, later replaced by updateProc only with non-empty results.
	resolved []string
	// failed maps a target to the time setStatus first marked it failed.
	// Entries are removed when the target is reported ok again, or by
	// expireFailedTargets once older than backendFailureRetryInterval.
	failed   map[string]time.Time
}

// newBackendTracker builds a tracker for addr and launches its background
// refresh goroutine. One lookup is performed synchronously so the tracker
// already has a target list (possibly empty) when it is returned; the
// caller must call Close to stop the refresher.
func newBackendTracker(addr string, resolver resolver, logger logger) *backendTracker {
	initialTargets := resolver.ResolveIP(addr)

	t := &backendTracker{
		log:      logger,
		addr:     addr,
		resolver: resolver,
		stopCh:   make(chan struct{}),
		resolved: initialTargets,
		failed:   make(map[string]time.Time),
	}
	go t.updateProc()
	return t
}

// Close signals the background refresh goroutine (updateProc) to exit. It
// does not wait for the goroutine to finish: a lookup already in progress
// completes first.
//
// Closing a closed channel panics, so Close checks first. That makes
// repeated sequential calls harmless no-ops. Concurrent calls to Close are
// still not synchronized with each other.
func (b *backendTracker) Close() {
	select {
	case <-b.stopCh:
		// Already closed by an earlier Close call.
	default:
		close(b.stopCh)
	}
}

// getTargets returns every currently resolved target. Targets not marked
// as failed come first, in random order, followed by the failed ones, also
// in random order. The result is a freshly allocated slice, so callers may
// modify it.
//
// NOTE(review): an earlier comment described this as "reverse preference
// order"; the code puts healthy targets at index 0. Confirm callers consume
// the slice from the front.
func (b *backendTracker) getTargets() []string {
	b.mx.Lock()
	defer b.mx.Unlock()

	var healthy, unhealthy []string
	for _, target := range b.resolved {
		if _, down := b.failed[target]; down {
			unhealthy = append(unhealthy, target)
			continue
		}
		healthy = append(healthy, target)
	}

	// Operands are evaluated left to right: healthy is shuffled before unhealthy.
	return append(shuffle(healthy), shuffle(unhealthy)...)
}

// setStatus records the outcome of using target addr. A failure marks the
// target as failed, stamped with the current time. A success clears an
// existing failure mark. Only transitions are logged; repeating the current
// state changes nothing, so the original failure time is kept.
func (b *backendTracker) setStatus(addr string, ok bool) {
	b.mx.Lock()
	defer b.mx.Unlock()

	_, markedFailed := b.failed[addr]
	switch {
	case ok && markedFailed:
		b.log.Printf("target %s now ok", addr)
		delete(b.failed, addr)
	case !ok && !markedFailed:
		b.log.Printf("target %s failed", addr)
		b.failed[addr] = time.Now()
	}
}

var (
	// backendUpdateInterval is the period of updateProc's ticker: how often
	// each tracker re-resolves its backend and expires old failure marks.
	backendUpdateInterval = time.Minute

	// backendFailureRetryInterval is how long a failure mark must have
	// existed before expireFailedTargets drops it. Expiry only runs on the
	// updateProc tick, so a mark can in practice outlive this interval.
	backendFailureRetryInterval = time.Minute
)

// expireFailedTargets removes every failure mark older than
// backendFailureRetryInterval, so getTargets stops ranking those targets
// behind the healthy ones.
func (b *backendTracker) expireFailedTargets() {
	b.mx.Lock()
	defer b.mx.Unlock()

	now := time.Now()
	for target, failedAt := range b.failed {
		// Deleting the current key while ranging over a map is safe in Go.
		if age := now.Sub(failedAt); age > backendFailureRetryInterval {
			delete(b.failed, target)
		}
	}
}

// updateProc is the tracker's background loop, started by newBackendTracker.
// On every backendUpdateInterval tick it expires old failure marks and
// re-resolves b.addr. It returns once stopCh is closed (see Close).
func (b *backendTracker) updateProc() {
	ticker := time.NewTicker(backendUpdateInterval)
	defer ticker.Stop()

	for {
		select {
		case <-ticker.C:
		case <-b.stopCh:
			return
		}

		b.expireFailedTargets()

		// The lookup runs without holding mx, so a slow resolver does not
		// block getTargets/setStatus callers.
		fresh := b.resolver.ResolveIP(b.addr)
		if len(fresh) == 0 {
			// Keep the previous list rather than dropping all targets.
			continue
		}
		b.mx.Lock()
		b.resolved = fresh
		b.mx.Unlock()
	}
}

var (
	// shuffleMx serializes every use of shuffleSrc. A Source from
	// rand.NewSource is not safe for concurrent use, and shuffle is called
	// from every tracker's getTargets. Each tracker holds only its own
	// mutex, so without this lock separate trackers would race on the
	// shared generator state.
	shuffleMx  sync.Mutex
	shuffleSrc = rand.NewSource(time.Now().UnixNano())
	// shuffleRnd wraps shuffleSrc once instead of allocating a new
	// *rand.Rand on every call. Guarded by shuffleMx.
	shuffleRnd = rand.New(shuffleSrc)
)

// shuffle randomly permutes values in place and returns the same slice for
// convenience. Slices with fewer than two elements are returned untouched
// without taking the lock. It is safe to call from multiple goroutines.
func shuffle(values []string) []string {
	if len(values) < 2 {
		return values
	}
	shuffleMx.Lock()
	defer shuffleMx.Unlock()
	shuffleRnd.Shuffle(len(values), func(i, j int) {
		values[i], values[j] = values[j], values[i]
	})
	return values
}