diff --git a/server/http.go b/server/http.go index 7248f0359046b0ef99760cf10e39690745eed715..12ecae6ef0e0d959bd98dc3955a63a1ef8dcd4cd 100644 --- a/server/http.go +++ b/server/http.go @@ -206,7 +206,11 @@ func (h *Server) loginCallback(w http.ResponseWriter, req *http.Request, usernam return httpSession.Save(req, w) } -func (h *Server) withAuth(f func(http.ResponseWriter, *http.Request, *authSession)) http.Handler { +func (h *Server) redirectToLogin(w http.ResponseWriter, req *http.Request) { + http.Redirect(w, req, h.loginHandler.makeLoginURL(req), http.StatusFound) +} + +func (h *Server) withAuth(f func(http.ResponseWriter, *http.Request, *authSession), authFail func(http.ResponseWriter, *http.Request)) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { httpSession, err := h.authSessionStore.Get(req, authSessionKey) if err != nil { @@ -223,7 +227,7 @@ func (h *Server) withAuth(f func(http.ResponseWriter, *http.Request, *authSessio if err := httpSession.Save(req, w); err != nil { log.Printf("error saving session: %v", err) } - http.Redirect(w, req, h.loginHandler.makeLoginURL(req), http.StatusFound) + authFail(w, req) }) } @@ -285,6 +289,10 @@ func (h *Server) handleHomepage(w http.ResponseWriter, req *http.Request, sessio http.Redirect(w, req, callbackURL, http.StatusFound) } +func (h *Server) alreadyLoggedOut(w http.ResponseWriter, req *http.Request) { + http.Error(w, "You do not seem to be logged in", http.StatusBadRequest) +} + type logoutServiceInfo struct { URL string `json:"url"` Name string `json:"name"` @@ -381,7 +389,7 @@ func (h *Server) Handler() http.Handler { // protection. m := http.NewServeMux() m.Handle(h.urlFor("/login"), h.loginHandler) - m.Handle(h.urlFor("/logout"), h.withAuth(h.handleLogout)) + m.Handle(h.urlFor("/logout"), h.withAuth(h.handleLogout, h.alreadyLoggedOut)) idph := http.Handler(m) if h.csrfSecret != nil { idph = csrf.Protect(h.csrfSecret)(idph) @@ -390,7 +398,7 @@ func (h *Server) Handler() http.Handler { // Add the SSO provider endpoints (root path and /exchange), // which do not need CSRF. We use a HandlerFunc to bypass the // '/' dispatch semantics of the standard http.ServeMux. - ssoh := h.withAuth(h.handleHomepage) + ssoh := h.withAuth(h.handleHomepage, h.redirectToLogin) userh := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch { case r.Method == "GET" && r.URL.Path == h.urlFor("/"): diff --git a/server/http_test.go b/server/http_test.go index c3fa9a6fcaa19273cbbd5fd752dd154da7ab2f2c..e11560f64fd793e68d2189857d7a9e0a7d129c42 100644 --- a/server/http_test.go +++ b/server/http_test.go @@ -249,7 +249,6 @@ func TestHTTP_LoginAndLogout(t *testing.T) { // Make a logout request. doGet(t, httpSrv, c, "/logout", checkStatusOk) - doPostForm(t, httpSrv, c, "/logout", nil, checkStatusOk) // This new authorization request should send us to the login page. v = make(url.Values)