Commit c1717fb2 authored by ale's avatar ale

Initial commit

parents
package main
import (
"context"
"flag"
"log"
"net/http"
"os"
"os/signal"
"syscall"
"time"
"git.autistici.org/id/auth/client"
"git.autistici.org/id/go-sso/server"
)
var (
addr = flag.String("addr", ":4141", "tcp `address` to listen on")
configPath = flag.String("config", "/etc/sso/server.yml", "configuration `file`")
authSocket = flag.String("auth-socket", client.DefaultSocketPath, "authentication socket `path`")
)
func main() {
log.SetFlags(0)
flag.Parse()
config, err := server.LoadConfig(*configPath)
if err != nil {
log.Fatal(err)
}
loginService, err := server.NewLoginService(config)
if err != nil {
log.Fatal(err)
}
authClient := client.New(*authSocket)
httpSrv, err := server.New(loginService, authClient, config)
if err != nil {
log.Fatal(err)
}
srv := &http.Server{
Addr: *addr,
Handler: httpSrv.Handler(),
ReadTimeout: 30 * time.Second,
WriteTimeout: 30 * time.Second,
IdleTimeout: 60 * time.Second,
}
done := make(chan struct{})
sigCh := make(chan os.Signal, 1)
go func() {
<-sigCh
log.Printf("exiting")
// Gracefully terminate for 3 seconds max, then shut
// down remaining clients.
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
if err := srv.Shutdown(ctx); err == context.Canceled {
srv.Close()
}
close(done)
}()
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
if err = srv.ListenAndServe(); err != http.ErrServerClosed {
log.Fatal("error: %v", err)
}
<-done
}
This source diff could not be displayed because it is too large. You can view the blob instead.
package server
import (
"errors"
"io/ioutil"
"log"
"regexp"
"time"
"git.autistici.org/id/go-sso/server/device"
"github.com/gorilla/securecookie"
"gopkg.in/yaml.v2"
)
// Config data for the SSO service.
type Config struct {
SecretKeyFile string `yaml:"secret_key_file"`
PublicKeyFile string `yaml:"public_key_file"`
Domain string `yaml:"domain"`
AllowedServices []string `yaml:"allowed_services"`
AllowedExchanges []*struct {
SrcRegexp string `yaml:"src_regexp"`
DstRegexp string `yaml:"dst_regexp"`
srcRx *regexp.Regexp
dstRx *regexp.Regexp
} `yaml:"allowed_exchanges"`
ServiceTTLs []*struct {
Regexp string `yaml:"regexp"`
TTLSeconds int `yaml:"ttl"`
rx *regexp.Regexp
} `yaml:"service_ttls"`
AuthSessionLifetimeSeconds int `yaml:"auth_session_lifetime"`
SessionSecrets [][]byte `yaml:"session_secrets"`
CSRFSecret []byte `yaml:"csrf_secret"`
AuthService string `yaml:"auth_service"`
DeviceManager *device.Config `yaml:"device_manager"`
allowedServicesRx []*regexp.Regexp
}
// LoadConfig reads configuration from a file.
func LoadConfig(path string) (*Config, error) {
data, err := ioutil.ReadFile(path)
if err != nil {
return nil, err
}
var config Config
if err := yaml.Unmarshal(data, &config); err != nil {
return nil, err
}
if err := config.valid(); err != nil {
return nil, err
}
if err := config.compile(); err != nil {
return nil, err
}
return &config, nil
}
// Check syntax (missing required values).
func (c *Config) valid() error {
if c.SecretKeyFile == "" {
return errors.New("secret_key_file is empty")
}
if c.PublicKeyFile == "" {
return errors.New("public_key_file is empty")
}
if c.Domain == "" {
return errors.New("domain is empty")
}
if len(c.AllowedServices) == 0 {
return errors.New("the list of allowed services is empty")
}
if c.AuthService == "" {
return errors.New("auth_service is empty")
}
// Some things we can autogenerate, but for testing purposes
// only. Print a warning.
if len(c.SessionSecrets) == 0 {
log.Printf("Warning: session_secrets unset, generating temporary random session secrets")
c.SessionSecrets = [][]byte{
securecookie.GenerateRandomKey(64),
securecookie.GenerateRandomKey(32),
}
}
return nil
}
// Compile the configuration (regular expressions etc).
func (c *Config) compile() error {
var err error
for _, svcttl := range c.ServiceTTLs {
svcttl.rx, err = regexp.Compile(svcttl.Regexp)
if err != nil {
return err
}
}
for _, xch := range c.AllowedExchanges {
xch.srcRx, err = regexp.Compile(xch.SrcRegexp)
if err != nil {
return err
}
xch.dstRx, err = regexp.Compile(xch.DstRegexp)
if err != nil {
return err
}
}
for _, pattern := range c.AllowedServices {
rx, err := regexp.Compile(pattern)
if err != nil {
return err
}
c.allowedServicesRx = append(c.allowedServicesRx, rx)
}
return nil
}
var defaultServiceTTL = 300 * time.Second
func (c *Config) getServiceTTL(service string) time.Duration {
for _, svcttl := range c.ServiceTTLs {
if svcttl.rx.MatchString(service) {
return time.Duration(svcttl.TTLSeconds) * time.Second
}
}
return defaultServiceTTL
}
func (c *Config) isServiceAllowed(service string) bool {
for _, rx := range c.allowedServicesRx {
if rx.MatchString(service) {
return true
}
}
return false
}
func (c *Config) isExchangeAllowed(src, dst string) bool {
for _, xch := range c.AllowedExchanges {
if xch.srcRx.MatchString(src) && xch.dstRx.MatchString(dst) {
return true
}
}
return false
}
package device
import (
"fmt"
"net"
"net/http"
"strings"
)
func (m *Manager) getIPFromRequest(req *http.Request) net.IP {
// Parse the RemoteAddr Request field, for starters.
host, _, err := net.SplitHostPort(req.RemoteAddr)
if err != nil {
host = req.RemoteAddr
}
ip := net.ParseIP(host)
if ip == nil {
return nil
}
// See if it's a trusted forwarder, in which case go for the
// X-Forwarded-For header.
if matchIPNetList(ip, m.trustedForwarders) {
fwdAddr := req.Header.Get(m.remoteAddrHeader)
if fwdAddr == "" {
return nil
}
ip = net.ParseIP(strings.Split(fwdAddr, ", ")[0])
if ip == nil {
return nil
}
}
return ip
}
func fullMask(ip net.IP) net.IPMask {
if ip.To4() == nil {
return net.CIDRMask(128, 128)
}
return net.CIDRMask(32, 32)
}
// ParseIPNetList turns a comma-separated list of IP addresses or CIDR
// networks into a net.IPNet slice.
func parseIPNetList(iplist []string) ([]net.IPNet, error) {
var nets []net.IPNet
for _, s := range iplist {
if s == "" {
continue
}
_, ipnet, err := net.ParseCIDR(s)
if err != nil {
ip := net.ParseIP(s)
if ip == nil {
return nil, fmt.Errorf("could not parse '%s'", s)
}
ipnet = &net.IPNet{IP: ip, Mask: fullMask(ip)}
}
nets = append(nets, *ipnet)
}
return nets, nil
}
// MatchIPNetList returns true if the given IP address matches one of
// the specified networks.
func matchIPNetList(ip net.IP, nets []net.IPNet) bool {
for _, n := range nets {
if n.Contains(ip) {
return true
}
}
return false
}
package device
import "net"
func (m *Manager) getZoneForIP(ip net.IP) (string, error) {
if m.geodb == nil {
return "", nil
}
// Only look up a single attribute (country).
var record struct {
Country struct {
ISOCode string `maxminddb:"iso_code"`
} `maxminddb:"country"`
}
if err := m.geodb.Lookup(ip, &record); err != nil {
return "", err
}
return record.Country.ISOCode, nil
}
package device
import (
"crypto/rand"
"encoding/hex"
"net"
"net/http"
"git.autistici.org/id/auth"
"github.com/gorilla/sessions"
"github.com/mssola/user_agent"
"github.com/oschwald/maxminddb-golang"
)
func randomDeviceID() string {
b := make([]byte, 8)
rand.Read(b)
return hex.EncodeToString(b)
}
// Manager can provide DeviceInfo entries for incoming HTTP requests.
type Manager struct {
store sessions.Store
geodb *maxminddb.Reader
trustedForwarders []net.IPNet
remoteAddrHeader string
}
// Config stores options for the device info manager.
type Config struct {
AuthKey []byte `yaml:"auth_key"`
GeoIPDataFile string `yaml:"geo_ip_data"`
TrustedForwarders []string `yaml:"trusted_forwarders"`
RemoteAddrHeader string `yaml:"remote_addr_header"`
}
// New returns a new Manager with the given configuration.
func New(config *Config) (*Manager, error) {
if config == nil {
config = &Config{}
}
tf, err := parseIPNetList(config.TrustedForwarders)
if err != nil {
return nil, err
}
var geodb *maxminddb.Reader
if config.GeoIPDataFile != "" {
geodb, err = maxminddb.Open(config.GeoIPDataFile)
if err != nil {
return nil, err
}
}
// The remote IP header (if any) defaults to X-Forwarded-For.
hdr := "X-Forwarded-For"
if config.RemoteAddrHeader != "" {
hdr = config.RemoteAddrHeader
}
return &Manager{
geodb: geodb,
store: newStore(config.AuthKey),
trustedForwarders: tf,
remoteAddrHeader: hdr,
}, nil
}
const deviceIDSessionName = "_dev"
// GetDeviceInfoFromRequest will retrieve or create a DeviceInfo
// object for the given request. It will always return a valid object.
// The ResponseWriter is needed to store the unique ID on the client
// when a new device info object is created.
func (m *Manager) GetDeviceInfoFromRequest(w http.ResponseWriter, req *http.Request) *auth.DeviceInfo {
session, _ := m.store.Get(req, deviceIDSessionName)
devID, ok := session.Values["id"].(string)
if !ok || devID == "" {
// Generate a new Device ID and save it on the client.
devID = randomDeviceID()
session.Values["id"] = devID
session.Save(req, w)
}
uaStr := req.UserAgent()
ua := user_agent.New(uaStr)
browser, _ := ua.Browser()
d := auth.DeviceInfo{
ID: devID,
UserAgent: uaStr,
Mobile: ua.Mobile(),
OS: ua.OS(),
Browser: browser,
}
if ip := m.getIPFromRequest(req); ip != nil {
d.RemoteAddr = ip.String()
if zone, err := m.getZoneForIP(ip); err == nil {
d.RemoteZone = zone
}
}
return &d
}
package device
import "github.com/gorilla/sessions"
const aVeryLongTimeInSeconds = 10 * 365 * 86400
func newStore(authKey []byte) sessions.Store {
// No encryption, long-term lifetime cookie.
store := sessions.NewCookieStore(authKey, nil)
store.Options = &sessions.Options{
Path: "/",
HttpOnly: true,
Secure: true,
MaxAge: aVeryLongTimeInSeconds,
}
return store
}
package server
//go:generate go-bindata --nocompress --pkg server static/... templates/...
import (
"encoding/gob"
"fmt"
"io"
"log"
"net/http"
"net/url"
"strings"
"time"
assetfs "github.com/elazarl/go-bindata-assetfs"
"github.com/gorilla/csrf"
"github.com/gorilla/mux"
"github.com/gorilla/sessions"
"git.autistici.org/id/auth"
authclient "git.autistici.org/id/auth/client"
"git.autistici.org/id/go-sso/server/device"
)
const authSessionKey = "_auth"
type authSession struct {
*ExpiringSession
// User name and other information (like group membership).
Username string
UserInfo *auth.UserInfo
// Services the user has logged in to from this session.
Services []string
}
// AddService adds a service to the current session (if it's not
// already there).
func (s *authSession) AddService(service string) {
for _, svc := range s.Services {
if svc == service {
return
}
}
s.Services = append(s.Services, service)
}
// By default, make users log in again after (almost) one day.
var defaultAuthSessionLifetime = 20 * time.Hour
func newAuthSession(ttl time.Duration, username string, userinfo *auth.UserInfo) *authSession {
return &authSession{
ExpiringSession: newExpiringSession(ttl),
Username: username,
UserInfo: userinfo,
}
}
func init() {
gob.Register(&authSession{})
}
// Server for the SSO protocol. Provides the HTTP interface to a
// LoginService.
type Server struct {
authSessionStore sessions.Store
authSessionLifetime time.Duration
loginHandler *loginHandler
loginService *LoginService
csrfSecret []byte
}
// New returns a new Server.
func New(loginService *LoginService, authClient authclient.Client, config *Config) (*Server, error) {
store := sessions.NewCookieStore(config.SessionSecrets...)
store.Options = &sessions.Options{
HttpOnly: true,
Secure: true,
MaxAge: 0,
Path: "/",
}
s := &Server{
authSessionLifetime: defaultAuthSessionLifetime,
authSessionStore: store,
loginService: loginService,
csrfSecret: config.CSRFSecret,
}
if config.AuthSessionLifetimeSeconds > 0 {
s.authSessionLifetime = time.Duration(config.AuthSessionLifetimeSeconds) * time.Second
}
devMgr, err := device.New(config.DeviceManager)
if err != nil {
return nil, err
}
s.loginHandler = newLoginHandler(s.loginCallback, devMgr, authClient, config.AuthService, config.SessionSecrets...)
return s, nil
}
func (h *Server) loginCallback(w http.ResponseWriter, req *http.Request, username string, userinfo *auth.UserInfo) error {
session := newAuthSession(h.authSessionLifetime, username, userinfo)
httpSession, _ := h.authSessionStore.Get(req, authSessionKey)
httpSession.Values["auth"] = session
return httpSession.Save(req, w)
}
func (h *Server) withAuth(f func(http.ResponseWriter, *http.Request, *authSession)) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
httpSession, err := h.authSessionStore.Get(req, authSessionKey)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
session, ok := httpSession.Values["auth"].(*authSession)
if ok && session != nil && session.Valid() {
f(w, req, session)
return
}
httpSession.Options.MaxAge = -1
_ = httpSession.Save(req, w)
http.Redirect(w, req, makeLoginURL(req), http.StatusFound)
})
}
func makeLoginURL(req *http.Request) string {
// Just concatenate path and raw request string.
v := make(url.Values)
v.Set("r", req.URL.Path+"?"+req.URL.RawQuery)
return "/login?" + v.Encode()
}
// Homepage handler. Authorizes an authenticated user to a service by
// signing a token with the user's identity. The client is redirected
// back to the service, with the signed token.
func (h *Server) handleHomepage(w http.ResponseWriter, req *http.Request, session *authSession) {
// Extract the authorization request parameters from the HTTP
// request.
username := session.Username
service := req.FormValue("s")
destination := req.FormValue("d")
nonce := req.FormValue("n")
var groups []string
reqGroups := strings.Split(req.FormValue("g"), ",")
if len(reqGroups) > 0 && session.UserInfo != nil {
groups = intersectGroups(reqGroups, session.UserInfo.Groups)
// We only make this check here as a convenience to
// the user (we may be able to show a nicer UI): the
// actual group ACL must be applied on the destination
// service, because the 'g' parameter is untrusted at
// this stage.
if len(groups) == 0 {
http.Error(w, "Unauthorized", http.StatusUnauthorized)
return
}
}
// Make the authorization request.
token, err := h.loginService.Authorize(username, service, destination, nonce, groups)
if err != nil {
log.Printf("auth error: %v: user=%s service=%s destination=%s nonce=%s groups=%s", err, username, service, destination, nonce, req.FormValue("g"))
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
session.AddService(service)
_ = sessions.Save(req, w)
// Redirect to service callback.
callbackURL := serviceCallback(service, destination, token)
http.Redirect(w, req, callbackURL, http.StatusFound)
}
// Returns the URL of the login handler on the target service.
func serviceCallback(service, destination, token string) string {
v := make(url.Values)
v.Set("t", token)
v.Set("d", destination)
return fmt.Sprintf("https://%ssso_login?%s", service, v.Encode())
}
func (h *Server) handleExchange(w http.ResponseWriter, req *http.Request) {
curToken := req.FormValue("cur_tkt")
curService := req.FormValue("cur_svc")
curNonce := req.FormValue("cur_nonce")
newService := req.FormValue("new_svc")
newNonce := req.FormValue("new_nonce")
reqGroups := strings.Split(req.FormValue("new_groups"), ",")
token, err := h.loginService.Exchange(curToken, curService, curNonce, newService, newNonce, reqGroups)
if err != nil {
log.Printf("exchange error: %v", err)
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
w.Header().Set("Content-Type", "text/plain")
_, _ = io.WriteString(w, token)
}
// Handler returns the http.Handler for the SSO server application.
func (h *Server) Handler() http.Handler {
m := mux.NewRouter()
var lh http.Handler
if h.csrfSecret != nil {
lh = csrf.Protect(h.csrfSecret)(h.loginHandler)
} else {
lh = h.loginHandler
}
m.Handle("/login", withDynamicHeaders(lh))
m.PathPrefix("/static/").Handler(http.StripPrefix("/static/", http.FileServer(&assetfs.AssetFS{
Asset: Asset,
AssetDir: AssetDir,
AssetInfo: AssetInfo,
Prefix: "static",
})))
m.Handle("/exchange", withDynamicHeaders(http.HandlerFunc(h.handleExchange)))
m.Handle("/", withDynamicHeaders(h.withAuth(h.handleHomepage)))
return m
}
// A relatively strict CSP.
const contentSecurityPolicy = "default-src 'none'; img-src 'self' data:; script-src 'self'; style-src 'self'; connect-src 'self';"
func withDynamicHeaders(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Pragma", "no-cache")
w.Header().Set("Cache-Control", "no-store")
w.Header().Set("Expires", "-1")
w.Header().Set("Content-Security-Policy", contentSecurityPolicy)
w.Header().Set("X-Frame-Options", "NONE")
w.Header().Set("X-XSS-Protection", "1; mode=block")
w.Header().Set("X-Content-Type-Options", "nosniff")
h.ServeHTTP(w, r)
})
}
package server
import (
"context"
"crypto/tls"
"errors"
"io/ioutil"
"net/http"
"net/http/cookiejar"
"net/http/httptest"
"net/url"
"os"
"regexp"
"strings"
"testing"
"git.autistici.org/id/auth"
)
type fakeAuthClient struct{}