package serverutil import ( "context" "crypto/tls" "fmt" "io" "log" "net" "net/http" _ "net/http/pprof" "os" "os/signal" "strconv" "syscall" "time" "git.autistici.org/ai3/go-common/tracing" "github.com/coreos/go-systemd/daemon" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promhttp" ) var gracefulShutdownTimeout = 3 * time.Second // ServerConfig stores common HTTP/HTTPS server configuration parameters. type ServerConfig struct { TLS *TLSServerConfig `yaml:"tls"` MaxInflightRequests int `yaml:"max_inflight_requests"` RequestTimeoutSecs int `yaml:"request_timeout"` TrustedForwarders []string `yaml:"trusted_forwarders"` } func (config *ServerConfig) buildHTTPServer(h http.Handler) (*http.Server, error) { var tlsConfig *tls.Config var err error if config != nil { if config.TLS != nil { tlsConfig, err = config.TLS.TLSConfig() if err != nil { return nil, err } h, err = config.TLS.TLSAuthWrapper(h) if err != nil { return nil, err } } // If TrustedForwarders is defined, rewrite the request // headers using X-Forwarded-Proto and X-Real-IP. if len(config.TrustedForwarders) > 0 { h, err = newProxyHeaders(h, config.TrustedForwarders) if err != nil { return nil, err } } // If MaxInflightRequests is set, enable the load // shedding wrapper. if config.MaxInflightRequests > 0 { h = newLoadSheddingWrapper(config.MaxInflightRequests, h) } } // Wrap the handler with a TimeoutHandler if 'request_timeout' // is set. h = addDefaultHandlers(h) if config.RequestTimeoutSecs > 0 { h = http.TimeoutHandler(h, time.Duration(config.RequestTimeoutSecs)*time.Second, "") } // These are not meant to be external-facing servers, so we // can be generous with the timeouts to keep the number of // reconnections low. return &http.Server{ Handler: h, ReadTimeout: 30 * time.Second, WriteTimeout: 30 * time.Second, IdleTimeout: 600 * time.Second, TLSConfig: tlsConfig, }, nil } // Serve HTTP(S) content on the specified address. If config.TLS is // not nil, enable HTTPS and TLS authentication. // // This function will return an error if there are problems creating // the listener, otherwise it will handle graceful termination on // SIGINT or SIGTERM and return nil. func Serve(h http.Handler, config *ServerConfig, addr string) error { // Wrap with tracing handler (exclude metrics and other // debugging endpoints). h = tracing.WrapHandler(h, guessEndpointName(addr)) // Create the HTTP server. srv, err := config.buildHTTPServer(h) if err != nil { return err } // Create the net.Listener first, so we can detect // initialization-time errors safely. l, err := net.Listen("tcp", addr) if err != nil { return err } if srv.TLSConfig != nil { l = tls.NewListener(l, srv.TLSConfig) } // Install a signal handler for gentle process termination. done := make(chan struct{}) sigCh := make(chan os.Signal, 1) go func() { <-sigCh log.Printf("exiting") // Gracefully terminate for 3 seconds max, then shut // down remaining clients. ctx, cancel := context.WithTimeout(context.Background(), gracefulShutdownTimeout) defer cancel() if err = srv.Shutdown(ctx); err == context.Canceled { if err = srv.Close(); err != nil { log.Printf("error terminating server: %v", err) } } close(done) }() signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) // Notify systemd that we are ready to serve. This call is // allowed to fail (in case there is no systemd). daemon.SdNotify(false, "READY=1") // nolint err = srv.Serve(l) if err != http.ErrServerClosed { return err } <-done return nil } func addDefaultHandlers(h http.Handler) http.Handler { root := http.NewServeMux() // Add an endpoint for HTTP health checking probes. root.Handle("/health", http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { io.WriteString(w, "OK") // nolint })) // Add an endpoint to serve Prometheus metrics. root.Handle("/metrics", promhttp.Handler()) // Let the default net/http handler deal with /debug/ // URLs. Packages such as net/http/pprof register their // handlers there in ways that aren't reproducible. root.Handle("/debug/", http.DefaultServeMux) // Forward everything else to the main handler, adding // Prometheus instrumentation (requests to /metrics and // /health are not included). root.Handle("/", promhttp.InstrumentHandlerInFlight(inFlightRequests, promhttp.InstrumentHandlerCounter(totalRequests, propagateDeadline(h)))) return root } const deadlineHeader = "X-RPC-Deadline" // Read an eventual deadline from the HTTP request, and set it as the // deadline of the request context. func propagateDeadline(h http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { if hdr := req.Header.Get(deadlineHeader); hdr != "" { if deadlineNano, err := strconv.ParseInt(hdr, 10, 64); err == nil { deadline := time.Unix(0, deadlineNano) ctx, cancel := context.WithDeadline(req.Context(), deadline) defer cancel() req = req.WithContext(ctx) } } h.ServeHTTP(w, req) }) } func guessEndpointName(addr string) string { _, port, err := net.SplitHostPort(addr) if err != nil { return addr } host, err := os.Hostname() if err != nil { return addr } return fmt.Sprintf("%s:%s", host, port) } // HTTP-related metrics. var ( // Since we instrument the root HTTP handler, we don't really // have a good way to set the 'handler' label based on the // request URL - but still, we'd like to set the label to // match what the other Prometheus jobs do. So we just set it // to 'all'. totalRequests = prometheus.NewCounterVec( prometheus.CounterOpts{ Name: "http_requests_total", Help: "Total number of requests.", ConstLabels: prometheus.Labels{ "handler": "all", }, }, []string{"code", "method"}, ) inFlightRequests = prometheus.NewGauge( prometheus.GaugeOpts{ Name: "http_requests_inflight", Help: "Number of in-flight requests.", }, ) ) func init() { prometheus.MustRegister(totalRequests, inFlightRequests) }