Skip to content
Snippets Groups Projects
Commit 0868f264 authored by ale's avatar ale
Browse files

Reduce timeout on calls to backends with multiple targets

Allows for failover in case a target times out.
parent 3f5c3e1b
No related branches found
No related tags found
No related merge requests found
...@@ -98,28 +98,35 @@ func newBalancedBackend(config *BackendConfig, resolver resolver) (*balancedBack ...@@ -98,28 +98,35 @@ func newBalancedBackend(config *BackendConfig, resolver resolver) (*balancedBack
// with a JSON-encoded request body. It will attempt to decode the // with a JSON-encoded request body. It will attempt to decode the
// response body as JSON. // response body as JSON.
func (b *balancedBackend) Call(ctx context.Context, shard, path string, req, resp interface{}) error { func (b *balancedBackend) Call(ctx context.Context, shard, path string, req, resp interface{}) error {
// Serialize the request body.
data, err := json.Marshal(req) data, err := json.Marshal(req)
if err != nil { if err != nil {
return err return err
} }
var tg targetGenerator = b.backendTracker // Create the target sequence for this call. If there are multiple
if b.sharded { // targets, reduce the timeout on each individual call accordingly to
if shard == "" { // accomodate eventual failover.
return fmt.Errorf("call without shard to sharded service %s", b.baseURI.String()) seq, err := b.makeSequence(shard)
} if err != nil {
tg = newShardedGenerator(shard, b.baseURI.Host, b.resolver) return err
}
innerTimeout := 1 * time.Hour
if deadline, ok := ctx.Deadline(); ok {
innerTimeout = time.Until(deadline) / time.Duration(seq.Len())
} }
seq := newSequence(tg)
b.log.Printf("%016x: initialized", seq.ID())
// Call the backends in the sequence until one succeeds, with an
// exponential backoff policy controlled by the outer Context.
var httpResp *http.Response var httpResp *http.Response
err = backoff.Retry(func() error { err = backoff.Retry(func() error {
req, rerr := b.newJSONRequest(path, shard, data) req, rerr := b.newJSONRequest(path, shard, data)
if rerr != nil { if rerr != nil {
return rerr return rerr
} }
httpResp, rerr = b.do(ctx, seq, req) innerCtx, cancel := context.WithTimeout(ctx, innerTimeout)
httpResp, rerr = b.do(innerCtx, seq, req)
cancel()
return rerr return rerr
}, backoff.WithContext(newExponentialBackOff(), ctx)) }, backoff.WithContext(newExponentialBackOff(), ctx))
if err != nil { if err != nil {
...@@ -127,16 +134,34 @@ func (b *balancedBackend) Call(ctx context.Context, shard, path string, req, res ...@@ -127,16 +134,34 @@ func (b *balancedBackend) Call(ctx context.Context, shard, path string, req, res
} }
defer httpResp.Body.Close() // nolint defer httpResp.Body.Close() // nolint
// Decode the response.
if httpResp.Header.Get("Content-Type") != "application/json" { if httpResp.Header.Get("Content-Type") != "application/json" {
return errors.New("not a JSON response") return errors.New("not a JSON response")
} }
if resp == nil { if resp == nil {
return nil return nil
} }
return json.NewDecoder(httpResp.Body).Decode(resp) return json.NewDecoder(httpResp.Body).Decode(resp)
} }
// Initialize a new target sequence.
func (b *balancedBackend) makeSequence(shard string) (*sequence, error) {
var tg targetGenerator = b.backendTracker
if b.sharded {
if shard == "" {
return nil, fmt.Errorf("call without shard to sharded service %s", b.baseURI.String())
}
tg = newShardedGenerator(shard, b.baseURI.Host, b.resolver)
}
seq := newSequence(tg)
if seq.Len() == 0 {
return nil, errNoTargets
}
b.log.Printf("%016x: initialized", seq.ID())
return seq, nil
}
// Return the URI to be used for the request. This is used both in the // Return the URI to be used for the request. This is used both in the
// Host HTTP header and as the TLS server name used to pick a server // Host HTTP header and as the TLS server name used to pick a server
// certificate (if using TLS). // certificate (if using TLS).
...@@ -213,6 +238,8 @@ func newSequence(tg targetGenerator) *sequence { ...@@ -213,6 +238,8 @@ func newSequence(tg targetGenerator) *sequence {
func (s *sequence) ID() uint64 { return s.id } func (s *sequence) ID() uint64 { return s.id }
func (s *sequence) Len() int { return len(s.targets) }
func (s *sequence) reloadTargets() { func (s *sequence) reloadTargets() {
targets := s.tg.getTargets() targets := s.tg.getTargets()
if len(targets) > 0 { if len(targets) > 0 {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment