package serverutil

import (
	"math"
	"net/http"
	"sync/atomic"

	"github.com/prometheus/client_golang/prometheus"
)

// loadSheddingWrapper is an http.Handler that forwards requests to h while
// the number of concurrently in-flight requests stays within limit, and
// rejects the excess.
type loadSheddingWrapper struct {
	// limit is the maximum number of concurrent requests admitted to h.
	limit int32
	// inflight counts requests currently being served; it is only accessed
	// through sync/atomic.
	inflight int32
	// h is the wrapped handler that receives admitted requests.
	h http.Handler
}

// newLoadSheddingWrapper returns a wrapper around h that admits at most limit
// concurrent requests. The in-flight count includes the current request, so a
// limit of 0 or less rejects every request.
//
// limit is clamped to the int32 range rather than converted directly. A bare
// int32(limit) conversion keeps only the low 32 bits, so an oversized value
// such as 1<<32+5 would silently become a limit of 5.
func newLoadSheddingWrapper(limit int, h http.Handler) *loadSheddingWrapper {
	switch {
	case limit > math.MaxInt32:
		limit = math.MaxInt32
	case limit < math.MinInt32:
		limit = math.MinInt32
	}
	return &loadSheddingWrapper{limit: int32(limit), h: h}
}

// ServeHTTP forwards r to the wrapped handler when the number of in-flight
// requests, counting this one, is at most l.limit. Otherwise it replies with
// 429 Too Many Requests and sets "Connection: close" on the response. Each
// outcome increments its own Prometheus counter.
func (l *loadSheddingWrapper) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	// The request occupies a slot for its whole lifetime, including the short
	// time spent writing a rejection.
	current := atomic.AddInt32(&l.inflight, 1)
	defer atomic.AddInt32(&l.inflight, -1)

	if current <= l.limit {
		allowedRequests.Inc()
		l.h.ServeHTTP(w, r)
		return
	}

	throttledRequests.Inc()
	// Must be set before http.Error writes the status line. With HTTP/1.x,
	// net/http closes the connection after sending this response.
	hdr := w.Header()
	hdr.Set("Connection", "close")
	http.Error(w, "Throttled", http.StatusTooManyRequests)
}

// throttledRequests counts requests that the load shedding wrapper rejected
// with 429 Too Many Requests.
//
// NOTE(review): Prometheus naming conventions suggest a "_total" suffix for
// counters. Renaming would break existing queries and dashboards, so the
// names are left as-is.
var throttledRequests = prometheus.NewCounter(prometheus.CounterOpts{
	Name: "ls_throttled_requests",
	Help: "Requests throttled by the load shedding wrapper.",
})

// allowedRequests counts requests that the load shedding wrapper forwarded to
// its wrapped handler.
var allowedRequests = prometheus.NewCounter(prometheus.CounterOpts{
	Name: "ls_allowed_requests",
	Help: "Requests allowed by the load shedding wrapper.",
})

// init registers the load shedding counters with the default Prometheus
// registerer. It panics if either registration fails, for example because a
// collector with the same name is already registered.
func init() {
	for _, c := range []prometheus.Collector{throttledRequests, allowedRequests} {
		prometheus.MustRegister(c)
	}
}