Skip to content
Snippets Groups Projects
balancer.go 7.81 KiB
Newer Older
package clientutil

import (
	"bytes"
	"context"
	"crypto/tls"
	"encoding/json"
	"errors"
	"fmt"
	"log"
	"math/rand"
	"net/http"
	"net/url"
	"os"
	"strconv"
	"strings"
	"time"

	"github.com/cenkalti/backoff"
)

// logger is the narrow logging interface used by this package.
// *log.Logger satisfies it.
type logger interface {
	Printf(format string, args ...interface{})
}

// nilLogger discards all output. It is used when debugging is not
// enabled in the backend configuration.
type nilLogger struct{}

// Printf ignores its arguments.
func (nilLogger) Printf(string, ...interface{}) {}

// Parameters that define the exponential backoff algorithm used.
var (
	ExponentialBackOffInitialInterval = 100 * time.Millisecond
	ExponentialBackOffMultiplier      = 1.4142
)

// newExponentialBackOff returns a backoff.ExponentialBackOff that uses this
// package's initial interval and multiplier. MaxElapsedTime is set to 0
// (no limit) because the overall timeout is expected to be dictated by the
// request Context.
func newExponentialBackOff() *backoff.ExponentialBackOff {
	eb := backoff.NewExponentialBackOff()
	eb.InitialInterval, eb.Multiplier = ExponentialBackOffInitialInterval, ExponentialBackOffMultiplier
	eb.MaxElapsedTime = 0 // the caller's Context decides when to give up
	return eb
}

// balancedBackend is a balancer for HTTP connections. It tries the
// available backends in turn, attempting to avoid ones that are erroring
// out, until one succeeds or returns a permanent error.
//
// It is not meant for load balancing of individual HTTP requests: it does
// nothing smart beyond trying to avoid broken targets. Its purpose is to
// provide a *reliable* connection to a set of equivalent services for HA
// purposes.
type balancedBackend struct {
	*backendTracker // target generator for non-sharded calls
	*transportCache // provides the HTTP transport for each target

	baseURI  *url.URL // base URI that request paths are appended to
	sharded  bool     // if true, every call must specify a shard
	resolver resolver // resolves shard-specific host names
	log      logger
}

// newBalancedBackend builds a balancedBackend from config: it parses the
// backend URL, builds the TLS configuration when one is provided, and sets
// up a stderr logger prefixed with the backend host when config.Debug is
// set (otherwise logging is discarded).
func newBalancedBackend(config *BackendConfig, resolver resolver) (*balancedBackend, error) {
	baseURI, err := url.Parse(config.URL)
	if err != nil {
		return nil, err
	}

	var tlsConf *tls.Config
	if tc := config.TLSConfig; tc != nil {
		if tlsConf, err = tc.TLSConfig(); err != nil {
			return nil, err
		}
	}

	var lg logger = &nilLogger{}
	if config.Debug {
		prefix := fmt.Sprintf("backend %s: ", baseURI.Host)
		lg = log.New(os.Stderr, prefix, 0)
	}

	bb := &balancedBackend{
		sharded:  config.Sharded,
		baseURI:  baseURI,
		resolver: resolver,
		log:      lg,
	}
	bb.backendTracker = newBackendTracker(baseURI.Host, resolver, lg)
	bb.transportCache = newTransportCache(tlsConf)
	return bb, nil
}

// Call the backend. Makes an HTTP POST request to the specified uri,
// with a JSON-encoded request body. It will attempt to decode the
// response body as JSON.
func (b *balancedBackend) Call(ctx context.Context, shard, path string, req, resp interface{}) error {
ale's avatar
ale committed
	// Serialize the request body.
	data, err := json.Marshal(req)
	if err != nil {
		return err
	}

ale's avatar
ale committed
	// Create the target sequence for this call. If there are multiple
	// targets, reduce the timeout on each individual call accordingly to
	// accomodate eventual failover.
	seq, err := b.makeSequence(shard)
	if err != nil {
		return err
	}
	innerTimeout := 1 * time.Hour
	if deadline, ok := ctx.Deadline(); ok {
		innerTimeout = time.Until(deadline) / time.Duration(seq.Len())
ale's avatar
ale committed
	// Call the backends in the sequence until one succeeds, with an
	// exponential backoff policy controlled by the outer Context.
	return backoff.Retry(func() error {
		req, rerr := b.newJSONRequest(path, shard, data)
		if rerr != nil {
			return rerr
		}
ale's avatar
ale committed
		innerCtx, cancel := context.WithTimeout(ctx, innerTimeout)
		defer cancel()

		// When do() returns successfully, we already know that the
		// response had an HTTP status of 200.
		httpResp, rerr := b.do(innerCtx, seq, req)
		if rerr != nil {
			return rerr
		}
		defer httpResp.Body.Close() // nolint

		// Decode the response, unless the 'resp' output is nil.
		if httpResp.Header.Get("Content-Type") != "application/json" {
			return errors.New("not a JSON response")
		}
		if resp == nil {
			return nil
		}
		return json.NewDecoder(httpResp.Body).Decode(resp)
	}, backoff.WithContext(newExponentialBackOff(), ctx))
ale's avatar
ale committed
}
ale's avatar
ale committed
// Initialize a new target sequence.
func (b *balancedBackend) makeSequence(shard string) (*sequence, error) {
	var tg targetGenerator = b.backendTracker
	if b.sharded {
		if shard == "" {
			return nil, fmt.Errorf("call without shard to sharded service %s", b.baseURI.String())
		}
		tg = newShardedGenerator(shard, b.baseURI.Host, b.resolver)
ale's avatar
ale committed
	seq := newSequence(tg)
	if seq.Len() == 0 {
		return nil, errNoTargets
ale's avatar
ale committed
	b.log.Printf("%016x: initialized", seq.ID())
	return seq, nil
}

// getURIForRequest returns the full URI for a request to path. For sharded
// backends with a non-empty shard, the host is prefixed with "<shard>.".
// This URI is used both in the Host HTTP header and as the TLS server name
// used to pick a server certificate (if using TLS).
func (b *balancedBackend) getURIForRequest(shard, path string) string {
	target := *b.baseURI // work on a copy; the base URI must stay untouched
	if shard != "" && b.sharded {
		target.Host = fmt.Sprintf("%s.%s", shard, target.Host)
	}
	target.Path = appendPath(target.Path, path)
	return target.String()
}

// newJSONRequest builds a POST http.Request for path (see getURIForRequest)
// whose body is data, labelled as JSON.
func (b *balancedBackend) newJSONRequest(path, shard string, data []byte) (*http.Request, error) {
	uri := b.getURIForRequest(shard, path)
	httpReq, err := http.NewRequest(http.MethodPost, uri, bytes.NewReader(data))
	if err != nil {
		return nil, err
	}
	hdr := httpReq.Header
	hdr.Set("Content-Type", "application/json")
	hdr.Set("Content-Length", strconv.Itoa(len(data)))
	return httpReq, nil
}

// do selects the next target from the given sequence and sends the request
// to it. A non-200 response is turned into an error via
// remoteErrorFromResponse (its body is closed and the returned response is
// nil). Errors for status codes that are not temporary are wrapped with
// backoff.Permanent so that retrying stops. The outcome of the attempt is
// reported to the sequence via seq.Done.
func (b *balancedBackend) do(ctx context.Context, seq *sequence, req *http.Request) (*http.Response, error) {
	target, err := seq.Next()
	if err != nil {
		// Propagate the error: returning a nil response together with a
		// nil error would make the caller dereference the response.
		return nil, err
	}

	b.log.Printf("sequence %016x: connecting to %s", seq.ID(), target)
	client := &http.Client{
		Transport: b.transportCache.getTransport(target),
	}
	resp, err := client.Do(req.WithContext(ctx))
	if err == nil && resp.StatusCode != http.StatusOK {
		err = remoteErrorFromResponse(resp)
		if !isStatusTemporary(resp.StatusCode) {
			err = backoff.Permanent(err)
		}
		resp.Body.Close() // nolint
		resp = nil
	}

	seq.Done(target, err)
	return resp, err
}

// errNoTargets is returned when no backend target is available for a call.
var (
	errNoTargets = errors.New("no available backends")
)

// targetGenerator supplies the list of targets for a sequence and receives
// feedback about calls made to those targets.
type targetGenerator interface {
	// getTargets returns the currently available targets.
	getTargets() []string
	// setStatus receives whether the last call to target succeeded.
	setStatus(target string, ok bool)
}

// sequence repeatedly iterates over the targets provided by a
// targetGenerator, in the order the generator returns them. Whenever the
// loaded list is exhausted, it asks the generator for a fresh one.
type sequence struct {
	id      uint64          // random identifier, printed in log messages
	tg      targetGenerator // source of targets and sink for call outcomes
	targets []string        // currently loaded targets
	pos     int             // index of the next target to hand out
}

// newSequence creates a sequence with a random ID and loads its initial
// list of targets from tg.
func newSequence(tg targetGenerator) *sequence {
	s := &sequence{id: rand.Uint64(), tg: tg}
	s.targets = tg.getTargets()
	return s
}

// ID returns the random identifier of this sequence, used to tag log
// messages.
func (s *sequence) ID() uint64 {
	return s.id
}

ale's avatar
ale committed
func (s *sequence) Len() int { return len(s.targets) }

// reloadTargets fetches a fresh target list from the generator and rewinds
// the sequence to its start. An empty result is ignored: the current list
// and position are left unchanged.
func (s *sequence) reloadTargets() {
	fresh := s.tg.getTargets()
	if len(fresh) == 0 {
		return
	}
	s.targets, s.pos = fresh, 0
}

// Next returns the next target of the sequence. Once the loaded list has
// been exhausted it reloads targets from the generator and starts over. It
// returns errNoTargets if no target is available at all.
func (s *sequence) Next() (t string, err error) {
	if s.pos >= len(s.targets) {
		s.reloadTargets()
		if len(s.targets) == 0 {
			return "", errNoTargets
		}
		// If the reload did not rewind the position (e.g. the generator
		// returned no targets and the old list was kept), start over from
		// the beginning instead of indexing past the end of the list.
		if s.pos >= len(s.targets) {
			s.pos = 0
		}
	}
	t = s.targets[s.pos]
	s.pos++
	return t, nil
}

// Done forwards the outcome of a call to target t to the generator's
// setStatus, passing true when err is nil and false otherwise.
func (s *sequence) Done(t string, err error) {
	succeeded := err == nil
	s.tg.setStatus(t, succeeded)
}

// shardedGenerator provides the targets for a call to a sharded service:
// the addresses the resolver returns for the shard-specific host name
// "<shard>.<base>". Status feedback is ignored.
type shardedGenerator struct {
	id    uint64
	addrs []string
}

// newShardedGenerator resolves "<shard>.<base>" once, at construction time,
// and returns a generator that always hands out the resulting addresses.
func newShardedGenerator(shard, base string, resolver resolver) *shardedGenerator {
	host := fmt.Sprintf("%s.%s", shard, base)
	g := &shardedGenerator{id: rand.Uint64()}
	g.addrs = resolver.ResolveIP(host)
	return g
}

// getTargets returns the addresses resolved at construction time.
func (g *shardedGenerator) getTargets() []string { return g.addrs }

// setStatus does nothing: sharded targets are not tracked.
func (*shardedGenerator) setStatus(string, bool) {}

// appendPath concatenates two URI paths. A trailing slash on a and a
// leading slash on b are collapsed into a single slash. When both parts
// are non-empty and neither provides a slash, one is inserted between
// them, so that "/api" + "status" gives "/api/status" rather than
// "/apistatus".
func appendPath(a, b string) string {
	aSlash := strings.HasSuffix(a, "/")
	bSlash := strings.HasPrefix(b, "/")
	switch {
	case aSlash && bSlash:
		return a + b[1:]
	case a != "" && b != "" && !aSlash && !bSlash:
		return a + "/" + b
	default:
		return a + b
	}
}

// isStatusTemporary reports whether an HTTP status code is treated as a
// temporary error, i.e. one that does not stop retries: 429 Too Many
// Requests, 502 Bad Gateway, 503 Service Unavailable and 504 Gateway
// Timeout.
func isStatusTemporary(code int) bool {
	return code == http.StatusTooManyRequests ||
		code == http.StatusBadGateway ||
		code == http.StatusServiceUnavailable ||
		code == http.StatusGatewayTimeout
}