dns.go 1.83 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42
package clientutil

import (
	"log"
	"net"
	"sync"
	"time"

	"golang.org/x/sync/singleflight"
)

// resolver turns one address string into a list of resolved address
// strings. The implementations in this file signal failure by returning an
// empty result rather than an error.
type resolver interface {
	// ResolveIP returns the addresses that hostport resolves to.
	ResolveIP(hostport string) []string
}

// dnsResolver is a resolver backed by the net package's host lookup.
type dnsResolver struct{}

// ResolveIP splits hostport into host and port, looks up the IP addresses
// of the host with net.LookupIP, and returns one "ip:port" string per
// address found. If hostport cannot be parsed or the lookup fails, the
// error is logged and nil is returned.
func (r *dnsResolver) ResolveIP(hostport string) []string {
	name, port, splitErr := net.SplitHostPort(hostport)
	if splitErr != nil {
		log.Printf("error parsing %s: %v", hostport, splitErr)
		return nil
	}

	ips, lookupErr := net.LookupIP(name)
	if lookupErr != nil {
		log.Printf("error resolving %s: %v", name, lookupErr)
		return nil
	}

	// Re-attach the original port to every resolved address.
	var out []string
	for i := range ips {
		out = append(out, net.JoinHostPort(ips[i].String(), port))
	}
	return out
}

// defaultResolver is the package-level resolver: a dnsCache wrapped around
// a plain dnsResolver.
var defaultResolver = newDNSCache(&dnsResolver{})

// cacheDatum is a single dnsCache entry: the addresses returned by the
// underlying resolver together with the time until which they count as
// fresh.
type cacheDatum struct {
	// addrs is the resolver result for the host. It may be empty, because
	// dnsCache.update stores results without checking their length.
	addrs    []string
	// deadline is the instant after which dnsCache.get stops reporting the
	// entry as fresh.
	deadline time.Time
}

ale's avatar
ale committed
43 44
// dnsCacheTTL is how long a cache entry is considered fresh after
// dnsCache.update stores it.
var dnsCacheTTL = 1 * time.Minute

45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76
// dnsCache is a resolver that caches the results of another resolver,
// keyed by the string passed to ResolveIP.
type dnsCache struct {
	// resolver is the underlying resolver queried when an entry is refreshed.
	resolver resolver
	// sf collapses concurrent refreshes of the same key into one lookup.
	sf       singleflight.Group
	// mx guards cache.
	mx       sync.RWMutex
	// cache maps the looked-up string to its most recent result.
	cache    map[string]cacheDatum
}

// newDNSCache returns an empty dnsCache that delegates lookups to the given
// resolver.
func newDNSCache(resolver resolver) *dnsCache {
	c := new(dnsCache)
	c.resolver = resolver
	c.cache = map[string]cacheDatum{}
	return c
}

// get looks up host in the cache. It returns the cached addresses and
// whether the entry is still fresh (its deadline lies in the future). When
// there is no entry it returns (nil, false). A stale entry still returns its
// addresses, paired with false.
//
// get does no locking itself; ResolveIP calls it with c.mx read-locked.
func (c *dnsCache) get(host string) ([]string, bool) {
	if entry, found := c.cache[host]; found {
		fresh := time.Now().Before(entry.deadline)
		return entry.addrs, fresh
	}
	return nil, false
}

func (c *dnsCache) update(host string) []string {
	v, _, _ := c.sf.Do(host, func() (interface{}, error) {
		addrs := c.resolver.ResolveIP(host)
		// By uncommenting this, we stop caching negative results.
		// if len(addrs) == 0 {
		// 	return nil, nil
		// }
		c.mx.Lock()
		c.cache[host] = cacheDatum{
			addrs:    addrs,
ale's avatar
ale committed
77
			deadline: time.Now().Add(dnsCacheTTL),
78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100
		}
		c.mx.Unlock()
		return addrs, nil
	})
	return v.([]string)
}

// ResolveIP returns the addresses for host, preferring the cache:
//
//   - a fresh entry is returned as is;
//   - a stale entry with at least one address is returned immediately while
//     a refresh runs in a background goroutine;
//   - otherwise (no entry, or a stale empty one) the lookup is performed
//     synchronously and its result returned.
func (c *dnsCache) ResolveIP(host string) []string {
	c.mx.RLock()
	cached, fresh := c.get(host)
	c.mx.RUnlock()

	switch {
	case fresh:
		return cached
	case len(cached) > 0:
		// Serve the stale addresses now and refresh asynchronously.
		go c.update(host)
		return cached
	default:
		return c.update(host)
	}
}