Commit 7bd90854 authored by ale's avatar ale

Upgrade go-common (enforce rpc timeouts)

parent 1723e6aa
Pipeline #6245 passed with stages
in 5 minutes and 48 seconds
......@@ -16,6 +16,13 @@ type BackendConfig struct {
TLSConfig *TLSClientConfig `yaml:"tls"`
Sharded bool `yaml:"sharded"`
Debug bool `yaml:"debug"`
// Connection timeout (if unset, use default value).
ConnectTimeout string `yaml:"connect_timeout"`
// Maximum timeout for each individual request to this backend
// (if unset, use the Context timeout).
RequestMaxTimeout string `yaml:"request_max_timeout"`
}
// Backend is a runtime class that provides http Clients for use with
......
......@@ -60,10 +60,11 @@ func newExponentialBackOff() *backoff.ExponentialBackOff {
type balancedBackend struct {
*backendTracker
*transportCache
baseURI *url.URL
sharded bool
resolver resolver
log logger
baseURI *url.URL
sharded bool
resolver resolver
log logger
requestMaxTimeout time.Duration
}
func newBalancedBackend(config *BackendConfig, resolver resolver) (*balancedBackend, error) {
......@@ -80,17 +81,36 @@ func newBalancedBackend(config *BackendConfig, resolver resolver) (*balancedBack
}
}
var connectTimeout time.Duration
if config.ConnectTimeout != "" {
t, err := time.ParseDuration(config.ConnectTimeout)
if err != nil {
return nil, fmt.Errorf("error in connect_timeout: %v", err)
}
connectTimeout = t
}
var reqTimeout time.Duration
if config.RequestMaxTimeout != "" {
t, err := time.ParseDuration(config.RequestMaxTimeout)
if err != nil {
return nil, fmt.Errorf("error in request_max_timeout: %v", err)
}
reqTimeout = t
}
var logger logger = &nilLogger{}
if config.Debug {
logger = log.New(os.Stderr, fmt.Sprintf("backend %s: ", u.Host), 0)
}
return &balancedBackend{
backendTracker: newBackendTracker(u.Host, resolver, logger),
transportCache: newTransportCache(tlsConfig),
sharded: config.Sharded,
baseURI: u,
resolver: resolver,
log: logger,
backendTracker: newBackendTracker(u.Host, resolver, logger),
transportCache: newTransportCache(tlsConfig, connectTimeout),
requestMaxTimeout: reqTimeout,
sharded: config.Sharded,
baseURI: u,
resolver: resolver,
log: logger,
}, nil
}
......@@ -115,6 +135,9 @@ func (b *balancedBackend) Call(ctx context.Context, shard, path string, req, res
if deadline, ok := ctx.Deadline(); ok {
innerTimeout = time.Until(deadline) / time.Duration(seq.Len())
}
if b.requestMaxTimeout > 0 && innerTimeout > b.requestMaxTimeout {
innerTimeout = b.requestMaxTimeout
}
// Call the backends in the sequence until one succeeds, with an
// exponential backoff policy controlled by the outer Context.
......@@ -198,7 +221,7 @@ func (b *balancedBackend) do(ctx context.Context, seq *sequence, req *http.Reque
client := &http.Client{
Transport: b.transportCache.getTransport(target),
}
resp, err = client.Do(req.WithContext(ctx))
resp, err = client.Do(propagateDeadline(ctx, req))
if err == nil && resp.StatusCode != 200 {
err = remoteErrorFromResponse(resp)
if !isStatusTemporary(resp.StatusCode) {
......@@ -212,6 +235,19 @@ func (b *balancedBackend) do(ctx context.Context, seq *sequence, req *http.Reque
return
}
const deadlineHeader = "X-RPC-Deadline"
// Propagate context deadline to the server using a HTTP header.
func propagateDeadline(ctx context.Context, req *http.Request) *http.Request {
req = req.WithContext(ctx)
if deadline, ok := ctx.Deadline(); ok {
req.Header.Set(deadlineHeader, strconv.FormatInt(deadline.UTC().UnixNano(), 10))
} else {
req.Header.Del(deadlineHeader)
}
return req
}
var errNoTargets = errors.New("no available backends")
type targetGenerator interface {
......
// +build go1.9
package clientutil
import (
"context"
"net"
"time"
)
func netDialContext(addr string, connectTimeout time.Duration) func(context.Context, string, string) (net.Conn, error) {
dialer := &net.Dialer{
Timeout: connectTimeout,
KeepAlive: 30 * time.Second,
DualStack: true,
}
return func(ctx context.Context, net string, _ string) (net.Conn, error) {
return dialer.DialContext(ctx, net, addr)
}
}
// +build !go1.9
package clientutil
import (
"context"
"net"
"time"
)
// Go < 1.9 does not have net.DialContext, reimplement it in terms of
// net.DialTimeout.
func netDialContext(addr string, connectTimeout time.Duration) func(context.Context, string, string) (net.Conn, error) {
return func(ctx context.Context, net string, _ string) (net.Conn, error) {
if deadline, ok := ctx.Deadline(); ok {
ctxTimeout := time.Until(deadline)
if ctxTimeout < connectTimeout {
connectTimeout = ctxTimeout
}
}
return net.DialTimeout(network, addr, connectTimeout)
}
}
package clientutil
import (
"context"
"crypto/tls"
"net"
"net/http"
"sync"
"time"
......@@ -11,31 +9,42 @@ import (
"git.autistici.org/ai3/go-common/tracing"
)
var defaultConnectTimeout = 30 * time.Second
// The transportCache is just a cache of http transports, each
// connecting to a specific address.
//
// We use this to control the HTTP Host header and the TLS ServerName
// independently of the target address.
type transportCache struct {
tlsConfig *tls.Config
tlsConfig *tls.Config
connectTimeout time.Duration
mx sync.RWMutex
transports map[string]http.RoundTripper
}
func newTransportCache(tlsConfig *tls.Config) *transportCache {
func newTransportCache(tlsConfig *tls.Config, connectTimeout time.Duration) *transportCache {
if connectTimeout == 0 {
connectTimeout = defaultConnectTimeout
}
return &transportCache{
tlsConfig: tlsConfig,
transports: make(map[string]http.RoundTripper),
tlsConfig: tlsConfig,
connectTimeout: connectTimeout,
transports: make(map[string]http.RoundTripper),
}
}
func (m *transportCache) newTransport(addr string) http.RoundTripper {
return tracing.WrapTransport(&http.Transport{
TLSClientConfig: m.tlsConfig,
DialContext: func(ctx context.Context, network, _ string) (net.Conn, error) {
return netDialContext(ctx, network, addr)
},
DialContext: netDialContext(addr, m.connectTimeout),
// Parameters match those of net/http.DefaultTransport.
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
})
}
......@@ -55,13 +64,3 @@ func (m *transportCache) getTransport(addr string) http.RoundTripper {
return t
}
// Go < 1.9 does not have net.DialContext, reimplement it in terms of
// net.DialTimeout.
func netDialContext(ctx context.Context, network, addr string) (net.Conn, error) {
timeout := 60 * time.Second // some arbitrary max timeout
if deadline, ok := ctx.Deadline(); ok {
timeout = time.Until(deadline)
}
return net.DialTimeout(network, addr, timeout)
}
......@@ -11,6 +11,7 @@ import (
_ "net/http/pprof"
"os"
"os/signal"
"strconv"
"syscall"
"time"
......@@ -26,6 +27,7 @@ var gracefulShutdownTimeout = 3 * time.Second
type ServerConfig struct {
TLS *TLSServerConfig `yaml:"tls"`
MaxInflightRequests int `yaml:"max_inflight_requests"`
RequestTimeoutSecs int `yaml:"request_timeout"`
TrustedForwarders []string `yaml:"trusted_forwarders"`
}
......@@ -60,11 +62,18 @@ func (config *ServerConfig) buildHTTPServer(h http.Handler) (*http.Server, error
}
}
// Wrap the handler with a TimeoutHandler if 'request_timeout'
// is set.
h = addDefaultHandlers(h)
if config.RequestTimeoutSecs > 0 {
h = http.TimeoutHandler(h, time.Duration(config.RequestTimeoutSecs)*time.Second, "")
}
// These are not meant to be external-facing servers, so we
// can be generous with the timeouts to keep the number of
// reconnections low.
return &http.Server{
Handler: defaultHandler(h),
Handler: h,
ReadTimeout: 30 * time.Second,
WriteTimeout: 30 * time.Second,
IdleTimeout: 600 * time.Second,
......@@ -134,7 +143,7 @@ func Serve(h http.Handler, config *ServerConfig, addr string) error {
return nil
}
func defaultHandler(h http.Handler) http.Handler {
func addDefaultHandlers(h http.Handler) http.Handler {
root := http.NewServeMux()
// Add an endpoint for HTTP health checking probes.
......@@ -154,11 +163,30 @@ func defaultHandler(h http.Handler) http.Handler {
// Prometheus instrumentation (requests to /metrics and
// /health are not included).
root.Handle("/", promhttp.InstrumentHandlerInFlight(inFlightRequests,
promhttp.InstrumentHandlerCounter(totalRequests, h)))
promhttp.InstrumentHandlerCounter(totalRequests,
propagateDeadline(h))))
return root
}
const deadlineHeader = "X-RPC-Deadline"
// Read an eventual deadline from the HTTP request, and set it as the
// deadline of the request context.
func propagateDeadline(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
if hdr := req.Header.Get(deadlineHeader); hdr != "" {
if deadlineNano, err := strconv.ParseInt(hdr, 10, 64); err == nil {
deadline := time.Unix(0, deadlineNano)
ctx, cancel := context.WithDeadline(req.Context(), deadline)
defer cancel()
req = req.WithContext(ctx)
}
}
h.ServeHTTP(w, req)
})
}
func guessEndpointName(addr string) string {
_, port, err := net.SplitHostPort(addr)
if err != nil {
......
......@@ -2,10 +2,12 @@ package serverutil
import (
"crypto/tls"
"errors"
"fmt"
"log"
"net/http"
"regexp"
"strings"
common "git.autistici.org/ai3/go-common"
)
......@@ -42,6 +44,31 @@ func (p *TLSAuthACL) match(req *http.Request) bool {
return false
}
// TLSAuthACLListFlag is a convenience type that allows callers to use
// the 'flag' package to specify a list of TLSAuthACL objects. It
// implements the flag.Value interface.
type TLSAuthACLListFlag []*TLSAuthACL
func (l TLSAuthACLListFlag) String() string {
var out []string
for _, acl := range l {
out = append(out, fmt.Sprintf("%s:%s", acl.Path, acl.CommonName))
}
return strings.Join(out, ",")
}
func (l *TLSAuthACLListFlag) Set(value string) error {
parts := strings.SplitN(value, ":", 2)
if len(parts) != 2 {
return errors.New("bad acl format")
}
*l = append(*l, &TLSAuthACL{
Path: parts[0],
CommonName: parts[1],
})
return nil
}
// TLSAuthConfig stores access control lists for TLS authentication. Access
// control lists are matched against the request path and the
// CommonName component of the peer certificate subject.
......
// Copyright 2018 The Prometheus Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package util
import (
"io/ioutil"
"strconv"
"strings"
)
// ParseUint32s parses a slice of strings into a slice of uint32s.
func ParseUint32s(ss []string) ([]uint32, error) {
us := make([]uint32, 0, len(ss))
for _, s := range ss {
u, err := strconv.ParseUint(s, 10, 32)
if err != nil {
return nil, err
}
us = append(us, uint32(u))
}
return us, nil
}
// ParseUint64s parses a slice of strings into a slice of uint64s.
func ParseUint64s(ss []string) ([]uint64, error) {
us := make([]uint64, 0, len(ss))
for _, s := range ss {
u, err := strconv.ParseUint(s, 10, 64)
if err != nil {
return nil, err
}
us = append(us, u)
}
return us, nil
}
// ReadUintFromFile reads a file and attempts to parse a uint64 from it.
func ReadUintFromFile(path string) (uint64, error) {
data, err := ioutil.ReadFile(path)
if err != nil {
return 0, err
}
return strconv.ParseUint(strings.TrimSpace(string(data)), 10, 64)
}
// ParseBool parses a string into a boolean pointer.
func ParseBool(b string) *bool {
var truth bool
switch b {
case "enabled":
truth = true
case "disabled":
truth = false
default:
return nil
}
return &truth
}
// Copyright 2018 The Prometheus Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// +build linux,!appengine
package util
import (
"bytes"
"os"
"syscall"
)
// SysReadFile is a simplified ioutil.ReadFile that invokes syscall.Read directly.
// https://github.com/prometheus/node_exporter/pull/728/files
func SysReadFile(file string) (string, error) {
f, err := os.Open(file)
if err != nil {
return "", err
}
defer f.Close()
// On some machines, hwmon drivers are broken and return EAGAIN. This causes
// Go's ioutil.ReadFile implementation to poll forever.
//
// Since we either want to read data or bail immediately, do the simplest
// possible read using syscall directly.
b := make([]byte, 128)
n, err := syscall.Read(int(f.Fd()), b)
if err != nil {
return "", err
}
return string(bytes.TrimSpace(b[:n])), nil
}
// Copyright 2019 The Prometheus Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// +build linux,appengine !linux
package util
import (
"fmt"
)
// SysReadFile is here implemented as a noop for builds that do not support
// the read syscall. For example Windows, or Linux on Google App Engine.
func SysReadFile(file string) (string, error) {
return "", fmt.Errorf("not supported on this platform")
}
// Copyright 2018 The Prometheus Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package nfs implements parsing of /proc/net/rpc/nfsd.
// Fields are documented in https://www.svennd.be/nfsd-stats-explained-procnetrpcnfsd/
package nfs
import (
"os"
"path"
)
// ReplyCache models the "rc" line.
type ReplyCache struct {
Hits uint64
Misses uint64
NoCache uint64
}
// FileHandles models the "fh" line.
type FileHandles struct {
Stale uint64
TotalLookups uint64
AnonLookups uint64
DirNoCache uint64
NoDirNoCache uint64
}
// InputOutput models the "io" line.
type InputOutput struct {
Read uint64
Write uint64
}
// Threads models the "th" line.
type Threads struct {
Threads uint64
FullCnt uint64
}
// ReadAheadCache models the "ra" line.
type ReadAheadCache struct {
CacheSize uint64
CacheHistogram []uint64
NotFound uint64
}
// Network models the "net" line.
type Network struct {
NetCount uint64
UDPCount uint64
TCPCount uint64
TCPConnect uint64
}
// ClientRPC models the nfs "rpc" line.
type ClientRPC struct {
RPCCount uint64
Retransmissions uint64
AuthRefreshes uint64
}
// ServerRPC models the nfsd "rpc" line.
type ServerRPC struct {
RPCCount uint64
BadCnt uint64
BadFmt uint64
BadAuth uint64
BadcInt uint64
}
// V2Stats models the "proc2" line.
type V2Stats struct {
Null uint64
GetAttr uint64
SetAttr uint64
Root uint64
Lookup uint64
ReadLink uint64
Read uint64
WrCache uint64
Write uint64
Create uint64
Remove uint64
Rename uint64
Link uint64
SymLink uint64
MkDir uint64
RmDir uint64
ReadDir uint64
FsStat uint64
}
// V3Stats models the "proc3" line.
type V3Stats struct {
Null uint64
GetAttr uint64
SetAttr uint64
Lookup uint64
Access uint64
ReadLink uint64
Read uint64
Write uint64
Create uint64
MkDir uint64
SymLink uint64
MkNod uint64
Remove uint64
RmDir uint64
Rename uint64
Link uint64