From 42482e8283c78c682c16672f0f2657a8bea9b285 Mon Sep 17 00:00:00 2001
From: ale <ale@incal.net>
Date: Mon, 16 Dec 2019 09:52:00 +0000
Subject: [PATCH] Use securecookie for the httpsso handler

There is no need for the complex gorilla/sessions machinery for what
is basically a single cookie, so we switch to using
gorilla/securecookie directly.
---
 httpsso/handler.go      | 129 +++++++++++++++++++++-------------------
 httpsso/handler_test.go |   8 +--
 2 files changed, 70 insertions(+), 67 deletions(-)

diff --git a/httpsso/handler.go b/httpsso/handler.go
index 402e778..e750f33 100644
--- a/httpsso/handler.go
+++ b/httpsso/handler.go
@@ -13,10 +13,13 @@ import (
 	"strings"
 	"time"
 
-	"github.com/gorilla/sessions"
-
 	"git.autistici.org/id/go-sso"
-	"git.autistici.org/id/go-sso/httputil"
+	"github.com/gorilla/securecookie"
+)
+
+const (
+	ssoCookieName   = "sso"
+	nonceCookieName = "sso_n"
 )
 
 type authSession struct {
@@ -30,11 +33,10 @@ type authSessionKeyType int
 const authSessionKey authSessionKeyType = 0
 
 func getCurrentAuthSession(req *http.Request) *authSession {
-	s, ok := req.Context().Value(authSessionKey).(*authSession)
-	if !ok {
-		return nil
+	if s, ok := req.Context().Value(authSessionKey).(*authSession); ok {
+		return s
 	}
-	return s
+	return nil
 }
 
 // Authenticated returns true if the user is successfully
@@ -70,30 +72,30 @@ func init() {
 
 // SSOWrapper protects http handlers with single-sign-on authentication.
 type SSOWrapper struct {
-	v              sso.Validator
-	sessionAuthKey []byte
-	sessionEncKey  []byte
-	serverURL      string
-	serverOrigin   string
-
-	TTL time.Duration
+	v            sso.Validator
+	sc           *securecookie.SecureCookie
+	serverURL    string
+	serverOrigin string
 }
 
 // NewSSOWrapper returns a new SSOWrapper that will authenticate users
 // on the specified login service.
-func NewSSOWrapper(serverURL string, pkey []byte, domain string, sessionAuthKey, sessionEncKey []byte) (*SSOWrapper, error) {
+func NewSSOWrapper(serverURL string, pkey []byte, domain string, sessionAuthKey, sessionEncKey []byte, ttl time.Duration) (*SSOWrapper, error) {
 	v, err := sso.NewValidator(pkey, domain)
 	if err != nil {
 		return nil, err
 	}
 
+	if ttl == 0 {
+		ttl = defaultAuthSessionTTL
+	}
+	sc := securecookie.New(sessionAuthKey, sessionEncKey).MaxAge(int(ttl.Seconds()))
+
 	return &SSOWrapper{
-		v:              v,
-		serverURL:      serverURL,
-		serverOrigin:   originFromURL(serverURL),
-		sessionAuthKey: sessionAuthKey,
-		sessionEncKey:  sessionEncKey,
-		TTL:            defaultAuthSessionTTL,
+		v:            v,
+		sc:           sc,
+		serverURL:    serverURL,
+		serverOrigin: originFromURL(serverURL),
 	}, nil
 }
 
@@ -101,50 +103,47 @@ func NewSSOWrapper(serverURL string, pkey []byte, domain string, sessionAuthKey,
 // Currently only a simple form of group-based ACLs is supported.
 func (s *SSOWrapper) Wrap(h http.Handler, service string, groups []string) http.Handler {
 	svcPath := pathFromService(service)
-	store := sessions.NewCookieStore(s.sessionAuthKey, s.sessionEncKey)
-	store.Options = &sessions.Options{
-		HttpOnly: true,
-		Secure:   true,
-		MaxAge:   0,
-		Path:     svcPath,
-	}
-
 	return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
-		session, _ := httputil.GetExpiringSession(req, store, "sso", s.TTL)
-
 		switch strings.TrimPrefix(req.URL.Path, svcPath) {
 		case "sso_login":
-			s.handleLogin(w, req, session, service, groups)
+			s.handleLogin(w, req, service, groups)
 
 		case "sso_logout":
-			s.handleLogout(w, req, session)
+			s.handleLogout(w, req)
 
 		default:
-			if auth, ok := session.Values["a"].(*authSession); ok && auth.Auth {
-				req.Header.Set("X-Authenticated-User", auth.Username)
+			var auth authSession
+			if cookie, err := req.Cookie(ssoCookieName); err == nil {
+				s.sc.Decode(ssoCookieName, cookie.Value, &auth) // nolint
+			}
 
-				req = req.WithContext(context.WithValue(req.Context(), authSessionKey, auth))
+			if auth.Auth {
+				req.Header.Set("X-Authenticated-User", auth.Username)
+				req = req.WithContext(context.WithValue(req.Context(), authSessionKey, &auth))
 				h.ServeHTTP(w, req)
 				return
 			}
 
-			s.redirectToLogin(w, req, session, service, groups)
+			s.redirectToLogin(w, req, service, groups)
 		}
 	})
 }
 
-func (s *SSOWrapper) handleLogin(w http.ResponseWriter, req *http.Request, session *httputil.ExpiringSession, service string, groups []string) {
+func (s *SSOWrapper) handleLogin(w http.ResponseWriter, req *http.Request, service string, groups []string) {
 	t := req.FormValue("t")
 	d := req.FormValue("d")
 
-	// Pop the nonce from the session.
-	nonce, ok := session.Values["nonce"].(string)
-	if !ok || nonce == "" {
+	// Pop the nonce from the cookies.
+	cookie, err := req.Cookie(nonceCookieName)
+	if err != nil {
 		log.Printf("got login request without nonce")
 		http.Error(w, "Missing nonce", http.StatusBadRequest)
 		return
 	}
-	delete(session.Values, "nonce")
+	nonce := cookie.Value
+	cookie.MaxAge = -1
+	cookie.Value = ""
+	http.SetCookie(w, cookie)
 
 	tkt, err := s.v.Validate(t, nonce, service, groups)
 	if err != nil {
@@ -154,28 +153,34 @@ func (s *SSOWrapper) handleLogin(w http.ResponseWriter, req *http.Request, sessi
 	}
 
 	// Authenticate the user.
-	session.Values["a"] = &authSession{
+	auth := authSession{
 		Auth:     true,
 		Username: tkt.User,
 		Groups:   tkt.Groups,
 	}
-	if err := sessions.Save(req, w); err != nil {
+	encoded, err := s.sc.Encode(ssoCookieName, &auth)
+	if err != nil {
 		log.Printf("error saving SSO session: %v", err)
 		http.Error(w, err.Error(), http.StatusInternalServerError)
 		return
 	}
+	http.SetCookie(w, &http.Cookie{
+		Name:     ssoCookieName,
+		Value:    encoded,
+		Path:     pathFromService(service),
+		Secure:   true,
+		HttpOnly: true,
+	})
+
 	http.Redirect(w, req, d, http.StatusFound)
 }
 
-func (s *SSOWrapper) handleLogout(w http.ResponseWriter, req *http.Request, session *httputil.ExpiringSession) {
-	// Delete the auth session.
-	session.Options.MaxAge = -1
-	delete(session.Values, "sso")
-
-	if err := sessions.Save(req, w); err != nil {
-		log.Printf("error saving SSO session: %v", err)
-		http.Error(w, err.Error(), http.StatusInternalServerError)
-		return
+func (s *SSOWrapper) handleLogout(w http.ResponseWriter, req *http.Request) {
+	// Delete the session cookie, if present.
+	if cookie, err := req.Cookie(ssoCookieName); err == nil {
+		cookie.MaxAge = -1
+		cookie.Value = ""
+		http.SetCookie(w, cookie)
 	}
 
 	w.Header().Set("Content-Type", "text/plain")
@@ -187,16 +192,16 @@ func (s *SSOWrapper) handleLogout(w http.ResponseWriter, req *http.Request, sess
 }
 
 // Redirect to the SSO server.
-func (s *SSOWrapper) redirectToLogin(w http.ResponseWriter, req *http.Request, session *httputil.ExpiringSession, service string, groups []string) {
-	// Generate a random nonce and store it in the local session.
+func (s *SSOWrapper) redirectToLogin(w http.ResponseWriter, req *http.Request, service string, groups []string) {
+	// Generate a random nonce and store it in a cookie.
 	nonce := makeUniqueNonce()
-	session.Values["nonce"] = nonce
-
-	if err := sessions.Save(req, w); err != nil {
-		log.Printf("error saving SSO session: %v", err)
-		http.Error(w, err.Error(), http.StatusInternalServerError)
-		return
-	}
+	http.SetCookie(w, &http.Cookie{
+		Name:     nonceCookieName,
+		Value:    nonce,
+		Path:     pathFromService(service) + "sso_login",
+		Secure:   true,
+		HttpOnly: true,
+	})
 
 	v := make(url.Values)
 	v.Set("s", service)
diff --git a/httpsso/handler_test.go b/httpsso/handler_test.go
index 7cb1f65..300af5e 100644
--- a/httpsso/handler_test.go
+++ b/httpsso/handler_test.go
@@ -12,7 +12,6 @@ import (
 	"testing"
 	"time"
 
-	"github.com/gorilla/mux"
 	"github.com/gorilla/securecookie"
 	"golang.org/x/crypto/ed25519"
 
@@ -108,9 +107,8 @@ func TestSSOWrapper(t *testing.T) {
 		t.Fatal(err)
 	}
 
-	// Build a test app - note that we want to use a gorilla Mux
-	// here, otherwise cookie-based sessions won't work.
-	m := mux.NewRouter()
+	// Build a test app.
+	m := http.NewServeMux()
 	m.HandleFunc("/test/groups", func(w http.ResponseWriter, req *http.Request) {
 		io.WriteString(w, strings.Join(Groups(req), ",")) // nolint
 	})
@@ -118,7 +116,7 @@ func TestSSOWrapper(t *testing.T) {
 		io.WriteString(w, "OK") // nolint
 	})
 
-	w, err := NewSSOWrapper("https://"+testLoginServer+"/", pub, testDomain, securecookie.GenerateRandomKey(64), securecookie.GenerateRandomKey(32))
+	w, err := NewSSOWrapper("https://"+testLoginServer+"/", pub, testDomain, securecookie.GenerateRandomKey(64), securecookie.GenerateRandomKey(32), 0)
 	if err != nil {
 		t.Fatal("NewSSOWrapper():", err)
 	}
-- 
GitLab