package clientutil

import (
	"context"
	"crypto/tls"
	"net"
	"net/http"
	"sync"
	"time"

	"git.autistici.org/ai3/go-common/tracing"
)

// transportCache keeps one http.RoundTripper per target address. Each
// cached transport always dials its own fixed address, which lets callers
// choose the HTTP Host header and the TLS ServerName independently of the
// address that is actually connected to.
type transportCache struct {
	// tlsConfig is handed to every transport this cache creates.
	tlsConfig *tls.Config

	// mx protects transports; lookups take the read lock, insertions
	// take the write lock.
	mx sync.RWMutex
	// transports maps a dial address to the RoundTripper that dials it.
	transports map[string]http.RoundTripper
}

// newTransportCache returns an empty, ready-to-use cache. Every transport
// it later creates shares tlsConfig (which may be nil).
func newTransportCache(tlsConfig *tls.Config) *transportCache {
	c := new(transportCache)
	c.tlsConfig = tlsConfig
	// The map must be allocated up front: getTransport writes into it.
	c.transports = make(map[string]http.RoundTripper)
	return c
}

// newTransport builds a RoundTripper, wrapped by tracing.WrapTransport,
// whose connections always go to addr. The address the http.Transport
// would normally dial (derived from the request URL) is discarded, so the
// URL host only influences the request itself, not the connection target.
func (m *transportCache) newTransport(addr string) http.RoundTripper {
	// Ignore the transport-supplied address and dial the fixed one.
	dialFixed := func(ctx context.Context, network, _ string) (net.Conn, error) {
		return netDialContext(ctx, network, addr)
	}

	base := &http.Transport{
		DialContext:     dialFixed,
		TLSClientConfig: m.tlsConfig,
	}
	return tracing.WrapTransport(base)
}

// getTransport returns the cached RoundTripper for addr, creating and
// storing one on first use. Concurrent callers asking for the same addr
// receive the same RoundTripper.
func (m *transportCache) getTransport(addr string) http.RoundTripper {
	// Fast path: most calls hit an existing entry under the shared lock.
	m.mx.RLock()
	rt, found := m.transports[addr]
	m.mx.RUnlock()
	if found {
		return rt
	}

	// Slow path: take the exclusive lock and look again, because another
	// goroutine may have inserted the entry between RUnlock and Lock.
	m.mx.Lock()
	rt, found = m.transports[addr]
	if !found {
		rt = m.newTransport(addr)
		m.transports[addr] = rt
	}
	m.mx.Unlock()

	return rt
}

// netDialContext connects to addr on the named network (e.g. "tcp").
//
// The dial is abandoned as soon as ctx is cancelled. It never runs longer
// than the earlier of ctx's deadline (if any) and a fixed 60-second cap.
// A context that is already cancelled or past its deadline yields an error
// without connecting.
func netDialContext(ctx context.Context, network, addr string) (net.Conn, error) {
	// The 60s cap is an arbitrary upper bound so that a dial without a
	// context deadline cannot hang indefinitely. net.Dialer.DialContext
	// applies whichever of Timeout and the context deadline expires first,
	// and aborts the dial on cancellation. The previous net.DialTimeout
	// emulation did not do the latter.
	d := net.Dialer{Timeout: 60 * time.Second}
	return d.DialContext(ctx, network, addr)
}