Commit 4d70b167 authored by ale's avatar ale

Refactor the login handler

The login handler is now a simpler, standalone http.Handler
wrapper. The separation between the SSO application and the login
handler is now fairly complete.

The login handler no longer forces the user to a specific workflow via
session cookies, but it works on a request-by-request basis instead,
which makes the "back" button works as expected (allowing the user to
bail out of a broken 2FA process, for example).

Session handling has been simplified as well: there is a single
session for authentication and login state, which should remove the
opportunity for session synchronization errors.
parent 6d3a620e
Pipeline #5345 passed with stages
in 3 minutes and 18 seconds
......@@ -20,8 +20,6 @@ import (
)
type authSession struct {
*httputil.ExpiringSession
Auth bool
Username string
Groups []string
......@@ -29,7 +27,7 @@ type authSession struct {
type authSessionKeyType int
var authSessionKey authSessionKeyType = 42
const authSessionKey authSessionKeyType = 0
func getCurrentAuthSession(req *http.Request) *authSession {
s, ok := req.Context().Value(authSessionKey).(*authSession)
......@@ -64,7 +62,7 @@ func Groups(req *http.Request) []string {
return nil
}
var authSessionLifetime = 1 * time.Hour
var defaultAuthSessionTTL = 1 * time.Hour
func init() {
gob.Register(&authSession{})
......@@ -77,6 +75,8 @@ type SSOWrapper struct {
sessionEncKey []byte
serverURL string
serverOrigin string
TTL time.Duration
}
// NewSSOWrapper returns a new SSOWrapper that will authenticate users
......@@ -93,6 +93,7 @@ func NewSSOWrapper(serverURL string, pkey []byte, domain string, sessionAuthKey,
serverOrigin: originFromURL(serverURL),
sessionAuthKey: sessionAuthKey,
sessionEncKey: sessionEncKey,
TTL: defaultAuthSessionTTL,
}, nil
}
......@@ -109,7 +110,7 @@ func (s *SSOWrapper) Wrap(h http.Handler, service string, groups []string) http.
}
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
session, _ := store.Get(req, "sso")
session, _ := httputil.GetExpiringSession(req, store, "sso", s.TTL)
switch strings.TrimPrefix(req.URL.Path, svcPath) {
case "sso_login":
......@@ -119,11 +120,11 @@ func (s *SSOWrapper) Wrap(h http.Handler, service string, groups []string) http.
s.handleLogout(w, req, session)
default:
if auth, ok := session.Values["a"].(*authSession); ok && auth.Valid() && auth.Auth {
if auth, ok := session.Values["a"].(*authSession); ok && auth.Auth {
req.Header.Set("X-Authenticated-User", auth.Username)
ctx := context.WithValue(req.Context(), authSessionKey, auth)
h.ServeHTTP(w, req.WithContext(ctx))
req = req.WithContext(context.WithValue(req.Context(), authSessionKey, auth))
h.ServeHTTP(w, req)
return
}
......@@ -132,7 +133,7 @@ func (s *SSOWrapper) Wrap(h http.Handler, service string, groups []string) http.
})
}
func (s *SSOWrapper) handleLogin(w http.ResponseWriter, req *http.Request, session *sessions.Session, service string, groups []string) {
func (s *SSOWrapper) handleLogin(w http.ResponseWriter, req *http.Request, session *httputil.ExpiringSession, service string, groups []string) {
t := req.FormValue("t")
d := req.FormValue("d")
......@@ -154,21 +155,25 @@ func (s *SSOWrapper) handleLogin(w http.ResponseWriter, req *http.Request, sessi
// Authenticate the user.
session.Values["a"] = &authSession{
ExpiringSession: httputil.NewExpiringSession(authSessionLifetime),
Auth: true,
Username: tkt.User,
Groups: tkt.Groups,
Auth: true,
Username: tkt.User,
Groups: tkt.Groups,
}
if err := sessions.Save(req, w); err != nil {
log.Printf("error saving SSO session: %v", err)
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
http.Redirect(w, req, d, http.StatusFound)
}
func (s *SSOWrapper) handleLogout(w http.ResponseWriter, req *http.Request, session *sessions.Session) {
func (s *SSOWrapper) handleLogout(w http.ResponseWriter, req *http.Request, session *httputil.ExpiringSession) {
// Delete the auth session.
session.Options.MaxAge = -1
delete(session.Values, "sso")
if err := sessions.Save(req, w); err != nil {
log.Printf("error saving SSO session: %v", err)
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
......@@ -182,11 +187,13 @@ func (s *SSOWrapper) handleLogout(w http.ResponseWriter, req *http.Request, sess
}
// Redirect to the SSO server.
func (s *SSOWrapper) redirectToLogin(w http.ResponseWriter, req *http.Request, session *sessions.Session, service string, groups []string) {
func (s *SSOWrapper) redirectToLogin(w http.ResponseWriter, req *http.Request, session *httputil.ExpiringSession, service string, groups []string) {
// Generate a random nonce and store it in the local session.
nonce := makeUniqueNonce()
session.Values["nonce"] = nonce
if err := sessions.Save(req, w); err != nil {
log.Printf("error saving SSO session: %v", err)
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
......
package httputil
import "net/http"
// WithDynamicHeaders wraps an http.Handler with cache-busting and
// security-related headers appropriate for a user-facing dynamic
// application. The 'csp' argument sets a default
// Content-Security-Policy.
func WithDynamicHeaders(h http.Handler, csp string) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
hdr := w.Header()
hdr.Set("Pragma", "no-cache")
hdr.Set("Cache-Control", "no-store")
hdr.Set("Expires", "-1")
hdr.Set("X-Frame-Options", "NONE")
hdr.Set("X-XSS-Protection", "1; mode=block")
hdr.Set("X-Content-Type-Options", "nosniff")
if csp != "" && hdr.Get("Content-Security-Policy") == "" {
hdr.Set("Content-Security-Policy", csp)
}
h.ServeHTTP(w, r)
})
}
package httputil
import (
"bytes"
"html/template"
"io"
"log"
"net/http"
"strconv"
"github.com/gorilla/csrf"
)
// A Renderer just renders HTML templates with some common context
// variables. Context is represented as a map[string]interface{}, to
// allow the merge operation.
type Renderer struct {
tpl *template.Template
vars map[string]interface{}
}
// NewRenderer creates a new Renderer with the provided templates and
// default variables.
func NewRenderer(tpl *template.Template, vars map[string]interface{}) *Renderer {
return &Renderer{
tpl: tpl,
vars: vars,
}
}
// Render the named HTML template to 'w'.
func (r *Renderer) Render(w http.ResponseWriter, req *http.Request, templateName string, data map[string]interface{}) {
// Merge default variables with the ones passed in 'data',
// without modifying either. Always populate the CRSFField
// variable with the current CSRF token.
vars := make(map[string]interface{})
vars["CSRFField"] = csrf.TemplateField(req)
for k, v := range r.vars {
vars[k] = v
}
for k, v := range data {
vars[k] = v
}
// Render the template into a buffer, to prevent returning
// half-rendered templates when there is an error.
var buf bytes.Buffer
if err := r.tpl.ExecuteTemplate(&buf, templateName, data); err != nil {
log.Printf("template rendering error for %s: %v", req.URL.String(), err)
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
// Write our response to the client.
w.Header().Set("Content-Length", strconv.Itoa(buf.Len()))
io.Copy(w, &buf) // nolint
}
......@@ -2,7 +2,11 @@ package httputil
import (
"encoding/gob"
"log"
"net/http"
"time"
"github.com/gorilla/sessions"
)
// ExpiringSession is a session with server-side expiration check.
......@@ -12,24 +16,85 @@ import (
// browser for the latter, but we enforce time-based expiration on the
// server.
type ExpiringSession struct {
Expiry time.Time
*sessions.Session
}
// NewExpiringSession returns a session that is valid for the given
// duration.
func NewExpiringSession(ttl time.Duration) *ExpiringSession {
return &ExpiringSession{
Expiry: time.Now().Add(ttl),
// GetExpiringSession wraps a Session (obtained from 'store') with
// an ExpiringSession. If it's invalid or expired, a new empty Session
// will be created with an expiration time set using 'ttl'.
func GetExpiringSession(req *http.Request, store sessions.Store, key string, ttl time.Duration) (*ExpiringSession, error) {
now := time.Now()
// An error here just means that we failed to decode the
// existing session for some reason. A new session will always
// be returned, so we just pass along the error to the caller
// (so it can be logged).
s, err := store.Get(req, key)
// See if we have a valid session first.
if !s.IsNew {
if exp, ok := s.Values["_exp"].(time.Time); ok && now.Before(exp) {
return &ExpiringSession{Session: s}, err
}
// We can't call sessions.NewSession() because that
// won't register the session with the Registry, so it
// won't be sent with the response. Wipe the data
// instead.
for k := range s.Values {
delete(s.Values, k)
}
}
// The session is either invalid or expired, create a new
// blank one containing no data.
expiry := now.Add(ttl)
s.Values["_exp"] = expiry
return &ExpiringSession{Session: s}, err
}
// Valid returns true if the session has not expired yet.
// It can be called with a nil receiver.
func (e *ExpiringSession) Valid() bool {
return e != nil && time.Now().Before(e.Expiry)
// Wrapper for an http.ResponseWriter that ensures all tracked
// sessions are saved before the request body is sent.
//
// We have to duplicate the logic to call WriteHeader on the first
// Write, otherwise the underlying ResponseWriter won't call our
// WriteHeader function but its own instead.
type sessionResponseWriter struct {
http.ResponseWriter
headerWritten bool
req *http.Request
}
func init() {
gob.Register(&ExpiringSession{})
func (w *sessionResponseWriter) WriteHeader(statusCode int) {
if statusCode >= 200 && statusCode < 400 {
if err := sessions.Save(w.req, w.ResponseWriter); err != nil {
log.Printf("error saving sessions: %v", err)
}
}
w.ResponseWriter.WriteHeader(statusCode)
w.headerWritten = true
}
func (w *sessionResponseWriter) Write(b []byte) (int, error) {
if !w.headerWritten {
w.WriteHeader(http.StatusOK)
}
return w.ResponseWriter.Write(b)
}
// NewSessionResponseWriter returns a wrapped http.ResponseWriter that
// will always remember to save the Gorilla sessions before writing
// the response body.
func NewSessionResponseWriter(w http.ResponseWriter, req *http.Request) http.ResponseWriter {
return &sessionResponseWriter{
ResponseWriter: w,
req: req,
}
}
func init() {
// Register time.Time with encoding/gob, to ensure that the
// ExpiringSession timestamp can be serialized.
var t time.Time
gob.Register(t)
}
package httputil
import (
"bytes"
"encoding/gob"
"reflect"
"net/http"
"testing"
"time"
"github.com/gorilla/sessions"
)
type mySession struct {
Data string
}
func init() {
gob.Register(&mySession{})
}
func TestExpiringSession(t *testing.T) {
type mySession struct {
*ExpiringSession
Data string
}
s := &mySession{
ExpiringSession: NewExpiringSession(60 * time.Second),
Data: "data",
}
store := sessions.NewCookieStore()
req, _ := http.NewRequest("GET", "http://localhost/", nil)
var buf bytes.Buffer
if err := gob.NewEncoder(&buf).Encode(s); err != nil {
t.Fatal("encode:", err)
httpsess, err := GetExpiringSession(req, store, "testkey", 60*time.Second)
if err != nil {
t.Errorf("store.Get error: %v", err)
}
var s2 mySession
if err := gob.NewDecoder(&buf).Decode(&s2); err != nil {
t.Fatal("decode:", err)
}
if !reflect.DeepEqual(s.Data, s2.Data) {
t.Fatalf("sessions differ: %+v vs %+v", s, &s2)
if _, ok := httpsess.Values["mykey"].(*mySession); ok {
t.Fatal("got a session without any data")
}
}
package httputil
import (
"bytes"
"io/ioutil"
"net/http"
"os"
"time"
)
// StaticContent is an http.Handler that serves in-memory data as if
// it were a static file.
type StaticContent struct {
modtime time.Time
name string
data []byte
}
// LoadStaticContent creates a StaticContent by loading data from a file.
func LoadStaticContent(path string) (*StaticContent, error) {
stat, err := os.Stat(path)
if err != nil {
return nil, err
}
data, err := ioutil.ReadFile(path) // #nosec
if err != nil {
return nil, err
}
return &StaticContent{
name: path,
modtime: stat.ModTime(),
data: data,
}, nil
}
func (c *StaticContent) ServeHTTP(w http.ResponseWriter, req *http.Request) {
http.ServeContent(w, req, c.name, c.modtime, bytes.NewReader(c.data))
}
......@@ -10,6 +10,7 @@ import (
"strings"
"git.autistici.org/id/auth"
"github.com/gorilla/securecookie"
"github.com/gorilla/sessions"
"github.com/mssola/user_agent"
)
......@@ -35,7 +36,7 @@ type Config struct {
}
// New returns a new Manager with the given configuration.
func New(config *Config) (*Manager, error) {
func New(config *Config, urlPrefix string) (*Manager, error) {
if config == nil {
config = &Config{}
}
......@@ -45,9 +46,15 @@ func New(config *Config) (*Manager, error) {
log.Printf("Warning: GeoIP disabled: %v", err)
}
// This should only happen in tests.
if config.AuthKey == "" {
log.Printf("Warning: device_manager.auth_key unset, generating temporary random secrets")
config.AuthKey = string(securecookie.GenerateRandomKey(64))
}
return &Manager{
geodb: geodb,
store: newStore([]byte(config.AuthKey)),
store: newStore([]byte(config.AuthKey), urlPrefix),
}, nil
}
......
......@@ -4,11 +4,11 @@ import "github.com/gorilla/sessions"
const aVeryLongTimeInSeconds = 10 * 365 * 86400
func newStore(authKey []byte) sessions.Store {
func newStore(authKey []byte, urlPrefix string) sessions.Store {
// No encryption, long-term lifetime cookie.
store := sessions.NewCookieStore(authKey, nil)
store.Options = &sessions.Options{
Path: "/",
Path: urlPrefix + "/",
HttpOnly: true,
Secure: true,
MaxAge: aVeryLongTimeInSeconds,
......
This diff is collapsed.
......@@ -161,7 +161,7 @@ func checkLoginPasswordPage(t testing.TB, resp *http.Response) {
var otpFieldRx = regexp.MustCompile(`<input[^>]*name="otp"`)
func checkLoginOTPPage(t testing.TB, resp *http.Response) {
if resp.Request.URL.Path != "/login" {
if resp.Request.URL.Path != "/login/otp" {
t.Errorf("request path is not /login (%s)", resp.Request.URL.String())
}
data, err := ioutil.ReadAll(resp.Body)
......@@ -283,7 +283,7 @@ func TestHTTP_LoginOTP(t *testing.T) {
// 302 redirect to the target service.
v = make(url.Values)
v.Set("otp", "123456")
doPostForm(t, httpSrv, c, "/login", v, checkRedirectToTargetService)
doPostForm(t, httpSrv, c, "/login/otp", v, checkRedirectToTargetService)
}
func createFakeKeyStore(t testing.TB, username, password string) *httptest.Server {
......@@ -304,7 +304,7 @@ func createFakeKeyStore(t testing.TB, username, password string) *httptest.Serve
t.Errorf("bad password in keystore Open request: expected %s, got %s", password, openReq.Password)
}
w.Header().Set("Content-Type", "application/json")
io.WriteString(w, "{}")
io.WriteString(w, "{}") // nolint
})
return httptest.NewServer(h)
}
......
......@@ -67,4 +67,4 @@ func (dl DefaultLogger) LogResponse(req *http.Request, res *http.Response, err e
}
// DefaultLoggedTransport wraps http.DefaultTransport to log using DefaultLogger
var DefaultLoggedTransport = NewLoggedTransport(http.DefaultTransport, DefaultLogger{})
//var DefaultLoggedTransport = NewLoggedTransport(http.DefaultTransport, DefaultLogger{})
This diff is collapsed.
This diff is collapsed.
......@@ -29,8 +29,8 @@ func testConfig(t testing.TB, tmpdir, keystoreURL string) *Config {
if err != nil {
t.Fatal(err)
}
ioutil.WriteFile(filepath.Join(tmpdir, "secret"), priv, 0600)
ioutil.WriteFile(filepath.Join(tmpdir, "public"), pub, 0600)
ioutil.WriteFile(filepath.Join(tmpdir, "secret"), priv, 0600) // nolint
ioutil.WriteFile(filepath.Join(tmpdir, "public"), pub, 0600) // nolint
cfgstr := fmt.Sprintf(`---
secret_key_file: %s
......@@ -51,7 +51,7 @@ keystore:
url: "%s"
`, keystoreURL)
}
ioutil.WriteFile(filepath.Join(tmpdir, "config"), []byte(cfgstr), 0600)
ioutil.WriteFile(filepath.Join(tmpdir, "config"), []byte(cfgstr), 0600) // nolint
config, err := loadConfig(filepath.Join(tmpdir, "config"))
if err != nil {
......
package server
// Returns the intersection of two string lists (in O(N^2) time).
func intersectGroups(a, b []string) []string {
var out []string
for _, aa := range a {
for _, bb := range b {
if aa == bb {
out = append(out, aa)
break
}
}
}
return out
}
......@@ -7,4 +7,4 @@ gorilla/context is a general purpose registry for global request variables.
> Note: gorilla/context, having been born well before `context.Context` existed, does not play well
> with the shallow copying of the request that [`http.Request.WithContext`](https://golang.org/pkg/net/http/#Request.WithContext) (added to net/http Go 1.7 onwards) performs. You should either use *just* gorilla/context, or moving forward, the new `http.Request.Context()`.
Read the full documentation here: http://www.gorillatoolkit.org/pkg/context
Read the full documentation here: https://www.gorillatoolkit.org/pkg/context
......@@ -25,7 +25,7 @@ with Go's `net/http` package (or any framework supporting `http.Handler`), inclu
* [**RecoveryHandler**](https://godoc.org/github.com/gorilla/handlers#RecoveryHandler) for recovering from unexpected panics.
Other handlers are documented [on the Gorilla
website](https://www.gorillatoolkit.org/pkg/handlers).
website](http://www.gorillatoolkit.org/pkg/handlers).
## Example
......
......@@ -19,16 +19,14 @@ type cors struct {
maxAge int
ignoreOptions bool
allowCredentials bool
optionStatusCode int
}
// OriginValidator takes an origin string and returns whether or not that origin is allowed.
type OriginValidator func(string) bool
var (
defaultCorsOptionStatusCode = 200
defaultCorsMethods = []string{"GET", "HEAD", "POST"}
defaultCorsHeaders = []string{"Accept", "Accept-Language", "Content-Language", "Origin"}
defaultCorsMethods = []string{"GET", "HEAD", "POST"}
defaultCorsHeaders = []string{"Accept", "Accept-Language", "Content-Language", "Origin"}
// (WebKit/Safari v9 sends the Origin header by default in AJAX requests)
)
......@@ -132,7 +130,6 @@ func (ch *cors) ServeHTTP(w http.ResponseWriter, r *http.Request) {
w.Header().Set(corsAllowOriginHeader, returnOrigin)
if r.Method == corsOptionMethod {
w.WriteHeader(ch.optionStatusCode)
return
}
ch.h.ServeHTTP(w, r)
......@@ -167,10 +164,9 @@ func CORS(opts ...CORSOption) func(http.Handler) http.Handler {
func parseCORSOptions(opts ...CORSOption) *cors {
ch := &cors{
allowedMethods: defaultCorsMethods,
allowedHeaders: defaultCorsHeaders,
allowedOrigins: []string{},
optionStatusCode: defaultCorsOptionStatusCode,
allowedMethods: defaultCorsMethods,
allowedHeaders: defaultCorsHeaders,
allowedOrigins: []string{},
}
for _, option := range opts {
......@@ -255,20 +251,7 @@ func AllowedOriginValidator(fn OriginValidator) CORSOption {
}
}
// OptionStatusCode sets a custom status code on the OPTIONS requests.
// Default behaviour sets it to 200 to reflect best practices. This is option is not mandatory
// and can be used if you need a custom status code (i.e 204).
//
// More informations on the spec:
// https://fetch.spec.whatwg.org/#cors-preflight-fetch
func OptionStatusCode(code int) CORSOption {
return func(ch *cors) error {
ch.optionStatusCode = code
return nil
}
}
// ExposedHeaders can be used to specify headers that are available
// ExposeHeaders can be used to specify headers that are available
// and will not be stripped out by the user-agent.
func ExposedHeaders(headers []string) CORSOption {
return func(ch *cors) error {
......
# This is the official list of gorilla/mux authors for copyright purposes.
#
# Please keep the list sorted.
Google LLC (https://opensource.google.com/)
Kamil Kisielk <kamil@kamilkisiel.net>
Matt Silverlock <matt@eatsleeprepeat.net>
Rodrigo Moraes (https://github.com/moraes)
**What version of Go are you running?** (Paste the output of `go version`)
**What version of gorilla/mux are you at?** (Paste the output of `git rev-parse HEAD` inside `$GOPATH/src/github.com/gorilla/mux`)
**Describe your problem** (and what you have tried so far)
**Paste a minimal, runnable, reproduction of your issue below** (use backticks to format it)
Copyright (c) 2012 Rodrigo Moraes. All rights reserved.
Copyright (c) 2012-2018 The Gorilla Authors. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
......
This diff is collapsed.
// +build !go1.7
package mux
import (
"net/http"
"github.com/gorilla/context"
)
func contextGet(r *http.Request, key interface{}) interface{} {
return context.Get(r, key)
}
func contextSet(r *http.Request, key, val interface{}) *http.Request {
if val == nil {
return r
}
context.Set(r, key, val)
return r
}
func contextClear(r *http.Request) {
context.Clear(r)
}
......@@ -238,5 +238,69 @@ as well:
url, err := r.Get("article").URL("subdomain", "news",
"category", "technology",
"id", "42")
Mux supports the addition of middlewares to a Router, which are executed in the order they are added if a match is found, including its subrouters. Middlewares are (typically) small pieces of code which take one request, do something with it, and pass it down to another middleware or the final handler. Some common use cases for middleware are request logging, header manipulation, or ResponseWriter hijacking.
type MiddlewareFunc func(http.Handler) http.Handler
Typically, the returned handler is a closure which does something with the http.ResponseWriter and http.Request passed to it, and then calls the handler passed as parameter to the MiddlewareFunc (closures can access variables from the context where they are created).
A very basic middleware which logs the URI of the request being handled could be written as:
func simpleMw(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Do stuff here
log.Println(r.RequestURI)
// Call the next handler, which can be another middleware in the chain, or the final handler.
next.ServeHTTP(w, r)
})
}
Middlewares can be added to a router using `Router.Use()`:
r := mux.NewRouter()
r.HandleFunc("/", handler)
r.Use(simpleMw)
A more complex authentication middleware, which maps session token to users, could be written as:
// Define our struct
type authenticationMiddleware struct {
tokenUsers map[string]string
}
// Initialize it somewhere
func (amw *authenticationMiddleware) Populate() {
amw.tokenUsers["00000000"] = "user0"
amw.tokenUsers["aaaaaaaa"] = "userA"
amw.tokenUsers["05f717e5"] = "randomUser"
amw.tokenUsers["deadbeef"] = "user0"
}
// Middleware function, which will be called for each request
func (amw *authenticationMiddleware) Middleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
token := r.Header.Get("X-Session-Token")
if user, found := amw.tokenUsers[token]; found {
// We found the token in our map