Skip to content
Snippets Groups Projects
Commit 3ecaec92 authored by ale's avatar ale
Browse files

Update ai3/go-common/clientutil

parent 0c0c8424
No related branches found
No related tags found
No related merge requests found
......@@ -16,6 +16,13 @@ type BackendConfig struct {
TLSConfig *TLSClientConfig `yaml:"tls"`
Sharded bool `yaml:"sharded"`
Debug bool `yaml:"debug"`
// Connection timeout (if unset, use default value).
ConnectTimeout string `yaml:"connect_timeout"`
// Maximum timeout for each individual request to this backend
// (if unset, use the Context timeout).
RequestMaxTimeout string `yaml:"request_max_timeout"`
}
// Backend is a runtime class that provides http Clients for use with
......
......@@ -60,10 +60,11 @@ func newExponentialBackOff() *backoff.ExponentialBackOff {
type balancedBackend struct {
*backendTracker
*transportCache
baseURI *url.URL
sharded bool
resolver resolver
log logger
baseURI *url.URL
sharded bool
resolver resolver
log logger
requestMaxTimeout time.Duration
}
func newBalancedBackend(config *BackendConfig, resolver resolver) (*balancedBackend, error) {
......@@ -80,17 +81,36 @@ func newBalancedBackend(config *BackendConfig, resolver resolver) (*balancedBack
}
}
var connectTimeout time.Duration
if config.ConnectTimeout != "" {
t, err := time.ParseDuration(config.ConnectTimeout)
if err != nil {
return nil, fmt.Errorf("error in connect_timeout: %v", err)
}
connectTimeout = t
}
var reqTimeout time.Duration
if config.RequestMaxTimeout != "" {
t, err := time.ParseDuration(config.RequestMaxTimeout)
if err != nil {
return nil, fmt.Errorf("error in request_max_timeout: %v", err)
}
reqTimeout = t
}
var logger logger = &nilLogger{}
if config.Debug {
logger = log.New(os.Stderr, fmt.Sprintf("backend %s: ", u.Host), 0)
}
return &balancedBackend{
backendTracker: newBackendTracker(u.Host, resolver, logger),
transportCache: newTransportCache(tlsConfig),
sharded: config.Sharded,
baseURI: u,
resolver: resolver,
log: logger,
backendTracker: newBackendTracker(u.Host, resolver, logger),
transportCache: newTransportCache(tlsConfig, connectTimeout),
requestMaxTimeout: reqTimeout,
sharded: config.Sharded,
baseURI: u,
resolver: resolver,
log: logger,
}, nil
}
......@@ -115,6 +135,9 @@ func (b *balancedBackend) Call(ctx context.Context, shard, path string, req, res
if deadline, ok := ctx.Deadline(); ok {
innerTimeout = time.Until(deadline) / time.Duration(seq.Len())
}
if b.requestMaxTimeout > 0 && innerTimeout > b.requestMaxTimeout {
innerTimeout = b.requestMaxTimeout
}
// Call the backends in the sequence until one succeeds, with an
// exponential backoff policy controlled by the outer Context.
......
// +build go1.9
package clientutil
import (
"context"
"net"
"time"
)
func netDialContext(addr string, connectTimeout time.Duration) func(context.Context, string, string) (net.Conn, error) {
dialer := &net.Dialer{
Timeout: connectTimeout,
KeepAlive: 30 * time.Second,
DualStack: true,
}
return func(ctx context.Context, net string, _ string) (net.Conn, error) {
return dialer.DialContext(ctx, net, addr)
}
}
// +build !go1.9
package clientutil
import (
"context"
"net"
"time"
)
// Go < 1.9 does not have net.DialContext, reimplement it in terms of
// net.DialTimeout.
func netDialContext(addr string, connectTimeout time.Duration) func(context.Context, string, string) (net.Conn, error) {
return func(ctx context.Context, net string, _ string) (net.Conn, error) {
if deadline, ok := ctx.Deadline(); ok {
ctxTimeout := time.Until(deadline)
if ctxTimeout < connectTimeout {
connectTimeout = ctxTimeout
}
}
return net.DialTimeout(network, addr, connectTimeout)
}
}
package clientutil
import (
"context"
"crypto/tls"
"net"
"net/http"
"sync"
"time"
......@@ -11,31 +9,42 @@ import (
"git.autistici.org/ai3/go-common/tracing"
)
var defaultConnectTimeout = 30 * time.Second
// The transportCache is just a cache of http transports, each
// connecting to a specific address.
//
// We use this to control the HTTP Host header and the TLS ServerName
// independently of the target address.
type transportCache struct {
tlsConfig *tls.Config
tlsConfig *tls.Config
connectTimeout time.Duration
mx sync.RWMutex
transports map[string]http.RoundTripper
}
func newTransportCache(tlsConfig *tls.Config) *transportCache {
func newTransportCache(tlsConfig *tls.Config, connectTimeout time.Duration) *transportCache {
if connectTimeout == 0 {
connectTimeout = defaultConnectTimeout
}
return &transportCache{
tlsConfig: tlsConfig,
transports: make(map[string]http.RoundTripper),
tlsConfig: tlsConfig,
connectTimeout: connectTimeout,
transports: make(map[string]http.RoundTripper),
}
}
func (m *transportCache) newTransport(addr string) http.RoundTripper {
return tracing.WrapTransport(&http.Transport{
TLSClientConfig: m.tlsConfig,
DialContext: func(ctx context.Context, network, _ string) (net.Conn, error) {
return netDialContext(ctx, network, addr)
},
DialContext: netDialContext(addr, m.connectTimeout),
// Parameters match those of net/http.DefaultTransport.
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
})
}
......@@ -55,13 +64,3 @@ func (m *transportCache) getTransport(addr string) http.RoundTripper {
return t
}
// Go < 1.9 does not have net.DialContext, reimplement it in terms of
// net.DialTimeout.
func netDialContext(ctx context.Context, network, addr string) (net.Conn, error) {
timeout := 60 * time.Second // some arbitrary max timeout
if deadline, ok := ctx.Deadline(); ok {
timeout = time.Until(deadline)
}
return net.DialTimeout(network, addr, timeout)
}
......@@ -2,10 +2,12 @@ package serverutil
import (
"crypto/tls"
"errors"
"fmt"
"log"
"net/http"
"regexp"
"strings"
common "git.autistici.org/ai3/go-common"
)
......@@ -42,6 +44,31 @@ func (p *TLSAuthACL) match(req *http.Request) bool {
return false
}
// TLSAuthACLListFlag is a convenience type that allows callers to use
// the 'flag' package to specify a list of TLSAuthACL objects. It
// implements the flag.Value interface.
type TLSAuthACLListFlag []*TLSAuthACL
func (l TLSAuthACLListFlag) String() string {
var out []string
for _, acl := range l {
out = append(out, fmt.Sprintf("%s:%s", acl.Path, acl.CommonName))
}
return strings.Join(out, ",")
}
func (l *TLSAuthACLListFlag) Set(value string) error {
parts := strings.SplitN(value, ":", 2)
if len(parts) != 2 {
return errors.New("bad acl format")
}
*l = append(*l, &TLSAuthACL{
Path: parts[0],
CommonName: parts[1],
})
return nil
}
// TLSAuthConfig stores access control lists for TLS authentication. Access
// control lists are matched against the request path and the
// CommonName component of the peer certificate subject.
......
......@@ -5,26 +5,26 @@
{
"checksumSHA1": "oUOxU+Tw1/jOzWVP05HuGvVSC/A=",
"path": "git.autistici.org/ai3/go-common",
"revision": "54f0ac4c46184ae44486a31ca2705076abcc5321",
"revisionTime": "2019-06-30T08:30:15Z"
"revision": "c165311f4270e8a2d75d9a610abdfb54d72ae4e5",
"revisionTime": "2020-01-06T11:09:19Z"
},
{
"checksumSHA1": "kJwm6y9JXhybelO2zUl7UbzIdP0=",
"checksumSHA1": "k/t0n698YTZJR6ncd9gzSvymXsE=",
"path": "git.autistici.org/ai3/go-common/clientutil",
"revision": "54f0ac4c46184ae44486a31ca2705076abcc5321",
"revisionTime": "2019-06-30T08:30:15Z"
"revision": "c165311f4270e8a2d75d9a610abdfb54d72ae4e5",
"revisionTime": "2020-01-06T11:09:19Z"
},
{
"checksumSHA1": "TKGUNmKxj7KH3qhwiCh/6quUnwc=",
"checksumSHA1": "Xm0ZN1urTQaagPJ5kJlCjeGONOU=",
"path": "git.autistici.org/ai3/go-common/serverutil",
"revision": "54f0ac4c46184ae44486a31ca2705076abcc5321",
"revisionTime": "2019-06-30T08:30:15Z"
"revision": "c165311f4270e8a2d75d9a610abdfb54d72ae4e5",
"revisionTime": "2020-01-06T11:09:19Z"
},
{
"checksumSHA1": "y5pRYZ/NhfEOCFslPEuUZTYXcro=",
"path": "git.autistici.org/ai3/go-common/tracing",
"revision": "54f0ac4c46184ae44486a31ca2705076abcc5321",
"revisionTime": "2019-06-30T08:30:15Z"
"revision": "c165311f4270e8a2d75d9a610abdfb54d72ae4e5",
"revisionTime": "2020-01-06T11:09:19Z"
},
{
"checksumSHA1": "5WLGZjUV9Ly/rMdQwo9j8FJSlQA=",
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment