diff --git a/go.mod b/go.mod index 93702e4b44260b30d3a238ab351d003bcc7fc416..50772127bd76e75719bb4ca2533c3f4af2086e77 100644 --- a/go.mod +++ b/go.mod @@ -13,7 +13,7 @@ require ( github.com/crewjam/saml v0.4.14 github.com/elazarl/go-bindata-assetfs v1.0.1 github.com/go-webauthn/webauthn v0.12.2 - github.com/gorilla/csrf v1.7.2 + github.com/gorilla/csrf v1.7.3 github.com/gorilla/mux v1.8.1 github.com/gorilla/securecookie v1.1.2 github.com/mssola/user_agent v0.6.0 diff --git a/go.sum b/go.sum index 8c98017d2e093588575f89c4b40c6531d1fd60e6..d82ef39b71ae2ebdda91a741dc2fa256d7e9cc8e 100644 --- a/go.sum +++ b/go.sum @@ -432,6 +432,8 @@ github.com/goreleaser/nfpm v1.2.1/go.mod h1:TtWrABZozuLOttX2uDlYyECfQX7x5XYkVxhj github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg= github.com/gorilla/csrf v1.7.2 h1:oTUjx0vyf2T+wkrx09Trsev1TE+/EbDAeHtSTbtC2eI= github.com/gorilla/csrf v1.7.2/go.mod h1:F1Fj3KG23WYHE6gozCmBAezKookxbIvUJT+121wTuLk= +github.com/gorilla/csrf v1.7.3 h1:BHWt6FTLZAb2HtWT5KDBf6qgpZzvtbp9QWDRKZMXJC0= +github.com/gorilla/csrf v1.7.3/go.mod h1:F1Fj3KG23WYHE6gozCmBAezKookxbIvUJT+121wTuLk= github.com/gorilla/mux v1.6.2/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs= github.com/gorilla/mux v1.7.3/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs= github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So= diff --git a/vendor/github.com/gorilla/csrf/csrf.go b/vendor/github.com/gorilla/csrf/csrf.go index 97a392568c06f657c01f125d07b4b358ad1b0308..5dda2547ca1434cf885daf31a6b9fab63c054566 100644 --- a/vendor/github.com/gorilla/csrf/csrf.go +++ b/vendor/github.com/gorilla/csrf/csrf.go @@ -1,10 +1,12 @@ package csrf import ( + "context" "errors" "fmt" "net/http" "net/url" + "slices" "github.com/gorilla/securecookie" ) @@ -22,6 +24,14 @@ const ( errorPrefix string = "gorilla/csrf: " ) +type contextKey string + +// PlaintextHTTPContextKey is the context key used to store whether the request +// is being served via plaintext HTTP. This is used to signal to the middleware +// that strict Referer checking should not be enforced as is done for HTTPS by +// default. +const PlaintextHTTPContextKey contextKey = "plaintext" + var ( // The name value used in form fields. fieldName = tokenKey @@ -41,6 +51,9 @@ var ( // ErrNoReferer is returned when a HTTPS request provides an empty Referer // header. ErrNoReferer = errors.New("referer not supplied") + // ErrBadOrigin is returned when the Origin header is present and is not a + // trusted origin. + ErrBadOrigin = errors.New("origin invalid") // ErrBadReferer is returned when the scheme & host in the URL do not match // the supplied Referer header. ErrBadReferer = errors.New("referer invalid") @@ -242,10 +255,50 @@ func (cs *csrf) ServeHTTP(w http.ResponseWriter, r *http.Request) { // HTTP methods not defined as idempotent ("safe") under RFC7231 require // inspection. if !contains(safeMethods, r.Method) { - // Enforce an origin check for HTTPS connections. As per the Django CSRF - // implementation (https://goo.gl/vKA7GE) the Referer header is almost - // always present for same-domain HTTP requests. - if r.URL.Scheme == "https" { + var isPlaintext bool + val := r.Context().Value(PlaintextHTTPContextKey) + if val != nil { + isPlaintext, _ = val.(bool) + } + + // take a copy of the request URL to avoid mutating the original + // attached to the request. + // set the scheme & host based on the request context as these are not + // populated by default for server requests + // ref: https://pkg.go.dev/net/http#Request + requestURL := *r.URL // shallow clone + + requestURL.Scheme = "https" + if isPlaintext { + requestURL.Scheme = "http" + } + if requestURL.Host == "" { + requestURL.Host = r.Host + } + + // if we have an Origin header, check it against our allowlist + origin := r.Header.Get("Origin") + if origin != "" { + parsedOrigin, err := url.Parse(origin) + if err != nil { + r = envError(r, ErrBadOrigin) + cs.opts.ErrorHandler.ServeHTTP(w, r) + return + } + if !sameOrigin(&requestURL, parsedOrigin) && !slices.Contains(cs.opts.TrustedOrigins, parsedOrigin.Host) { + r = envError(r, ErrBadOrigin) + cs.opts.ErrorHandler.ServeHTTP(w, r) + return + } + } + + // If we are serving via TLS and have no Origin header, prevent against + // CSRF via HTTP machine in the middle attacks by enforcing strict + // Referer origin checks. Consider an attacker who performs a + // successful HTTP Machine-in-the-Middle attack and uses this to inject + // a form and cause submission to our origin. We strictly disallow + // cleartext HTTP origins and evaluate the domain against an allowlist. + if origin == "" && !isPlaintext { // Fetch the Referer value. Call the error handler if it's empty or // otherwise fails to parse. referer, err := url.Parse(r.Referer()) @@ -255,18 +308,17 @@ func (cs *csrf) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } - valid := sameOrigin(r.URL, referer) - - if !valid { - for _, trustedOrigin := range cs.opts.TrustedOrigins { - if referer.Host == trustedOrigin { - valid = true - break - } - } + // disallow cleartext HTTP referers when serving via TLS + if referer.Scheme == "http" { + r = envError(r, ErrBadReferer) + cs.opts.ErrorHandler.ServeHTTP(w, r) + return } - if !valid { + // If the request is being served via TLS and the Referer is not the + // same origin, check the domain against our allowlist. We only + // check when we have host information from the referer. + if referer.Host != "" && referer.Host != r.Host && !slices.Contains(cs.opts.TrustedOrigins, referer.Host) { r = envError(r, ErrBadReferer) cs.opts.ErrorHandler.ServeHTTP(w, r) return @@ -308,6 +360,15 @@ func (cs *csrf) ServeHTTP(w http.ResponseWriter, r *http.Request) { contextClear(r) } +// PlaintextHTTPRequest accepts as input a http.Request and returns a new +// http.Request with the PlaintextHTTPContextKey set to true. This is used to +// signal to the CSRF middleware that the request is being served over plaintext +// HTTP and that Referer-based origin allow-listing checks should be skipped. +func PlaintextHTTPRequest(r *http.Request) *http.Request { + ctx := context.WithValue(r.Context(), PlaintextHTTPContextKey, true) + return r.WithContext(ctx) +} + // unauthorizedhandler sets a HTTP 403 Forbidden status and writes the // CSRF failure reason to the response. func unauthorizedHandler(w http.ResponseWriter, r *http.Request) { diff --git a/vendor/modules.txt b/vendor/modules.txt index 4a05d044a99afc93bec126a04708e7fd0caf00b5..7fe1c4e90e28913aa937a66ed912394d7dd6a9df 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -83,7 +83,7 @@ github.com/google/go-tpm/tpmutil/tbs # github.com/google/uuid v1.6.0 ## explicit github.com/google/uuid -# github.com/gorilla/csrf v1.7.2 +# github.com/gorilla/csrf v1.7.3 ## explicit; go 1.20 github.com/gorilla/csrf # github.com/gorilla/mux v1.8.1