Commit 3ec9dca3 authored by ale's avatar ale

Store the session in the http.Request context

Add methods to the httpsso package to retrieve authentication info
from the current session (by passing an *http.Request object).
parent 285b177c
package httpsso
import (
"context"
"crypto/rand"
"encoding/gob"
"encoding/hex"
......@@ -23,6 +24,44 @@ type authSession struct {
Auth bool
Username string
Groups []string
}
type authSessionKeyType int
var authSessionKey authSessionKeyType = 42
func getCurrentAuthSession(req *http.Request) *authSession {
s, ok := req.Context().Value(authSessionKey).(*authSession)
if !ok {
return nil
}
return s
}
// Authenticated returns true if the user is successfully
// authenticated, in the call trace following SSOWrapper.Wrap.
func Authenticated(req *http.Request) bool {
if s := getCurrentAuthSession(req); s != nil {
return s.Auth
}
return false
}
// Username of the currently authenticated user.
func Username(req *http.Request) string {
if s := getCurrentAuthSession(req); s != nil && s.Auth {
return s.Username
}
return ""
}
// Groups returns the group list for the currently authenticated user.
func Groups(req *http.Request) []string {
if s := getCurrentAuthSession(req); s != nil && s.Auth {
return s.Groups
}
return nil
}
var authSessionLifetime = 1 * time.Hour
......@@ -83,7 +122,8 @@ func (s *SSOWrapper) Wrap(h http.Handler, service string, groups []string) http.
if auth, ok := session.Values["a"].(*authSession); ok && auth.Valid() && auth.Auth {
req.Header.Set("X-Authenticated-User", auth.Username)
h.ServeHTTP(w, req)
ctx := context.WithValue(req.Context(), authSessionKey, auth)
h.ServeHTTP(w, req.WithContext(ctx))
return
}
......@@ -117,6 +157,7 @@ func (s *SSOWrapper) handleLogin(w http.ResponseWriter, req *http.Request, sessi
ExpiringSession: httputil.NewExpiringSession(authSessionLifetime),
Auth: true,
Username: tkt.User,
Groups: tkt.Groups,
}
if err := sessions.Save(req, w); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
......@@ -137,7 +178,7 @@ func (s *SSOWrapper) handleLogout(w http.ResponseWriter, req *http.Request, sess
w.Header().Set("Access-Control-Allow-Origin", s.serverOrigin)
w.Header().Set("Access-Control-Allow-Credentials", "true")
}
io.WriteString(w, "OK")
io.WriteString(w, "OK") // nolint
}
// Redirect to the SSO server.
......
......@@ -8,6 +8,7 @@ import (
"net/http/cookiejar"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"
......@@ -41,6 +42,8 @@ const (
testLoginServer = "login.example.com"
)
var testGroups = []string{"group1", "group2"}
func makeAuthRequest(t testing.TB, c *http.Client, base, path, service, domain string, priv []byte) []byte {
resp, err := c.Get(base + path)
if err != nil {
......@@ -65,13 +68,13 @@ func makeAuthRequest(t testing.TB, c *http.Client, base, path, service, domain s
t.Fatal(err)
}
nonce := loc.Query().Get("n")
tkt := sso.NewTicket("user", service, domain, nonce, nil, 300*time.Second)
tkt := sso.NewTicket("user", service, domain, nonce, testGroups, 300*time.Second)
signed, err := signer.Sign(tkt)
if err != nil {
t.Fatal("Sign():", err)
}
u := make(url.Values)
destURL := "https://" + testHost + "/test"
destURL := "https://" + testHost + path
u.Set("d", destURL)
u.Set("t", signed)
resp, err = c.Get(base + "/sso_login?" + u.Encode())
......@@ -108,8 +111,11 @@ func TestSSOWrapper(t *testing.T) {
// Build a test app - note that we want to use a gorilla Mux
// here, otherwise cookie-based sessions won't work.
m := mux.NewRouter()
m.HandleFunc("/test/groups", func(w http.ResponseWriter, req *http.Request) {
io.WriteString(w, strings.Join(Groups(req), ",")) // nolint
})
m.HandleFunc("/test", func(w http.ResponseWriter, _ *http.Request) {
io.WriteString(w, "OK")
io.WriteString(w, "OK") // nolint
})
w, err := NewSSOWrapper("https://"+testLoginServer+"/", pub, testDomain, securecookie.GenerateRandomKey(64), securecookie.GenerateRandomKey(32))
......@@ -118,13 +124,20 @@ func TestSSOWrapper(t *testing.T) {
}
// Start a local test https server.
srv := httptest.NewTLSServer(w.Wrap(m, testService, nil))
srv := httptest.NewTLSServer(w.Wrap(m, testService, testGroups))
defer srv.Close()
// Request a sample URL.
c := newTestHTTPClient()
data := string(makeAuthRequest(t, c, srv.URL, "/test", testService, testDomain, priv))
if data != "OK" {
t.Fatalf("Get() returned bad data: %s", data)
t.Fatalf("Get(/test) returned bad data: %s", data)
}
// Another URL, clean client, verify context values.
c = newTestHTTPClient()
data = string(makeAuthRequest(t, c, srv.URL, "/test/groups", testService, testDomain, priv))
if data != "group1,group2" {
t.Fatalf("Get(/test/groups) returned bad data: %s", data)
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment