package clientutil

import (
	"context"
	"crypto/tls"
	"net"
	"net/http"
	"sync"
	"time"

	"git.autistici.org/ai3/go-common/tracing"
)

// transportCache is a cache of HTTP transports keyed by address. Each
// transport always dials its own fixed address, ignoring the host in
// the request URL.
//
// We use this to control the HTTP Host header and the TLS ServerName
// independently of the target address.
type transportCache struct {
	// tlsConfig is shared by every transport created by this cache.
	tlsConfig *tls.Config

	// mx guards transports.
	mx         sync.RWMutex
	transports map[string]http.RoundTripper
}

// newTransportCache returns an empty transportCache whose transports
// will all use tlsConfig.
func newTransportCache(tlsConfig *tls.Config) *transportCache {
	return &transportCache{
		tlsConfig:  tlsConfig,
		transports: make(map[string]http.RoundTripper),
	}
}

// newTransport builds a tracing-wrapped http.Transport that always
// connects to addr. The address passed to DialContext (derived from the
// request URL) is deliberately discarded.
func (m *transportCache) newTransport(addr string) http.RoundTripper {
	return tracing.WrapTransport(&http.Transport{
		TLSClientConfig: m.tlsConfig,
		DialContext: func(ctx context.Context, network, _ string) (net.Conn, error) {
			return netDialContext(ctx, network, addr)
		},
	})
}

// getTransport returns the cached transport for addr, creating it on
// first use. The map is re-checked under the write lock so that
// concurrent first callers for the same addr end up sharing a single
// transport.
func (m *transportCache) getTransport(addr string) http.RoundTripper {
	m.mx.RLock()
	t, ok := m.transports[addr]
	m.mx.RUnlock()

	if !ok {
		m.mx.Lock()
		if t, ok = m.transports[addr]; !ok {
			t = m.newTransport(addr)
			m.transports[addr] = t
		}
		m.mx.Unlock()
	}

	return t
}

// netDialContext dials addr on the given network, honoring both the
// deadline and the cancellation of ctx.
//
// If ctx carries no deadline, the dial is bounded by an arbitrary 60
// second timeout. If it does carry one, that deadline alone governs the
// dial (it may be longer or shorter than 60 seconds).
func netDialContext(ctx context.Context, network, addr string) (net.Conn, error) {
	var d net.Dialer
	if _, ok := ctx.Deadline(); !ok {
		d.Timeout = 60 * time.Second // some arbitrary max timeout
	}
	// Unlike net.DialTimeout, DialContext aborts as soon as ctx is
	// canceled instead of waiting out the full timeout.
	return d.DialContext(ctx, network, addr)
}