Skip to content
Snippets Groups Projects
Commit 3ec9dca3 authored by ale's avatar ale
Browse files

Store the session in the http.Request context

Add methods to the httpsso package to retrieve authentication info
from the current session (by passing an *http.Request object).
parent 285b177c
Branches
No related tags found
No related merge requests found
package httpsso
import (
"context"
"crypto/rand"
"encoding/gob"
"encoding/hex"
......@@ -23,6 +24,44 @@ type authSession struct {
Auth bool
Username string
Groups []string
}
type authSessionKeyType int
var authSessionKey authSessionKeyType = 42
func getCurrentAuthSession(req *http.Request) *authSession {
s, ok := req.Context().Value(authSessionKey).(*authSession)
if !ok {
return nil
}
return s
}
// Authenticated returns true if the user is successfully
// authenticated, in the call trace following SSOWrapper.Wrap.
func Authenticated(req *http.Request) bool {
if s := getCurrentAuthSession(req); s != nil {
return s.Auth
}
return false
}
// Username of the currently authenticated user.
func Username(req *http.Request) string {
if s := getCurrentAuthSession(req); s != nil && s.Auth {
return s.Username
}
return ""
}
// Groups returns the group list for the currently authenticated user.
func Groups(req *http.Request) []string {
if s := getCurrentAuthSession(req); s != nil && s.Auth {
return s.Groups
}
return nil
}
var authSessionLifetime = 1 * time.Hour
......@@ -83,7 +122,8 @@ func (s *SSOWrapper) Wrap(h http.Handler, service string, groups []string) http.
if auth, ok := session.Values["a"].(*authSession); ok && auth.Valid() && auth.Auth {
req.Header.Set("X-Authenticated-User", auth.Username)
h.ServeHTTP(w, req)
ctx := context.WithValue(req.Context(), authSessionKey, auth)
h.ServeHTTP(w, req.WithContext(ctx))
return
}
......@@ -117,6 +157,7 @@ func (s *SSOWrapper) handleLogin(w http.ResponseWriter, req *http.Request, sessi
ExpiringSession: httputil.NewExpiringSession(authSessionLifetime),
Auth: true,
Username: tkt.User,
Groups: tkt.Groups,
}
if err := sessions.Save(req, w); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
......@@ -137,7 +178,7 @@ func (s *SSOWrapper) handleLogout(w http.ResponseWriter, req *http.Request, sess
w.Header().Set("Access-Control-Allow-Origin", s.serverOrigin)
w.Header().Set("Access-Control-Allow-Credentials", "true")
}
io.WriteString(w, "OK")
io.WriteString(w, "OK") // nolint
}
// Redirect to the SSO server.
......
......@@ -8,6 +8,7 @@ import (
"net/http/cookiejar"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"
......@@ -41,6 +42,8 @@ const (
testLoginServer = "login.example.com"
)
var testGroups = []string{"group1", "group2"}
func makeAuthRequest(t testing.TB, c *http.Client, base, path, service, domain string, priv []byte) []byte {
resp, err := c.Get(base + path)
if err != nil {
......@@ -65,13 +68,13 @@ func makeAuthRequest(t testing.TB, c *http.Client, base, path, service, domain s
t.Fatal(err)
}
nonce := loc.Query().Get("n")
tkt := sso.NewTicket("user", service, domain, nonce, nil, 300*time.Second)
tkt := sso.NewTicket("user", service, domain, nonce, testGroups, 300*time.Second)
signed, err := signer.Sign(tkt)
if err != nil {
t.Fatal("Sign():", err)
}
u := make(url.Values)
destURL := "https://" + testHost + "/test"
destURL := "https://" + testHost + path
u.Set("d", destURL)
u.Set("t", signed)
resp, err = c.Get(base + "/sso_login?" + u.Encode())
......@@ -108,8 +111,11 @@ func TestSSOWrapper(t *testing.T) {
// Build a test app - note that we want to use a gorilla Mux
// here, otherwise cookie-based sessions won't work.
m := mux.NewRouter()
m.HandleFunc("/test/groups", func(w http.ResponseWriter, req *http.Request) {
io.WriteString(w, strings.Join(Groups(req), ",")) // nolint
})
m.HandleFunc("/test", func(w http.ResponseWriter, _ *http.Request) {
io.WriteString(w, "OK")
io.WriteString(w, "OK") // nolint
})
w, err := NewSSOWrapper("https://"+testLoginServer+"/", pub, testDomain, securecookie.GenerateRandomKey(64), securecookie.GenerateRandomKey(32))
......@@ -118,13 +124,20 @@ func TestSSOWrapper(t *testing.T) {
}
// Start a local test https server.
srv := httptest.NewTLSServer(w.Wrap(m, testService, nil))
srv := httptest.NewTLSServer(w.Wrap(m, testService, testGroups))
defer srv.Close()
// Request a sample URL.
c := newTestHTTPClient()
data := string(makeAuthRequest(t, c, srv.URL, "/test", testService, testDomain, priv))
if data != "OK" {
t.Fatalf("Get() returned bad data: %s", data)
t.Fatalf("Get(/test) returned bad data: %s", data)
}
// Another URL, clean client, verify context values.
c = newTestHTTPClient()
data = string(makeAuthRequest(t, c, srv.URL, "/test/groups", testService, testDomain, priv))
if data != "group1,group2" {
t.Fatalf("Get(/test/groups) returned bad data: %s", data)
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment