Commit 4d70b167 authored by ale's avatar ale

Refactor the login handler

The login handler is now a simpler, standalone http.Handler
wrapper. The separation between the SSO application and the login
handler is now fairly complete.

The login handler no longer forces the user to a specific workflow via
session cookies, but it works on a request-by-request basis instead,
which makes the "back" button works as expected (allowing the user to
bail out of a broken 2FA process, for example).

Session handling has been simplified as well: there is a single
session for authentication and login state, which should remove the
opportunity for session synchronization errors.
parent 6d3a620e
Pipeline #5345 passed with stages
in 3 minutes and 18 seconds
...@@ -20,8 +20,6 @@ import ( ...@@ -20,8 +20,6 @@ import (
) )
type authSession struct { type authSession struct {
*httputil.ExpiringSession
Auth bool Auth bool
Username string Username string
Groups []string Groups []string
...@@ -29,7 +27,7 @@ type authSession struct { ...@@ -29,7 +27,7 @@ type authSession struct {
type authSessionKeyType int type authSessionKeyType int
var authSessionKey authSessionKeyType = 42 const authSessionKey authSessionKeyType = 0
func getCurrentAuthSession(req *http.Request) *authSession { func getCurrentAuthSession(req *http.Request) *authSession {
s, ok := req.Context().Value(authSessionKey).(*authSession) s, ok := req.Context().Value(authSessionKey).(*authSession)
...@@ -64,7 +62,7 @@ func Groups(req *http.Request) []string { ...@@ -64,7 +62,7 @@ func Groups(req *http.Request) []string {
return nil return nil
} }
var authSessionLifetime = 1 * time.Hour var defaultAuthSessionTTL = 1 * time.Hour
func init() { func init() {
gob.Register(&authSession{}) gob.Register(&authSession{})
...@@ -77,6 +75,8 @@ type SSOWrapper struct { ...@@ -77,6 +75,8 @@ type SSOWrapper struct {
sessionEncKey []byte sessionEncKey []byte
serverURL string serverURL string
serverOrigin string serverOrigin string
TTL time.Duration
} }
// NewSSOWrapper returns a new SSOWrapper that will authenticate users // NewSSOWrapper returns a new SSOWrapper that will authenticate users
...@@ -93,6 +93,7 @@ func NewSSOWrapper(serverURL string, pkey []byte, domain string, sessionAuthKey, ...@@ -93,6 +93,7 @@ func NewSSOWrapper(serverURL string, pkey []byte, domain string, sessionAuthKey,
serverOrigin: originFromURL(serverURL), serverOrigin: originFromURL(serverURL),
sessionAuthKey: sessionAuthKey, sessionAuthKey: sessionAuthKey,
sessionEncKey: sessionEncKey, sessionEncKey: sessionEncKey,
TTL: defaultAuthSessionTTL,
}, nil }, nil
} }
...@@ -109,7 +110,7 @@ func (s *SSOWrapper) Wrap(h http.Handler, service string, groups []string) http. ...@@ -109,7 +110,7 @@ func (s *SSOWrapper) Wrap(h http.Handler, service string, groups []string) http.
} }
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
session, _ := store.Get(req, "sso") session, _ := httputil.GetExpiringSession(req, store, "sso", s.TTL)
switch strings.TrimPrefix(req.URL.Path, svcPath) { switch strings.TrimPrefix(req.URL.Path, svcPath) {
case "sso_login": case "sso_login":
...@@ -119,11 +120,11 @@ func (s *SSOWrapper) Wrap(h http.Handler, service string, groups []string) http. ...@@ -119,11 +120,11 @@ func (s *SSOWrapper) Wrap(h http.Handler, service string, groups []string) http.
s.handleLogout(w, req, session) s.handleLogout(w, req, session)
default: default:
if auth, ok := session.Values["a"].(*authSession); ok && auth.Valid() && auth.Auth { if auth, ok := session.Values["a"].(*authSession); ok && auth.Auth {
req.Header.Set("X-Authenticated-User", auth.Username) req.Header.Set("X-Authenticated-User", auth.Username)
ctx := context.WithValue(req.Context(), authSessionKey, auth) req = req.WithContext(context.WithValue(req.Context(), authSessionKey, auth))
h.ServeHTTP(w, req.WithContext(ctx)) h.ServeHTTP(w, req)
return return
} }
...@@ -132,7 +133,7 @@ func (s *SSOWrapper) Wrap(h http.Handler, service string, groups []string) http. ...@@ -132,7 +133,7 @@ func (s *SSOWrapper) Wrap(h http.Handler, service string, groups []string) http.
}) })
} }
func (s *SSOWrapper) handleLogin(w http.ResponseWriter, req *http.Request, session *sessions.Session, service string, groups []string) { func (s *SSOWrapper) handleLogin(w http.ResponseWriter, req *http.Request, session *httputil.ExpiringSession, service string, groups []string) {
t := req.FormValue("t") t := req.FormValue("t")
d := req.FormValue("d") d := req.FormValue("d")
...@@ -154,21 +155,25 @@ func (s *SSOWrapper) handleLogin(w http.ResponseWriter, req *http.Request, sessi ...@@ -154,21 +155,25 @@ func (s *SSOWrapper) handleLogin(w http.ResponseWriter, req *http.Request, sessi
// Authenticate the user. // Authenticate the user.
session.Values["a"] = &authSession{ session.Values["a"] = &authSession{
ExpiringSession: httputil.NewExpiringSession(authSessionLifetime), Auth: true,
Auth: true, Username: tkt.User,
Username: tkt.User, Groups: tkt.Groups,
Groups: tkt.Groups,
} }
if err := sessions.Save(req, w); err != nil { if err := sessions.Save(req, w); err != nil {
log.Printf("error saving SSO session: %v", err)
http.Error(w, err.Error(), http.StatusInternalServerError) http.Error(w, err.Error(), http.StatusInternalServerError)
return return
} }
http.Redirect(w, req, d, http.StatusFound) http.Redirect(w, req, d, http.StatusFound)
} }
func (s *SSOWrapper) handleLogout(w http.ResponseWriter, req *http.Request, session *sessions.Session) { func (s *SSOWrapper) handleLogout(w http.ResponseWriter, req *http.Request, session *httputil.ExpiringSession) {
// Delete the auth session.
session.Options.MaxAge = -1 session.Options.MaxAge = -1
delete(session.Values, "sso")
if err := sessions.Save(req, w); err != nil { if err := sessions.Save(req, w); err != nil {
log.Printf("error saving SSO session: %v", err)
http.Error(w, err.Error(), http.StatusInternalServerError) http.Error(w, err.Error(), http.StatusInternalServerError)
return return
} }
...@@ -182,11 +187,13 @@ func (s *SSOWrapper) handleLogout(w http.ResponseWriter, req *http.Request, sess ...@@ -182,11 +187,13 @@ func (s *SSOWrapper) handleLogout(w http.ResponseWriter, req *http.Request, sess
} }
// Redirect to the SSO server. // Redirect to the SSO server.
func (s *SSOWrapper) redirectToLogin(w http.ResponseWriter, req *http.Request, session *sessions.Session, service string, groups []string) { func (s *SSOWrapper) redirectToLogin(w http.ResponseWriter, req *http.Request, session *httputil.ExpiringSession, service string, groups []string) {
// Generate a random nonce and store it in the local session. // Generate a random nonce and store it in the local session.
nonce := makeUniqueNonce() nonce := makeUniqueNonce()
session.Values["nonce"] = nonce session.Values["nonce"] = nonce
if err := sessions.Save(req, w); err != nil { if err := sessions.Save(req, w); err != nil {
log.Printf("error saving SSO session: %v", err)
http.Error(w, err.Error(), http.StatusInternalServerError) http.Error(w, err.Error(), http.StatusInternalServerError)
return return
} }
......
package httputil
import "net/http"
// WithDynamicHeaders wraps an http.Handler with cache-busting and
// security-related headers appropriate for a user-facing dynamic
// application. The 'csp' argument sets a default
// Content-Security-Policy.
func WithDynamicHeaders(h http.Handler, csp string) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
hdr := w.Header()
hdr.Set("Pragma", "no-cache")
hdr.Set("Cache-Control", "no-store")
hdr.Set("Expires", "-1")
hdr.Set("X-Frame-Options", "NONE")
hdr.Set("X-XSS-Protection", "1; mode=block")
hdr.Set("X-Content-Type-Options", "nosniff")
if csp != "" && hdr.Get("Content-Security-Policy") == "" {
hdr.Set("Content-Security-Policy", csp)
}
h.ServeHTTP(w, r)
})
}
package httputil
import (
"bytes"
"html/template"
"io"
"log"
"net/http"
"strconv"
"github.com/gorilla/csrf"
)
// A Renderer just renders HTML templates with some common context
// variables. Context is represented as a map[string]interface{}, to
// allow the merge operation.
type Renderer struct {
tpl *template.Template
vars map[string]interface{}
}
// NewRenderer creates a new Renderer with the provided templates and
// default variables.
func NewRenderer(tpl *template.Template, vars map[string]interface{}) *Renderer {
return &Renderer{
tpl: tpl,
vars: vars,
}
}
// Render the named HTML template to 'w'.
func (r *Renderer) Render(w http.ResponseWriter, req *http.Request, templateName string, data map[string]interface{}) {
// Merge default variables with the ones passed in 'data',
// without modifying either. Always populate the CRSFField
// variable with the current CSRF token.
vars := make(map[string]interface{})
vars["CSRFField"] = csrf.TemplateField(req)
for k, v := range r.vars {
vars[k] = v
}
for k, v := range data {
vars[k] = v
}
// Render the template into a buffer, to prevent returning
// half-rendered templates when there is an error.
var buf bytes.Buffer
if err := r.tpl.ExecuteTemplate(&buf, templateName, data); err != nil {
log.Printf("template rendering error for %s: %v", req.URL.String(), err)
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
// Write our response to the client.
w.Header().Set("Content-Length", strconv.Itoa(buf.Len()))
io.Copy(w, &buf) // nolint
}
...@@ -2,7 +2,11 @@ package httputil ...@@ -2,7 +2,11 @@ package httputil
import ( import (
"encoding/gob" "encoding/gob"
"log"
"net/http"
"time" "time"
"github.com/gorilla/sessions"
) )
// ExpiringSession is a session with server-side expiration check. // ExpiringSession is a session with server-side expiration check.
...@@ -12,24 +16,85 @@ import ( ...@@ -12,24 +16,85 @@ import (
// browser for the latter, but we enforce time-based expiration on the // browser for the latter, but we enforce time-based expiration on the
// server. // server.
type ExpiringSession struct { type ExpiringSession struct {
Expiry time.Time *sessions.Session
} }
// NewExpiringSession returns a session that is valid for the given // GetExpiringSession wraps a Session (obtained from 'store') with
// duration. // an ExpiringSession. If it's invalid or expired, a new empty Session
func NewExpiringSession(ttl time.Duration) *ExpiringSession { // will be created with an expiration time set using 'ttl'.
return &ExpiringSession{ func GetExpiringSession(req *http.Request, store sessions.Store, key string, ttl time.Duration) (*ExpiringSession, error) {
Expiry: time.Now().Add(ttl), now := time.Now()
// An error here just means that we failed to decode the
// existing session for some reason. A new session will always
// be returned, so we just pass along the error to the caller
// (so it can be logged).
s, err := store.Get(req, key)
// See if we have a valid session first.
if !s.IsNew {
if exp, ok := s.Values["_exp"].(time.Time); ok && now.Before(exp) {
return &ExpiringSession{Session: s}, err
}
// We can't call sessions.NewSession() because that
// won't register the session with the Registry, so it
// won't be sent with the response. Wipe the data
// instead.
for k := range s.Values {
delete(s.Values, k)
}
} }
// The session is either invalid or expired, create a new
// blank one containing no data.
expiry := now.Add(ttl)
s.Values["_exp"] = expiry
return &ExpiringSession{Session: s}, err
} }
// Valid returns true if the session has not expired yet. // Wrapper for an http.ResponseWriter that ensures all tracked
// It can be called with a nil receiver. // sessions are saved before the request body is sent.
func (e *ExpiringSession) Valid() bool { //
return e != nil && time.Now().Before(e.Expiry) // We have to duplicate the logic to call WriteHeader on the first
// Write, otherwise the underlying ResponseWriter won't call our
// WriteHeader function but its own instead.
type sessionResponseWriter struct {
http.ResponseWriter
headerWritten bool
req *http.Request
} }
func init() { func (w *sessionResponseWriter) WriteHeader(statusCode int) {
gob.Register(&ExpiringSession{}) if statusCode >= 200 && statusCode < 400 {
if err := sessions.Save(w.req, w.ResponseWriter); err != nil {
log.Printf("error saving sessions: %v", err)
}
}
w.ResponseWriter.WriteHeader(statusCode)
w.headerWritten = true
} }
func (w *sessionResponseWriter) Write(b []byte) (int, error) {
if !w.headerWritten {
w.WriteHeader(http.StatusOK)
}
return w.ResponseWriter.Write(b)
}
// NewSessionResponseWriter returns a wrapped http.ResponseWriter that
// will always remember to save the Gorilla sessions before writing
// the response body.
func NewSessionResponseWriter(w http.ResponseWriter, req *http.Request) http.ResponseWriter {
return &sessionResponseWriter{
ResponseWriter: w,
req: req,
}
}
func init() {
// Register time.Time with encoding/gob, to ensure that the
// ExpiringSession timestamp can be serialized.
var t time.Time
gob.Register(t)
}
package httputil package httputil
import ( import (
"bytes"
"encoding/gob" "encoding/gob"
"reflect" "net/http"
"testing" "testing"
"time" "time"
"github.com/gorilla/sessions"
) )
type mySession struct {
Data string
}
func init() {
gob.Register(&mySession{})
}
func TestExpiringSession(t *testing.T) { func TestExpiringSession(t *testing.T) {
type mySession struct { store := sessions.NewCookieStore()
*ExpiringSession req, _ := http.NewRequest("GET", "http://localhost/", nil)
Data string
}
s := &mySession{
ExpiringSession: NewExpiringSession(60 * time.Second),
Data: "data",
}
var buf bytes.Buffer httpsess, err := GetExpiringSession(req, store, "testkey", 60*time.Second)
if err := gob.NewEncoder(&buf).Encode(s); err != nil { if err != nil {
t.Fatal("encode:", err) t.Errorf("store.Get error: %v", err)
} }
var s2 mySession
if err := gob.NewDecoder(&buf).Decode(&s2); err != nil { if _, ok := httpsess.Values["mykey"].(*mySession); ok {
t.Fatal("decode:", err) t.Fatal("got a session without any data")
}
if !reflect.DeepEqual(s.Data, s2.Data) {
t.Fatalf("sessions differ: %+v vs %+v", s, &s2)
} }
} }
package httputil
import (
"bytes"
"io/ioutil"
"net/http"
"os"
"time"
)
// StaticContent is an http.Handler that serves in-memory data as if
// it were a static file.
type StaticContent struct {
modtime time.Time
name string
data []byte
}
// LoadStaticContent creates a StaticContent by loading data from a file.
func LoadStaticContent(path string) (*StaticContent, error) {
stat, err := os.Stat(path)
if err != nil {
return nil, err
}
data, err := ioutil.ReadFile(path) // #nosec
if err != nil {
return nil, err
}
return &StaticContent{
name: path,
modtime: stat.ModTime(),
data: data,
}, nil
}
func (c *StaticContent) ServeHTTP(w http.ResponseWriter, req *http.Request) {
http.ServeContent(w, req, c.name, c.modtime, bytes.NewReader(c.data))
}
...@@ -10,6 +10,7 @@ import ( ...@@ -10,6 +10,7 @@ import (
"strings" "strings"
"git.autistici.org/id/auth" "git.autistici.org/id/auth"
"github.com/gorilla/securecookie"
"github.com/gorilla/sessions" "github.com/gorilla/sessions"
"github.com/mssola/user_agent" "github.com/mssola/user_agent"
) )
...@@ -35,7 +36,7 @@ type Config struct { ...@@ -35,7 +36,7 @@ type Config struct {
} }
// New returns a new Manager with the given configuration. // New returns a new Manager with the given configuration.
func New(config *Config) (*Manager, error) { func New(config *Config, urlPrefix string) (*Manager, error) {
if config == nil { if config == nil {
config = &Config{} config = &Config{}
} }
...@@ -45,9 +46,15 @@ func New(config *Config) (*Manager, error) { ...@@ -45,9 +46,15 @@ func New(config *Config) (*Manager, error) {
log.Printf("Warning: GeoIP disabled: %v", err) log.Printf("Warning: GeoIP disabled: %v", err)
} }
// This should only happen in tests.
if config.AuthKey == "" {
log.Printf("Warning: device_manager.auth_key unset, generating temporary random secrets")
config.AuthKey = string(securecookie.GenerateRandomKey(64))
}
return &Manager{ return &Manager{
geodb: geodb, geodb: geodb,
store: newStore([]byte(config.AuthKey)), store: newStore([]byte(config.AuthKey), urlPrefix),
}, nil }, nil
} }
......
...@@ -4,11 +4,11 @@ import "github.com/gorilla/sessions" ...@@ -4,11 +4,11 @@ import "github.com/gorilla/sessions"
const aVeryLongTimeInSeconds = 10 * 365 * 86400 const aVeryLongTimeInSeconds = 10 * 365 * 86400
func newStore(authKey []byte) sessions.Store { func newStore(authKey []byte, urlPrefix string) sessions.Store {
// No encryption, long-term lifetime cookie. // No encryption, long-term lifetime cookie.
store := sessions.NewCookieStore(authKey, nil) store := sessions.NewCookieStore(authKey, nil)
store.Options = &sessions.Options{ store.Options = &sessions.Options{
Path: "/", Path: urlPrefix + "/",
HttpOnly: true, HttpOnly: true,
Secure: true, Secure: true,
MaxAge: aVeryLongTimeInSeconds, MaxAge: aVeryLongTimeInSeconds,
......
This diff is collapsed.
...@@ -161,7 +161,7 @@ func checkLoginPasswordPage(t testing.TB, resp *http.Response) { ...@@ -161,7 +161,7 @@ func checkLoginPasswordPage(t testing.TB, resp *http.Response) {
var otpFieldRx = regexp.MustCompile(`<input[^>]*name="otp"`) var otpFieldRx = regexp.MustCompile(`<input[^>]*name="otp"`)
func checkLoginOTPPage(t testing.TB, resp *http.Response) { func checkLoginOTPPage(t testing.TB, resp *http.Response) {
if resp.Request.URL.Path != "/login" { if resp.Request.URL.Path != "/login/otp" {
t.Errorf("request path is not /login (%s)", resp.Request.URL.String()) t.Errorf("request path is not /login (%s)", resp.Request.URL.String())
} }
data, err := ioutil.ReadAll(resp.Body) data, err := ioutil.ReadAll(resp.Body)
...@@ -283,7 +283,7 @@ func TestHTTP_LoginOTP(t *testing.T) { ...@@ -283,7 +283,7 @@ func TestHTTP_LoginOTP(t *testing.T) {
// 302 redirect to the target service. // 302 redirect to the target service.
v = make(url.Values) v = make(url.Values)
v.Set("otp", "123456") v.Set("otp", "123456")
doPostForm(t, httpSrv, c, "/login", v, checkRedirectToTargetService) doPostForm(t, httpSrv, c, "/login/otp", v, checkRedirectToTargetService)
} }
func createFakeKeyStore(t testing.TB, username, password string) *httptest.Server { func createFakeKeyStore(t testing.TB, username, password string) *httptest.Server {
...@@ -304,7 +304,7 @@ func createFakeKeyStore(t testing.TB, username, password string) *httptest.Serve ...@@ -304,7 +304,7 @@ func createFakeKeyStore(t testing.TB, username, password string) *httptest.Serve
t.Errorf("bad password in keystore Open request: expected %s, got %s", password, openReq.Password) t.Errorf("bad password in keystore Open request: expected %s, got %s", password, openReq.Password)
} }
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
io.WriteString(w, "{}") io.WriteString(w, "{}") // nolint
}) })
return httptest.NewServer(h) return httptest.NewServer(h)
} }
......
...@@ -67,4 +67,4 @@ func (dl DefaultLogger) LogResponse(req *http.Request, res *http.Response, err e ...@@ -67,4 +67,4 @@ func (dl DefaultLogger) LogResponse(req *http.Request, res *http.Response, err e
} }
// DefaultLoggedTransport wraps http.DefaultTransport to log using DefaultLogger // DefaultLoggedTransport wraps http.DefaultTransport to log using DefaultLogger
var DefaultLoggedTransport = NewLoggedTransport(http.DefaultTransport, DefaultLogger{})