Skip to content
Snippets Groups Projects
Commit 3057be76 authored by ale's avatar ale
Browse files

Upgrade go-common (enforce rpc timeouts)

parent f5a57b01
No related branches found
No related tags found
No related merge requests found
......@@ -16,6 +16,13 @@ type BackendConfig struct {
TLSConfig *TLSClientConfig `yaml:"tls"`
Sharded bool `yaml:"sharded"`
Debug bool `yaml:"debug"`
// Connection timeout (if unset, use default value).
ConnectTimeout string `yaml:"connect_timeout"`
// Maximum timeout for each individual request to this backend
// (if unset, use the Context timeout).
RequestMaxTimeout string `yaml:"request_max_timeout"`
}
// Backend is a runtime class that provides http Clients for use with
......
......@@ -60,10 +60,11 @@ func newExponentialBackOff() *backoff.ExponentialBackOff {
type balancedBackend struct {
*backendTracker
*transportCache
baseURI *url.URL
sharded bool
resolver resolver
log logger
baseURI *url.URL
sharded bool
resolver resolver
log logger
requestMaxTimeout time.Duration
}
func newBalancedBackend(config *BackendConfig, resolver resolver) (*balancedBackend, error) {
......@@ -80,17 +81,36 @@ func newBalancedBackend(config *BackendConfig, resolver resolver) (*balancedBack
}
}
var connectTimeout time.Duration
if config.ConnectTimeout != "" {
t, err := time.ParseDuration(config.ConnectTimeout)
if err != nil {
return nil, fmt.Errorf("error in connect_timeout: %v", err)
}
connectTimeout = t
}
var reqTimeout time.Duration
if config.RequestMaxTimeout != "" {
t, err := time.ParseDuration(config.RequestMaxTimeout)
if err != nil {
return nil, fmt.Errorf("error in request_max_timeout: %v", err)
}
reqTimeout = t
}
var logger logger = &nilLogger{}
if config.Debug {
logger = log.New(os.Stderr, fmt.Sprintf("backend %s: ", u.Host), 0)
}
return &balancedBackend{
backendTracker: newBackendTracker(u.Host, resolver, logger),
transportCache: newTransportCache(tlsConfig),
sharded: config.Sharded,
baseURI: u,
resolver: resolver,
log: logger,
backendTracker: newBackendTracker(u.Host, resolver, logger),
transportCache: newTransportCache(tlsConfig, connectTimeout),
requestMaxTimeout: reqTimeout,
sharded: config.Sharded,
baseURI: u,
resolver: resolver,
log: logger,
}, nil
}
......@@ -115,6 +135,9 @@ func (b *balancedBackend) Call(ctx context.Context, shard, path string, req, res
if deadline, ok := ctx.Deadline(); ok {
innerTimeout = time.Until(deadline) / time.Duration(seq.Len())
}
if b.requestMaxTimeout > 0 && innerTimeout > b.requestMaxTimeout {
innerTimeout = b.requestMaxTimeout
}
// Call the backends in the sequence until one succeeds, with an
// exponential backoff policy controlled by the outer Context.
......@@ -198,7 +221,7 @@ func (b *balancedBackend) do(ctx context.Context, seq *sequence, req *http.Reque
client := &http.Client{
Transport: b.transportCache.getTransport(target),
}
resp, err = client.Do(req.WithContext(ctx))
resp, err = client.Do(propagateDeadline(ctx, req))
if err == nil && resp.StatusCode != 200 {
err = remoteErrorFromResponse(resp)
if !isStatusTemporary(resp.StatusCode) {
......@@ -212,6 +235,19 @@ func (b *balancedBackend) do(ctx context.Context, seq *sequence, req *http.Reque
return
}
const deadlineHeader = "X-RPC-Deadline"
// Propagate context deadline to the server using a HTTP header.
func propagateDeadline(ctx context.Context, req *http.Request) *http.Request {
req = req.WithContext(ctx)
if deadline, ok := ctx.Deadline(); ok {
req.Header.Set(deadlineHeader, strconv.FormatInt(deadline.UTC().UnixNano(), 10))
} else {
req.Header.Del(deadlineHeader)
}
return req
}
var errNoTargets = errors.New("no available backends")
type targetGenerator interface {
......
// +build go1.9
package clientutil
import (
"context"
"net"
"time"
)
func netDialContext(addr string, connectTimeout time.Duration) func(context.Context, string, string) (net.Conn, error) {
dialer := &net.Dialer{
Timeout: connectTimeout,
KeepAlive: 30 * time.Second,
DualStack: true,
}
return func(ctx context.Context, net string, _ string) (net.Conn, error) {
return dialer.DialContext(ctx, net, addr)
}
}
// +build !go1.9
package clientutil
import (
"context"
"net"
"time"
)
// Go < 1.9 does not have net.DialContext, reimplement it in terms of
// net.DialTimeout.
func netDialContext(addr string, connectTimeout time.Duration) func(context.Context, string, string) (net.Conn, error) {
return func(ctx context.Context, net string, _ string) (net.Conn, error) {
if deadline, ok := ctx.Deadline(); ok {
ctxTimeout := time.Until(deadline)
if ctxTimeout < connectTimeout {
connectTimeout = ctxTimeout
}
}
return net.DialTimeout(network, addr, connectTimeout)
}
}
package clientutil
import (
"context"
"crypto/tls"
"net"
"net/http"
"sync"
"time"
......@@ -11,31 +9,42 @@ import (
"git.autistici.org/ai3/go-common/tracing"
)
var defaultConnectTimeout = 30 * time.Second
// The transportCache is just a cache of http transports, each
// connecting to a specific address.
//
// We use this to control the HTTP Host header and the TLS ServerName
// independently of the target address.
type transportCache struct {
tlsConfig *tls.Config
tlsConfig *tls.Config
connectTimeout time.Duration
mx sync.RWMutex
transports map[string]http.RoundTripper
}
func newTransportCache(tlsConfig *tls.Config) *transportCache {
func newTransportCache(tlsConfig *tls.Config, connectTimeout time.Duration) *transportCache {
if connectTimeout == 0 {
connectTimeout = defaultConnectTimeout
}
return &transportCache{
tlsConfig: tlsConfig,
transports: make(map[string]http.RoundTripper),
tlsConfig: tlsConfig,
connectTimeout: connectTimeout,
transports: make(map[string]http.RoundTripper),
}
}
func (m *transportCache) newTransport(addr string) http.RoundTripper {
return tracing.WrapTransport(&http.Transport{
TLSClientConfig: m.tlsConfig,
DialContext: func(ctx context.Context, network, _ string) (net.Conn, error) {
return netDialContext(ctx, network, addr)
},
DialContext: netDialContext(addr, m.connectTimeout),
// Parameters match those of net/http.DefaultTransport.
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
})
}
......@@ -55,13 +64,3 @@ func (m *transportCache) getTransport(addr string) http.RoundTripper {
return t
}
// Go < 1.9 does not have net.DialContext, reimplement it in terms of
// net.DialTimeout.
func netDialContext(ctx context.Context, network, addr string) (net.Conn, error) {
timeout := 60 * time.Second // some arbitrary max timeout
if deadline, ok := ctx.Deadline(); ok {
timeout = time.Until(deadline)
}
return net.DialTimeout(network, addr, timeout)
}
......@@ -11,6 +11,7 @@ import (
_ "net/http/pprof"
"os"
"os/signal"
"strconv"
"syscall"
"time"
......@@ -26,6 +27,7 @@ var gracefulShutdownTimeout = 3 * time.Second
type ServerConfig struct {
TLS *TLSServerConfig `yaml:"tls"`
MaxInflightRequests int `yaml:"max_inflight_requests"`
RequestTimeoutSecs int `yaml:"request_timeout"`
TrustedForwarders []string `yaml:"trusted_forwarders"`
}
......@@ -60,11 +62,18 @@ func (config *ServerConfig) buildHTTPServer(h http.Handler) (*http.Server, error
}
}
// Wrap the handler with a TimeoutHandler if 'request_timeout'
// is set.
h = addDefaultHandlers(h)
if config.RequestTimeoutSecs > 0 {
h = http.TimeoutHandler(h, time.Duration(config.RequestTimeoutSecs)*time.Second, "")
}
// These are not meant to be external-facing servers, so we
// can be generous with the timeouts to keep the number of
// reconnections low.
return &http.Server{
Handler: defaultHandler(h),
Handler: h,
ReadTimeout: 30 * time.Second,
WriteTimeout: 30 * time.Second,
IdleTimeout: 600 * time.Second,
......@@ -134,7 +143,7 @@ func Serve(h http.Handler, config *ServerConfig, addr string) error {
return nil
}
func defaultHandler(h http.Handler) http.Handler {
func addDefaultHandlers(h http.Handler) http.Handler {
root := http.NewServeMux()
// Add an endpoint for HTTP health checking probes.
......@@ -154,11 +163,30 @@ func defaultHandler(h http.Handler) http.Handler {
// Prometheus instrumentation (requests to /metrics and
// /health are not included).
root.Handle("/", promhttp.InstrumentHandlerInFlight(inFlightRequests,
promhttp.InstrumentHandlerCounter(totalRequests, h)))
promhttp.InstrumentHandlerCounter(totalRequests,
propagateDeadline(h))))
return root
}
const deadlineHeader = "X-RPC-Deadline"
// Read an eventual deadline from the HTTP request, and set it as the
// deadline of the request context.
func propagateDeadline(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
if hdr := req.Header.Get(deadlineHeader); hdr != "" {
if deadlineNano, err := strconv.ParseInt(hdr, 10, 64); err == nil {
deadline := time.Unix(0, deadlineNano)
ctx, cancel := context.WithDeadline(req.Context(), deadline)
defer cancel()
req = req.WithContext(ctx)
}
}
h.ServeHTTP(w, req)
})
}
func guessEndpointName(addr string) string {
_, port, err := net.SplitHostPort(addr)
if err != nil {
......
......@@ -2,10 +2,12 @@ package serverutil
import (
"crypto/tls"
"errors"
"fmt"
"log"
"net/http"
"regexp"
"strings"
common "git.autistici.org/ai3/go-common"
)
......@@ -42,6 +44,31 @@ func (p *TLSAuthACL) match(req *http.Request) bool {
return false
}
// TLSAuthACLListFlag is a convenience type that allows callers to use
// the 'flag' package to specify a list of TLSAuthACL objects. It
// implements the flag.Value interface.
type TLSAuthACLListFlag []*TLSAuthACL
func (l TLSAuthACLListFlag) String() string {
var out []string
for _, acl := range l {
out = append(out, fmt.Sprintf("%s:%s", acl.Path, acl.CommonName))
}
return strings.Join(out, ",")
}
func (l *TLSAuthACLListFlag) Set(value string) error {
parts := strings.SplitN(value, ":", 2)
if len(parts) != 2 {
return errors.New("bad acl format")
}
*l = append(*l, &TLSAuthACL{
Path: parts[0],
CommonName: parts[1],
})
return nil
}
// TLSAuthConfig stores access control lists for TLS authentication. Access
// control lists are matched against the request path and the
// CommonName component of the peer certificate subject.
......
......@@ -5,26 +5,26 @@
{
"checksumSHA1": "oUOxU+Tw1/jOzWVP05HuGvVSC/A=",
"path": "git.autistici.org/ai3/go-common",
"revision": "54f0ac4c46184ae44486a31ca2705076abcc5321",
"revisionTime": "2019-06-30T08:30:15Z"
"revision": "d4396660b1f09d2d07a7271b8a480726c5d29a6e",
"revisionTime": "2020-02-06T11:03:59Z"
},
{
"checksumSHA1": "kJwm6y9JXhybelO2zUl7UbzIdP0=",
"checksumSHA1": "/uo2RHmQ/sae2RMYKm81zb6OUn4=",
"path": "git.autistici.org/ai3/go-common/clientutil",
"revision": "54f0ac4c46184ae44486a31ca2705076abcc5321",
"revisionTime": "2019-06-30T08:30:15Z"
"revision": "d4396660b1f09d2d07a7271b8a480726c5d29a6e",
"revisionTime": "2020-02-06T11:03:59Z"
},
{
"checksumSHA1": "TKGUNmKxj7KH3qhwiCh/6quUnwc=",
"checksumSHA1": "xrew5jDNLDh5bhTZ4jgO0rX0N+4=",
"path": "git.autistici.org/ai3/go-common/serverutil",
"revision": "54f0ac4c46184ae44486a31ca2705076abcc5321",
"revisionTime": "2019-06-30T08:30:15Z"
"revision": "d4396660b1f09d2d07a7271b8a480726c5d29a6e",
"revisionTime": "2020-02-06T11:03:59Z"
},
{
"checksumSHA1": "y5pRYZ/NhfEOCFslPEuUZTYXcro=",
"path": "git.autistici.org/ai3/go-common/tracing",
"revision": "54f0ac4c46184ae44486a31ca2705076abcc5321",
"revisionTime": "2019-06-30T08:30:15Z"
"revision": "d4396660b1f09d2d07a7271b8a480726c5d29a6e",
"revisionTime": "2020-02-06T11:03:59Z"
},
{
"checksumSHA1": "FRxoT4jwgKDffIm5RwpFWjVVilc=",
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment