Commit 4d70b167 authored by ale's avatar ale

Refactor the login handler

The login handler is now a simpler, standalone http.Handler
wrapper. The separation between the SSO application and the login
handler is now fairly complete.

The login handler no longer forces the user to a specific workflow via
session cookies, but it works on a request-by-request basis instead,
which makes the "back" button works as expected (allowing the user to
bail out of a broken 2FA process, for example).

Session handling has been simplified as well: there is a single
session for authentication and login state, which should remove the
opportunity for session synchronization errors.
parent 6d3a620e
Pipeline #5345 passed with stages
in 3 minutes and 18 seconds
......@@ -20,8 +20,6 @@ import (
)
type authSession struct {
*httputil.ExpiringSession
Auth bool
Username string
Groups []string
......@@ -29,7 +27,7 @@ type authSession struct {
type authSessionKeyType int
var authSessionKey authSessionKeyType = 42
const authSessionKey authSessionKeyType = 0
func getCurrentAuthSession(req *http.Request) *authSession {
s, ok := req.Context().Value(authSessionKey).(*authSession)
......@@ -64,7 +62,7 @@ func Groups(req *http.Request) []string {
return nil
}
var authSessionLifetime = 1 * time.Hour
var defaultAuthSessionTTL = 1 * time.Hour
func init() {
gob.Register(&authSession{})
......@@ -77,6 +75,8 @@ type SSOWrapper struct {
sessionEncKey []byte
serverURL string
serverOrigin string
TTL time.Duration
}
// NewSSOWrapper returns a new SSOWrapper that will authenticate users
......@@ -93,6 +93,7 @@ func NewSSOWrapper(serverURL string, pkey []byte, domain string, sessionAuthKey,
serverOrigin: originFromURL(serverURL),
sessionAuthKey: sessionAuthKey,
sessionEncKey: sessionEncKey,
TTL: defaultAuthSessionTTL,
}, nil
}
......@@ -109,7 +110,7 @@ func (s *SSOWrapper) Wrap(h http.Handler, service string, groups []string) http.
}
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
session, _ := store.Get(req, "sso")
session, _ := httputil.GetExpiringSession(req, store, "sso", s.TTL)
switch strings.TrimPrefix(req.URL.Path, svcPath) {
case "sso_login":
......@@ -119,11 +120,11 @@ func (s *SSOWrapper) Wrap(h http.Handler, service string, groups []string) http.
s.handleLogout(w, req, session)
default:
if auth, ok := session.Values["a"].(*authSession); ok && auth.Valid() && auth.Auth {
if auth, ok := session.Values["a"].(*authSession); ok && auth.Auth {
req.Header.Set("X-Authenticated-User", auth.Username)
ctx := context.WithValue(req.Context(), authSessionKey, auth)
h.ServeHTTP(w, req.WithContext(ctx))
req = req.WithContext(context.WithValue(req.Context(), authSessionKey, auth))
h.ServeHTTP(w, req)
return
}
......@@ -132,7 +133,7 @@ func (s *SSOWrapper) Wrap(h http.Handler, service string, groups []string) http.
})
}
func (s *SSOWrapper) handleLogin(w http.ResponseWriter, req *http.Request, session *sessions.Session, service string, groups []string) {
func (s *SSOWrapper) handleLogin(w http.ResponseWriter, req *http.Request, session *httputil.ExpiringSession, service string, groups []string) {
t := req.FormValue("t")
d := req.FormValue("d")
......@@ -154,21 +155,25 @@ func (s *SSOWrapper) handleLogin(w http.ResponseWriter, req *http.Request, sessi
// Authenticate the user.
session.Values["a"] = &authSession{
ExpiringSession: httputil.NewExpiringSession(authSessionLifetime),
Auth: true,
Username: tkt.User,
Groups: tkt.Groups,
Auth: true,
Username: tkt.User,
Groups: tkt.Groups,
}
if err := sessions.Save(req, w); err != nil {
log.Printf("error saving SSO session: %v", err)
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
http.Redirect(w, req, d, http.StatusFound)
}
func (s *SSOWrapper) handleLogout(w http.ResponseWriter, req *http.Request, session *sessions.Session) {
func (s *SSOWrapper) handleLogout(w http.ResponseWriter, req *http.Request, session *httputil.ExpiringSession) {
// Delete the auth session.
session.Options.MaxAge = -1
delete(session.Values, "sso")
if err := sessions.Save(req, w); err != nil {
log.Printf("error saving SSO session: %v", err)
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
......@@ -182,11 +187,13 @@ func (s *SSOWrapper) handleLogout(w http.ResponseWriter, req *http.Request, sess
}
// Redirect to the SSO server.
func (s *SSOWrapper) redirectToLogin(w http.ResponseWriter, req *http.Request, session *sessions.Session, service string, groups []string) {
func (s *SSOWrapper) redirectToLogin(w http.ResponseWriter, req *http.Request, session *httputil.ExpiringSession, service string, groups []string) {
// Generate a random nonce and store it in the local session.
nonce := makeUniqueNonce()
session.Values["nonce"] = nonce
if err := sessions.Save(req, w); err != nil {
log.Printf("error saving SSO session: %v", err)
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
......
package httputil
import "net/http"
// WithDynamicHeaders wraps an http.Handler with cache-busting and
// security-related headers appropriate for a user-facing dynamic
// application. The 'csp' argument sets a default
// Content-Security-Policy.
func WithDynamicHeaders(h http.Handler, csp string) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
hdr := w.Header()
hdr.Set("Pragma", "no-cache")
hdr.Set("Cache-Control", "no-store")
hdr.Set("Expires", "-1")
hdr.Set("X-Frame-Options", "NONE")
hdr.Set("X-XSS-Protection", "1; mode=block")
hdr.Set("X-Content-Type-Options", "nosniff")
if csp != "" && hdr.Get("Content-Security-Policy") == "" {
hdr.Set("Content-Security-Policy", csp)
}
h.ServeHTTP(w, r)
})
}
package httputil
import (
"bytes"
"html/template"
"io"
"log"
"net/http"
"strconv"
"github.com/gorilla/csrf"
)
// A Renderer just renders HTML templates with some common context
// variables. Context is represented as a map[string]interface{}, to
// allow the merge operation.
type Renderer struct {
tpl *template.Template
vars map[string]interface{}
}
// NewRenderer creates a new Renderer with the provided templates and
// default variables.
func NewRenderer(tpl *template.Template, vars map[string]interface{}) *Renderer {
return &Renderer{
tpl: tpl,
vars: vars,
}
}
// Render the named HTML template to 'w'.
func (r *Renderer) Render(w http.ResponseWriter, req *http.Request, templateName string, data map[string]interface{}) {
// Merge default variables with the ones passed in 'data',
// without modifying either. Always populate the CRSFField
// variable with the current CSRF token.
vars := make(map[string]interface{})
vars["CSRFField"] = csrf.TemplateField(req)
for k, v := range r.vars {
vars[k] = v
}
for k, v := range data {
vars[k] = v
}
// Render the template into a buffer, to prevent returning
// half-rendered templates when there is an error.
var buf bytes.Buffer
if err := r.tpl.ExecuteTemplate(&buf, templateName, data); err != nil {
log.Printf("template rendering error for %s: %v", req.URL.String(), err)
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
// Write our response to the client.
w.Header().Set("Content-Length", strconv.Itoa(buf.Len()))
io.Copy(w, &buf) // nolint
}
......@@ -2,7 +2,11 @@ package httputil
import (
"encoding/gob"
"log"
"net/http"
"time"
"github.com/gorilla/sessions"
)
// ExpiringSession is a session with server-side expiration check.
......@@ -12,24 +16,85 @@ import (
// browser for the latter, but we enforce time-based expiration on the
// server.
type ExpiringSession struct {
Expiry time.Time
*sessions.Session
}
// NewExpiringSession returns a session that is valid for the given
// duration.
func NewExpiringSession(ttl time.Duration) *ExpiringSession {
return &ExpiringSession{
Expiry: time.Now().Add(ttl),
// GetExpiringSession wraps a Session (obtained from 'store') with
// an ExpiringSession. If it's invalid or expired, a new empty Session
// will be created with an expiration time set using 'ttl'.
func GetExpiringSession(req *http.Request, store sessions.Store, key string, ttl time.Duration) (*ExpiringSession, error) {
now := time.Now()
// An error here just means that we failed to decode the
// existing session for some reason. A new session will always
// be returned, so we just pass along the error to the caller
// (so it can be logged).
s, err := store.Get(req, key)
// See if we have a valid session first.
if !s.IsNew {
if exp, ok := s.Values["_exp"].(time.Time); ok && now.Before(exp) {
return &ExpiringSession{Session: s}, err
}
// We can't call sessions.NewSession() because that
// won't register the session with the Registry, so it
// won't be sent with the response. Wipe the data
// instead.
for k := range s.Values {
delete(s.Values, k)
}
}
// The session is either invalid or expired, create a new
// blank one containing no data.
expiry := now.Add(ttl)
s.Values["_exp"] = expiry
return &ExpiringSession{Session: s}, err
}
// Valid returns true if the session has not expired yet.
// It can be called with a nil receiver.
func (e *ExpiringSession) Valid() bool {
return e != nil && time.Now().Before(e.Expiry)
// Wrapper for an http.ResponseWriter that ensures all tracked
// sessions are saved before the request body is sent.
//
// We have to duplicate the logic to call WriteHeader on the first
// Write, otherwise the underlying ResponseWriter won't call our
// WriteHeader function but its own instead.
type sessionResponseWriter struct {
http.ResponseWriter
headerWritten bool
req *http.Request
}
func init() {
gob.Register(&ExpiringSession{})
func (w *sessionResponseWriter) WriteHeader(statusCode int) {
if statusCode >= 200 && statusCode < 400 {
if err := sessions.Save(w.req, w.ResponseWriter); err != nil {
log.Printf("error saving sessions: %v", err)
}
}
w.ResponseWriter.WriteHeader(statusCode)
w.headerWritten = true
}
func (w *sessionResponseWriter) Write(b []byte) (int, error) {
if !w.headerWritten {
w.WriteHeader(http.StatusOK)
}
return w.ResponseWriter.Write(b)
}
// NewSessionResponseWriter returns a wrapped http.ResponseWriter that
// will always remember to save the Gorilla sessions before writing
// the response body.
func NewSessionResponseWriter(w http.ResponseWriter, req *http.Request) http.ResponseWriter {
return &sessionResponseWriter{
ResponseWriter: w,
req: req,
}
}
func init() {
// Register time.Time with encoding/gob, to ensure that the
// ExpiringSession timestamp can be serialized.
var t time.Time
gob.Register(t)
}
package httputil
import (
"bytes"
"encoding/gob"
"reflect"
"net/http"
"testing"
"time"
"github.com/gorilla/sessions"
)
type mySession struct {
Data string
}
func init() {
gob.Register(&mySession{})
}
func TestExpiringSession(t *testing.T) {
type mySession struct {
*ExpiringSession
Data string
}
s := &mySession{
ExpiringSession: NewExpiringSession(60 * time.Second),
Data: "data",
}
store := sessions.NewCookieStore()
req, _ := http.NewRequest("GET", "http://localhost/", nil)
var buf bytes.Buffer
if err := gob.NewEncoder(&buf).Encode(s); err != nil {
t.Fatal("encode:", err)
httpsess, err := GetExpiringSession(req, store, "testkey", 60*time.Second)
if err != nil {
t.Errorf("store.Get error: %v", err)
}
var s2 mySession
if err := gob.NewDecoder(&buf).Decode(&s2); err != nil {
t.Fatal("decode:", err)
}
if !reflect.DeepEqual(s.Data, s2.Data) {
t.Fatalf("sessions differ: %+v vs %+v", s, &s2)
if _, ok := httpsess.Values["mykey"].(*mySession); ok {
t.Fatal("got a session without any data")
}
}
package httputil
import (
"bytes"
"io/ioutil"
"net/http"
"os"
"time"
)
// StaticContent is an http.Handler that serves in-memory data as if
// it were a static file.
type StaticContent struct {
modtime time.Time
name string
data []byte
}
// LoadStaticContent creates a StaticContent by loading data from a file.
func LoadStaticContent(path string) (*StaticContent, error) {
stat, err := os.Stat(path)
if err != nil {
return nil, err
}
data, err := ioutil.ReadFile(path) // #nosec
if err != nil {
return nil, err
}
return &StaticContent{
name: path,
modtime: stat.ModTime(),
data: data,
}, nil
}
func (c *StaticContent) ServeHTTP(w http.ResponseWriter, req *http.Request) {
http.ServeContent(w, req, c.name, c.modtime, bytes.NewReader(c.data))
}
......@@ -10,6 +10,7 @@ import (
"strings"
"git.autistici.org/id/auth"
"github.com/gorilla/securecookie"
"github.com/gorilla/sessions"
"github.com/mssola/user_agent"
)
......@@ -35,7 +36,7 @@ type Config struct {
}
// New returns a new Manager with the given configuration.
func New(config *Config) (*Manager, error) {
func New(config *Config, urlPrefix string) (*Manager, error) {
if config == nil {
config = &Config{}
}
......@@ -45,9 +46,15 @@ func New(config *Config) (*Manager, error) {
log.Printf("Warning: GeoIP disabled: %v", err)
}
// This should only happen in tests.
if config.AuthKey == "" {
log.Printf("Warning: device_manager.auth_key unset, generating temporary random secrets")
config.AuthKey = string(securecookie.GenerateRandomKey(64))
}
return &Manager{
geodb: geodb,
store: newStore([]byte(config.AuthKey)),
store: newStore([]byte(config.AuthKey), urlPrefix),
}, nil
}
......
......@@ -4,11 +4,11 @@ import "github.com/gorilla/sessions"
const aVeryLongTimeInSeconds = 10 * 365 * 86400
func newStore(authKey []byte) sessions.Store {
func newStore(authKey []byte, urlPrefix string) sessions.Store {
// No encryption, long-term lifetime cookie.
store := sessions.NewCookieStore(authKey, nil)
store.Options = &sessions.Options{
Path: "/",
Path: urlPrefix + "/",
HttpOnly: true,
Secure: true,
MaxAge: aVeryLongTimeInSeconds,
......
This diff is collapsed.
......@@ -161,7 +161,7 @@ func checkLoginPasswordPage(t testing.TB, resp *http.Response) {
var otpFieldRx = regexp.MustCompile(`<input[^>]*name="otp"`)
func checkLoginOTPPage(t testing.TB, resp *http.Response) {
if resp.Request.URL.Path != "/login" {
if resp.Request.URL.Path != "/login/otp" {
t.Errorf("request path is not /login (%s)", resp.Request.URL.String())
}
data, err := ioutil.ReadAll(resp.Body)
......@@ -283,7 +283,7 @@ func TestHTTP_LoginOTP(t *testing.T) {
// 302 redirect to the target service.
v = make(url.Values)
v.Set("otp", "123456")
doPostForm(t, httpSrv, c, "/login", v, checkRedirectToTargetService)
doPostForm(t, httpSrv, c, "/login/otp", v, checkRedirectToTargetService)
}
func createFakeKeyStore(t testing.TB, username, password string) *httptest.Server {
......@@ -304,7 +304,7 @@ func createFakeKeyStore(t testing.TB, username, password string) *httptest.Serve
t.Errorf("bad password in keystore Open request: expected %s, got %s", password, openReq.Password)
}
w.Header().Set("Content-Type", "application/json")
io.WriteString(w, "{}")
io.WriteString(w, "{}") // nolint
})
return httptest.NewServer(h)
}
......
......@@ -67,4 +67,4 @@ func (dl DefaultLogger) LogResponse(req *http.Request, res *http.Response, err e
}
// DefaultLoggedTransport wraps http.DefaultTransport to log using DefaultLogger
var DefaultLoggedTransport = NewLoggedTransport(http.DefaultTransport, DefaultLogger{})
//var DefaultLoggedTransport = NewLoggedTransport(http.DefaultTransport, DefaultLogger{})
<
package server
import (
"encoding/gob"
"encoding/json"
"errors"
"fmt"
"log"
"net/http"
"net/url"
"strings"
"time"
"github.com/gorilla/sessions"
"github.com/tstranex/u2f"
"go.opencensus.io/trace"
"git.autistici.org/id/auth"
authclient "git.autistici.org/id/auth/client"
"git.autistici.org/id/go-sso/httputil"
"git.autistici.org/id/go-sso/server/device"
)
const loginSessionKey = "_login"
type loginSession struct {
*httputil.ExpiringSession
State loginState
// Post-login redirection URL.
Redir string
// Cached from the first form.
Username string
Password string
// The auth.Response is cached for 2FA.
AuthResponse *auth.Response
}
// The login session is short-lived, it only needs to last for the duration of
// the login process itself.
var defaultLoginSessionLifetime = 10 * time.Minute
func newLoginSession() *loginSession {
return &loginSession{
ExpiringSession: httputil.NewExpiringSession(defaultLoginSessionLifetime),
State: loginStatePassword,
}
}
type loginState int
const (
loginStateNone = iota
loginStatePassword
loginStateOTP
loginStateU2F
loginStateSuccess
)
func init() {
gob.Register(&loginSession{})
}
type loginCallbackFunc func(http.ResponseWriter, *http.Request, string, string, *auth.UserInfo) error
type loginHandler struct {
authClient authclient.Client
authService string
u2fAppID string
urlPrefix string
devMgr *device.Manager
loginCallback loginCallbackFunc
loginSessionStore sessions.Store
renderer *renderer
}
// NewLoginHandler will wrap an http.Handler with the login workflow,
// invoking it only on successful login.
func newLoginHandler(okHandler loginCallbackFunc, devMgr *device.Manager, authClient authclient.Client, authService, u2fAppID, urlPrefix string, rndr *renderer, keyPairs ...[]byte) *loginHandler {
store := sessions.NewCookieStore(keyPairs...)
store.Options = &sessions.Options{
HttpOnly: true,
Secure: true,
MaxAge: 0,
}
return &loginHandler{
authClient: authClient,
authService: authService,
u2fAppID: u2fAppID,
urlPrefix: strings.TrimRight(urlPrefix, "/"),
devMgr: devMgr,
loginCallback: okHandler,
loginSessionStore: store,
renderer: rndr,
}
}
func (l *loginHandler) fetchOrInitSession(req *http.Request) (*sessions.Session, *loginSession, error) {
// Either fetch the current session or create a new blank one.
httpSession, err := l.loginSessionStore.Get(req, loginSessionKey)
if err != nil {
return nil, nil, err
}
session, ok := httpSession.Values["data"].(*loginSession)
if !ok || !session.Valid() {
session = newLoginSession()
// Initialize session. The only parameter is 'r', the target
// redirect location. Enforce relative redirect URL (no host
// should be specified).
session.Redir = req.FormValue("r")
if session.Redir == "" {
return nil, nil, errors.New("empty login redirect target")
}
if !strings.HasPrefix(session.Redir, "/") || strings.HasPrefix(session.Redir, "//") {
return nil, nil, errors.New("bad login redirect target")
}
httpSession.Values["data"] = session
}
return httpSession, session, nil
}
// The login session controls the flow of the client - it's just a way
// to ensure that every step is authorized as part of the login
// sequence.
func (l *loginHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
httpSession, session, err := l.fetchOrInitSession(req)
if err != nil {
log.Printf("login session init error: %v", err)
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
// Dispatch the current state to its handler. Handlers will
// handle the current request and either 1) validate the
// request successfully and move to the next state, or 2)
// return a response to the user. Handlers fall through to the
// next state on success.
for {