Skip to content
Snippets Groups Projects
handlers.go 4.78 KiB
Newer Older
  • Learn to ignore specific revisions
  • ale's avatar
    ale committed
    // Copyright 2013 The Gorilla Authors. All rights reserved.
    // Use of this source code is governed by a BSD-style
    // license that can be found in the LICENSE file.
    
    package handlers
    
    import (
    	"bufio"
    	"fmt"
    	"net"
    	"net/http"
    	"sort"
    	"strings"
    )
    
    // MethodHandler maps HTTP method names (for example "GET") to the
    // http.Handler that serves requests made with that method.
    //
    // A request whose method is a key of the map is forwarded to the matching
    // handler. For any other request the Allow header is set to the map's keys,
    // sorted and joined with ", ", and then:
    //
    // an OPTIONS request is answered with status 200 and no body;
    //
    // every other request is answered with status 405 and the body
    // "Method not allowed".
    type MethodHandler map[string]http.Handler
    
    // ServeHTTP dispatches req to the handler registered for req.Method, or
    // writes the Allow header followed by a 200 (OPTIONS) or 405 response when no
    // handler is registered for that method.
    func (h MethodHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
    	if target, found := h[req.Method]; found {
    		target.ServeHTTP(w, req)
    		return
    	}
    
    	// Map iteration order is random, so the names are sorted to keep the
    	// Allow header stable between requests.
    	methods := make([]string, 0, len(h))
    	for name := range h {
    		methods = append(methods, name)
    	}
    	sort.Strings(methods)
    	w.Header().Set("Allow", strings.Join(methods, ", "))
    
    	switch req.Method {
    	case "OPTIONS":
    		w.WriteHeader(http.StatusOK)
    	default:
    		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
    	}
    }
    
    // responseLogger decorates an http.ResponseWriter, remembering the status
    // code passed to WriteHeader and the total number of body bytes the wrapped
    // writer reported as written.
    type responseLogger struct {
    	w      http.ResponseWriter
    	status int
    	size   int
    }
    
    // Header returns the header map of the wrapped ResponseWriter.
    func (l *responseLogger) Header() http.Header {
    	return l.w.Header()
    }
    
    // Write forwards b to the wrapped ResponseWriter and adds the number of bytes
    // it reports to the running size, even when the write returns an error.
    func (l *responseLogger) Write(b []byte) (int, error) {
    	written, err := l.w.Write(b)
    	l.size += written
    	return written, err
    }
    
    // WriteHeader forwards the status code s to the wrapped ResponseWriter and
    // then records it.
    func (l *responseLogger) WriteHeader(s int) {
    	l.w.WriteHeader(s)
    	l.status = s
    }
    
    // Status reports the most recently recorded status code.
    func (l *responseLogger) Status() int {
    	return l.status
    }
    
    // Size reports the number of body bytes written so far.
    func (l *responseLogger) Size() int {
    	return l.size
    }
    
    // Flush flushes the wrapped ResponseWriter when it implements http.Flusher
    // and does nothing otherwise.
    func (l *responseLogger) Flush() {
    	if flusher, canFlush := l.w.(http.Flusher); canFlush {
    		flusher.Flush()
    	}
    }
    
    // hijackLogger is a responseLogger that additionally offers Hijack, so the
    // wrapper can stand in for a ResponseWriter that supports connection
    // hijacking.
    type hijackLogger struct {
    	responseLogger
    }
    
    // Hijack hands control of the connection to the caller by delegating to the
    // wrapped ResponseWriter.
    //
    // It returns http.ErrNotSupported, rather than panicking, when the wrapped
    // ResponseWriter does not implement http.Hijacker. When the hijack succeeds
    // and no status has been recorded yet, the recorded status becomes
    // http.StatusSwitchingProtocols.
    func (l *hijackLogger) Hijack() (net.Conn, *bufio.ReadWriter, error) {
    	h, ok := l.responseLogger.w.(http.Hijacker)
    	if !ok {
    		return nil, nil, http.ErrNotSupported
    	}
    	conn, rw, err := h.Hijack()
    	if err == nil && l.responseLogger.status == 0 {
    		// The status will be StatusSwitchingProtocols if there was no error and
    		// WriteHeader has not been called yet
    		l.responseLogger.status = http.StatusSwitchingProtocols
    	}
    	return conn, rw, err
    }
    
    // closeNotifyWriter pairs a loggingResponseWriter with an http.CloseNotifier
    // by embedding both in a single struct.
    //
    // NOTE(review): loggingResponseWriter is declared outside this part of the
    // file; confirm its method set before relying on which methods are promoted.
    type closeNotifyWriter struct {
    	loggingResponseWriter
    	http.CloseNotifier
    }
    
    // hijackCloseNotifier combines a loggingResponseWriter with both an
    // http.Hijacker and an http.CloseNotifier by embedding all three in a single
    // struct.
    //
    // NOTE(review): loggingResponseWriter is declared outside this part of the
    // file; confirm its method set before relying on which methods are promoted.
    type hijackCloseNotifier struct {
    	loggingResponseWriter
    	http.Hijacker
    	http.CloseNotifier
    }
    
    // isContentType reports whether the media type in the Content-Type header of
    // h equals contentType.
    //
    // Only the "type/subtype" part is compared: everything from the first ';'
    // onwards (parameters such as charset) is ignored, surrounding whitespace is
    // trimmed, and the comparison is case-insensitive, because media type names
    // are case-insensitive in HTTP.
    func isContentType(h http.Header, contentType string) bool {
    	ct := h.Get("Content-Type")
    	if i := strings.IndexByte(ct, ';'); i != -1 {
    		ct = ct[:i]
    	}
    	// Whitespace is allowed around the ';' separator, so "text/html ; q=1"
    	// must still match "text/html".
    	return strings.EqualFold(strings.TrimSpace(ct), strings.TrimSpace(contentType))
    }
    
    // ContentTypeHandler returns an http.Handler that guards h with a check of
    // the request's Content-Type header against contentTypes.
    //
    // The check applies to PUT, POST and PATCH requests only; requests using any
    // other method are handed to h unchanged. A checked request reaches h when
    // its content type matches at least one entry of contentTypes; otherwise the
    // response is an HTTP 415 error naming the received and expected types.
    func ContentTypeHandler(h http.Handler, contentTypes ...string) http.Handler {
    	guarded := func(w http.ResponseWriter, r *http.Request) {
    		switch r.Method {
    		case "PUT", "POST", "PATCH":
    			// Methods that carry a body: fall through to the check below.
    		default:
    			h.ServeHTTP(w, r)
    			return
    		}
    
    		for i := range contentTypes {
    			if isContentType(r.Header, contentTypes[i]) {
    				h.ServeHTTP(w, r)
    				return
    			}
    		}
    
    		received := r.Header.Get("Content-Type")
    		msg := fmt.Sprintf("Unsupported content type %q; expected one of %q", received, contentTypes)
    		http.Error(w, msg, http.StatusUnsupportedMediaType)
    	}
    	return http.HandlerFunc(guarded)
    }
    
    const (
    	// HTTPMethodOverrideHeader is the name of the request header that
    	// HTTPMethodOverrideHandler reads to find an override for the request
    	// method.
    	HTTPMethodOverrideHeader = "X-HTTP-Method-Override"
    	// HTTPMethodOverrideFormKey is the form key that
    	// HTTPMethodOverrideHandler reads to find an override for the request
    	// method.
    	HTTPMethodOverrideFormKey = "_method"
    )
    
    // HTTPMethodOverrideHandler returns an http.Handler that may rewrite the
    // method of a request before passing it on to h.
    //
    // Clients that cannot send every HTTP verb can tunnel one through a POST.
    // Only POST requests are inspected, since letting e.g. a GET be upgraded to a
    // write would be unsafe. The override is read from the _method form value
    // and, only when that value is empty, from the X-HTTP-Method-Override header;
    // a non-empty form value therefore always wins over the header. The override
    // is applied only when it is exactly PUT, PATCH or DELETE; anything else
    // leaves the request method untouched.
    func HTTPMethodOverrideHandler(h http.Handler) http.Handler {
    	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    		if r.Method != "POST" {
    			h.ServeHTTP(w, r)
    			return
    		}
    
    		override := r.FormValue(HTTPMethodOverrideFormKey)
    		if override == "" {
    			override = r.Header.Get(HTTPMethodOverrideHeader)
    		}
    
    		switch override {
    		case "PUT", "PATCH", "DELETE":
    			r.Method = override
    		}
    		h.ServeHTTP(w, r)
    	})
    }