package clientutil import ( "bytes" "context" "crypto/tls" "encoding/json" "errors" "fmt" "log" "math/rand" "net/http" "net/url" "os" "strconv" "strings" "time" "github.com/cenkalti/backoff" ) // Our own narrow logger interface. type logger interface { Printf(string, ...interface{}) } // A nilLogger is used when Config.Debug is false. type nilLogger struct{} func (l nilLogger) Printf(_ string, _ ...interface{}) {} // Parameters that define the exponential backoff algorithm used. var ( ExponentialBackOffInitialInterval = 100 * time.Millisecond ExponentialBackOffMultiplier = 1.4142 ) // newExponentialBackOff creates a backoff.ExponentialBackOff object // with our own default values. func newExponentialBackOff() *backoff.ExponentialBackOff { b := backoff.NewExponentialBackOff() b.InitialInterval = ExponentialBackOffInitialInterval b.Multiplier = ExponentialBackOffMultiplier // Set MaxElapsedTime to 0 because we expect the overall // timeout to be dictated by the request Context. b.MaxElapsedTime = 0 return b } // Balancer for HTTP connections. It will round-robin across available // backends, trying to avoid ones that are erroring out, until one // succeeds or returns a permanent error. // // This object should not be used for load balancing of individual // HTTP requests: it doesn't do anything smart beyond trying to avoid // broken targets. It's meant to provide a *reliable* connection to a // set of equivalent services for HA purposes. type balancedBackend struct { *backendTracker *transportCache baseURI *url.URL sharded bool resolver resolver log logger } func newBalancedBackend(config *BackendConfig, resolver resolver) (*balancedBackend, error) { u, err := url.Parse(config.URL) if err != nil { return nil, err } var tlsConfig *tls.Config if config.TLSConfig != nil { tlsConfig, err = config.TLSConfig.TLSConfig() if err != nil { return nil, err } } var logger logger = &nilLogger{} if config.Debug { logger = log.New(os.Stderr, fmt.Sprintf("backend %s: ", u.Host), 0) } return &balancedBackend{ backendTracker: newBackendTracker(u.Host, resolver, logger), transportCache: newTransportCache(tlsConfig), sharded: config.Sharded, baseURI: u, resolver: resolver, log: logger, }, nil } // Call the backend. Makes an HTTP POST request to the specified uri, // with a JSON-encoded request body. It will attempt to decode the // response body as JSON. func (b *balancedBackend) Call(ctx context.Context, shard, path string, req, resp interface{}) error { // Serialize the request body. data, err := json.Marshal(req) if err != nil { return err } // Create the target sequence for this call. If there are multiple // targets, reduce the timeout on each individual call accordingly to // accomodate eventual failover. seq, err := b.makeSequence(shard) if err != nil { return err } innerTimeout := 1 * time.Hour if deadline, ok := ctx.Deadline(); ok { innerTimeout = time.Until(deadline) / time.Duration(seq.Len()) } // Call the backends in the sequence until one succeeds, with an // exponential backoff policy controlled by the outer Context. return backoff.Retry(func() error { req, rerr := b.newJSONRequest(path, shard, data) if rerr != nil { return rerr } innerCtx, cancel := context.WithTimeout(ctx, innerTimeout) defer cancel() // When do() returns successfully, we already know that the // response had an HTTP status of 200. httpResp, rerr := b.do(innerCtx, seq, req) if rerr != nil { return rerr } defer httpResp.Body.Close() // nolint // Decode the response, unless the 'resp' output is nil. if httpResp.Header.Get("Content-Type") != "application/json" { return errors.New("not a JSON response") } if resp == nil { return nil } return json.NewDecoder(httpResp.Body).Decode(resp) }, backoff.WithContext(newExponentialBackOff(), ctx)) } // Initialize a new target sequence. func (b *balancedBackend) makeSequence(shard string) (*sequence, error) { var tg targetGenerator = b.backendTracker if b.sharded { if shard == "" { return nil, fmt.Errorf("call without shard to sharded service %s", b.baseURI.String()) } tg = newShardedGenerator(shard, b.baseURI.Host, b.resolver) } seq := newSequence(tg) if seq.Len() == 0 { return nil, errNoTargets } b.log.Printf("%016x: initialized", seq.ID()) return seq, nil } // Return the URI to be used for the request. This is used both in the // Host HTTP header and as the TLS server name used to pick a server // certificate (if using TLS). func (b *balancedBackend) getURIForRequest(shard, path string) string { u := *b.baseURI if b.sharded && shard != "" { u.Host = fmt.Sprintf("%s.%s", shard, u.Host) } u.Path = appendPath(u.Path, path) return u.String() } // Build a http.Request object. func (b *balancedBackend) newJSONRequest(path, shard string, data []byte) (*http.Request, error) { req, err := http.NewRequest("POST", b.getURIForRequest(shard, path), bytes.NewReader(data)) if err != nil { return nil, err } req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Length", strconv.FormatInt(int64(len(data)), 10)) return req, nil } // Select a new target from the given sequence and send the request to // it. Wrap HTTP errors in a RemoteError object. func (b *balancedBackend) do(ctx context.Context, seq *sequence, req *http.Request) (resp *http.Response, err error) { target, terr := seq.Next() if terr != nil { return } b.log.Printf("sequence %016x: connecting to %s", seq.ID(), target) client := &http.Client{ Transport: b.transportCache.getTransport(target), } resp, err = client.Do(req.WithContext(ctx)) if err == nil && resp.StatusCode != 200 { err = remoteErrorFromResponse(resp) if !isStatusTemporary(resp.StatusCode) { err = backoff.Permanent(err) } resp.Body.Close() // nolint resp = nil } seq.Done(target, err) return } var errNoTargets = errors.New("no available backends") type targetGenerator interface { getTargets() []string setStatus(string, bool) } // A replicatedSequence repeatedly iterates over available backends in order of // preference. Once in a while it refreshes its list of available // targets. type sequence struct { id uint64 tg targetGenerator targets []string pos int } func newSequence(tg targetGenerator) *sequence { return &sequence{ id: rand.Uint64(), tg: tg, targets: tg.getTargets(), } } func (s *sequence) ID() uint64 { return s.id } func (s *sequence) Len() int { return len(s.targets) } func (s *sequence) reloadTargets() { targets := s.tg.getTargets() if len(targets) > 0 { s.targets = targets s.pos = 0 } } // Next returns the next target. func (s *sequence) Next() (t string, err error) { if s.pos >= len(s.targets) { s.reloadTargets() if len(s.targets) == 0 { err = errNoTargets return } } t = s.targets[s.pos] s.pos++ return } func (s *sequence) Done(t string, err error) { s.tg.setStatus(t, err == nil) } // A shardedGenerator returns a single sharded target to a sequence. type shardedGenerator struct { id uint64 addrs []string } func newShardedGenerator(shard, base string, resolver resolver) *shardedGenerator { return &shardedGenerator{ id: rand.Uint64(), addrs: resolver.ResolveIP(fmt.Sprintf("%s.%s", shard, base)), } } func (g *shardedGenerator) getTargets() []string { return g.addrs } func (g *shardedGenerator) setStatus(_ string, _ bool) {} // Concatenate two URI paths. func appendPath(a, b string) string { if strings.HasSuffix(a, "/") && strings.HasPrefix(b, "/") { return a + b[1:] } return a + b } // Some HTTP status codes are treated are temporary errors. func isStatusTemporary(code int) bool { switch code { case http.StatusTooManyRequests, http.StatusBadGateway, http.StatusServiceUnavailable, http.StatusGatewayTimeout: return true default: return false } }