dns.go 1.83 KB
Newer Older
ale's avatar
ale committed
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42
package clientutil

import (
	"log"
	"net"
	"sync"
	"time"

	"golang.org/x/sync/singleflight"
)

// resolver turns an address string into a list of concrete addresses.
// The implementation in this file (dnsResolver) takes "host:port" and
// returns one "ip:port" string per IP found for host, or nil on failure.
type resolver interface {
	ResolveIP(hostport string) []string
}

// dnsResolver is a resolver that performs uncached lookups with
// net.LookupIP.
type dnsResolver struct{}

// ResolveIP splits hostport into host and port, looks up all IP addresses
// of host, and returns each one joined back with the original port
// (IPv6 addresses come back bracketed, e.g. "[::1]:53").
// A malformed hostport or a lookup error is logged and yields nil.
func (r *dnsResolver) ResolveIP(hostport string) []string {
	host, port, splitErr := net.SplitHostPort(hostport)
	if splitErr != nil {
		log.Printf("error parsing %s: %v", hostport, splitErr)
		return nil
	}

	addrs, lookupErr := net.LookupIP(host)
	if lookupErr != nil {
		log.Printf("error resolving %s: %v", host, lookupErr)
		return nil
	}

	// Left nil (not an empty slice) when there is nothing to append.
	var joined []string
	for i := range addrs {
		joined = append(joined, net.JoinHostPort(addrs[i].String(), port))
	}
	return joined
}

// defaultResolver is the package-level resolver: system DNS lookups
// (dnsResolver) wrapped in a TTL cache (dnsCache).
var defaultResolver = newDNSCache(new(dnsResolver))

// cacheDatum is a single cached resolution result.
type cacheDatum struct {
	// addrs holds the resolver's result; it may be nil when the lookup
	// failed, in which case the failure itself is what gets cached.
	addrs []string

	// deadline is the instant after which the entry counts as stale.
	deadline time.Time
}

ale's avatar
ale committed
43 44
// dnsCacheTTL is how long a cached result, successful or empty, stays
// fresh. It is a var rather than a const, presumably so tests can shorten
// it — confirm before relying on that.
var dnsCacheTTL = time.Minute

ale's avatar
ale committed
45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76
// dnsCache wraps a resolver with a per-key TTL cache. Concurrent refreshes
// of the same key are collapsed into a single lookup via singleflight.
type dnsCache struct {
	// resolver performs the underlying, uncached lookups.
	resolver resolver

	// sf deduplicates in-flight refreshes, keyed by the ResolveIP argument.
	sf singleflight.Group

	// mx guards cache.
	mx sync.RWMutex

	// cache maps the string passed to ResolveIP to its latest result.
	cache map[string]cacheDatum
}

// newDNSCache returns an empty cache that delegates lookups to resolver.
func newDNSCache(resolver resolver) *dnsCache {
	c := new(dnsCache)
	c.resolver = resolver
	c.cache = make(map[string]cacheDatum)
	return c
}

// get looks up host in the cache without any locking of its own. Callers
// in this file hold c.mx (read lock) around it.
//
// A missing entry yields (nil, false). A present entry yields its
// addresses together with whether its deadline is still in the future,
// so an expired entry returns its stale addresses with false.
func (c *dnsCache) get(host string) ([]string, bool) {
	if entry, found := c.cache[host]; found {
		return entry.addrs, time.Now().Before(entry.deadline)
	}
	return nil, false
}

func (c *dnsCache) update(host string) []string {
	v, _, _ := c.sf.Do(host, func() (interface{}, error) {
		addrs := c.resolver.ResolveIP(host)
		// By uncommenting this, we stop caching negative results.
		// if len(addrs) == 0 {
		// 	return nil, nil
		// }
		c.mx.Lock()
		c.cache[host] = cacheDatum{
			addrs:    addrs,
ale's avatar
ale committed
77
			deadline: time.Now().Add(dnsCacheTTL),
ale's avatar
ale committed
78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100
		}
		c.mx.Unlock()
		return addrs, nil
	})
	return v.([]string)
}

// ResolveIP returns the addresses for host, consulting the cache first.
// A fresh entry is returned directly. A stale entry that still has
// addresses is returned immediately while a background goroutine refreshes
// it. A missing entry, or a stale entry with no addresses, is resolved
// synchronously.
func (c *dnsCache) ResolveIP(host string) []string {
	c.mx.RLock()
	cached, fresh := c.get(host)
	c.mx.RUnlock()

	switch {
	case fresh:
		return cached
	case len(cached) > 0:
		// Serve the stale value now. Every stale hit starts a goroutine,
		// but singleflight collapses them into one lookup per host.
		go c.update(host)
		return cached
	default:
		return c.update(host)
	}
}