transport.go 1.54 KB
Newer Older
1 2 3 4 5 6 7 8 9
package clientutil

import (
	"context"
	"crypto/tls"
	"net"
	"net/http"
	"sync"
	"time"
ale's avatar
ale committed
10 11

	"git.autistici.org/ai3/go-common/tracing"
12 13
)

ale's avatar
ale committed
14 15
// transportCache keeps one http.RoundTripper per target address.
//
// Each transport it builds always connects to its own fixed address,
// whatever host the request URL names. This lets callers control the
// HTTP Host header and the TLS ServerName independently of the address
// the connection is actually made to.
type transportCache struct {
	// tlsConfig is handed to every transport this cache creates.
	tlsConfig *tls.Config

	// mx guards transports.
	mx sync.RWMutex
	// transports maps a target address to the RoundTripper dialing it.
	transports map[string]http.RoundTripper
}

ale's avatar
ale committed
26 27 28 29 30 31
// newTransportCache returns an empty transportCache. Every transport
// created through it will use tlsConfig (which may be nil).
func newTransportCache(tlsConfig *tls.Config) *transportCache {
	cache := new(transportCache)
	cache.tlsConfig = tlsConfig
	cache.transports = map[string]http.RoundTripper{}
	return cache
}
32

ale's avatar
ale committed
33
func (m *transportCache) newTransport(addr string) http.RoundTripper {
ale's avatar
ale committed
34
	return tracing.WrapTransport(&http.Transport{
ale's avatar
ale committed
35 36 37 38
		TLSClientConfig: m.tlsConfig,
		DialContext: func(ctx context.Context, network, _ string) (net.Conn, error) {
			return netDialContext(ctx, network, addr)
		},
ale's avatar
ale committed
39
	})
40 41
}

ale's avatar
ale committed
42 43 44 45
func (m *transportCache) getTransport(addr string) http.RoundTripper {
	m.mx.RLock()
	t, ok := m.transports[addr]
	m.mx.RUnlock()
46

ale's avatar
ale committed
47 48 49 50 51
	if !ok {
		m.mx.Lock()
		if t, ok = m.transports[addr]; !ok {
			t = m.newTransport(addr)
			m.transports[addr] = t
52
		}
ale's avatar
ale committed
53
		m.mx.Unlock()
54 55
	}

ale's avatar
ale committed
56
	return t
57 58
}

ale's avatar
ale committed
59 60
// netDialContext connects to addr over network. It honors both the
// deadline and the cancellation of ctx.
//
// It previously reimplemented context dialing with net.DialTimeout, which
// only looked at ctx's deadline. A request canceled without a deadline would
// therefore keep dialing for up to a minute. net.Dialer.DialContext aborts
// as soon as ctx is done.
//
// Timeout policy: when ctx has no deadline, the dial is capped at an
// arbitrary 60 seconds. When ctx has a deadline, that deadline alone bounds
// the dial, even if it is further away than 60 seconds.
func netDialContext(ctx context.Context, network, addr string) (net.Conn, error) {
	const maxDialTimeout = 60 * time.Second // some arbitrary max timeout

	var d net.Dialer
	if _, hasDeadline := ctx.Deadline(); !hasDeadline {
		d.Timeout = maxDialTimeout
	}
	return d.DialContext(ctx, network, addr)
}