http.go 5.2 KB
Newer Older
ale's avatar
ale committed
1 2 3 4 5
package serverutil

import (
	"context"
	"crypto/tls"
ale's avatar
ale committed
6
	"fmt"
ale's avatar
ale committed
7 8 9 10
	"io"
	"log"
	"net"
	"net/http"
ale's avatar
ale committed
11
	_ "net/http/pprof"
ale's avatar
ale committed
12 13 14 15 16
	"os"
	"os/signal"
	"syscall"
	"time"

ale's avatar
ale committed
17
	"git.autistici.org/ai3/go-common/tracing"
ale's avatar
ale committed
18
	"github.com/coreos/go-systemd/daemon"
ale's avatar
ale committed
19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81
	"github.com/prometheus/client_golang/prometheus"
	"github.com/prometheus/client_golang/prometheus/promhttp"
)

// gracefulShutdownTimeout bounds how long Serve lets
// http.Server.Shutdown drain in-flight requests after it receives
// SIGINT or SIGTERM.
var gracefulShutdownTimeout = 3000 * time.Millisecond

// ServerConfig stores common HTTP/HTTPS server configuration parameters.
//
// A nil *ServerConfig is accepted by Serve: in that case none of the
// optional features below are enabled and a plain HTTP server is built.
type ServerConfig struct {
	// TLS, if non-nil, enables HTTPS (its TLSConfig is installed on the
	// server and listener) and wraps the handler with TLS.TLSAuthWrapper
	// for TLS authentication.
	TLS                 *TLSServerConfig `yaml:"tls"`
	// MaxInflightRequests, if greater than zero, enables the load
	// shedding wrapper with this limit; zero or negative disables it.
	MaxInflightRequests int              `yaml:"max_inflight_requests"`
	// TrustedForwarders, if non-empty, enables rewriting of request
	// headers using X-Forwarded-Proto and X-Real-IP (via newProxyHeaders).
	// NOTE(review): presumably a list of forwarder addresses or networks
	// whose headers are trusted — confirm the accepted format against
	// newProxyHeaders.
	TrustedForwarders   []string         `yaml:"trusted_forwarders"`
}

// buildHTTPServer wraps h with the optional middleware selected by
// config and returns an (unstarted) *http.Server for it. A nil config
// is allowed and produces a server with no optional wrappers.
//
// Wrappers are applied innermost first:
//   - TLS authentication, when config.TLS is set (this also populates
//     the returned server's TLSConfig, which is nil otherwise);
//   - trusted-forwarder header rewriting, when TrustedForwarders is
//     non-empty;
//   - load shedding, when MaxInflightRequests is positive.
//
// The wrapped handler is finally mounted by defaultHandler. Any error
// from building a wrapper is returned with a nil server.
func (config *ServerConfig) buildHTTPServer(h http.Handler) (*http.Server, error) {
	var (
		tlsConfig *tls.Config
		err       error
	)

	if config != nil && config.TLS != nil {
		if tlsConfig, err = config.TLS.TLSConfig(); err != nil {
			return nil, err
		}
		if h, err = config.TLS.TLSAuthWrapper(h); err != nil {
			return nil, err
		}
	}

	// Rewrite request headers from X-Forwarded-Proto and X-Real-IP
	// only when the operator has listed trusted forwarders.
	if config != nil && len(config.TrustedForwarders) > 0 {
		if h, err = newProxyHeaders(h, config.TrustedForwarders); err != nil {
			return nil, err
		}
	}

	if config != nil && config.MaxInflightRequests > 0 {
		h = newLoadSheddingWrapper(config.MaxInflightRequests, h)
	}

	// These are not meant to be external-facing servers, so the
	// timeouts are generous to keep the number of reconnections low.
	// NOTE(review): net/http/pprof appears to reject CPU profiles whose
	// duration is >= the server's WriteTimeout, so the default 30s
	// /debug/pprof/profile may need ?seconds=<30 — confirm.
	srv := &http.Server{
		Handler:      defaultHandler(h),
		ReadTimeout:  30 * time.Second,
		WriteTimeout: 30 * time.Second,
		IdleTimeout:  600 * time.Second,
		TLSConfig:    tlsConfig,
	}
	return srv, nil
}

// Serve HTTP(S) content on the specified address. If config.TLS is
// not nil, enable HTTPS and TLS authentication.
//
// This function will return an error if there are problems creating
// the listener, otherwise it will handle graceful termination on
// SIGINT or SIGTERM and return nil.
func Serve(h http.Handler, config *ServerConfig, addr string) error {
ale's avatar
ale committed
82 83 84 85
	// Wrap with tracing handler (exclude metrics and other
	// debugging endpoints).
	h = tracing.WrapHandler(h, guessEndpointName(addr))

ale's avatar
ale committed
86 87 88 89 90 91
	// Create the HTTP server.
	srv, err := config.buildHTTPServer(h)
	if err != nil {
		return err
	}

ale's avatar
ale committed
92 93
	// Create the net.Listener first, so we can detect
	// initialization-time errors safely.
ale's avatar
ale committed
94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123
	l, err := net.Listen("tcp", addr)
	if err != nil {
		return err
	}
	if srv.TLSConfig != nil {
		l = tls.NewListener(l, srv.TLSConfig)
	}

	// Install a signal handler for gentle process termination.
	done := make(chan struct{})
	sigCh := make(chan os.Signal, 1)
	go func() {
		<-sigCh
		log.Printf("exiting")

		// Gracefully terminate for 3 seconds max, then shut
		// down remaining clients.
		ctx, cancel := context.WithTimeout(context.Background(), gracefulShutdownTimeout)
		defer cancel()
		if err = srv.Shutdown(ctx); err == context.Canceled {
			if err = srv.Close(); err != nil {
				log.Printf("error terminating server: %v", err)
			}
		}

		close(done)
	}()

	signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)

ale's avatar
ale committed
124 125 126
	// Notify systemd that we are ready to serve. This call is
	// allowed to fail (in case there is no systemd).
	daemon.SdNotify(false, "READY=1") // nolint
ale's avatar
ale committed
127

ale's avatar
ale committed
128 129 130 131 132 133 134 135 136 137 138 139 140 141
	err = srv.Serve(l)
	if err != http.ErrServerClosed {
		return err
	}

	<-done
	return nil
}

func defaultHandler(h http.Handler) http.Handler {
	root := http.NewServeMux()

	// Add an endpoint for HTTP health checking probes.
	root.Handle("/health", http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
ale's avatar
ale committed
142
		io.WriteString(w, "OK") // nolint
ale's avatar
ale committed
143 144 145 146 147
	}))

	// Add an endpoint to serve Prometheus metrics.
	root.Handle("/metrics", promhttp.Handler())

ale's avatar
ale committed
148 149 150 151
	// Let the default net/http handler deal with /debug/
	// URLs. Packages such as net/http/pprof register their
	// handlers there in ways that aren't reproducible.
	root.Handle("/debug/", http.DefaultServeMux)
ale's avatar
ale committed
152 153 154 155 156 157 158 159 160 161

	// Forward everything else to the main handler, adding
	// Prometheus instrumentation (requests to /metrics and
	// /health are not included).
	root.Handle("/", promhttp.InstrumentHandlerInFlight(inFlightRequests,
		promhttp.InstrumentHandlerCounter(totalRequests, h)))

	return root
}

// guessEndpointName derives a tracing endpoint name from a listen
// address by pairing the local hostname with the address's port
// (e.g. ":8080" becomes "<hostname>:8080"). If the address has no
// parseable port, or the hostname cannot be determined, the address
// is returned unchanged.
func guessEndpointName(addr string) string {
	if _, port, err := net.SplitHostPort(addr); err == nil {
		if hostname, err := os.Hostname(); err == nil {
			return fmt.Sprintf("%s:%s", hostname, port)
		}
	}
	return addr
}

ale's avatar
ale committed
174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201
// HTTP-related metrics.
var (
	// Since we instrument the root HTTP handler, we don't really
	// have a good way to set the 'handler' label based on the
	// request URL - but still, we'd like to set the label to
	// match what the other Prometheus jobs do. So we just set it
	// to 'all'.
	totalRequests = prometheus.NewCounterVec(
		prometheus.CounterOpts{
			Name: "http_requests_total",
			Help: "Total number of requests.",
			ConstLabels: prometheus.Labels{
				"handler": "all",
			},
		},
		[]string{"code", "method"},
	)
	inFlightRequests = prometheus.NewGauge(
		prometheus.GaugeOpts{
			Name: "http_requests_inflight",
			Help: "Number of in-flight requests.",
		},
	)
)

// init registers the HTTP metrics with the default Prometheus registry.
// MustRegister panics if a registration fails.
func init() {
	prometheus.MustRegister(totalRequests)
	prometheus.MustRegister(inFlightRequests)
}