// dns.go: copied from a fork of ai3/tools/acmeserver (at capture time the
// fork was 90 commits behind the upstream repository).

package clientutil

import (
	"log"
	"net"
	"sync"
	"time"

	"golang.org/x/sync/singleflight"
)

// resolver turns an address string into a list of "ip:port" candidates.
// Implementations in this file return nil when resolution fails.
type resolver interface {
	ResolveIP(hostport string) []string
}

// dnsResolver is a resolver backed by the system lookup (net.LookupIP).
// It is stateless; errors are logged and reported as a nil result.
type dnsResolver struct{}

// ResolveIP splits hostport into host and port, looks up every IP address
// of host, and returns one "ip:port" string per address (IPv6 addresses are
// bracketed by net.JoinHostPort). It returns nil if hostport cannot be
// parsed, if the lookup fails, or if the lookup yields no addresses.
func (r *dnsResolver) ResolveIP(hostport string) []string {
	host, port, err := net.SplitHostPort(hostport)
	if err != nil {
		log.Printf("error parsing %s: %v", hostport, err)
		return nil
	}

	ips, err := net.LookupIP(host)
	if err != nil {
		log.Printf("error resolving %s: %v", host, err)
		return nil
	}
	if len(ips) == 0 {
		// Keep the nil (not empty-but-allocated) result for "no addresses".
		return nil
	}

	out := make([]string, len(ips))
	for i, ip := range ips {
		out[i] = net.JoinHostPort(ip.String(), port)
	}
	return out
}

// defaultResolver is the package-level resolver: system DNS lookups
// (dnsResolver) wrapped in a dnsCache.
var defaultResolver = newDNSCache(new(dnsResolver))

// cacheDatum is a single dnsCache entry: the addresses last obtained for a
// host and the instant after which they are considered stale.
type cacheDatum struct {
	addrs    []string  // result of the last lookup; may be empty (negative result)
	deadline time.Time // the entry is fresh while time.Now() is before this
}

// dnsCacheTTL is how long a cache entry is served without refreshing it.
// NOTE(review): declared as a var rather than a const, presumably so tests
// can shorten it — confirm.
var dnsCacheTTL = time.Minute

// dnsCache is a resolver that memoizes another resolver's results for
// dnsCacheTTL, collapsing concurrent lookups of the same key.
type dnsCache struct {
	resolver resolver           // consulted on misses and refreshes
	sf       singleflight.Group // de-duplicates in-flight lookups per key
	mx       sync.RWMutex       // guards cache
	cache    map[string]cacheDatum
}

// newDNSCache returns an empty dnsCache that delegates lookups to resolver.
func newDNSCache(resolver resolver) *dnsCache {
	c := &dnsCache{resolver: resolver}
	c.cache = make(map[string]cacheDatum)
	return c
}

// get looks host up in the cache. It returns the cached addresses (nil if
// host has no entry) and whether the entry is still fresh. A stale entry
// still returns its addresses, together with false.
//
// get does not lock: the caller must hold c.mx (a read lock is enough).
func (c *dnsCache) get(host string) ([]string, bool) {
	entry, found := c.cache[host]
	if found {
		return entry.addrs, time.Now().Before(entry.deadline)
	}
	return nil, false
}

// update resolves host through the wrapped resolver, stores the result in
// the cache with a deadline dnsCacheTTL from now, and returns it.
// Concurrent calls for the same host share a single lookup (singleflight),
// and every caller receives the same slice.
//
// Empty results are stored too (negative caching), so a failing host is not
// looked up again until its entry goes stale.
//
// NOTE(review): when ResolveIP starts a background refresh of a stale,
// non-empty entry and that lookup fails, the previously good addresses are
// overwritten by the empty result. Confirm this is the intended trade-off.
func (c *dnsCache) update(host string) []string {
	// The closure never returns an error, so Do's error and its "shared"
	// flag carry no information here.
	v, _, _ := c.sf.Do(host, func() (interface{}, error) {
		addrs := c.resolver.ResolveIP(host)
		// By uncommenting this, we stop caching negative results.
		// if len(addrs) == 0 {
		// 	return nil, nil
		// }
		c.mx.Lock()
		c.cache[host] = cacheDatum{
			addrs:    addrs,
			deadline: time.Now().Add(dnsCacheTTL),
		}
		c.mx.Unlock()
		return addrs, nil
	})
	// Comma-ok assertion: if the closure returns an untyped nil (as it does
	// with the block above uncommented), v is a nil interface and a plain
	// v.([]string) would panic. This form yields a nil slice instead.
	addrs, _ := v.([]string)
	return addrs
}

// ResolveIP implements resolver on top of the cache. The key is passed
// unchanged to the wrapped resolver (dnsResolver expects "host:port").
//
//   - fresh entry: returned as is;
//   - stale entry with addresses: the stale addresses are returned at once
//     and a refresh runs in a background goroutine;
//   - no entry, or a stale empty entry: resolved synchronously.
//
// The returned slice is shared with the cache; callers must not modify it.
func (c *dnsCache) ResolveIP(host string) []string {
	c.mx.RLock()
	cached, fresh := c.get(host)
	c.mx.RUnlock()

	switch {
	case fresh:
		return cached
	case len(cached) > 0:
		// Serve stale data rather than block the caller on DNS.
		go c.update(host)
		return cached
	default:
		return c.update(host)
	}
}