From b1c0a012a9d00b10f7652b689ae0c6f703e6a6b1 Mon Sep 17 00:00:00 2001
From: ale <ale@incal.net>
Date: Sun, 27 Jan 2019 14:44:37 +0000
Subject: [PATCH] Simplify login handler by isolating session init logic

And standardize on a single structure for the gorilla session
map, using the 'data' key for our gob-encoded objects.
---
 server/http.go  | 13 +++++++------
 server/login.go | 47 +++++++++++++++++++++++++++++++----------------
 2 files changed, 38 insertions(+), 22 deletions(-)

diff --git a/server/http.go b/server/http.go
index 9ee3696..a439be9 100644
--- a/server/http.go
+++ b/server/http.go
@@ -200,7 +200,7 @@ func (h *Server) loginCallback(w http.ResponseWriter, req *http.Request, usernam
 	// Create cookie-based session for the authenticated user.
 	session := newAuthSession(h.authSessionLifetime, username, userinfo)
 	httpSession, _ := h.authSessionStore.Get(req, authSessionKey) // nolint
-	httpSession.Values["auth"] = session
+	httpSession.Values["data"] = session
 	return httpSession.Save(req, w)
 }
 
@@ -211,12 +211,13 @@ func (h *Server) withAuth(f func(http.ResponseWriter, *http.Request, *authSessio
 			http.Error(w, err.Error(), http.StatusInternalServerError)
 			return
 		}
-		session, ok := httpSession.Values["auth"].(*authSession)
-		if ok && session != nil && session.Valid() {
+		session, ok := httpSession.Values["data"].(*authSession)
+		if ok && session.Valid() {
 			f(w, req, session)
 			return
 		}
 		httpSession.Options.MaxAge = -1
+		delete(httpSession.Values, "data")
 		if err := httpSession.Save(req, w); err != nil {
 			log.Printf("error saving session: %v", err)
 		}
@@ -224,9 +225,9 @@ func (h *Server) withAuth(f func(http.ResponseWriter, *http.Request, *authSessio
 	})
 }
 
-// Homepage handler. Authorizes an authenticated user to a service by
-// signing a token with the user's identity. The client is redirected
-// back to the service, with the signed token.
+// Token signing handler. Authorizes an authenticated user to a service by
+// signing a token with the user's identity. The client is redirected back to
+// the original service, with the signed token.
 func (h *Server) handleHomepage(w http.ResponseWriter, req *http.Request, session *authSession) {
 	// Extract the authorization request parameters from the HTTP
 	// request.
diff --git a/server/login.go b/server/login.go
index 8a182f1..3118f93 100644
--- a/server/login.go
+++ b/server/login.go
@@ -105,21 +105,43 @@ func newLoginHandler(okHandler loginCallbackFunc, devMgr *device.Manager, authCl
 	}
 }
 
+func (l *loginHandler) fetchOrInitSession(req *http.Request) (*sessions.Session, *loginSession, error) {
+	// Either fetch the current session or create a new blank one.
+	httpSession, err := l.loginSessionStore.Get(req, loginSessionKey)
+	if err != nil {
+		return nil, nil, err
+	}
+	session, ok := httpSession.Values["data"].(*loginSession)
+	if !ok || !session.Valid() {
+		session = newLoginSession()
+
+		// Initialize session. The only parameter is 'r', the target
+		// redirect location. Enforce relative redirect URL (no host
+		// should be specified).
+		session.Redir = req.FormValue("r")
+		if session.Redir == "" {
+			return nil, nil, errors.New("empty login redirect target")
+		}
+		if !strings.HasPrefix(session.Redir, "/") || strings.HasPrefix(session.Redir, "//") {
+			return nil, nil, errors.New("bad login redirect target")
+		}
+
+		httpSession.Values["data"] = session
+	}
+
+	return httpSession, session, nil
+}
+
 // The login session controls the flow of the client - it's just a way
 // to ensure that every step is authorized as part of the login
 // sequence.
 func (l *loginHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
-	// Either fetch the current session or create a new blank one.
-	httpSession, err := l.loginSessionStore.Get(req, loginSessionKey)
+	httpSession, session, err := l.fetchOrInitSession(req)
 	if err != nil {
-		http.Error(w, err.Error(), http.StatusInternalServerError)
+		log.Printf("login session init error: %v", err)
+		http.Error(w, err.Error(), http.StatusBadRequest)
 		return
 	}
-	session, ok := httpSession.Values["ls"].(*loginSession)
-	if !ok || session == nil || !session.Valid() {
-		session = newLoginSession()
-		httpSession.Values["ls"] = session
-	}
 
 	// Dispatch the current state to its handler. Handlers will
 	// handle the current request and either 1) validate the
@@ -142,6 +164,7 @@ func (l *loginHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
 			// the login callback, before redirecting to the
 			// original URL.
 			httpSession.Options.MaxAge = -1
+			delete(httpSession.Values, "data")
 			if err := httpSession.Save(req, w); err != nil {
 				log.Printf("login error saving session: %v", err)
 				http.Error(w, err.Error(), http.StatusInternalServerError)
@@ -188,14 +211,6 @@ func (l *loginHandler) handlePassword(w http.ResponseWriter, req *http.Request,
 	username := req.FormValue("username")
 	password := req.FormValue("password")
 
-	if req.Method == "GET" && session.Redir == "" {
-		session.Redir = req.FormValue("r")
-		// Enforce relative redirect URL (no host specified).
-		if session.Redir == "" || !strings.HasPrefix(session.Redir, "/") {
-			return loginStateNone, nil, errors.New("bad request")
-		}
-	}
-
 	// If the request is a POST, attempt login with username/password.
 	env := map[string]interface{}{
 		"Error":    false,
-- 
GitLab