From 193e29e61c81e6bb5548bfe89bba05836d06b61f Mon Sep 17 00:00:00 2001 From: ale <ale@incal.net> Date: Thu, 16 Aug 2018 22:47:58 +0100 Subject: [PATCH] Refactor the clientutil package Implement a simpler API for the Backend interface, removing most old public methods and replacing them with a single Call() method, making the package look a bit more like an actual rpc package (so, hopefully, easier to replace in the future). --- clientutil/backend.go | 111 +++------------ clientutil/backend_test.go | 266 ++++++++++++++++++++++++++++++++++++ clientutil/balancer.go | 271 +++++++++++++++++++++++++++++++++++++ clientutil/cpu.prof | Bin 0 -> 4566 bytes clientutil/dns.go | 98 ++++++++++++++ clientutil/dns_test.go | 27 ++++ clientutil/doc.go | 37 +++++ clientutil/error.go | 35 +++++ clientutil/json.go | 45 ------ clientutil/retry.go | 92 ------------- clientutil/retry_test.go | 200 --------------------------- clientutil/track.go | 123 +++++++++++++++++ clientutil/transport.go | 175 +++++------------------- 13 files changed, 908 insertions(+), 572 deletions(-) create mode 100644 clientutil/backend_test.go create mode 100644 clientutil/balancer.go create mode 100644 clientutil/cpu.prof create mode 100644 clientutil/dns.go create mode 100644 clientutil/dns_test.go create mode 100644 clientutil/doc.go create mode 100644 clientutil/error.go delete mode 100644 clientutil/json.go delete mode 100644 clientutil/retry.go delete mode 100644 clientutil/retry_test.go create mode 100644 clientutil/track.go diff --git a/clientutil/backend.go b/clientutil/backend.go index f74d539..6580d0e 100644 --- a/clientutil/backend.go +++ b/clientutil/backend.go @@ -1,15 +1,10 @@ package clientutil import ( - "crypto/tls" - "fmt" - "net/http" - "net/url" - "sync" - "time" + "context" ) -// BackendConfig specifies the configuration to access a service. +// BackendConfig specifies the configuration of a service backend. 
// // Services with multiple backends can be replicated or partitioned, // depending on a configuration switch, making it a deployment-time @@ -18,102 +13,30 @@ import ( // 'shard' parameter on their APIs. type BackendConfig struct { URL string `yaml:"url"` - Sharded bool `yaml:"sharded"` TLSConfig *TLSClientConfig `yaml:"tls_config"` + Sharded bool `yaml:"sharded"` + Debug bool `yaml:"debug"` } // Backend is a runtime class that provides http Clients for use with // a specific service backend. If the service can't be partitioned, -// pass an empty string to the Client method. +// pass an empty string to the Call method. type Backend interface { - // URL for the service for a specific shard. - URL(string) string + // Call a remote method. The sharding behavior is the following: + // + // Services that support sharding (partitioning) should always + // include the shard ID in their Call() requests. Users can + // then configure backends to be sharded or not in their + // Config. When invoking Call with a shard ID on a non-sharded + // service, the shard ID is simply ignored. Invoking Call + // *without* a shard ID on a sharded service is an error. + Call(context.Context, string, string, interface{}, interface{}) error - // Client that can be used to make a request to the service. - Client(string) *http.Client + // Close all resources associated with the backend. + Close() } // NewBackend returns a new Backend with the given config. 
func NewBackend(config *BackendConfig) (Backend, error) { - u, err := url.Parse(config.URL) - if err != nil { - return nil, err - } - - var tlsConfig *tls.Config - if config.TLSConfig != nil { - tlsConfig, err = config.TLSConfig.TLSConfig() - if err != nil { - return nil, err - } - } - - if !config.Sharded { - return &replicatedClient{ - u: u, - c: newHTTPClient(u, tlsConfig), - }, nil - } - return &shardedClient{ - baseURL: u, - tlsConfig: tlsConfig, - urls: make(map[string]*url.URL), - shards: make(map[string]*http.Client), - }, nil -} - -type replicatedClient struct { - c *http.Client - u *url.URL -} - -func (r *replicatedClient) Client(_ string) *http.Client { return r.c } -func (r *replicatedClient) URL(_ string) string { return r.u.String() } - -type shardedClient struct { - baseURL *url.URL - tlsConfig *tls.Config - mx sync.Mutex - urls map[string]*url.URL - shards map[string]*http.Client -} - -func (s *shardedClient) getShardURL(shard string) *url.URL { - if shard == "" { - return s.baseURL - } - u, ok := s.urls[shard] - if !ok { - var tmp = *s.baseURL - tmp.Host = fmt.Sprintf("%s.%s", shard, tmp.Host) - u = &tmp - s.urls[shard] = u - } - return u -} - -func (s *shardedClient) URL(shard string) string { - s.mx.Lock() - defer s.mx.Unlock() - return s.getShardURL(shard).String() -} - -func (s *shardedClient) Client(shard string) *http.Client { - s.mx.Lock() - defer s.mx.Unlock() - - client, ok := s.shards[shard] - if !ok { - u := s.getShardURL(shard) - client = newHTTPClient(u, s.tlsConfig) - s.shards[shard] = client - } - return client -} - -func newHTTPClient(u *url.URL, tlsConfig *tls.Config) *http.Client { - return &http.Client{ - Transport: NewTransport([]string{u.Host}, tlsConfig, nil), - Timeout: 30 * time.Second, - } + return newBalancedBackend(config, defaultResolver) } diff --git a/clientutil/backend_test.go b/clientutil/backend_test.go new file mode 100644 index 0000000..ac11093 --- /dev/null +++ b/clientutil/backend_test.go @@ -0,0 +1,266 @@ 
+package clientutil + +import ( + "context" + "io" + "log" + "math/rand" + "net" + "net/http" + "net/http/httptest" + "net/url" + "testing" + "time" +) + +type tcpHandler interface { + Handle(net.Conn) +} + +type tcpHandlerFunc func(net.Conn) + +func (f tcpHandlerFunc) Handle(c net.Conn) { f(c) } + +// Base TCP server type (to build fake LDAP servers). +type tcpServer struct { + l net.Listener + handler tcpHandler +} + +func newTCPServer(t testing.TB, handler tcpHandler) *tcpServer { + l, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatal("Listen():", err) + } + log.Printf("started new tcp server on %s", l.Addr().String()) + s := &tcpServer{l: l, handler: handler} + go s.serve() + return s +} + +func (s *tcpServer) serve() { + for { + conn, err := s.l.Accept() + if err != nil { + return + } + go func(c net.Conn) { + s.handler.Handle(c) + c.Close() + }(conn) + } +} + +func (s *tcpServer) Addr() string { + return s.l.Addr().String() +} + +func (s *tcpServer) Close() { + s.l.Close() +} + +// A test server that will close all incoming connections right away. +func newConnFailServer(t testing.TB) *tcpServer { + return newTCPServer(t, tcpHandlerFunc(func(c net.Conn) {})) +} + +// A test server that will close all connections after a 1s delay. +func newConnFailDelayServer(t testing.TB) *tcpServer { + return newTCPServer(t, tcpHandlerFunc(func(c net.Conn) { time.Sleep(1 * time.Second) })) +} + +type httpServer struct { + *httptest.Server +} + +func (s *httpServer) Addr() string { + u, _ := url.Parse(s.Server.URL) + return u.Host +} + +// An HTTP server that will always return a specific HTTP status using +// http.Error(). 
+func newErrorHTTPServer(statusCode int) *httpServer { + return &httpServer{httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + w.Header().Set("Connection", "close") + http.Error(w, "oh no", statusCode) + }))} +} + +func newJSONHTTPServer() *httpServer { + return &httpServer{httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + io.WriteString(w, "{\"value\": 42}") + }))} +} + +func newHostCountingJSONHTTPServer() (*httpServer, map[string]int) { + counters := make(map[string]int) + return &httpServer{httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + counters[r.Host]++ + w.Header().Set("Content-Type", "application/json") + io.WriteString(w, "{\"value\": 42}") + }))}, counters +} + +type testServer interface { + Addr() string + Close() +} + +type testBackends struct { + servers []testServer + addrs []string +} + +func newTestBackends(servers ...testServer) *testBackends { + b := new(testBackends) + for _, s := range servers { + b.servers = append(b.servers, s) + b.addrs = append(b.addrs, s.Addr()) + } + return b +} + +func (b *testBackends) ResolveIP(_ string) []string { + return b.addrs +} + +func (b *testBackends) stop(i int) { + b.servers[i].Close() +} + +func (b *testBackends) close() { + for _, s := range b.servers { + s.Close() + } +} + +// Do a number of fake requests to a test JSONHTTPServer. If shards is +// not nil, set up a fake sharded service and pick one of the given +// shards randomly on every request. 
+func doJSONRequests(backends *testBackends, u string, n int, shards []string) (int, int) { + b, err := newBalancedBackend(&BackendConfig{ + URL: u, + Debug: true, + Sharded: len(shards) > 0, + }, backends) + if err != nil { + panic(err) + } + defer b.Close() + + var errs, oks int + for i := 0; i < n; i++ { + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + + var resp struct { + Value int `json:"value"` + } + var shard string + if len(shards) > 0 { + shard = shards[rand.Intn(len(shards))] + } + err = b.Call(ctx, "/", shard, struct{}{}, &resp) + cancel() + if err != nil { + errs++ + log.Printf("request error: %v", err) + } else if resp.Value != 42 { + errs++ + } else { + oks++ + } + } + + return oks, errs +} + +func TestBackend_TargetsDown(t *testing.T) { + b := newTestBackends(newJSONHTTPServer(), newJSONHTTPServer(), newJSONHTTPServer()) + defer b.close() + + oks, errs := doJSONRequests(b, "http://test/", 10, nil) + if errs > 0 { + t.Fatalf("errs=%d", errs) + } + if oks == 0 { + t.Fatal("oks=0") + } + + // Stop the first two backends, request should still succeed. 
+ b.stop(0) + b.stop(1) + + oks, errs = doJSONRequests(b, "http://test/", 10, nil) + if errs > 0 { + t.Fatalf("errs=%d", errs) + } + if oks < 10 { + t.Fatalf("oks=%d", oks) + } +} + +func TestBackend_OverloadedTargets(t *testing.T) { + b := newTestBackends(newErrorHTTPServer(http.StatusTooManyRequests), newJSONHTTPServer()) + defer b.close() + + oks, errs := doJSONRequests(b, "http://test/", 10, nil) + if errs > 0 { + t.Fatalf("errs=%d", errs) + } + if oks < 10 { + t.Fatalf("oks=%d", oks) + } +} + +func TestBackend_BrokenTarget(t *testing.T) { + b := newTestBackends(newConnFailServer(t), newJSONHTTPServer()) + defer b.close() + + oks, errs := doJSONRequests(b, "http://test/", 10, nil) + if errs > 0 { + t.Fatalf("errs=%d", errs) + } + if oks == 0 { + t.Fatal("oks=0") + } +} + +func TestBackend_HighLatencyTarget(t *testing.T) { + b := newTestBackends(newConnFailDelayServer(t), newJSONHTTPServer()) + defer b.close() + + oks, errs := doJSONRequests(b, "http://test/", 10, nil) + // At most one request should fail (timing out). + if errs > 1 { + t.Fatalf("errs=%d", errs) + } + if oks == 0 { + t.Fatal("oks=0") + } +} + +func TestBackend_Sharded(t *testing.T) { + srv, counters := newHostCountingJSONHTTPServer() + b := newTestBackends(srv) + defer b.close() + + // Make some requests to two different shards (simulated by a + // single http server), and count the Host headers seen. 
+ shards := []string{"s1", "s2"} + oks, errs := doJSONRequests(b, "http://test/", 10, shards) + if errs > 0 { + t.Fatalf("errs=%d", errs) + } + if oks == 0 { + t.Fatal("oks=0") + } + + for _, s := range shards { + n := counters[s+".test"] + if n == 0 { + t.Errorf("no requests for shard %s", s) + } + } +} diff --git a/clientutil/balancer.go b/clientutil/balancer.go new file mode 100644 index 0000000..1f6df88 --- /dev/null +++ b/clientutil/balancer.go @@ -0,0 +1,271 @@ +package clientutil + +import ( + "bytes" + "context" + "crypto/tls" + "encoding/json" + "errors" + "fmt" + "log" + "math/rand" + "net/http" + "net/url" + "os" + "strconv" + "strings" + "time" + + "github.com/cenkalti/backoff" +) + +// Our own narrow logger interface. +type logger interface { + Printf(string, ...interface{}) +} + +// A nilLogger is used when Config.Debug is false. +type nilLogger struct{} + +func (l nilLogger) Printf(_ string, _ ...interface{}) {} + +// Parameters that define the exponential backoff algorithm used. +var ( + ExponentialBackOffInitialInterval = 100 * time.Millisecond + ExponentialBackOffMultiplier = 1.4142 +) + +// newExponentialBackOff creates a backoff.ExponentialBackOff object +// with our own default values. +func newExponentialBackOff() *backoff.ExponentialBackOff { + b := backoff.NewExponentialBackOff() + b.InitialInterval = ExponentialBackOffInitialInterval + b.Multiplier = ExponentialBackOffMultiplier + + // Set MaxElapsedTime to 0 because we expect the overall + // timeout to be dictated by the request Context. + b.MaxElapsedTime = 0 + + return b +} + +// Balancer for HTTP connections. It will round-robin across available +// backends, trying to avoid ones that are erroring out, until one +// succeeds or returns a permanent error. +// +// This object should not be used for load balancing of individual +// HTTP requests: it doesn't do anything smart beyond trying to avoid +// broken targets. 
It's meant to provide a *reliable* connection to a +// set of equivalent services for HA purposes. +type balancedBackend struct { + *backendTracker + *transportCache + baseURI *url.URL + sharded bool + resolver resolver + log logger +} + +func newBalancedBackend(config *BackendConfig, resolver resolver) (*balancedBackend, error) { + u, err := url.Parse(config.URL) + if err != nil { + return nil, err + } + + var tlsConfig *tls.Config + if config.TLSConfig != nil { + tlsConfig, err = config.TLSConfig.TLSConfig() + if err != nil { + return nil, err + } + } + + var logger logger = &nilLogger{} + if config.Debug { + logger = log.New(os.Stderr, fmt.Sprintf("backend %s: ", u.Host), 0) + } + return &balancedBackend{ + backendTracker: newBackendTracker(u.Host, resolver, logger), + transportCache: newTransportCache(tlsConfig), + sharded: config.Sharded, + baseURI: u, + resolver: resolver, + log: logger, + }, nil +} + +// Call the backend. Makes an HTTP POST request to the specified uri, +// with a JSON-encoded request body. It will attempt to decode the +// response body as JSON. 
+func (b *balancedBackend) Call(ctx context.Context, path, shard string, req, resp interface{}) error { + data, err := json.Marshal(req) + if err != nil { + return err + } + + var tg targetGenerator = b.backendTracker + if b.sharded && shard != "" { + tg = newShardedGenerator(shard, b.baseURI.Host, b.resolver) + } + seq := newSequence(tg) + b.log.Printf("%016x: initialized", seq.ID()) + + var httpResp *http.Response + err = backoff.Retry(func() error { + req, rerr := b.newJSONRequest(path, shard, data) + if rerr != nil { + return rerr + } + httpResp, rerr = b.do(ctx, seq, req) + return rerr + }, backoff.WithContext(newExponentialBackOff(), ctx)) + if err != nil { + return err + } + defer httpResp.Body.Close() // nolint + + if httpResp.Header.Get("Content-Type") != "application/json" { + return errors.New("not a JSON response") + } + + if resp == nil { + return nil + } + return json.NewDecoder(httpResp.Body).Decode(resp) +} + +// Return the URI to be used for the request. This is used both in the +// Host HTTP header and as the TLS server name used to pick a server +// certificate (if using TLS). +func (b *balancedBackend) getURIForRequest(shard, path string) string { + u := *b.baseURI + if b.sharded && shard != "" { + u.Host = fmt.Sprintf("%s.%s", shard, u.Host) + } + u.Path = appendPath(u.Path, path) + return u.String() +} + +// Build a http.Request object. +func (b *balancedBackend) newJSONRequest(path, shard string, data []byte) (*http.Request, error) { + req, err := http.NewRequest("POST", b.getURIForRequest(shard, path), bytes.NewReader(data)) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Content-Length", strconv.FormatInt(int64(len(data)), 10)) + return req, nil +} + +// Select a new target from the given sequence and send the request to +// it. Wrap HTTP errors in a RemoteError object. 
+func (b *balancedBackend) do(ctx context.Context, seq *sequence, req *http.Request) (resp *http.Response, err error) { + target, terr := seq.Next() + if terr != nil { + return nil, terr // a bare return would yield (nil, nil), which the retry loop treats as success + } + + b.log.Printf("sequence %016x: connecting to %s", seq.ID(), target) + client := &http.Client{ + Transport: b.transportCache.getTransport(target), + } + resp, err = client.Do(req.WithContext(ctx)) + if err == nil && resp.StatusCode != 200 { + err = remoteErrorFromResponse(resp) + if !isStatusTemporary(resp.StatusCode) { + err = backoff.Permanent(err) + } + resp.Body.Close() // nolint + resp = nil + } + + seq.Done(target, err) + return +} + +var errNoTargets = errors.New("no available backends") + +type targetGenerator interface { + getTargets() []string + setStatus(string, bool) +} + +// A sequence repeatedly iterates over available backends in order of +// preference. When it runs out of targets it asks its targetGenerator +// for a fresh list. +type sequence struct { + id uint64 + tg targetGenerator + targets []string + pos int +} + +func newSequence(tg targetGenerator) *sequence { + return &sequence{ + id: rand.Uint64(), + tg: tg, + targets: tg.getTargets(), + } +} + +func (s *sequence) ID() uint64 { return s.id } + +func (s *sequence) reloadTargets() { + targets := s.tg.getTargets() + if len(targets) > 0 { + s.targets = targets + s.pos = 0 + } +} + +// Next returns the next target, or errNoTargets if none is available. +func (s *sequence) Next() (t string, err error) { + if s.pos >= len(s.targets) { + s.reloadTargets() + if s.pos >= len(s.targets) { // reload produced nothing: pos is still past the end + err = errNoTargets + return + } + } + t = s.targets[s.pos] + s.pos++ + return +} + +func (s *sequence) Done(t string, err error) { + s.tg.setStatus(t, err == nil) +} + +// A shardedGenerator returns a single sharded target to a sequence.
+type shardedGenerator struct { + id uint64 + addrs []string +} + +func newShardedGenerator(shard, base string, resolver resolver) *shardedGenerator { + return &shardedGenerator{ + id: rand.Uint64(), + addrs: resolver.ResolveIP(fmt.Sprintf("%s.%s", shard, base)), + } +} + +func (g *shardedGenerator) getTargets() []string { return g.addrs } +func (g *shardedGenerator) setStatus(_ string, _ bool) {} + +// Concatenate two URI paths. +func appendPath(a, b string) string { + if strings.HasSuffix(a, "/") && strings.HasPrefix(b, "/") { + return a + b[1:] + } + return a + b +} + +// Some HTTP status codes are treated are temporary errors. +func isStatusTemporary(code int) bool { + switch code { + case http.StatusTooManyRequests, http.StatusBadGateway, http.StatusServiceUnavailable, http.StatusGatewayTimeout: + return true + default: + return false + } +} diff --git a/clientutil/cpu.prof b/clientutil/cpu.prof new file mode 100644 index 0000000000000000000000000000000000000000..89fe0a7d104f6c1ca9466d218abbfe393d431402 GIT binary patch literal 4566 zcmb2|=3oE;mj6pCI7>vXPyH6JaBwl}7uSXz9cwkV<u&eFda3P(<C|#{bvQR{bPNgK zpC`I%T~>H#a)Fs;o5amFiF^rzZr&R|)gN1^f1D<%zR`GI&g3~qJZ<{k>iE5>C`y~5 za?WCPvhnkMtHQj?x+VK9m#=*Heee5yyZ1-0k~y#P;`M9k#@h2S5|R_Ov~O^{u>O7L z<UjND#QQgozk6R_9G2p8_QWYJGofE`wU0vV*Y5viVymobx^$}O8kS$b_7z3FsedN@ zrKe|=WUOW9RHyqlx9|J-)_e1Y_n{wqHk|S@>+-AqE%bNthW9!DvQy3+d;j-OWAwYd zKTfrtTqq>gv?+j7n)~ss^Y?0RpEfS<c)$C%+|SfbCZ9w4wng7LmQL~fz|FZt>zIl} z|Kq0{m~vR`KAs9Z<aS_bAd@s>gMxoRL+_FM`^u(lm{`!qnWHi#dFKJI5B%%zct>!~ zk<{EEG^0cKspr$j4op@SAxgaxk&9X)HRt{RbyCoaQ@b@YDP7a&*vX<+rW}#DKW{uI z8F9%;X>SNLU|oJ<1%q_&`Tx&Y5*I(@jemX7Ky>-3sa`V-|6JT{uRf9I8n;w)riJgk zbqvzY`~SUq?IO+WUUN%gMx*hQnO-wmKe+qv@!KNsMr4b?nM0AEd{`qS|NXSyuM#CG z^DQ$~VX@BnutPaIa#laJ7xlJ^HkpZuggmkIn$dfBcOQ3#jE>+1iy7U;Pc6ANeEjj} zdYzZh@xwcX8hhI-E&S)LXW)J;Jzq3JBJbNq1Dyk`Ow4~iB{nk3e6!^iIM{JO?Lnir zz2Z#a2Qvkop0RW!g(N)v@xz))yIucBgk?zLRGWCt&7D84i=XQc@w(x1g6oz`L<jd% zSFR176~C@=D^CqF<J|X7+KzM8#;88c18R?(_3!)NNU1pO-*-~r_>P#x466^s*ZdPn 
z6fL;t-^aD$XjG%`!Q}ioiaA_1R#9`zSwlje&ieE+fT_0st3Qvgj;!gdUe2k#vgUR? z!l%^zX2lq~GPYH`dzzf0aA@+rj~t~D>esas%|zuZ{)KTJ4p|}QwA_i=>_jiu1}~e! zZwwc;9?r;jQ8h`NyeERuj5ogiW!a%+2W20#<V%?(Z{Fj=XeK+q);;0!Lsx+tB@tcT ztY*yj_kA|_n0Da18b=PB&f!$1*`42I#Bcam1nF&<xj|++caGS+3eS*ZIu>t_CyK2( zAo^*^Cr$&VUvYJLJ{=FQoXWhMerC>6hRGk-pXa!mTCpUgV^X5(x+&qBiDuIKOJDeH zsxW&abRj3AclA?KFV^4>*Y5B2oxpT@YLFTGy>BPEwj^}#+19nuas%J>ldnY+ia(g= z&s|n<rnR|l+u2s$9gg=_FwE}e-+$HrhE2uU<IfJN-BHO8Ig}$eZ&!TE*_wi*A3rA= z=S@rYnsN3xce`GIlWxOPrYMH+2ck?-4gB^0WE=%&s4w8ylC-*JsZg5WpUcl#L|Rue zbsZABqqIQ7u0CK+C~G%s2y5;QwM`qi`uA1N3yI-K|NB6%r7UOq`k)5gBlqv>Y$_Mr z^CY=@0kbvBHjeW)e=S%yNd3E8?q^z(wERwJ+{FCA$4_panzQrsGyM*i)nY-WA&;l? zbssqPph5nNM9$H*N*fGoKBzVNw!dfc;M!oZ@56?O&co7mQb!V{^CnN9(tmz_K$Prk z!Q~d$AAD?HboJ_i+PbJK2g>f~Efz}K`s32X%_3)_A~bu}x32z}^;3#<Mpt~zYaYpk zhxlumLvJ)%KgphIu=?=%JBk^l7Zfu}3uNQeGrm^b-x%H*#_#`gOVy1|YuEOi1OI=# z$!;+{U|ze1vFxGV&S=JK9sDu7`fKiT`2XPDDOmP6FEY5n^kKJt*Q&fG?m9(NM<!vW z=$`mJ?=3!x9gcn6s@Lu49((PS`ZHTWt+gvdU%!%;n!Q46_Q%ideG`Sw7Ti0(JIy00 zXx<Wr=%)F<q5_2KHAVa4y8eHw-MEQm-%ImFss|>XUM-Sv`a{TlL#rxHneVgD26-fU z`Dq`R_SpWO@|j3B@tPdTFA4ntm8^9?`y8IuOR@)5TC9Ju;+d-Vblv!cEjPXXypx=` z{&e%ZeKI^IPi(zrtUYA^Z1bL|6+5K%|5|tAlIRE7`>VAVh*v*deM(Wk;>)Sfi5XX4 zxq8i5d$77*;L9e_rh85O^Mu9LwWQbn@-nz6BAcTtw?|<^=!azaXjaKH967T8o;(r! 
z^f=4HKG>qpG$Nz3CAQ!GFZ;#2*ABciWQt}syTx?7A-={dLHEM*WZelz*X&7|uxrP% z)VI67N{b{&e{9_^w<P%5A?Ld2IPU9@ziwLGa#QWk-}^^m_wc<ss`&bVYk117v^`%s zP8|ALa5B8cw&Dfn8?Kt?2jeH6;4L_D+{}Tm_|dDJqw@c!?>KCJXnswO@|&Z3Kk;&H z(Ej)$-sOhbLEXtmD`zLjzY|ux+a!M9!hPii%lPEkiT3ZN=gXvAIp#Ziea7CRZl*Q7 z^R1P%K5<#RxStSy!2N@A(vOVe=VL?;yKCFL%f7%Rf3*M9RIUxHKWINcH|bgUA#J-d z)fygqwO8wy=hr$!>|w2YlpOuw@sGm9LiK-t4=;Lm>p`Y`cC-Hjb(g!P4`=(&Q#yP5 zA-71MKz6$w=Wmt5`FThG|M?pJ;>hX;o?6q4%0%y|JoKIP$1}_M_J;pEUc8t65Uq7a z;B~`n)jg4uvQM4<eA0Z`mbHIZ^_^R{_VyHm+kNNj{!5t7@$+a15lxy~;Bbg}Qp|dW z-BR;y1eb*T_nE-PTXK87*v$f7yShItnzvcqSN69@h|5g*ut24019Sb)^uXfihe<O9 z_KDi5K3rb&C2G%KIl+BL3sP)6nr}B>ct2~c^@^>r?fzfcuRFdL5M8(U*UU3zkL2si zj3qblF4I)avDRT)xvD3(Ia+0p>m>e*{ckVmC))oFXn1?*{(a4?7g-ihc7Mp)5qc<^ zd584=Z|{_!Jj{M@+pdQ9j@rZANj26R%O5w|*YrJ*{a_KRnwrBdGwnkGLw3cN7>56s zH3H9Z{t4V4u~Y2dIrSKBcNH5Ufs#K#{}-#gx2l-?D@bxTpPlZ*<;)hH|7)gSNK?0Q zpBZ$l?2xR`J;4Jq@-?5;KK<9Ln6Yk>(KhZruZoaK@25;L@P63n&$cg{A@f4&_TupG zdAbYLE4$`cv@ZSl(cRUl?C{}9L8r=ohx>Gt9*z~dr+V^ksNy#1chf%Pd8TCY|NAjb zq2hb+X&vrK-Ag0;na&)TZ>#(%pn0*0Z|e-U?;aKX^N&=2c;3!3tKo6oG_?nDobSJ0 zjQjJDvGG*_cdU)#vcra){egV`JbjWooO*2cdd-;mWBc)68tjt<#4q|EsujAYe)70x zie$~lFvHzN^)EOL`P5Zx9?GXb%$MKY8#ALx{13yD2TNMiRe#ng-LCkc$hiFB|K9~t zex4Ot4V!;l;P(^S%^tV-M@t0%oM|7nsI2;0_2-(noWhlc`lTOQW-z9!*f?j*u>5f@ z__*h5TgP+n8aE5goBD+PMS^&JMMTD3?siRn=QZMTQ$9F8V*32SU%%^L+ONxg9Wx93 z4i!(TS<Ix<?jJ8=9>A9Oee0<Mm7M3jlca8S>py+3E~Rox{rOz0YiA<LzEvInT2w8* z^1pv>_xBr;a@T$2%4KcmZTs8)@aUE~={x2vls3HWdE9id-n?zAORXnwT)QJr>dvR4 zd5L_#q(1L>z0&KBiO|Q|Zt309uUGpXKJ|9BU$kY^mDN{gM2Q#Ynoi1C5goJos=0~j zbrZF#%bv|zw(IDsRi-a&7G8V(wNu33^t$?FPE)5Ni!N=7>p#3T#c|oW?AaB)QP%_3 zuj=L1+`G-}=9~-S%cWku67KvJ;v4K+TO7M-o!ivMM@?l^Rx93|k`W?3OIj*q>q^Jf zm$PRJovhlcH+4;>hiOJ{XGm@Dx=h|zE3&_<#a1tu+EuoG_x-D;m+!3XE8FUG?a8{R zRktlq2fM$zIVUPDbK2XYyW1wMvpRV)a^A#;8%?&`xL>Yxx?FwL!{)iwp?h0mW2Ljj zCKePWrQazwsgKslT`k;nbw*zL_uccBeP30q?=BYI!BM&BXud!<YoO%qhyN}$o$kDJ z$7lOAPqS;sV@%)tQaf(-cX?E7k7SOJ&qV(B&tGp#Exa52TJQLNLje;Pk6U^f7lk|` 
zD=alD6)jKSaDU?|_C0B$R(JL5>u&5KfnA*&c8L3{vd<81Gnm|#w(4`+;+@;>-JZKO z_Wv0H!_3a=EX!JRoqlVv?^d~UuXV9C_?d4yqQxQQ{$$DCYZuO}n^tCObm!x|e{xAP z71*|X{3$=TZK~N$t85EH7lrwv9Cs8qFQ3%lA-eM5+>D-iy4#+9tZl4xRy6gwof$lR z;={6Y!L$E=f2)>O*uSFk=UVw~wUvEoJ0ov@Jh5RL{{oLCSq--qy|lIQGQDc5pgVQ* zm#EUi?V<Vo+0jcPmmb|)(f>R(dE%-@=S6kbT-&?F-^yn_IP>smz}FSUU5_*03F`E$ z3P1jIwujd9kF`g;?QS?dHCqxFyVa(lW?OdAJl%O84({6_>!D!IV`6vu+m7?PS+(=B zgYI&gcG@o3@%rY?9;@Xu?teP7Ci(Rd$&9O-hrO&UPhS#J)Bh3}{{G>~IVZ1vWHg+1 zb<VwaJ`dOM#B#AZxf%OL`@cR~?W&;i?){?A%h|jQubY(2yYas8e3|X>4VR`*fAC(B zQ`Ds>dEzS5ogrPXHeX&YZW{llSN({elK--n6P9T@L0h*jHl6jH@3wq#pQrpHKV_|@ z2|7Vnm7cX`<gC>#S`juk{>=AECxhS#0)_v6x2#xg7qC#xBzSYlyqLu=_ulbxOZJSu z`{SX?mW*YRGd3yYR@m-z@!4)%I5&Flte27>Tb6ylwK8N)bM?7@y4%v1d|Q3?W1`#I zhc^n_Z#-$c`Tw!xt6BW{6Mi{L6<)e`=jOJh&GGXeuX|qi!*Y_%v9&4Yyc1JCv2zB9 zm+w{BaaX#x#DDd(Lht+cms)1NTbU@78S!$)^@*Ex-ii3{O9=kwspWI#`n<%q4|h#B z<Gmhy_~+3rXTM9v^;+pkXeQ|{ds|dGtA1v7noaiX5Ob4om7T9GztxonZrmd8Bw>>n z!snPY(P7HF4<R4dS*O^0o(gDJ3tq<8$P~Em(9;rGejkT_Q3kg!#E9`uOA`ye`$(?# zMGMz2xsX-oFSpw3#lO@m|5kJ{bJ_lW{S6CF+|!R;kd<mDbu;FWNy%B`ob29b7Hs8g zv0{JNtQ`eKF3(%9ve%mBWb~h3CszxX{}EA_zQ6u)f?<g5;i*MBL9U$lrg^bW_BPqp z`d--c_B5}R-tDu6+a`-&|I66D<=6f&uiJ6&W^9}G#B+Z@Olj1;*(H0c9@V`)r|$Rg zkzPY<^!dB*({fe^$VgtlEVF*whJ*1MG0W%PiO!Fhm3imH+|&0{W{SL3c5XlSO7DA^ z`s6;VLXUr0yUuNnt5bQKm6>lE(;r>(K9()&?dG<<*P0k-O>tOl&{^8`ruW#9;NE9~ z6BDoaCrxbB&5X|Bx@2h=V;IQT`M<+%g%#VZGrWBdKOQ>&_EGx9SKk&cy!v)M@5;P- z&vUQkP8S_FjWK-Bd4jXd$L~o{n{0Q;&(#-iR2e`2CQ?50#;gDSdb4d6Y=TTZYOa_s zySj~CHOMbzZj^+}SGjfHFMZj0t}y4gy-k3pq{{V&dDphZ$Srj|TTuOJtL=wl*%eWX z+>5LA{x8c}VlsE~xmL6Dwu{6r#TUoEt9X6Ica@du)P)l1n=eN6r>XkZseZ3|vt;$D zqmO2vs`$7jIbu<Uy6B3e+}UZ*c0~P~clD}Y)rY7{dFpjux{GV)Hq8ifUN*On)#0LE z?ybyU6Cxj7)fF<y*d2Z*oOyqAZmhAZujrls^Ex``-OGBt*JkCV-7`*H{hqgInMYvm z^F8Hz(`Fp}=hU~?J)>@a)XR6<zwCH+Q_{KLbybUh#<MN5aetncUHJDS?vuhzi~4{l zo~rOy4>Ge@{dOl;Z_AtA+_Pj=@LY!bH?Cg2XSreJi`TCI8Y=p4%{>wL;N_jBTBmt= z(VEJ0i+?omIop_Rkm|iT?VR;ackw2Ra$WyS%ge=AQ-3zM?fdG=HmyC`ZFR8ZufrX# 
zc{!U){1%sgTqB~VVwuA9S5`&-k|S5w!i3;UPeR&bmU~~*4cV#?uwUxRJC?`HeJ@p# gdY64K(thcdyeo6}pZ&}X4FCVXj&iMRP-b8N06^*FKmY&$ literal 0 HcmV?d00001 diff --git a/clientutil/dns.go b/clientutil/dns.go new file mode 100644 index 0000000..ed30f87 --- /dev/null +++ b/clientutil/dns.go @@ -0,0 +1,98 @@ +package clientutil + +import ( + "log" + "net" + "sync" + "time" + + "golang.org/x/sync/singleflight" +) + +type resolver interface { + ResolveIP(string) []string +} + +type dnsResolver struct{} + +func (r *dnsResolver) ResolveIP(hostport string) []string { + var resolved []string + host, port, err := net.SplitHostPort(hostport) + if err != nil { + log.Printf("error parsing %s: %v", hostport, err) + return nil + } + hostIPs, err := net.LookupIP(host) + if err != nil { + log.Printf("error resolving %s: %v", host, err) + return nil + } + for _, ip := range hostIPs { + resolved = append(resolved, net.JoinHostPort(ip.String(), port)) + } + return resolved +} + +var defaultResolver = newDNSCache(&dnsResolver{}) + +type cacheDatum struct { + addrs []string + deadline time.Time +} + +type dnsCache struct { + resolver resolver + sf singleflight.Group + mx sync.RWMutex + cache map[string]cacheDatum +} + +func newDNSCache(resolver resolver) *dnsCache { + return &dnsCache{ + resolver: resolver, + cache: make(map[string]cacheDatum), + } +} + +func (c *dnsCache) get(host string) ([]string, bool) { + d, ok := c.cache[host] + if !ok { + return nil, false + } + return d.addrs, d.deadline.After(time.Now()) +} + +func (c *dnsCache) update(host string) []string { + v, _, _ := c.sf.Do(host, func() (interface{}, error) { + addrs := c.resolver.ResolveIP(host) + // By uncommenting this, we stop caching negative results. 
+ // if len(addrs) == 0 { + // return nil, nil + // } + c.mx.Lock() + c.cache[host] = cacheDatum{ + addrs: addrs, + deadline: time.Now().Add(60 * time.Second), + } + c.mx.Unlock() + return addrs, nil + }) + return v.([]string) +} + +func (c *dnsCache) ResolveIP(host string) []string { + c.mx.RLock() + addrs, ok := c.get(host) + c.mx.RUnlock() + + if ok { + return addrs + } + + if len(addrs) > 0 { + go c.update(host) + return addrs + } + + return c.update(host) +} diff --git a/clientutil/dns_test.go b/clientutil/dns_test.go new file mode 100644 index 0000000..3162c5e --- /dev/null +++ b/clientutil/dns_test.go @@ -0,0 +1,27 @@ +package clientutil + +import "testing" + +type fakeResolver struct { + addrs []string + requests int +} + +func (r *fakeResolver) ResolveIP(host string) []string { + r.requests++ + return r.addrs +} + +func TestDNSCache(t *testing.T) { + r := &fakeResolver{addrs: []string{"1.2.3.4"}} + c := newDNSCache(r) + for i := 0; i < 5; i++ { + addrs := c.ResolveIP("a.b.c.d") + if len(addrs) != 1 { + t.Errorf("ResolveIP returned bad response: %v", addrs) + } + } + if r.requests != 1 { + t.Errorf("cached resolver has wrong number of requests: %d, expecting 1", r.requests) + } +} diff --git a/clientutil/doc.go b/clientutil/doc.go new file mode 100644 index 0000000..421915b --- /dev/null +++ b/clientutil/doc.go @@ -0,0 +1,37 @@ +// Package clientutil implements a very simple style of JSON RPC. +// +// Requests and responses are both encoded in JSON, and they should +// have the "application/json" Content-Type. +// +// HTTP response statuses other than 200 indicate an error: in this +// case, the response body may contain (in plain text) further details +// about the error. Some HTTP status codes are considered temporary +// errors (incl. 429 for throttling). The client will retry requests, +// if targets are available, until the context expires - so it's quite +// important to remember to set a timeout on the context given to the +// Call() function! 
+// +// The client handles both replicated services and sharded +// (partitioned) services. Users of this package that want to support +// sharded deployments are supposed to pass a shard ID to every +// Call(). At the deployment stage, sharding can be enabled via the +// configuration. +// +// For replicated services, the client will expect the provided +// hostname to resolve to one or more IP addresses, in which case it +// will pick a random IP address on every new request, while +// remembering which addresses have had errors and trying to avoid +// them. It will however send an occasional request to the failed +// targets, to see if they've come back. +// +// For sharded services, the client makes simple HTTP requests to the +// specific target identified by the shard. It does this by prepending +// the shard ID to the backend hostname (so a request to "example.com" +// with shard ID "1" becomes a request to "1.example.com"). +// +// The difference with other JSON-RPC implementations is that we use a +// different URI for every method, and we force the usage of +// request/response types. This makes it easy for projects to +// eventually migrate to GRPC. +// +package clientutil diff --git a/clientutil/error.go b/clientutil/error.go new file mode 100644 index 0000000..f011e16 --- /dev/null +++ b/clientutil/error.go @@ -0,0 +1,35 @@ +package clientutil + +import ( + "fmt" + "io/ioutil" + "net/http" +) + +// RemoteError represents a HTTP error from the server. The status +// code and response body can be retrieved with the StatusCode() and +// Body() methods. +type RemoteError struct { + statusCode int + body string +} + +func remoteErrorFromResponse(resp *http.Response) *RemoteError { + // Optimistically read the response body, ignoring errors. + var body string + if data, err := ioutil.ReadAll(resp.Body); err == nil { + body = string(data) + } + return &RemoteError{statusCode: resp.StatusCode, body: body} +} + +// Error implements the error interface. 
+func (e *RemoteError) Error() string { + return fmt.Sprintf("%d - %s", e.statusCode, e.body) +} + +// StatusCode returns the HTTP status code. +func (e *RemoteError) StatusCode() int { return e.statusCode } + +// Body returns the response body. +func (e *RemoteError) Body() string { return e.body } diff --git a/clientutil/json.go b/clientutil/json.go deleted file mode 100644 index 5fc1ab2..0000000 --- a/clientutil/json.go +++ /dev/null @@ -1,45 +0,0 @@ -package clientutil - -import ( - "bytes" - "context" - "encoding/json" - "errors" - "fmt" - "net/http" -) - -// DoJSONHTTPRequest makes an HTTP POST request to the specified uri, -// with a JSON-encoded request body. It will attempt to decode the -// response body as JSON. -func DoJSONHTTPRequest(ctx context.Context, client *http.Client, uri string, req, resp interface{}) error { - data, err := json.Marshal(req) - if err != nil { - return err - } - - httpReq, err := http.NewRequest("POST", uri, bytes.NewReader(data)) - if err != nil { - return err - } - httpReq.Header.Set("Content-Type", "application/json") - httpReq = httpReq.WithContext(ctx) - - httpResp, err := RetryHTTPDo(client, httpReq, NewExponentialBackOff()) - if err != nil { - return err - } - defer httpResp.Body.Close() - - if httpResp.StatusCode != 200 { - return fmt.Errorf("HTTP status %d", httpResp.StatusCode) - } - if httpResp.Header.Get("Content-Type") != "application/json" { - return errors.New("not a JSON response") - } - - if resp == nil { - return nil - } - return json.NewDecoder(httpResp.Body).Decode(resp) -} diff --git a/clientutil/retry.go b/clientutil/retry.go deleted file mode 100644 index 3ca7b51..0000000 --- a/clientutil/retry.go +++ /dev/null @@ -1,92 +0,0 @@ -package clientutil - -import ( - "errors" - "net/http" - "time" - - "github.com/cenkalti/backoff" -) - -// NewExponentialBackOff creates a backoff.ExponentialBackOff object -// with our own default values. 
-func NewExponentialBackOff() *backoff.ExponentialBackOff { - b := backoff.NewExponentialBackOff() - b.InitialInterval = 100 * time.Millisecond - //b.Multiplier = 1.4142 - return b -} - -// A temporary (retriable) error is something that has a Temporary method. -type tempError interface { - Temporary() bool -} - -type tempErrorWrapper struct { - error -} - -func (t tempErrorWrapper) Temporary() bool { return true } - -// TempError makes a temporary (retriable) error out of a normal error. -func TempError(err error) error { - return tempErrorWrapper{err} -} - -// Retry operation op until it succeeds according to the backoff -// policy b. -// -// Note that this function reverses the error semantics of -// backoff.Operation: all errors are permanent unless explicitly -// marked as temporary (i.e. they have a Temporary() method that -// returns true). This is to better align with the errors returned by -// the net package. -func Retry(op backoff.Operation, b backoff.BackOff) error { - innerOp := func() error { - err := op() - if err == nil { - return err - } - if tmpErr, ok := err.(tempError); ok && tmpErr.Temporary() { - return err - } - return backoff.Permanent(err) - } - return backoff.Retry(innerOp, b) -} - -var errHTTPBackOff = TempError(errors.New("temporary http error")) - -func isStatusTemporary(code int) bool { - switch code { - case http.StatusTooManyRequests, http.StatusBadGateway, http.StatusServiceUnavailable, http.StatusGatewayTimeout: - return true - default: - return false - } -} - -// RetryHTTPDo retries an HTTP request until it succeeds, according to -// the backoff policy b. It will retry on temporary network errors and -// upon receiving specific temporary HTTP errors. It will use the -// context associated with the HTTP request object. -func RetryHTTPDo(client *http.Client, req *http.Request, b backoff.BackOff) (*http.Response, error) { - var resp *http.Response - op := func() error { - // Clear up previous response if set. 
- if resp != nil { - resp.Body.Close() - } - - var err error - resp, err = client.Do(req) - if err == nil && isStatusTemporary(resp.StatusCode) { - resp.Body.Close() - return errHTTPBackOff - } - return err - } - - err := Retry(op, backoff.WithContext(b, req.Context())) - return resp, err -} diff --git a/clientutil/retry_test.go b/clientutil/retry_test.go deleted file mode 100644 index b7d5f03..0000000 --- a/clientutil/retry_test.go +++ /dev/null @@ -1,200 +0,0 @@ -package clientutil - -import ( - "context" - "io" - "io/ioutil" - "log" - "net" - "net/http" - "net/http/httptest" - "net/url" - "testing" - "time" -) - -type tcpHandler interface { - Handle(net.Conn) -} - -type tcpHandlerFunc func(net.Conn) - -func (f tcpHandlerFunc) Handle(c net.Conn) { f(c) } - -// Base TCP server type (to build fake LDAP servers). -type tcpServer struct { - l net.Listener - handler tcpHandler -} - -func newTCPServer(t testing.TB, handler tcpHandler) *tcpServer { - l, err := net.Listen("tcp", "localhost:0") - if err != nil { - t.Fatal("Listen():", err) - } - log.Printf("started new tcp server on %s", l.Addr().String()) - s := &tcpServer{l: l, handler: handler} - go s.serve() - return s -} - -func (s *tcpServer) serve() { - for { - conn, err := s.l.Accept() - if err != nil { - return - } - go func(c net.Conn) { - s.handler.Handle(c) - c.Close() - }(conn) - } -} - -func (s *tcpServer) Addr() string { - return s.l.Addr().String() -} - -func (s *tcpServer) Close() { - s.l.Close() -} - -// A test server that will close all incoming connections right away. -func newConnFailServer(t testing.TB) *tcpServer { - return newTCPServer(t, tcpHandlerFunc(func(c net.Conn) {})) -} - -// A test server that will close all connections after a 1s delay. 
-func newConnFailDelayServer(t testing.TB) *tcpServer { - return newTCPServer(t, tcpHandlerFunc(func(c net.Conn) { time.Sleep(1 * time.Second) })) -} - -type httpServer struct { - *httptest.Server -} - -func (s *httpServer) Addr() string { - u, _ := url.Parse(s.Server.URL) - return u.Host -} - -// An HTTP server that will always return a specific HTTP status. -func newStatusHTTPServer(statusCode int) *httpServer { - return &httpServer{httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - w.WriteHeader(statusCode) - io.WriteString(w, "hello\n") - }))} -} - -func newOKHTTPServer() *httpServer { - return &httpServer{httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "text/plain") - io.WriteString(w, "OK") - }))} -} - -type testServer interface { - Addr() string - Close() -} - -type testBackends struct { - servers []testServer - addrs []string -} - -func newTestBackends(servers ...testServer) *testBackends { - b := new(testBackends) - for _, s := range servers { - b.servers = append(b.servers, s) - b.addrs = append(b.addrs, s.Addr()) - } - return b -} - -func (b *testBackends) ResolveIPs(_ []string) []string { - return b.addrs -} - -func (b *testBackends) stop(i int) { - b.servers[i].Close() -} - -func (b *testBackends) close() { - for _, s := range b.servers { - s.Close() - } -} - -func doRequests(backends *testBackends, u string, n int) (int, int) { - c := &http.Client{ - Transport: NewTransport([]string{"backend"}, nil, backends), - } - b := NewExponentialBackOff() - - var errs, oks int - for i := 0; i < n; i++ { - ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) - req, _ := http.NewRequest("GET", u, nil) - resp, err := RetryHTTPDo(c, req.WithContext(ctx), b) - cancel() - if err != nil { - errs++ - continue - } - _, err = ioutil.ReadAll(resp.Body) - resp.Body.Close() - if resp.StatusCode != 200 { - errs++ - continue - } - if err != nil { - 
errs++ - continue - } - oks++ - } - - return oks, errs -} - -func TestRetryHTTP_BackendsDown(t *testing.T) { - b := newTestBackends(newOKHTTPServer(), newOKHTTPServer(), newOKHTTPServer()) - defer b.close() - - oks, errs := doRequests(b, "http://backend/", 100) - if errs > 0 { - t.Fatalf("errs=%d", errs) - } - if oks == 0 { - t.Fatal("oks=0") - } - - b.stop(0) - b.stop(1) - - oks, errs = doRequests(b, "http://backend/", 100) - if errs > 0 { - t.Fatalf("errs=%d", errs) - } - if oks == 0 { - t.Fatal("oks=0") - } -} - -func TestRetryHTTP_HighLatencyBackend(t *testing.T) { - b := newTestBackends(newConnFailDelayServer(t), newOKHTTPServer()) - defer b.close() - - _, _ = doRequests(b, "http://backend/", 10) - // Silly transport.go load balancer only balances connections, - // so in this scenario we'll just keep hitting the slow - // server, exhausting our deadline budget, and never fail over - // to the secondary backend :( - // if errs > 0 { - // t.Fatalf("errs=%d", errs) - // } - // if oks == 0 { - // t.Fatal("oks=0") - // } -} diff --git a/clientutil/track.go b/clientutil/track.go new file mode 100644 index 0000000..2db20bb --- /dev/null +++ b/clientutil/track.go @@ -0,0 +1,123 @@ +package clientutil + +import ( + "math/rand" + "sync" + "time" +) + +// The backendTracker tracks the state of the targets associated with +// a backend, and periodically checks DNS for updates. +type backendTracker struct { + log logger + addr string + resolver resolver + stopCh chan struct{} + + mx sync.Mutex + resolved []string + failed map[string]time.Time +} + +func newBackendTracker(addr string, resolver resolver, logger logger) *backendTracker { + // Resolve the targets once before returning. 
+ b := &backendTracker{ + addr: addr, + resolver: resolver, + resolved: resolver.ResolveIP(addr), + failed: make(map[string]time.Time), + stopCh: make(chan struct{}), + log: logger, + } + go b.updateProc() + return b +} + +// Close stops the background update goroutine. It must be called at +// most once: closing stopCh a second time panics. +func (b *backendTracker) Close() { + close(b.stopCh) +} + +// getTargets returns every resolved target: the ones not currently +// marked as failed come first, followed by the failed ones. Each of +// the two groups is shuffled independently. +func (b *backendTracker) getTargets() []string { + b.mx.Lock() + defer b.mx.Unlock() + + var good, bad []string + for _, t := range b.resolved { + if _, ok := b.failed[t]; ok { + bad = append(bad, t) + } else { + good = append(good, t) + } + } + + good = shuffle(good) + bad = shuffle(bad) + + return append(good, bad...) +} + +// setStatus records the outcome of an attempt against addr. A failure +// on a target not yet marked as failed stores the current time for it; +// a success on a failed target clears the mark. Only these two +// transitions are logged, any other call is a no-op. +func (b *backendTracker) setStatus(addr string, ok bool) { + b.mx.Lock() + + _, isFailed := b.failed[addr] + if isFailed && ok { + b.log.Printf("target %s now ok", addr) + delete(b.failed, addr) + } else if !isFailed && !ok { + b.log.Printf("target %s failed", addr) + b.failed[addr] = time.Now() + } + + b.mx.Unlock() +} + +var ( + // Period of the background loop that expires failure marks and + // resolves the backend address again. + backendUpdateInterval = 60 * time.Second + // Age after which a failure mark is dropped. Marks are only + // checked when the background loop ticks. + backendFailureRetryInterval = 60 * time.Second +) + +// expireFailedTargets drops the failure marks that are older than +// backendFailureRetryInterval. +func (b *backendTracker) expireFailedTargets() { + b.mx.Lock() + now := time.Now() + for k, v := range b.failed { + if now.Sub(v) > backendFailureRetryInterval { + delete(b.failed, k) + } + } + b.mx.Unlock() +} + +// updateProc runs in its own goroutine until stopCh is closed. On +// every tick it expires old failure marks and resolves addr again; an +// empty resolver result leaves the previous target list in place. +func (b *backendTracker) updateProc() { + tick := time.NewTicker(backendUpdateInterval) + defer tick.Stop() + for { + select { + case <-b.stopCh: + return + case <-tick.C: + b.expireFailedTargets() + resolved := b.resolver.ResolveIP(b.addr) + if len(resolved) > 0 { + b.mx.Lock() + b.resolved = resolved + b.mx.Unlock() + } + } + } +} + +// Random source used by shuffle, seeded once at package initialization. +var shuffleSrc = rand.NewSource(time.Now().UnixNano()) + +// Re-order elements of a slice randomly.
+// The slice is shuffled in place and also returned. Slices with fewer +// than two elements are returned untouched. Safe for concurrent use. +func shuffle(values []string) []string { + if len(values) < 2 { + return values + } + // shuffleSrc is shared by every tracker and, like any source + // returned by rand.NewSource, is not safe for concurrent use: + // the per-tracker mutex held by callers does not protect it. + shuffleMx.Lock() + defer shuffleMx.Unlock() + rnd := rand.New(shuffleSrc) + for i := len(values) - 1; i > 0; i-- { + j := rnd.Intn(i + 1) + values[i], values[j] = values[j], values[i] + } + return values +} + +// shuffleMx serializes access to shuffleSrc. +var shuffleMx sync.Mutex diff --git a/clientutil/transport.go b/clientutil/transport.go index e4f98e3..843a760 100644 --- a/clientutil/transport.go +++ b/clientutil/transport.go @@ -3,170 +3,63 @@ package clientutil import ( "context" "crypto/tls" - "errors" - "log" "net" "net/http" "sync" "time" ) -var errAllBackendsFailed = errors.New("all backends failed") - -type dnsResolver struct{} - -func (r *dnsResolver) ResolveIPs(hosts []string) []string { - var resolved []string - for _, hostport := range hosts { - host, port, err := net.SplitHostPort(hostport) - if err != nil { - log.Printf("error parsing %s: %v", hostport, err) - continue - } - hostIPs, err := net.LookupIP(host) - if err != nil { - log.Printf("error resolving %s: %v", host, err) - continue - } - for _, ip := range hostIPs { - resolved = append(resolved, net.JoinHostPort(ip.String(), port)) - } - } - return resolved -} - -var defaultResolver = &dnsResolver{} - -type resolver interface { - ResolveIPs([]string) []string -} - -// Balancer for HTTP connections. It will round-robin across available -// backends, trying to avoid ones that are erroring out, until one -// succeeds or they all fail. +// The transportCache is just a cache of http transports, each +// connecting to a specific address. // -// This object should not be used for load balancing of individual -// HTTP requests: once a new connection is established, requests will -// be sent over it until it errors out. It's meant to provide a -// *reliable* connection to a set of equivalent backends for HA -// purposes. -type balancer struct { - hosts []string - resolver resolver - stop chan bool +// We use this to control the HTTP Host header and the TLS ServerName +// independently of the target address.
+type transportCache struct { + tlsConfig *tls.Config - // List of currently valid (or untested) backends, and ones - // that errored out at least once. - mx sync.Mutex - addrs []string - ok map[string]bool + mx sync.RWMutex + transports map[string]http.RoundTripper } -var backendUpdateInterval = 60 * time.Second +func newTransportCache(tlsConfig *tls.Config) *transportCache { + return &transportCache{ + tlsConfig: tlsConfig, + transports: make(map[string]http.RoundTripper), + } +} -// Periodically update the list of available backends. -func (b *balancer) updateProc() { - tick := time.NewTicker(backendUpdateInterval) - for { - select { - case <-b.stop: - return - case <-tick.C: - resolved := b.resolver.ResolveIPs(b.hosts) - if len(resolved) > 0 { - b.mx.Lock() - b.addrs = resolved - b.mx.Unlock() - } - } +func (m *transportCache) newTransport(addr string) http.RoundTripper { + return &http.Transport{ + TLSClientConfig: m.tlsConfig, + DialContext: func(ctx context.Context, network, _ string) (net.Conn, error) { + return netDialContext(ctx, network, addr) + }, } } -// Returns a list of all available backends, split into "good ones" -// (no errors seen since last successful connection) and "bad ones". -func (b *balancer) getBackends() ([]string, []string) { - b.mx.Lock() - defer b.mx.Unlock() +func (m *transportCache) getTransport(addr string) http.RoundTripper { + m.mx.RLock() + t, ok := m.transports[addr] + m.mx.RUnlock() - var good, bad []string - for _, addr := range b.addrs { - if ok := b.ok[addr]; ok { - good = append(good, addr) - } else { - bad = append(bad, addr) + if !ok { + m.mx.Lock() + if t, ok = m.transports[addr]; !ok { + t = m.newTransport(addr) + m.transports[addr] = t } + m.mx.Unlock() } - return good, bad -} -func (b *balancer) notify(addr string, ok bool) { - b.mx.Lock() - b.ok[addr] = ok - b.mx.Unlock() + return t } +// Go < 1.9 does not have net.DialContext, reimplement it in terms of +// net.DialTimeout. 
func netDialContext(ctx context.Context, network, addr string) (net.Conn, error) { - timeout := 30 * time.Second - // Go < 1.9 does not have net.DialContext, reimplement it in - // terms of net.DialTimeout. + timeout := 60 * time.Second // some arbitrary max timeout if deadline, ok := ctx.Deadline(); ok { timeout = time.Until(deadline) } return net.DialTimeout(network, addr, timeout) } - -func (b *balancer) dial(ctx context.Context, network, addr string) (net.Conn, error) { - // Start by attempting a connection on 'good' targets. - good, bad := b.getBackends() - - for _, addr := range good { - // Go < 1.9 does not have DialContext, deal with it - conn, err := netDialContext(ctx, network, addr) - if err == nil { - return conn, nil - } else if err == context.Canceled { - // A timeout might be bad, set the error bit - // on the connection. - b.notify(addr, false) - return nil, err - } - b.notify(addr, false) - } - - for _, addr := range bad { - conn, err := netDialContext(ctx, network, addr) - if err == nil { - b.notify(addr, true) - return conn, nil - } else if err == context.Canceled { - return nil, err - } - } - - return nil, errAllBackendsFailed -} - -// NewTransport returns a suitably configured http.RoundTripper that -// talks to a specific backend service. It performs discovery of -// available backends via DNS (using A or AAAA record lookups), tries -// to route traffic away from faulty backends. -// -// It will periodically attempt to rediscover new backends. -func NewTransport(backends []string, tlsConf *tls.Config, resolver resolver) http.RoundTripper { - if resolver == nil { - resolver = defaultResolver - } - addrs := resolver.ResolveIPs(backends) - b := &balancer{ - hosts: backends, - resolver: resolver, - addrs: addrs, - ok: make(map[string]bool), - } - go b.updateProc() - - return &http.Transport{ - DialContext: b.dial, - TLSClientConfig: tlsConf, - } -} -- GitLab