// Package serverutil provides helpers to run HTTP(S) servers with
// optional TLS, load shedding, request timeouts, compression,
// Prometheus instrumentation and graceful shutdown.
package serverutil

import (
	"compress/gzip"
	"context"
	"crypto/tls"
	"errors"
	"fmt"
	"io"
	"log"
	"net"
	"net/http"
	_ "net/http/pprof"
	"os"
	"os/signal"
	"strconv"
	"syscall"
	"time"

	"git.autistici.org/ai3/go-common/tracing"
	"github.com/NYTimes/gziphandler"
	"github.com/coreos/go-systemd/v22/daemon"
	"github.com/prometheus/client_golang/prometheus"
	"github.com/prometheus/client_golang/prometheus/promhttp"
)

// gracefulShutdownTimeout is the timeout of the context handed to
// http.Server.Shutdown when the server is being stopped.
var gracefulShutdownTimeout = 3 * time.Second

// gzipLevel is the compression level given to the gzip handler.
var gzipLevel = gzip.BestSpeed

// gzipMinSize is the value given to gziphandler.MinSize.
var gzipMinSize = 1300

// gzipContentTypes lists the response content types given to
// gziphandler.ContentTypes.
var gzipContentTypes = []string{
	"application/json",
	"application/javascript",
	"text/html",
	"text/plain",
	"text/css",
}

// ServerConfig stores common HTTP/HTTPS server configuration parameters.
type ServerConfig struct {
	TLS                 *TLSServerConfig `yaml:"tls"`
	MaxInflightRequests int              `yaml:"max_inflight_requests"`
	RequestTimeoutSecs  int              `yaml:"request_timeout"`
	TrustedForwarders   []string         `yaml:"trusted_forwarders"`
ale's avatar
ale committed

	// TODO: switch do disable_compression (flip default) later.
	EnableCompression bool `yaml:"enable_compression"`
ale's avatar
ale committed
}

func (config *ServerConfig) buildHTTPHandler(h http.Handler) (http.Handler, *tls.Config, error) {
ale's avatar
ale committed
	var tlsConfig *tls.Config
	var err error
	if config != nil {
		if config.TLS != nil {
			tlsConfig, err = config.TLS.TLSConfig()
ale's avatar
ale committed
			if err != nil {
ale's avatar
ale committed
				return nil, nil, err
ale's avatar
ale committed
			}
			h, err = config.TLS.TLSAuthWrapper(h)
ale's avatar
ale committed
			if err != nil {
ale's avatar
ale committed
				return nil, nil, err
ale's avatar
ale committed
			}
		}

		// If TrustedForwarders is defined, rewrite the request
		// headers using X-Forwarded-Proto and X-Real-IP.
		if len(config.TrustedForwarders) > 0 {
			h, err = newProxyHeaders(h, config.TrustedForwarders)
			if err != nil {
ale's avatar
ale committed
				return nil, nil, err
			}
		}

		// If MaxInflightRequests is set, enable the load
		// shedding wrapper.
		if config.MaxInflightRequests > 0 {
			h = newLoadSheddingWrapper(config.MaxInflightRequests, h)
ale's avatar
ale committed
		}

		// Wrap the handler with a TimeoutHandler if 'request_timeout'
		// is set.
		if config.RequestTimeoutSecs > 0 {
			h = http.TimeoutHandler(h, time.Duration(config.RequestTimeoutSecs)*time.Second, "")
		}
ale's avatar
ale committed
	// Add all the default handlers (health, monitoring, etc).
	h = addDefaultHandlers(h)

	// Optionally enable compression.
	if config != nil && config.EnableCompression {
		gzwrap, err := gziphandler.GzipHandlerWithOpts(
			gziphandler.CompressionLevel(gzipLevel),
			gziphandler.MinSize(gzipMinSize),
			gziphandler.ContentTypes(gzipContentTypes),
		)
		if err != nil {
			return nil, nil, err
		}
		h = gzwrap(h)
	}

	return h, tlsConfig, nil
// buildListener opens a TCP listener on addr, so that bind errors are
// reported before the server starts. When tlsConfig is not nil the
// listener is wrapped so that accepted connections speak TLS.
func buildListener(addr string, tlsConfig *tls.Config) (net.Listener, error) {
	tcpListener, err := net.Listen("tcp", addr)
	if err != nil {
		return nil, err
	}
	if tlsConfig == nil {
		return tcpListener, nil
	}
	return tls.NewListener(tcpListener, tlsConfig), nil
}

func buildServer(h http.Handler, config *ServerConfig, addr string) (*http.Server, error) {
ale's avatar
ale committed
	// Wrap with tracing handler (exclude metrics and other
	// debugging endpoints).
	h = tracing.WrapHandler(h, guessEndpointName(addr))

ale's avatar
ale committed
	// Create the top-level HTTP handler with all our additions.
	hh, tlsConfig, err := config.buildHTTPHandler(h)
	if err != nil {
ale's avatar
ale committed
	// These are not meant to be external-facing servers, so we
	// can be generous with the timeouts to keep the number of
	// reconnections low.
	srv := &http.Server{
		Handler:           hh,
		ReadHeaderTimeout: 30 * time.Second,
		IdleTimeout:       600 * time.Second,
		TLSConfig:         tlsConfig,
	}

	return srv, nil
}

// Serve HTTP(S) content on the specified address. If config.TLS is
// not nil, enable HTTPS and TLS authentication.
//
// This function will return an error if there are problems creating
// the listener, otherwise it will handle graceful termination on
// SIGINT or SIGTERM and return nil.
func Serve(h http.Handler, config *ServerConfig, addr string) error {
	srv, err := buildServer(h, config, addr)
	if err != nil {
		return err
	}

	l, err := buildListener(addr, srv.TLSConfig)
	if err != nil {
		return err
ale's avatar
ale committed
	}

	// Install a signal handler for gentle process termination.
	done := make(chan struct{})
	sigCh := make(chan os.Signal, 1)
	go func() {
		<-sigCh
		log.Printf("exiting")

		// Gracefully terminate for 3 seconds max, then shut
		// down remaining clients.
		ctx, cancel := context.WithTimeout(context.Background(), gracefulShutdownTimeout)
		defer cancel()
ale's avatar
ale committed
		if err = srv.Shutdown(ctx); err == context.Canceled {
			if err = srv.Close(); err != nil {
ale's avatar
ale committed
				log.Printf("error terminating server: %v", err)
			}
		}

		close(done)
	}()
ale's avatar
ale committed
	signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)

	// Notify systemd that we are ready to serve. This call is
	// allowed to fail (in case there is no systemd).
	daemon.SdNotify(false, "READY=1") // nolint

	err = srv.Serve(l)
ale's avatar
ale committed
	if err != http.ErrServerClosed {
ale's avatar
ale committed
		return err
	}

	<-done
	return nil
}

// ServeWithContext behaves like Serve, except that termination is
// driven by ctx rather than by signals: once ctx is done the server is
// shut down (for at most gracefulShutdownTimeout) and then closed.
// A server stopped this way results in a nil error.
func ServeWithContext(ctx context.Context, h http.Handler, config *ServerConfig, addr string) error {
	server, err := buildServer(h, config, addr)
	if err != nil {
		return err
	}

	listener, err := buildListener(addr, server.TLSConfig)
	if err != nil {
		return err
	}

	go func() {
		<-ctx.Done()

		shutdownCtx, cancel := context.WithTimeout(context.Background(), gracefulShutdownTimeout)
		defer cancel()
		server.Shutdown(shutdownCtx) // nolint: errcheck
		server.Close()
	}()

	daemon.SdNotify(false, "READY=1") // nolint

	if err := server.Serve(listener); err != http.ErrServerClosed {
		return err
	}
	return nil
}

func addDefaultHandlers(h http.Handler) http.Handler {
ale's avatar
ale committed
	root := http.NewServeMux()

	// Add an endpoint for HTTP health checking probes.
	root.Handle("/health", http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		io.WriteString(w, "OK") // nolint
	}))

	// Add an endpoint to serve Prometheus metrics.
ale's avatar
ale committed
	root.Handle("/metrics", promhttp.Handler())
ale's avatar
ale committed
	// Let the default net/http handler deal with /debug/
	// URLs. Packages such as net/http/pprof register their
	// handlers there in ways that aren't reproducible.
	root.Handle("/debug/", http.DefaultServeMux)

	// Forward everything else to the main handler, adding
	// Prometheus instrumentation (requests to /metrics and
	// /health are not included).
	root.Handle("/", promhttp.InstrumentHandlerInFlight(inFlightRequests,
		promhttp.InstrumentHandlerCounter(totalRequests,
			propagateDeadline(h))))

	return root
ale's avatar
ale committed
}

// deadlineHeader is the request header that carries the caller's
// deadline, as a decimal count of nanoseconds since the Unix epoch.
const deadlineHeader = "X-RPC-Deadline"

// propagateDeadline returns a handler that, when the request carries
// a parseable deadline header, runs h with a request context bound to
// that deadline. Requests with a missing or malformed header are
// passed to h untouched.
func propagateDeadline(h http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
		raw := req.Header.Get(deadlineHeader)
		if raw == "" {
			h.ServeHTTP(w, req)
			return
		}

		nanos, err := strconv.ParseInt(raw, 10, 64)
		if err != nil {
			h.ServeHTTP(w, req)
			return
		}

		ctx, cancel := context.WithDeadline(req.Context(), time.Unix(0, nanos))
		defer cancel()
		h.ServeHTTP(w, req.WithContext(ctx))
	})
}

// guessEndpointName builds a "hostname:port" name from the port found
// in addr and the local host name. If addr cannot be split into host
// and port, or the host name is unavailable, addr is returned as is.
func guessEndpointName(addr string) string {
	if _, port, err := net.SplitHostPort(addr); err == nil {
		if hostname, err := os.Hostname(); err == nil {
			return fmt.Sprintf("%s:%s", hostname, port)
		}
	}
	return addr
}

// HTTP-related metrics.
var (
ale's avatar
ale committed
	// Since we instrument the root HTTP handler, we don't really
	// have a good way to set the 'handler' label based on the
	// request URL - but still, we'd like to set the label to
	// match what the other Prometheus jobs do. So we just set it
	// to 'all'.
ale's avatar
ale committed
	totalRequests = prometheus.NewCounterVec(
		prometheus.CounterOpts{
ale's avatar
ale committed
			Name: "http_requests_total",
ale's avatar
ale committed
			Help: "Total number of requests.",
ale's avatar
ale committed
			ConstLabels: prometheus.Labels{
				"handler": "all",
			},
ale's avatar
ale committed
		},
ale's avatar
ale committed
		[]string{"code", "method"},
ale's avatar
ale committed
	)
	inFlightRequests = prometheus.NewGauge(
		prometheus.GaugeOpts{
ale's avatar
ale committed
			Name: "http_requests_inflight",
ale's avatar
ale committed
			Help: "Number of in-flight requests.",
		},
	)
)

// init registers the HTTP metrics with the default Prometheus
// registry; a registration failure panics.
func init() {
	prometheus.MustRegister(totalRequests)
	prometheus.MustRegister(inFlightRequests)
}