package clientutil

import (
	"context"
	"crypto/tls"
	"net"
	"net/http"
	"sync"
	"time"

	"git.autistici.org/ai3/go-common/tracing"
)

// transportCache is a cache of HTTP transports, keyed by the network
// address each transport connects to.
//
// We use this to control the HTTP Host header and the TLS ServerName
// independently of the target address.
type transportCache struct {
	// tlsConfig is passed to every transport created by the cache.
	tlsConfig *tls.Config

	// mx guards transports.
	mx         sync.RWMutex
	transports map[string]http.RoundTripper
}

// newTransportCache returns an empty transportCache. Every transport it
// later creates will use tlsConfig as its TLS client configuration.
func newTransportCache(tlsConfig *tls.Config) *transportCache {
	c := new(transportCache)
	c.tlsConfig = tlsConfig
	c.transports = map[string]http.RoundTripper{}
	return c
}

func (m *transportCache) newTransport(addr string) http.RoundTripper {
ale's avatar
ale committed
	return tracing.WrapTransport(&http.Transport{
		TLSClientConfig: m.tlsConfig,
		DialContext: func(ctx context.Context, network, _ string) (net.Conn, error) {
			return netDialContext(ctx, network, addr)
		},
ale's avatar
ale committed
	})
ale's avatar
ale committed
}

func (m *transportCache) getTransport(addr string) http.RoundTripper {
	m.mx.RLock()
	t, ok := m.transports[addr]
	m.mx.RUnlock()
ale's avatar
ale committed

	if !ok {
		m.mx.Lock()
		if t, ok = m.transports[addr]; !ok {
			t = m.newTransport(addr)
			m.transports[addr] = t
ale's avatar
ale committed
		}
ale's avatar
ale committed
	}

ale's avatar
ale committed
}

// netDialContext connects to addr on the named network. It honors both the
// deadline and the cancellation of ctx. When ctx carries no deadline, the
// dial is bounded by an arbitrary 60 second timeout so that it cannot hang
// indefinitely.
func netDialContext(ctx context.Context, network, addr string) (net.Conn, error) {
	var d net.Dialer
	// If ctx has a deadline, the Dialer uses it directly, exactly as the
	// old time.Until(deadline) conversion did. Only fall back to the fixed
	// cap when there is no deadline to inherit.
	if _, ok := ctx.Deadline(); !ok {
		d.Timeout = 60 * time.Second // some arbitrary max timeout
	}
	return d.DialContext(ctx, network, addr)
}