diff --git a/clientutil/balancer.go b/clientutil/balancer.go index 8f6e99f197e648ea18bf4024874ec0f5c5594c9a..5c37d6bfcb7e97c0c249d21c236d90df7dcc1637 100644 --- a/clientutil/balancer.go +++ b/clientutil/balancer.go @@ -221,7 +221,7 @@ func (b *balancedBackend) do(ctx context.Context, seq *sequence, req *http.Reque client := &http.Client{ Transport: b.transportCache.getTransport(target), } - resp, err = client.Do(req.WithContext(ctx)) + resp, err = client.Do(propagateDeadline(ctx, req)) if err == nil && resp.StatusCode != 200 { err = remoteErrorFromResponse(resp) if !isStatusTemporary(resp.StatusCode) { @@ -235,6 +235,19 @@ func (b *balancedBackend) do(ctx context.Context, seq *sequence, req *http.Reque return } +const deadlineHeader = "X-RPC-Deadline" + +// Propagate context deadline to the server using a HTTP header. +func propagateDeadline(ctx context.Context, req *http.Request) *http.Request { + req = req.WithContext(ctx) + if deadline, ok := ctx.Deadline(); ok { + req.Header.Set(deadlineHeader, strconv.FormatInt(deadline.UTC().UnixNano(), 10)) + } else { + req.Header.Del(deadlineHeader) + } + return req +} + var errNoTargets = errors.New("no available backends") type targetGenerator interface { diff --git a/serverutil/http.go b/serverutil/http.go index f1c69b0b64e7d06fa0ad4ffa34b03c1330e80ce4..641c0b9bcae0001b66322e04b28c9da981617424 100644 --- a/serverutil/http.go +++ b/serverutil/http.go @@ -11,6 +11,7 @@ import ( _ "net/http/pprof" "os" "os/signal" + "strconv" "syscall" "time" @@ -162,11 +163,30 @@ func addDefaultHandlers(h http.Handler) http.Handler { // Prometheus instrumentation (requests to /metrics and // /health are not included). root.Handle("/", promhttp.InstrumentHandlerInFlight(inFlightRequests, - promhttp.InstrumentHandlerCounter(totalRequests, h))) + promhttp.InstrumentHandlerCounter(totalRequests, + propagateDeadline(h)))) return root } +const deadlineHeader = "X-RPC-Deadline" + +// Read an eventual deadline from the HTTP request, and set it as the +// deadline of the request context. +func propagateDeadline(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + if hdr := req.Header.Get(deadlineHeader); hdr != "" { + if deadlineNano, err := strconv.ParseInt(hdr, 10, 64); err == nil { + deadline := time.Unix(0, deadlineNano) + ctx, cancel := context.WithDeadline(req.Context(), deadline) + defer cancel() + req = req.WithContext(ctx) + } + } + h.ServeHTTP(w, req) + }) +} + func guessEndpointName(addr string) string { _, port, err := net.SplitHostPort(addr) if err != nil {