package httpsso

import (
	"crypto/rand"
	"encoding/gob"
	"encoding/hex"
	"io"
	"log"
	"net/http"
	"net/url"
	"strings"
	"time"

	"github.com/gorilla/sessions"

	"git.autistici.org/id/go-sso"
	"git.autistici.org/id/go-sso/httputil"
)

// authSession is the authentication state kept in the "sso" cookie
// session under the key "a". It is registered with encoding/gob in init
// (presumably because the cookie store gob-encodes session values), so
// the field names and types below are effectively part of the cookie
// format: renaming them would change how existing session cookies decode.
type authSession struct {
	// Embedded expiry tracking; Wrap only trusts the session when Valid()
	// returns true. NOTE(review): presumably Valid reports whether the
	// lifetime given to httputil.NewExpiringSession has not yet elapsed —
	// confirm in httputil.
	*httputil.ExpiringSession

	// Auth is set to true by handleLogin after a token has validated.
	Auth     bool
	// Username is the user from the validated token (tkt.User); Wrap
	// forwards it to the protected handler as X-Authenticated-User.
	Username string
}

// authSessionLifetime is how long an authenticated local session stays
// valid after a successful login before the user is sent back to the
// SSO server.
var authSessionLifetime = time.Hour

// init registers *authSession with encoding/gob so that values of that
// type can be stored in the cookie-backed session.
func init() {
	gob.Register(new(authSession))
}

// SSOWrapper protects http handlers with single-sign-on authentication.
// Create one with NewSSOWrapper, then call Wrap once per protected
// handler.
type SSOWrapper struct {
	v              sso.Validator // checks tokens handed back by the SSO server
	sessionAuthKey []byte        // first key passed to sessions.NewCookieStore (authentication)
	sessionEncKey  []byte        // second key passed to sessions.NewCookieStore (encryption)
	serverURL      string        // login service URL that unauthenticated users are redirected to
}

// NewSSOWrapper builds an SSOWrapper that sends unauthenticated users to
// the login service at serverURL. pkey and domain are used to construct
// the token validator via sso.NewValidator; if that fails, its error is
// returned unchanged and the wrapper is nil. sessionAuthKey and
// sessionEncKey are the keys for the local session cookie store.
func NewSSOWrapper(serverURL string, pkey []byte, domain string, sessionAuthKey, sessionEncKey []byte) (*SSOWrapper, error) {
	validator, err := sso.NewValidator(pkey, domain)
	if err != nil {
		return nil, err
	}

	wrapper := &SSOWrapper{
		serverURL:      serverURL,
		sessionAuthKey: sessionAuthKey,
		sessionEncKey:  sessionEncKey,
	}
	wrapper.v = validator
	return wrapper, nil
}

// Wrap returns an http.Handler that puts SSO authentication in front of h
// for the given service. Access control is a simple group-based ACL: the
// groups list is sent to the SSO server on login and passed to the token
// validator when the user comes back.
//
// Below the service path (see pathFromService) two endpoints are handled
// here instead of by h: "sso_login" completes the login handshake and
// "sso_logout" drops the local session. Any other request reaches h, with
// an X-Authenticated-User header, only if the session holds a valid
// authentication; otherwise the browser is redirected to the login server.
func (s *SSOWrapper) Wrap(h http.Handler, service string, groups []string) http.Handler {
	svcPath := pathFromService(service)

	// One cookie store per wrapped service, with cookies scoped to the
	// service path, sent only over HTTPS and hidden from scripts. MaxAge 0
	// means (per gorilla/sessions) no Max-Age attribute, i.e. a
	// browser-session cookie.
	store := sessions.NewCookieStore(s.sessionAuthKey, s.sessionEncKey)
	store.Options = &sessions.Options{
		Path:     svcPath,
		MaxAge:   0,
		Secure:   true,
		HttpOnly: true,
	}

	serve := func(w http.ResponseWriter, req *http.Request) {
		// The error is ignored on purpose: gorilla/sessions documents that
		// Get returns a new session together with the error when an
		// existing cookie cannot be decoded, so a bad or stale cookie just
		// leads to a fresh login.
		session, _ := store.Get(req, "sso")

		endpoint := strings.TrimPrefix(req.URL.Path, svcPath)
		if endpoint == "sso_login" {
			s.handleLogin(w, req, session, service, groups)
			return
		}
		if endpoint == "sso_logout" {
			s.handleLogout(w, req, session)
			return
		}

		auth, ok := session.Values["a"].(*authSession)
		if !ok || !auth.Valid() || !auth.Auth {
			s.redirectToLogin(w, req, session, service, groups)
			return
		}
		req.Header.Set("X-Authenticated-User", auth.Username)
		h.ServeHTTP(w, req)
	}
	return http.HandlerFunc(serve)
}

// handleLogin completes the SSO handshake on the service's "sso_login"
// endpoint. It reads the token from the "t" form value and the
// post-login destination from "d". It removes the nonce that
// redirectToLogin stored in the session and validates the token against
// that nonce, the service name and the required groups. On success the
// session is marked as authenticated for the token's user (expiring after
// authSessionLifetime) and saved, and the browser is redirected to "d".
//
// A missing nonce or a failed validation yields 400 Bad Request. A
// failure to save the session yields 500 Internal Server Error.
func (s *SSOWrapper) handleLogin(w http.ResponseWriter, req *http.Request, session *sessions.Session, service string, groups []string) {
	t := req.FormValue("t")
	d := req.FormValue("d")

	// Pop the nonce from the session. The cookie saved below no longer
	// carries it, so the nonce cannot be used for a second login.
	nonce, ok := session.Values["nonce"].(string)
	if !ok || nonce == "" {
		log.Printf("got login request without nonce")
		http.Error(w, "Missing nonce", http.StatusBadRequest)
		return
	}
	delete(session.Values, "nonce")

	tkt, err := s.v.Validate(t, nonce, service, groups)
	if err != nil {
		// Do not log the token itself: it is a bearer credential, and it
		// comes straight from the (untrusted) request.
		log.Printf("token validation error for service %q: %v", service, err)
		http.Error(w, err.Error(), http.StatusBadRequest)
		return
	}

	// Authenticate the user.
	session.Values["a"] = &authSession{
		ExpiringSession: httputil.NewExpiringSession(authSessionLifetime),
		Auth:            true,
		Username:        tkt.User,
	}
	if err := sessions.Save(req, w); err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}

	// NOTE(review): d comes from the request and is not validated here.
	// It is only honored after a token bound to this session's nonce has
	// validated, but restricting it to the service's own host would be a
	// useful defense in depth.
	http.Redirect(w, req, d, http.StatusFound)
}

// handleLogout drops the local SSO session on the service's "sso_logout"
// endpoint and replies with a plain-text "OK". It sets CORS headers so
// that the login server's origin (derived from serverURL) may call this
// endpoint with credentials. NOTE(review): presumably this is called
// cross-origin by the SSO server's logout page — confirm.
// If the session cannot be saved, it replies 500 without CORS headers.
func (s *SSOWrapper) handleLogout(w http.ResponseWriter, req *http.Request, session *sessions.Session) {
	// Per gorilla/sessions, MaxAge < 0 makes Save emit a cookie that
	// deletes the session immediately.
	session.Options.MaxAge = -1
	if err := sessions.Save(req, w); err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}

	w.Header().Set("Content-Type", "text/plain")
	w.Header().Set("Access-Control-Allow-Origin", originOf(s.serverURL))
	w.Header().Set("Access-Control-Allow-Credentials", "true")
	io.WriteString(w, "OK")
}

// originOf returns the origin ("scheme://host[:port]") of rawURL, which
// is the form browsers compare against Access-Control-Allow-Origin. Any
// path, query or fragment is dropped, so a login URL such as
// "https://sso.example.com/login/" becomes "https://sso.example.com".
// If rawURL is not an absolute URL, it falls back to rawURL with trailing
// slashes removed.
func originOf(rawURL string) string {
	u, err := url.Parse(rawURL)
	if err != nil || u.Scheme == "" || u.Host == "" {
		return strings.TrimRight(rawURL, "/")
	}
	return u.Scheme + "://" + u.Host
}

// redirectToLogin starts the SSO handshake. It stores a fresh random
// nonce in the local session, saves it, and redirects the browser
// (302 Found) to the login server at s.serverURL with these query
// parameters:
//
//	s: the service name
//	d: the current request URL (with scheme forced to https), for after login
//	n: the nonce, which handleLogin later passes to the token validator
//	g: the required groups, comma-separated
//
// If the session cannot be saved, it replies 500 and does not redirect.
func (s *SSOWrapper) redirectToLogin(w http.ResponseWriter, req *http.Request, session *sessions.Session, service string, groups []string) {
	// Generate a random nonce and store it in the local session. It must
	// be persisted before redirecting so handleLogin can find it.
	nonce := makeUniqueNonce()
	session.Values["nonce"] = nonce
	if err := sessions.Save(req, w); err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}

	v := make(url.Values)
	v.Set("s", service)
	v.Set("d", getFullURL(req, "https").String())
	v.Set("n", nonce)
	v.Set("g", strings.Join(groups, ","))

	// serverURL may already carry a query string: extend it instead of
	// starting a second "?" (which would corrupt both).
	sep := "?"
	if strings.Contains(s.serverURL, "?") {
		sep = "&"
	}
	loginURL := s.serverURL + sep + v.Encode()
	http.Redirect(w, req, loginURL, http.StatusFound)
}

// pathFromService extracts the URL path from a service specification of
// the form "host/path/". The result is everything from the first '/'
// onwards, so it always has a leading slash. It ends with a slash only if
// the service specification does; that is not enforced here.
//
// A service with no path component (e.g. "example.com") maps to "/".
// Returning "" there would give the session cookie an empty Path. It
// would also keep Wrap from recognizing "/sso_login" and "/sso_logout",
// because trimming an empty prefix leaves the leading slash in place.
func pathFromService(service string) string {
	i := strings.IndexByte(service, '/')
	if i < 0 {
		return "/"
	}
	return service[i:]
}

// getFullURL builds an absolute URL for req. It copies req.URL (which on
// server requests normally holds only the path and query), then fills in
// the request's Host and the given scheme. req.URL itself is left
// untouched.
func getFullURL(req *http.Request, scheme string) *url.URL {
	full := new(url.URL)
	*full = *req.URL
	full.Host = req.Host
	full.Scheme = scheme
	return full
}

// makeUniqueNonce returns a random nonce encoded as a 16-character
// lowercase hex string (8 bytes read from crypto/rand). It panics if the
// random source cannot supply those bytes, since no safe nonce can be
// produced in that case.
func makeUniqueNonce() string {
	raw := make([]byte, 8)
	_, err := io.ReadFull(rand.Reader, raw)
	if err != nil {
		panic(err)
	}
	return hex.EncodeToString(raw)
}