package saml import ( "crypto/rand" "crypto/tls" "encoding/base64" "encoding/hex" "encoding/xml" "errors" "fmt" "io" "io/ioutil" "net/http" "net/url" "os" "strings" "time" "github.com/crewjam/saml" "github.com/crewjam/saml/logger" "github.com/gorilla/mux" "gopkg.in/yaml.v2" "git.autistici.org/id/go-sso/httpsso" ) type Config struct { BaseURL string `yaml:"base_url"` UsersFile string `yaml:"users_file"` // SAML X509 credentials. CertificateFile string `yaml:"certificate_file"` PrivateKeyFile string `yaml:"private_key_file"` // SSO configuration. SessionAuthKey string `yaml:"session_auth_key"` SessionEncKey string `yaml:"session_enc_key"` SSOLoginServerURL string `yaml:"sso_server_url"` SSOPublicKeyFile string `yaml:"sso_public_key_file"` SSODomain string `yaml:"sso_domain"` // Service provider config. ServiceProviders []string `yaml:"service_providers"` parsedServiceProviders map[string]*saml.EntityDescriptor } // Sanity checks for the configuration. func (c *Config) check() error { switch len(c.SessionAuthKey) { case 32, 64: case 0: return errors.New("session_auth_key is empty") default: return errors.New("session_auth_key must be a random string of 32 or 64 bytes") } switch len(c.SessionEncKey) { case 16, 24, 32: case 0: return errors.New("session_enc_key is empty") default: return errors.New("session_enc_key must be a random string of 16, 24 or 32 bytes") } if c.SSOLoginServerURL == "" { return errors.New("sso_server_url is empty") } if c.SSODomain == "" { return errors.New("sso_domain is empty") } return nil } func (c *Config) loadServiceProviders() error { c.parsedServiceProviders = make(map[string]*saml.EntityDescriptor) for _, path := range c.ServiceProviders { data, err := ioutil.ReadFile(path) if err != nil { return err } var ent saml.EntityDescriptor if err := xml.Unmarshal(data, &ent); err != nil { return err } c.parsedServiceProviders[ent.EntityID] = &ent } return nil } func (c *Config) GetServiceProvider(r *http.Request, serviceProviderID string) (*saml.EntityDescriptor, error) { srv, ok := c.parsedServiceProviders[serviceProviderID] if !ok { return nil, os.ErrNotExist } return srv, nil } // Read users from a YAML-encoded file, in a format surprisingly // compatible with git.autistici.org/id/auth/server. // // TODO: Make it retrieve the email addresses as extra data in the SSO // token (this feature is currently unsupported by the SSO server, // even though the auth-server provides the information). type userInfo struct { Name string `yaml:"name"` Email string `yaml:"email"` } type userFileBackend struct { users map[string]userInfo } func newUserFileBackend(path string) (*userFileBackend, error) { data, err := ioutil.ReadFile(path) if err != nil { return nil, err } var userList []userInfo if err := yaml.Unmarshal(data, &userList); err != nil { return nil, err } users := make(map[string]userInfo) for _, u := range userList { users[u.Name] = u } return &userFileBackend{users}, nil } func (b *userFileBackend) GetSession(w http.ResponseWriter, r *http.Request, req *saml.IdpAuthnRequest) *saml.Session { // The request should have the X-Authenticated-User header. username := r.Header.Get("X-Authenticated-User") if username == "" { http.Error(w, "No user found", http.StatusInternalServerError) return nil } user, ok := b.users[username] if !ok { http.Error(w, "User not found", http.StatusInternalServerError) return nil } return &saml.Session{ ID: base64.StdEncoding.EncodeToString(randomBytes(32)), CreateTime: saml.TimeNow(), ExpireTime: saml.TimeNow().Add(sessionMaxAge), Index: hex.EncodeToString(randomBytes(32)), UserName: user.Name, UserEmail: user.Email, UserCommonName: user.Name, UserGivenName: user.Name, } } func NewSAMLIDP(config *Config) (http.Handler, error) { if err := config.check(); err != nil { return nil, err } if err := config.loadServiceProviders(); err != nil { return nil, err } cert, err := tls.LoadX509KeyPair(config.CertificateFile, config.PrivateKeyFile) if err != nil { return nil, err } pkey, err := ioutil.ReadFile(config.SSOPublicKeyFile) if err != nil { return nil, err } w, err := httpsso.NewSSOWrapper(config.SSOLoginServerURL, pkey, config.SSODomain, []byte(config.SessionAuthKey), []byte(config.SessionEncKey)) if err != nil { return nil, err } baseURL, err := url.Parse(config.BaseURL) if err != nil { return nil, err } ssoURL := baseURL ssoURL.Path += "/sso" metadataURL := baseURL metadataURL.Path += "/metadata" svc := fmt.Sprintf("%s%s", baseURL.Host, baseURL.Path) if !strings.HasSuffix(svc, "/") { svc += "/" } users, err := newUserFileBackend(config.UsersFile) if err != nil { return nil, err } idp := &saml.IdentityProvider{ Key: cert.PrivateKey, Certificate: cert.Leaf, Logger: logger.DefaultLogger, SSOURL: *ssoURL, ServiceProviderProvider: config, SessionProvider: users, } h := idp.Handler() root := mux.NewRouter() root.Handle(ssoURL.Path, w.Wrap(h, svc, nil)) root.Handle(metadataURL.Path, h) return root, nil } func randomBytes(n int) []byte { b := make([]byte, n) if _, err := io.ReadFull(rand.Reader, b[:]); err != nil { panic(err) } return b } var sessionMaxAge = 300 * time.Second