From 0a3d704ccb8c5b1ef9497ef44d1d1ce719dec459 Mon Sep 17 00:00:00 2001
From: ale <ale@incal.net>
Date: Mon, 27 Nov 2017 08:45:52 +0000
Subject: [PATCH] Add load shedding wrapper, Prometheus metrics

---
 serverutil/http.go          | 57 +++++++++++++++++++++++++++++++++----
 serverutil/load_shedding.go | 51 +++++++++++++++++++++++++++++++++
 2 files changed, 102 insertions(+), 6 deletions(-)
 create mode 100644 serverutil/load_shedding.go

diff --git a/serverutil/http.go b/serverutil/http.go
index 079b933..1c218fe 100644
--- a/serverutil/http.go
+++ b/serverutil/http.go
@@ -9,36 +9,52 @@ import (
 	"os/signal"
 	"syscall"
 	"time"
+
+	"github.com/prometheus/client_golang/prometheus"
+	"github.com/prometheus/client_golang/prometheus/promhttp"
 )
 
+var gracefulShutdownTimeout = 3 * time.Second
+
+// ServerConfig stores common HTTP/HTTPS server configuration parameters.
+type ServerConfig struct {
+	TLS                 *TLSServerConfig `yaml:"tls"`
+	MaxInflightRequests int              `yaml:"max_inflight_requests"`
+}
+
 // Serve HTTP(S) content on the specified address. If serverConfig is
 // not nil, enable HTTPS and TLS authentication.
 //
 // This function will return an error if there are problems creating
 // the listener, otherwise it will handle graceful termination on
 // SIGINT or SIGTERM and return nil.
-func Serve(h http.Handler, serverConfig *TLSServerConfig, addr string) (err error) {
+func Serve(h http.Handler, serverConfig *ServerConfig, addr string) (err error) {
 	var tlsConfig *tls.Config
-	if serverConfig != nil {
-		tlsConfig, err = serverConfig.TLSConfig()
+	if serverConfig.TLS != nil {
+		tlsConfig, err = serverConfig.TLS.TLSConfig()
 		if err != nil {
 			return err
 		}
-		h, err = serverConfig.TLSAuthWrapper(h)
+		h, err = serverConfig.TLS.TLSAuthWrapper(h)
 		if err != nil {
 			return err
 		}
 	}
 
+	if serverConfig.MaxInflightRequests > 0 {
+		h = newLoadSheddingWrapper(serverConfig.MaxInflightRequests, h)
+	}
+
 	srv := &http.Server{
 		Addr:         addr,
-		Handler:      h,
+		Handler:      instrumentHandler(h),
 		ReadTimeout:  30 * time.Second,
 		WriteTimeout: 30 * time.Second,
 		IdleTimeout:  60 * time.Second,
 		TLSConfig:    tlsConfig,
 	}
 
+	// Install a signal handler for gentle process termination.
 	done := make(chan struct{})
 	sigCh := make(chan os.Signal, 1)
 	go func() {
@@ -47,7 +63,7 @@ func Serve(h http.Handler, serverConfig *TLSServerConfig, addr string) (err erro
 
 		// Gracefully terminate for 3 seconds max, then shut
 		// down remaining clients.
-		ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
+		ctx, cancel := context.WithTimeout(context.Background(), gracefulShutdownTimeout)
 		defer cancel()
 		if err := srv.Shutdown(ctx); err == context.Canceled {
 			if err := srv.Close(); err != nil {
@@ -66,3 +82,32 @@ func Serve(h http.Handler, serverConfig *TLSServerConfig, addr string) (err erro
 	<-done
 	return nil
 }
+
+func instrumentHandler(h http.Handler) http.Handler {
+	root := http.NewServeMux()
+	root.Handle("/metrics", promhttp.Handler())
+	root.Handle("/", h)
+	return promhttp.InstrumentHandlerInFlight(inFlightRequests,
+		promhttp.InstrumentHandlerCounter(totalRequests, root))
+}
+
+// HTTP-related metrics.
+var (
+	totalRequests = prometheus.NewCounterVec(
+		prometheus.CounterOpts{
+			Name: "total_requests",
+			Help: "Total number of requests.",
+		},
+		[]string{"code"},
+	)
+	inFlightRequests = prometheus.NewGauge(
+		prometheus.GaugeOpts{
+			Name: "inflight_requests",
+			Help: "Number of in-flight requests.",
+		},
+	)
+)
+
+func init() {
+	prometheus.MustRegister(totalRequests, inFlightRequests)
+}
diff --git a/serverutil/load_shedding.go b/serverutil/load_shedding.go
new file mode 100644
index 0000000..beb2ae0
--- /dev/null
+++ b/serverutil/load_shedding.go
@@ -0,0 +1,51 @@
+package serverutil
+
+import (
+	"net/http"
+	"sync/atomic"
+
+	"github.com/prometheus/client_golang/prometheus"
+)
+
+type loadSheddingWrapper struct {
+	limit, inflight int32
+	h               http.Handler
+}
+
+func newLoadSheddingWrapper(limit int, h http.Handler) *loadSheddingWrapper {
+	return &loadSheddingWrapper{limit: int32(limit), h: h}
+}
+
+func (l *loadSheddingWrapper) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+	inflight := atomic.AddInt32(&l.inflight, 1)
+	defer atomic.AddInt32(&l.inflight, -1)
+
+	if inflight > l.limit {
+		throttledRequests.Inc()
+		w.Header().Set("Connection", "close")
+		http.Error(w, "Throttled", http.StatusTooManyRequests)
+		return
+	}
+
+	allowedRequests.Inc()
+	l.h.ServeHTTP(w, r)
+}
+
+var (
+	throttledRequests = prometheus.NewCounter(
+		prometheus.CounterOpts{
+			Name: "ls_throttled_requests",
+			Help: "Requests throttled by the load shedding wrapper.",
+		},
+	)
+	allowedRequests = prometheus.NewCounter(
+		prometheus.CounterOpts{
+			Name: "ls_allowed_requests",
+			Help: "Requests allowed by the load shedding wrapper.",
+		},
+	)
+)
+
+func init() {
+	prometheus.MustRegister(throttledRequests, allowedRequests)
+}
-- 
GitLab