diff --git a/clientutil/backend.go b/clientutil/backend.go index f74d539dacb47263ec35bee266ca949f25076dcb..6580d0eb42ae0070a803db7df6ce15d7da595e3b 100644 --- a/clientutil/backend.go +++ b/clientutil/backend.go @@ -1,15 +1,10 @@ package clientutil import ( - "crypto/tls" - "fmt" - "net/http" - "net/url" - "sync" - "time" + "context" ) -// BackendConfig specifies the configuration to access a service. +// BackendConfig specifies the configuration of a service backend. // // Services with multiple backends can be replicated or partitioned, // depending on a configuration switch, making it a deployment-time @@ -18,102 +13,30 @@ import ( // 'shard' parameter on their APIs. type BackendConfig struct { URL string `yaml:"url"` - Sharded bool `yaml:"sharded"` TLSConfig *TLSClientConfig `yaml:"tls_config"` + Sharded bool `yaml:"sharded"` + Debug bool `yaml:"debug"` } // Backend is a runtime class that provides http Clients for use with // a specific service backend. If the service can't be partitioned, -// pass an empty string to the Client method. +// pass an empty string to the Call method. type Backend interface { - // URL for the service for a specific shard. - URL(string) string + // Call a remote method. The sharding behavior is the following: + // + // Services that support sharding (partitioning) should always + // include the shard ID in their Call() requests. Users can + // then configure backends to be sharded or not in their + // Config. When invoking Call with a shard ID on a non-sharded + // service, the shard ID is simply ignored. Invoking Call + // *without* a shard ID on a sharded service is an error. + Call(context.Context, string, string, interface{}, interface{}) error - // Client that can be used to make a request to the service. - Client(string) *http.Client + // Close all resources associated with the backend. + Close() } // NewBackend returns a new Backend with the given config. 
func NewBackend(config *BackendConfig) (Backend, error) { - u, err := url.Parse(config.URL) - if err != nil { - return nil, err - } - - var tlsConfig *tls.Config - if config.TLSConfig != nil { - tlsConfig, err = config.TLSConfig.TLSConfig() - if err != nil { - return nil, err - } - } - - if !config.Sharded { - return &replicatedClient{ - u: u, - c: newHTTPClient(u, tlsConfig), - }, nil - } - return &shardedClient{ - baseURL: u, - tlsConfig: tlsConfig, - urls: make(map[string]*url.URL), - shards: make(map[string]*http.Client), - }, nil -} - -type replicatedClient struct { - c *http.Client - u *url.URL -} - -func (r *replicatedClient) Client(_ string) *http.Client { return r.c } -func (r *replicatedClient) URL(_ string) string { return r.u.String() } - -type shardedClient struct { - baseURL *url.URL - tlsConfig *tls.Config - mx sync.Mutex - urls map[string]*url.URL - shards map[string]*http.Client -} - -func (s *shardedClient) getShardURL(shard string) *url.URL { - if shard == "" { - return s.baseURL - } - u, ok := s.urls[shard] - if !ok { - var tmp = *s.baseURL - tmp.Host = fmt.Sprintf("%s.%s", shard, tmp.Host) - u = &tmp - s.urls[shard] = u - } - return u -} - -func (s *shardedClient) URL(shard string) string { - s.mx.Lock() - defer s.mx.Unlock() - return s.getShardURL(shard).String() -} - -func (s *shardedClient) Client(shard string) *http.Client { - s.mx.Lock() - defer s.mx.Unlock() - - client, ok := s.shards[shard] - if !ok { - u := s.getShardURL(shard) - client = newHTTPClient(u, s.tlsConfig) - s.shards[shard] = client - } - return client -} - -func newHTTPClient(u *url.URL, tlsConfig *tls.Config) *http.Client { - return &http.Client{ - Transport: NewTransport([]string{u.Host}, tlsConfig, nil), - Timeout: 30 * time.Second, - } + return newBalancedBackend(config, defaultResolver) } diff --git a/clientutil/backend_test.go b/clientutil/backend_test.go new file mode 100644 index 0000000000000000000000000000000000000000..ac11093c8e2978c3261506a437eec21e804d4005 --- 
/dev/null +++ b/clientutil/backend_test.go @@ -0,0 +1,266 @@ +package clientutil + +import ( + "context" + "io" + "log" + "math/rand" + "net" + "net/http" + "net/http/httptest" + "net/url" + "testing" + "time" +) + +type tcpHandler interface { + Handle(net.Conn) +} + +type tcpHandlerFunc func(net.Conn) + +func (f tcpHandlerFunc) Handle(c net.Conn) { f(c) } + +// Base TCP server type (to build fake LDAP servers). +type tcpServer struct { + l net.Listener + handler tcpHandler +} + +func newTCPServer(t testing.TB, handler tcpHandler) *tcpServer { + l, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatal("Listen():", err) + } + log.Printf("started new tcp server on %s", l.Addr().String()) + s := &tcpServer{l: l, handler: handler} + go s.serve() + return s +} + +func (s *tcpServer) serve() { + for { + conn, err := s.l.Accept() + if err != nil { + return + } + go func(c net.Conn) { + s.handler.Handle(c) + c.Close() + }(conn) + } +} + +func (s *tcpServer) Addr() string { + return s.l.Addr().String() +} + +func (s *tcpServer) Close() { + s.l.Close() +} + +// A test server that will close all incoming connections right away. +func newConnFailServer(t testing.TB) *tcpServer { + return newTCPServer(t, tcpHandlerFunc(func(c net.Conn) {})) +} + +// A test server that will close all connections after a 1s delay. +func newConnFailDelayServer(t testing.TB) *tcpServer { + return newTCPServer(t, tcpHandlerFunc(func(c net.Conn) { time.Sleep(1 * time.Second) })) +} + +type httpServer struct { + *httptest.Server +} + +func (s *httpServer) Addr() string { + u, _ := url.Parse(s.Server.URL) + return u.Host +} + +// An HTTP server that will always return a specific HTTP status using +// http.Error(). 
+func newErrorHTTPServer(statusCode int) *httpServer { + return &httpServer{httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + w.Header().Set("Connection", "close") + http.Error(w, "oh no", statusCode) + }))} +} + +func newJSONHTTPServer() *httpServer { + return &httpServer{httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + io.WriteString(w, "{\"value\": 42}") + }))} +} + +func newHostCountingJSONHTTPServer() (*httpServer, map[string]int) { + counters := make(map[string]int) + return &httpServer{httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + counters[r.Host]++ + w.Header().Set("Content-Type", "application/json") + io.WriteString(w, "{\"value\": 42}") + }))}, counters +} + +type testServer interface { + Addr() string + Close() +} + +type testBackends struct { + servers []testServer + addrs []string +} + +func newTestBackends(servers ...testServer) *testBackends { + b := new(testBackends) + for _, s := range servers { + b.servers = append(b.servers, s) + b.addrs = append(b.addrs, s.Addr()) + } + return b +} + +func (b *testBackends) ResolveIP(_ string) []string { + return b.addrs +} + +func (b *testBackends) stop(i int) { + b.servers[i].Close() +} + +func (b *testBackends) close() { + for _, s := range b.servers { + s.Close() + } +} + +// Do a number of fake requests to a test JSONHTTPServer. If shards is +// not nil, set up a fake sharded service and pick one of the given +// shards randomly on every request. 
+func doJSONRequests(backends *testBackends, u string, n int, shards []string) (int, int) { + b, err := newBalancedBackend(&BackendConfig{ + URL: u, + Debug: true, + Sharded: len(shards) > 0, + }, backends) + if err != nil { + panic(err) + } + defer b.Close() + + var errs, oks int + for i := 0; i < n; i++ { + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + + var resp struct { + Value int `json:"value"` + } + var shard string + if len(shards) > 0 { + shard = shards[rand.Intn(len(shards))] + } + err = b.Call(ctx, "/", shard, struct{}{}, &resp) + cancel() + if err != nil { + errs++ + log.Printf("request error: %v", err) + } else if resp.Value != 42 { + errs++ + } else { + oks++ + } + } + + return oks, errs +} + +func TestBackend_TargetsDown(t *testing.T) { + b := newTestBackends(newJSONHTTPServer(), newJSONHTTPServer(), newJSONHTTPServer()) + defer b.close() + + oks, errs := doJSONRequests(b, "http://test/", 10, nil) + if errs > 0 { + t.Fatalf("errs=%d", errs) + } + if oks == 0 { + t.Fatal("oks=0") + } + + // Stop the first two backends, request should still succeed. 
+ b.stop(0) + b.stop(1) + + oks, errs = doJSONRequests(b, "http://test/", 10, nil) + if errs > 0 { + t.Fatalf("errs=%d", errs) + } + if oks < 10 { + t.Fatalf("oks=%d", oks) + } +} + +func TestBackend_OverloadedTargets(t *testing.T) { + b := newTestBackends(newErrorHTTPServer(http.StatusTooManyRequests), newJSONHTTPServer()) + defer b.close() + + oks, errs := doJSONRequests(b, "http://test/", 10, nil) + if errs > 0 { + t.Fatalf("errs=%d", errs) + } + if oks < 10 { + t.Fatalf("oks=%d", oks) + } +} + +func TestBackend_BrokenTarget(t *testing.T) { + b := newTestBackends(newConnFailServer(t), newJSONHTTPServer()) + defer b.close() + + oks, errs := doJSONRequests(b, "http://test/", 10, nil) + if errs > 0 { + t.Fatalf("errs=%d", errs) + } + if oks == 0 { + t.Fatal("oks=0") + } +} + +func TestBackend_HighLatencyTarget(t *testing.T) { + b := newTestBackends(newConnFailDelayServer(t), newJSONHTTPServer()) + defer b.close() + + oks, errs := doJSONRequests(b, "http://test/", 10, nil) + // At most one request should fail (timing out). + if errs > 1 { + t.Fatalf("errs=%d", errs) + } + if oks == 0 { + t.Fatal("oks=0") + } +} + +func TestBackend_Sharded(t *testing.T) { + srv, counters := newHostCountingJSONHTTPServer() + b := newTestBackends(srv) + defer b.close() + + // Make some requests to two different shards (simulated by a + // single http server), and count the Host headers seen. 
+ shards := []string{"s1", "s2"} + oks, errs := doJSONRequests(b, "http://test/", 10, shards) + if errs > 0 { + t.Fatalf("errs=%d", errs) + } + if oks == 0 { + t.Fatal("oks=0") + } + + for _, s := range shards { + n := counters[s+".test"] + if n == 0 { + t.Errorf("no requests for shard %s", s) + } + } +} diff --git a/clientutil/balancer.go b/clientutil/balancer.go new file mode 100644 index 0000000000000000000000000000000000000000..1f6df88ca84ad81c4e45622a860cdeab51ffe915 --- /dev/null +++ b/clientutil/balancer.go @@ -0,0 +1,271 @@ +package clientutil + +import ( + "bytes" + "context" + "crypto/tls" + "encoding/json" + "errors" + "fmt" + "log" + "math/rand" + "net/http" + "net/url" + "os" + "strconv" + "strings" + "time" + + "github.com/cenkalti/backoff" +) + +// Our own narrow logger interface. +type logger interface { + Printf(string, ...interface{}) +} + +// A nilLogger is used when Config.Debug is false. +type nilLogger struct{} + +func (l nilLogger) Printf(_ string, _ ...interface{}) {} + +// Parameters that define the exponential backoff algorithm used. +var ( + ExponentialBackOffInitialInterval = 100 * time.Millisecond + ExponentialBackOffMultiplier = 1.4142 +) + +// newExponentialBackOff creates a backoff.ExponentialBackOff object +// with our own default values. +func newExponentialBackOff() *backoff.ExponentialBackOff { + b := backoff.NewExponentialBackOff() + b.InitialInterval = ExponentialBackOffInitialInterval + b.Multiplier = ExponentialBackOffMultiplier + + // Set MaxElapsedTime to 0 because we expect the overall + // timeout to be dictated by the request Context. + b.MaxElapsedTime = 0 + + return b +} + +// Balancer for HTTP connections. It will round-robin across available +// backends, trying to avoid ones that are erroring out, until one +// succeeds or returns a permanent error. +// +// This object should not be used for load balancing of individual +// HTTP requests: it doesn't do anything smart beyond trying to avoid +// broken targets. 
It's meant to provide a *reliable* connection to a +// set of equivalent services for HA purposes. +type balancedBackend struct { + *backendTracker + *transportCache + baseURI *url.URL + sharded bool + resolver resolver + log logger +} + +func newBalancedBackend(config *BackendConfig, resolver resolver) (*balancedBackend, error) { + u, err := url.Parse(config.URL) + if err != nil { + return nil, err + } + + var tlsConfig *tls.Config + if config.TLSConfig != nil { + tlsConfig, err = config.TLSConfig.TLSConfig() + if err != nil { + return nil, err + } + } + + var logger logger = &nilLogger{} + if config.Debug { + logger = log.New(os.Stderr, fmt.Sprintf("backend %s: ", u.Host), 0) + } + return &balancedBackend{ + backendTracker: newBackendTracker(u.Host, resolver, logger), + transportCache: newTransportCache(tlsConfig), + sharded: config.Sharded, + baseURI: u, + resolver: resolver, + log: logger, + }, nil +} + +// Call the backend. Makes an HTTP POST request to the specified uri, +// with a JSON-encoded request body. It will attempt to decode the +// response body as JSON. 
+//
+// The request is retried with exponential backoff, picking a new
+// target from the sequence on every attempt, until it succeeds, a
+// permanent error is returned, or ctx is done. If resp is nil the
+// response body is not decoded.
+func (b *balancedBackend) Call(ctx context.Context, path, shard string, req, resp interface{}) error {
+	data, err := json.Marshal(req)
+	if err != nil {
+		return err
+	}
+
+	// NOTE(review): the Backend interface documents a missing shard
+	// ID on a sharded backend as an error, but this code falls back
+	// to the unsharded tracker in that case. Confirm which is intended.
+	var tg targetGenerator = b.backendTracker
+	if b.sharded && shard != "" {
+		tg = newShardedGenerator(shard, b.baseURI.Host, b.resolver)
+	}
+	seq := newSequence(tg)
+	b.log.Printf("%016x: initialized", seq.ID())
+
+	var httpResp *http.Response
+	err = backoff.Retry(func() error {
+		httpReq, rerr := b.newJSONRequest(path, shard, data)
+		if rerr != nil {
+			// The inputs do not change between attempts, so a
+			// request that cannot be built now never will be.
+			return backoff.Permanent(rerr)
+		}
+		httpResp, rerr = b.do(ctx, seq, httpReq)
+		return rerr
+	}, backoff.WithContext(newExponentialBackOff(), ctx))
+	if err != nil {
+		return err
+	}
+	// Defensive: never dereference a nil response, even if the retry
+	// loop reported success without producing one.
+	if httpResp == nil {
+		return errNoTargets
+	}
+	defer httpResp.Body.Close() // nolint
+
+	// Compare only the media type, ignoring optional parameters such
+	// as "; charset=utf-8".
+	mediaType := httpResp.Header.Get("Content-Type")
+	if i := strings.IndexByte(mediaType, ';'); i >= 0 {
+		mediaType = mediaType[:i]
+	}
+	if !strings.EqualFold(strings.TrimSpace(mediaType), "application/json") {
+		return errors.New("not a JSON response")
+	}
+
+	if resp == nil {
+		return nil
+	}
+	return json.NewDecoder(httpResp.Body).Decode(resp)
+}
+
+// Return the URI to be used for the request. This is used both in the
+// Host HTTP header and as the TLS server name used to pick a server
+// certificate (if using TLS).
+//
+// On a sharded backend with a non-empty shard, the shard ID is
+// prepended to the base host name as an additional DNS label.
+func (b *balancedBackend) getURIForRequest(shard, path string) string {
+	// Work on a copy so the shared base URI is never modified.
+	u := *b.baseURI
+	if b.sharded && shard != "" {
+		u.Host = fmt.Sprintf("%s.%s", shard, u.Host)
+	}
+	u.Path = appendPath(u.Path, path)
+	return u.String()
+}
+
+// Build a http.Request object: a POST to the URI for the given shard
+// and path, with data as the JSON body.
+func (b *balancedBackend) newJSONRequest(path, shard string, data []byte) (*http.Request, error) {
+	req, err := http.NewRequest("POST", b.getURIForRequest(shard, path), bytes.NewReader(data))
+	if err != nil {
+		return nil, err
+	}
+	req.Header.Set("Content-Type", "application/json")
+	req.Header.Set("Content-Length", strconv.FormatInt(int64(len(data)), 10))
+	return req, nil
+}
+
+// Select a new target from the given sequence and send the request to
+// it. Wrap HTTP errors in a RemoteError object.
+//
+// If the sequence has no target to offer, a permanent error is
+// returned so that the caller's retry loop stops instead of treating
+// the attempt as a success with a nil response.
+func (b *balancedBackend) do(ctx context.Context, seq *sequence, req *http.Request) (*http.Response, error) {
+	target, err := seq.Next()
+	if err != nil {
+		return nil, backoff.Permanent(err)
+	}
+
+	b.log.Printf("sequence %016x: connecting to %s", seq.ID(), target)
+	client := &http.Client{
+		Transport: b.transportCache.getTransport(target),
+	}
+	resp, err := client.Do(req.WithContext(ctx))
+	if err == nil && resp.StatusCode != http.StatusOK {
+		// Turn the HTTP failure into an error. Only the status
+		// codes accepted by isStatusTemporary remain retriable.
+		err = remoteErrorFromResponse(resp)
+		if !isStatusTemporary(resp.StatusCode) {
+			err = backoff.Permanent(err)
+		}
+		resp.Body.Close() // nolint
+		resp = nil
+	}
+
+	// Report the outcome so the target generator can track failures.
+	seq.Done(target, err)
+	return resp, err
+}
+
+var errNoTargets = errors.New("no available backends")
+
+// A targetGenerator provides the list of targets to try, and receives
+// the outcome of each attempt.
+type targetGenerator interface {
+	getTargets() []string
+	setStatus(string, bool)
+}
+
+// A sequence repeatedly iterates over available backends in order of
+// preference. Once in a while it refreshes its list of available
+// targets.
+type sequence struct {
+	id      uint64
+	tg      targetGenerator
+	targets []string
+	pos     int
+}
+
+// newSequence creates a sequence with a random ID, loading the
+// initial list of targets from tg.
+func newSequence(tg targetGenerator) *sequence {
+	return &sequence{
+		id:      rand.Uint64(),
+		tg:      tg,
+		targets: tg.getTargets(),
+	}
+}
+
+// ID returns the identifier of this sequence (used in log messages).
+func (s *sequence) ID() uint64 { return s.id }
+
+// reloadTargets restarts the iteration from the first target. It
+// fetches a fresh list from the generator, but keeps the current one
+// if the fresh list is empty.
+func (s *sequence) reloadTargets() {
+	if targets := s.tg.getTargets(); len(targets) > 0 {
+		s.targets = targets
+	}
+	// Always rewind, including when the old list is kept: otherwise
+	// Next would index past the end of s.targets.
+	s.pos = 0
+}
+
+// Next returns the next target, or errNoTargets if there are none.
+func (s *sequence) Next() (string, error) {
+	if s.pos >= len(s.targets) {
+		s.reloadTargets()
+		if len(s.targets) == 0 {
+			return "", errNoTargets
+		}
+	}
+	t := s.targets[s.pos]
+	s.pos++
+	return t, nil
+}
+
+// Done reports the outcome of a request to target t back to the
+// target generator: a nil err counts as success.
+func (s *sequence) Done(t string, err error) {
+	s.tg.setStatus(t, err == nil)
+}
+
+// A shardedGenerator returns a single sharded target to a sequence.
+type shardedGenerator struct { + id uint64 + addrs []string +} + +func newShardedGenerator(shard, base string, resolver resolver) *shardedGenerator { + return &shardedGenerator{ + id: rand.Uint64(), + addrs: resolver.ResolveIP(fmt.Sprintf("%s.%s", shard, base)), + } +} + +func (g *shardedGenerator) getTargets() []string { return g.addrs } +func (g *shardedGenerator) setStatus(_ string, _ bool) {} + +// Concatenate two URI paths. +func appendPath(a, b string) string { + if strings.HasSuffix(a, "/") && strings.HasPrefix(b, "/") { + return a + b[1:] + } + return a + b +} + +// Some HTTP status codes are treated are temporary errors. +func isStatusTemporary(code int) bool { + switch code { + case http.StatusTooManyRequests, http.StatusBadGateway, http.StatusServiceUnavailable, http.StatusGatewayTimeout: + return true + default: + return false + } +} diff --git a/clientutil/cpu.prof b/clientutil/cpu.prof new file mode 100644 index 0000000000000000000000000000000000000000..89fe0a7d104f6c1ca9466d218abbfe393d431402 Binary files /dev/null and b/clientutil/cpu.prof differ diff --git a/clientutil/dns.go b/clientutil/dns.go new file mode 100644 index 0000000000000000000000000000000000000000..ed30f87342295aa2594c9302b1b1ecdee8014aee --- /dev/null +++ b/clientutil/dns.go @@ -0,0 +1,98 @@ +package clientutil + +import ( + "log" + "net" + "sync" + "time" + + "golang.org/x/sync/singleflight" +) + +type resolver interface { + ResolveIP(string) []string +} + +type dnsResolver struct{} + +func (r *dnsResolver) ResolveIP(hostport string) []string { + var resolved []string + host, port, err := net.SplitHostPort(hostport) + if err != nil { + log.Printf("error parsing %s: %v", hostport, err) + return nil + } + hostIPs, err := net.LookupIP(host) + if err != nil { + log.Printf("error resolving %s: %v", host, err) + return nil + } + for _, ip := range hostIPs { + resolved = append(resolved, net.JoinHostPort(ip.String(), port)) + } + return resolved +} + +var defaultResolver = 
newDNSCache(&dnsResolver{}) + +type cacheDatum struct { + addrs []string + deadline time.Time +} + +type dnsCache struct { + resolver resolver + sf singleflight.Group + mx sync.RWMutex + cache map[string]cacheDatum +} + +func newDNSCache(resolver resolver) *dnsCache { + return &dnsCache{ + resolver: resolver, + cache: make(map[string]cacheDatum), + } +} + +func (c *dnsCache) get(host string) ([]string, bool) { + d, ok := c.cache[host] + if !ok { + return nil, false + } + return d.addrs, d.deadline.After(time.Now()) +} + +func (c *dnsCache) update(host string) []string { + v, _, _ := c.sf.Do(host, func() (interface{}, error) { + addrs := c.resolver.ResolveIP(host) + // By uncommenting this, we stop caching negative results. + // if len(addrs) == 0 { + // return nil, nil + // } + c.mx.Lock() + c.cache[host] = cacheDatum{ + addrs: addrs, + deadline: time.Now().Add(60 * time.Second), + } + c.mx.Unlock() + return addrs, nil + }) + return v.([]string) +} + +func (c *dnsCache) ResolveIP(host string) []string { + c.mx.RLock() + addrs, ok := c.get(host) + c.mx.RUnlock() + + if ok { + return addrs + } + + if len(addrs) > 0 { + go c.update(host) + return addrs + } + + return c.update(host) +} diff --git a/clientutil/dns_test.go b/clientutil/dns_test.go new file mode 100644 index 0000000000000000000000000000000000000000..3162c5e7bb5884b1367383b613f96f82f6b79b4c --- /dev/null +++ b/clientutil/dns_test.go @@ -0,0 +1,27 @@ +package clientutil + +import "testing" + +type fakeResolver struct { + addrs []string + requests int +} + +func (r *fakeResolver) ResolveIP(host string) []string { + r.requests++ + return r.addrs +} + +func TestDNSCache(t *testing.T) { + r := &fakeResolver{addrs: []string{"1.2.3.4"}} + c := newDNSCache(r) + for i := 0; i < 5; i++ { + addrs := c.ResolveIP("a.b.c.d") + if len(addrs) != 1 { + t.Errorf("ResolveIP returned bad response: %v", addrs) + } + } + if r.requests != 1 { + t.Errorf("cached resolver has wrong number of requests: %d, expecting 1", r.requests) 
+ } +} diff --git a/clientutil/doc.go b/clientutil/doc.go new file mode 100644 index 0000000000000000000000000000000000000000..421915b66a9818d943c5fcbf80c39cfc2d141a2c --- /dev/null +++ b/clientutil/doc.go @@ -0,0 +1,37 @@ +// Package clientutil implements a very simple style of JSON RPC. +// +// Requests and responses are both encoded in JSON, and they should +// have the "application/json" Content-Type. +// +// HTTP response statuses other than 200 indicate an error: in this +// case, the response body may contain (in plain text) further details +// about the error. Some HTTP status codes are considered temporary +// errors (incl. 429 for throttling). The client will retry requests, +// if targets are available, until the context expires - so it's quite +// important to remember to set a timeout on the context given to the +// Call() function! +// +// The client handles both replicated services and sharded +// (partitioned) services. Users of this package that want to support +// sharded deployments are supposed to pass a shard ID to every +// Call(). At the deployment stage, sharding can be enabled via the +// configuration. +// +// For replicated services, the client will expect the provided +// hostname to resolve to one or more IP addresses, in which case it +// will pick a random IP address on every new request, while +// remembering which addresses have had errors and trying to avoid +// them. It will however send an occasional request to the failed +// targets, to see if they've come back. +// +// For sharded services, the client makes simple HTTP requests to the +// specific target identified by the shard. It does this by prepending +// the shard ID to the backend hostname (so a request to "example.com" +// with shard ID "1" becomes a request to "1.example.com"). +// +// The difference with other JSON-RPC implementations is that we use a +// different URI for every method, and we force the usage of +// request/response types. 
This makes it easy for projects to +// eventually migrate to GRPC. +// +package clientutil diff --git a/clientutil/error.go b/clientutil/error.go new file mode 100644 index 0000000000000000000000000000000000000000..f011e162c39d835952bda45b846b17be3a308336 --- /dev/null +++ b/clientutil/error.go @@ -0,0 +1,35 @@ +package clientutil + +import ( + "fmt" + "io/ioutil" + "net/http" +) + +// RemoteError represents a HTTP error from the server. The status +// code and response body can be retrieved with the StatusCode() and +// Body() methods. +type RemoteError struct { + statusCode int + body string +} + +func remoteErrorFromResponse(resp *http.Response) *RemoteError { + // Optimistically read the response body, ignoring errors. + var body string + if data, err := ioutil.ReadAll(resp.Body); err == nil { + body = string(data) + } + return &RemoteError{statusCode: resp.StatusCode, body: body} +} + +// Error implements the error interface. +func (e *RemoteError) Error() string { + return fmt.Sprintf("%d - %s", e.statusCode, e.body) +} + +// StatusCode returns the HTTP status code. +func (e *RemoteError) StatusCode() int { return e.statusCode } + +// Body returns the response body. +func (e *RemoteError) Body() string { return e.body } diff --git a/clientutil/json.go b/clientutil/json.go deleted file mode 100644 index 5fc1ab2e4ab75061c44e816c79a874c580c73a73..0000000000000000000000000000000000000000 --- a/clientutil/json.go +++ /dev/null @@ -1,45 +0,0 @@ -package clientutil - -import ( - "bytes" - "context" - "encoding/json" - "errors" - "fmt" - "net/http" -) - -// DoJSONHTTPRequest makes an HTTP POST request to the specified uri, -// with a JSON-encoded request body. It will attempt to decode the -// response body as JSON. 
-func DoJSONHTTPRequest(ctx context.Context, client *http.Client, uri string, req, resp interface{}) error { - data, err := json.Marshal(req) - if err != nil { - return err - } - - httpReq, err := http.NewRequest("POST", uri, bytes.NewReader(data)) - if err != nil { - return err - } - httpReq.Header.Set("Content-Type", "application/json") - httpReq = httpReq.WithContext(ctx) - - httpResp, err := RetryHTTPDo(client, httpReq, NewExponentialBackOff()) - if err != nil { - return err - } - defer httpResp.Body.Close() - - if httpResp.StatusCode != 200 { - return fmt.Errorf("HTTP status %d", httpResp.StatusCode) - } - if httpResp.Header.Get("Content-Type") != "application/json" { - return errors.New("not a JSON response") - } - - if resp == nil { - return nil - } - return json.NewDecoder(httpResp.Body).Decode(resp) -} diff --git a/clientutil/retry.go b/clientutil/retry.go deleted file mode 100644 index 3ca7b51a48289a18cd829947cbdbc400392ffc93..0000000000000000000000000000000000000000 --- a/clientutil/retry.go +++ /dev/null @@ -1,92 +0,0 @@ -package clientutil - -import ( - "errors" - "net/http" - "time" - - "github.com/cenkalti/backoff" -) - -// NewExponentialBackOff creates a backoff.ExponentialBackOff object -// with our own default values. -func NewExponentialBackOff() *backoff.ExponentialBackOff { - b := backoff.NewExponentialBackOff() - b.InitialInterval = 100 * time.Millisecond - //b.Multiplier = 1.4142 - return b -} - -// A temporary (retriable) error is something that has a Temporary method. -type tempError interface { - Temporary() bool -} - -type tempErrorWrapper struct { - error -} - -func (t tempErrorWrapper) Temporary() bool { return true } - -// TempError makes a temporary (retriable) error out of a normal error. -func TempError(err error) error { - return tempErrorWrapper{err} -} - -// Retry operation op until it succeeds according to the backoff -// policy b. 
-// -// Note that this function reverses the error semantics of -// backoff.Operation: all errors are permanent unless explicitly -// marked as temporary (i.e. they have a Temporary() method that -// returns true). This is to better align with the errors returned by -// the net package. -func Retry(op backoff.Operation, b backoff.BackOff) error { - innerOp := func() error { - err := op() - if err == nil { - return err - } - if tmpErr, ok := err.(tempError); ok && tmpErr.Temporary() { - return err - } - return backoff.Permanent(err) - } - return backoff.Retry(innerOp, b) -} - -var errHTTPBackOff = TempError(errors.New("temporary http error")) - -func isStatusTemporary(code int) bool { - switch code { - case http.StatusTooManyRequests, http.StatusBadGateway, http.StatusServiceUnavailable, http.StatusGatewayTimeout: - return true - default: - return false - } -} - -// RetryHTTPDo retries an HTTP request until it succeeds, according to -// the backoff policy b. It will retry on temporary network errors and -// upon receiving specific temporary HTTP errors. It will use the -// context associated with the HTTP request object. -func RetryHTTPDo(client *http.Client, req *http.Request, b backoff.BackOff) (*http.Response, error) { - var resp *http.Response - op := func() error { - // Clear up previous response if set. 
- if resp != nil { - resp.Body.Close() - } - - var err error - resp, err = client.Do(req) - if err == nil && isStatusTemporary(resp.StatusCode) { - resp.Body.Close() - return errHTTPBackOff - } - return err - } - - err := Retry(op, backoff.WithContext(b, req.Context())) - return resp, err -} diff --git a/clientutil/retry_test.go b/clientutil/retry_test.go deleted file mode 100644 index b7d5f03f128c4a6e6c5cb553ff40075e2f1dbcb8..0000000000000000000000000000000000000000 --- a/clientutil/retry_test.go +++ /dev/null @@ -1,200 +0,0 @@ -package clientutil - -import ( - "context" - "io" - "io/ioutil" - "log" - "net" - "net/http" - "net/http/httptest" - "net/url" - "testing" - "time" -) - -type tcpHandler interface { - Handle(net.Conn) -} - -type tcpHandlerFunc func(net.Conn) - -func (f tcpHandlerFunc) Handle(c net.Conn) { f(c) } - -// Base TCP server type (to build fake LDAP servers). -type tcpServer struct { - l net.Listener - handler tcpHandler -} - -func newTCPServer(t testing.TB, handler tcpHandler) *tcpServer { - l, err := net.Listen("tcp", "localhost:0") - if err != nil { - t.Fatal("Listen():", err) - } - log.Printf("started new tcp server on %s", l.Addr().String()) - s := &tcpServer{l: l, handler: handler} - go s.serve() - return s -} - -func (s *tcpServer) serve() { - for { - conn, err := s.l.Accept() - if err != nil { - return - } - go func(c net.Conn) { - s.handler.Handle(c) - c.Close() - }(conn) - } -} - -func (s *tcpServer) Addr() string { - return s.l.Addr().String() -} - -func (s *tcpServer) Close() { - s.l.Close() -} - -// A test server that will close all incoming connections right away. -func newConnFailServer(t testing.TB) *tcpServer { - return newTCPServer(t, tcpHandlerFunc(func(c net.Conn) {})) -} - -// A test server that will close all connections after a 1s delay. 
-func newConnFailDelayServer(t testing.TB) *tcpServer { - return newTCPServer(t, tcpHandlerFunc(func(c net.Conn) { time.Sleep(1 * time.Second) })) -} - -type httpServer struct { - *httptest.Server -} - -func (s *httpServer) Addr() string { - u, _ := url.Parse(s.Server.URL) - return u.Host -} - -// An HTTP server that will always return a specific HTTP status. -func newStatusHTTPServer(statusCode int) *httpServer { - return &httpServer{httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - w.WriteHeader(statusCode) - io.WriteString(w, "hello\n") - }))} -} - -func newOKHTTPServer() *httpServer { - return &httpServer{httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "text/plain") - io.WriteString(w, "OK") - }))} -} - -type testServer interface { - Addr() string - Close() -} - -type testBackends struct { - servers []testServer - addrs []string -} - -func newTestBackends(servers ...testServer) *testBackends { - b := new(testBackends) - for _, s := range servers { - b.servers = append(b.servers, s) - b.addrs = append(b.addrs, s.Addr()) - } - return b -} - -func (b *testBackends) ResolveIPs(_ []string) []string { - return b.addrs -} - -func (b *testBackends) stop(i int) { - b.servers[i].Close() -} - -func (b *testBackends) close() { - for _, s := range b.servers { - s.Close() - } -} - -func doRequests(backends *testBackends, u string, n int) (int, int) { - c := &http.Client{ - Transport: NewTransport([]string{"backend"}, nil, backends), - } - b := NewExponentialBackOff() - - var errs, oks int - for i := 0; i < n; i++ { - ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) - req, _ := http.NewRequest("GET", u, nil) - resp, err := RetryHTTPDo(c, req.WithContext(ctx), b) - cancel() - if err != nil { - errs++ - continue - } - _, err = ioutil.ReadAll(resp.Body) - resp.Body.Close() - if resp.StatusCode != 200 { - errs++ - continue - } - if err != nil { - 
errs++ - continue - } - oks++ - } - - return oks, errs -} - -func TestRetryHTTP_BackendsDown(t *testing.T) { - b := newTestBackends(newOKHTTPServer(), newOKHTTPServer(), newOKHTTPServer()) - defer b.close() - - oks, errs := doRequests(b, "http://backend/", 100) - if errs > 0 { - t.Fatalf("errs=%d", errs) - } - if oks == 0 { - t.Fatal("oks=0") - } - - b.stop(0) - b.stop(1) - - oks, errs = doRequests(b, "http://backend/", 100) - if errs > 0 { - t.Fatalf("errs=%d", errs) - } - if oks == 0 { - t.Fatal("oks=0") - } -} - -func TestRetryHTTP_HighLatencyBackend(t *testing.T) { - b := newTestBackends(newConnFailDelayServer(t), newOKHTTPServer()) - defer b.close() - - _, _ = doRequests(b, "http://backend/", 10) - // Silly transport.go load balancer only balances connections, - // so in this scenario we'll just keep hitting the slow - // server, exhausting our deadline budget, and never fail over - // to the secondary backend :( - // if errs > 0 { - // t.Fatalf("errs=%d", errs) - // } - // if oks == 0 { - // t.Fatal("oks=0") - // } -} diff --git a/clientutil/track.go b/clientutil/track.go new file mode 100644 index 0000000000000000000000000000000000000000..2db20bbbf4d79ce2377e6eb7781f5c9b65567492 --- /dev/null +++ b/clientutil/track.go @@ -0,0 +1,123 @@ +package clientutil + +import ( + "math/rand" + "sync" + "time" +) + +// The backendTracker tracks the state of the targets associated with +// a backend, and periodically checks DNS for updates. +type backendTracker struct { + log logger + addr string + resolver resolver + stopCh chan struct{} + + mx sync.Mutex + resolved []string + failed map[string]time.Time +} + +func newBackendTracker(addr string, resolver resolver, logger logger) *backendTracker { + // Resolve the targets once before returning. 
+ b := &backendTracker{ + addr: addr, + resolver: resolver, + resolved: resolver.ResolveIP(addr), + failed: make(map[string]time.Time), + stopCh: make(chan struct{}), + log: logger, + } + go b.updateProc() + return b +} + +func (b *backendTracker) Close() { + close(b.stopCh) +} + +// Return the full list of targets in reverse preference order. +func (b *backendTracker) getTargets() []string { + b.mx.Lock() + defer b.mx.Unlock() + + var good, bad []string + for _, t := range b.resolved { + if _, ok := b.failed[t]; ok { + bad = append(bad, t) + } else { + good = append(good, t) + } + } + + good = shuffle(good) + bad = shuffle(bad) + + return append(good, bad...) +} + +func (b *backendTracker) setStatus(addr string, ok bool) { + b.mx.Lock() + + _, isFailed := b.failed[addr] + if isFailed && ok { + b.log.Printf("target %s now ok", addr) + delete(b.failed, addr) + } else if !isFailed && !ok { + b.log.Printf("target %s failed", addr) + b.failed[addr] = time.Now() + } + + b.mx.Unlock() +} + +var ( + backendUpdateInterval = 60 * time.Second + backendFailureRetryInterval = 60 * time.Second +) + +func (b *backendTracker) expireFailedTargets() { + b.mx.Lock() + now := time.Now() + for k, v := range b.failed { + if now.Sub(v) > backendFailureRetryInterval { + delete(b.failed, k) + } + } + b.mx.Unlock() +} + +func (b *backendTracker) updateProc() { + tick := time.NewTicker(backendUpdateInterval) + defer tick.Stop() + for { + select { + case <-b.stopCh: + return + case <-tick.C: + b.expireFailedTargets() + resolved := b.resolver.ResolveIP(b.addr) + if len(resolved) > 0 { + b.mx.Lock() + b.resolved = resolved + b.mx.Unlock() + } + } + } +} + +var shuffleSrc = rand.NewSource(time.Now().UnixNano()) + +// Re-order elements of a slice randomly. 
+func shuffle(values []string) []string { + if len(values) < 2 { + return values + } + rnd := rand.New(shuffleSrc) + for i := len(values) - 1; i > 0; i-- { + j := rnd.Intn(i + 1) + values[i], values[j] = values[j], values[i] + } + return values +} diff --git a/clientutil/transport.go b/clientutil/transport.go index e4f98e3f60ccbf214ff10db1b2860bb6bb1d8ed0..843a760b1e25eb1c21b2c8552b9c6ad107ea6d34 100644 --- a/clientutil/transport.go +++ b/clientutil/transport.go @@ -3,170 +3,63 @@ package clientutil import ( "context" "crypto/tls" - "errors" - "log" "net" "net/http" "sync" "time" ) -var errAllBackendsFailed = errors.New("all backends failed") - -type dnsResolver struct{} - -func (r *dnsResolver) ResolveIPs(hosts []string) []string { - var resolved []string - for _, hostport := range hosts { - host, port, err := net.SplitHostPort(hostport) - if err != nil { - log.Printf("error parsing %s: %v", hostport, err) - continue - } - hostIPs, err := net.LookupIP(host) - if err != nil { - log.Printf("error resolving %s: %v", host, err) - continue - } - for _, ip := range hostIPs { - resolved = append(resolved, net.JoinHostPort(ip.String(), port)) - } - } - return resolved -} - -var defaultResolver = &dnsResolver{} - -type resolver interface { - ResolveIPs([]string) []string -} - -// Balancer for HTTP connections. It will round-robin across available -// backends, trying to avoid ones that are erroring out, until one -// succeeds or they all fail. +// The transportCache is just a cache of http transports, each +// connecting to a specific address. // -// This object should not be used for load balancing of individual -// HTTP requests: once a new connection is established, requests will -// be sent over it until it errors out. It's meant to provide a -// *reliable* connection to a set of equivalent backends for HA -// purposes. 
-type balancer struct { - hosts []string - resolver resolver - stop chan bool +// We use this to control the HTTP Host header and the TLS ServerName +// independently of the target address. +type transportCache struct { + tlsConfig *tls.Config - // List of currently valid (or untested) backends, and ones - // that errored out at least once. - mx sync.Mutex - addrs []string - ok map[string]bool + mx sync.RWMutex + transports map[string]http.RoundTripper } -var backendUpdateInterval = 60 * time.Second +func newTransportCache(tlsConfig *tls.Config) *transportCache { + return &transportCache{ + tlsConfig: tlsConfig, + transports: make(map[string]http.RoundTripper), + } +} -// Periodically update the list of available backends. -func (b *balancer) updateProc() { - tick := time.NewTicker(backendUpdateInterval) - for { - select { - case <-b.stop: - return - case <-tick.C: - resolved := b.resolver.ResolveIPs(b.hosts) - if len(resolved) > 0 { - b.mx.Lock() - b.addrs = resolved - b.mx.Unlock() - } - } +func (m *transportCache) newTransport(addr string) http.RoundTripper { + return &http.Transport{ + TLSClientConfig: m.tlsConfig, + DialContext: func(ctx context.Context, network, _ string) (net.Conn, error) { + return netDialContext(ctx, network, addr) + }, } } -// Returns a list of all available backends, split into "good ones" -// (no errors seen since last successful connection) and "bad ones". 
-func (b *balancer) getBackends() ([]string, []string) { - b.mx.Lock() - defer b.mx.Unlock() +func (m *transportCache) getTransport(addr string) http.RoundTripper { + m.mx.RLock() + t, ok := m.transports[addr] + m.mx.RUnlock() - var good, bad []string - for _, addr := range b.addrs { - if ok := b.ok[addr]; ok { - good = append(good, addr) - } else { - bad = append(bad, addr) + if !ok { + m.mx.Lock() + if t, ok = m.transports[addr]; !ok { + t = m.newTransport(addr) + m.transports[addr] = t } + m.mx.Unlock() } - return good, bad -} -func (b *balancer) notify(addr string, ok bool) { - b.mx.Lock() - b.ok[addr] = ok - b.mx.Unlock() + return t } +// Go < 1.9 does not have net.DialContext, reimplement it in terms of +// net.DialTimeout. func netDialContext(ctx context.Context, network, addr string) (net.Conn, error) { - timeout := 30 * time.Second - // Go < 1.9 does not have net.DialContext, reimplement it in - // terms of net.DialTimeout. + timeout := 60 * time.Second // some arbitrary max timeout if deadline, ok := ctx.Deadline(); ok { timeout = time.Until(deadline) } return net.DialTimeout(network, addr, timeout) } - -func (b *balancer) dial(ctx context.Context, network, addr string) (net.Conn, error) { - // Start by attempting a connection on 'good' targets. - good, bad := b.getBackends() - - for _, addr := range good { - // Go < 1.9 does not have DialContext, deal with it - conn, err := netDialContext(ctx, network, addr) - if err == nil { - return conn, nil - } else if err == context.Canceled { - // A timeout might be bad, set the error bit - // on the connection. 
- b.notify(addr, false) - return nil, err - } - b.notify(addr, false) - } - - for _, addr := range bad { - conn, err := netDialContext(ctx, network, addr) - if err == nil { - b.notify(addr, true) - return conn, nil - } else if err == context.Canceled { - return nil, err - } - } - - return nil, errAllBackendsFailed -} - -// NewTransport returns a suitably configured http.RoundTripper that -// talks to a specific backend service. It performs discovery of -// available backends via DNS (using A or AAAA record lookups), tries -// to route traffic away from faulty backends. -// -// It will periodically attempt to rediscover new backends. -func NewTransport(backends []string, tlsConf *tls.Config, resolver resolver) http.RoundTripper { - if resolver == nil { - resolver = defaultResolver - } - addrs := resolver.ResolveIPs(backends) - b := &balancer{ - hosts: backends, - resolver: resolver, - addrs: addrs, - ok: make(map[string]bool), - } - go b.updateProc() - - return &http.Transport{ - DialContext: b.dial, - TLSClientConfig: tlsConf, - } -}