package saml

import (
	"crypto/rand"
	"crypto/tls"
	"crypto/x509"
	"encoding/base64"
	"encoding/hex"
	"encoding/xml"
	"errors"
	"fmt"
	"io"
	"io/ioutil"
	"net/http"
	"net/url"
	"os"
	"strings"
	"time"

	"github.com/crewjam/saml"
	"github.com/crewjam/saml/logger"
	"github.com/gorilla/mux"
	"gopkg.in/yaml.v2"

	"git.autistici.org/id/go-sso/httpsso"
)

// Config holds the settings used by NewSAMLIDP to build the SAML
// identity provider. The yaml tags name the corresponding keys in the
// configuration file.
type Config struct {
	// Public base URL of the IdP. NewSAMLIDP parses it with url.Parse,
	// derives the "/sso" and "/metadata" endpoints from its path, and
	// uses host+path as the service name given to the SSO wrapper.
	BaseURL string `yaml:"base_url"`

	// Path of the YAML users file, read by newUserFileBackend.
	UsersFile string `yaml:"users_file"`

	// SAML X509 credentials: certificate and private key files, loaded
	// together with tls.LoadX509KeyPair (presumably PEM-encoded).
	CertificateFile string `yaml:"certificate_file"`
	PrivateKeyFile  string `yaml:"private_key_file"`

	// SSO configuration, all handed to httpsso.NewSSOWrapper.
	// check requires SessionAuthKey to be 32 or 64 bytes long, and
	// SessionEncKey 16, 24 or 32 bytes long. It also requires
	// SSOLoginServerURL and SSODomain to be non-empty.
	// SSOPublicKeyFile is read verbatim and its raw contents are passed
	// to the wrapper.
	SessionAuthKey    string `yaml:"session_auth_key"`
	SessionEncKey     string `yaml:"session_enc_key"`
	SSOLoginServerURL string `yaml:"sso_server_url"`
	SSOPublicKeyFile  string `yaml:"sso_public_key_file"`
	SSODomain         string `yaml:"sso_domain"`

	// Service provider config. ServiceProviders lists paths of SAML
	// EntityDescriptor metadata files. loadServiceProviders parses them
	// into parsedServiceProviders, keyed by entity ID, and
	// GetServiceProvider looks them up there.
	ServiceProviders       []string `yaml:"service_providers"`
	parsedServiceProviders map[string]*saml.EntityDescriptor
}

// check performs sanity checks on the configuration. It returns an
// error describing the first problem found, or nil.
//
// It only inspects field values. Missing or unreadable files are
// reported later, when NewSAMLIDP loads them.
func (c *Config) check() error {
	switch len(c.SessionAuthKey) {
	case 32, 64:
	case 0:
		return errors.New("session_auth_key is empty")
	default:
		return errors.New("session_auth_key must be a random string of 32 or 64 bytes")
	}
	switch len(c.SessionEncKey) {
	case 16, 24, 32:
	case 0:
		return errors.New("session_enc_key is empty")
	default:
		return errors.New("session_enc_key must be a random string of 16, 24 or 32 bytes")
	}
	if c.SSOLoginServerURL == "" {
		return errors.New("sso_server_url is empty")
	}
	if c.SSODomain == "" {
		return errors.New("sso_domain is empty")
	}

	// NewSAMLIDP derives the IdP's SSO and metadata endpoint URLs, and
	// the SSO service name, from base_url. An empty or relative value
	// would silently produce relative endpoints, so reject it here.
	if c.BaseURL == "" {
		return errors.New("base_url is empty")
	}
	u, err := url.Parse(c.BaseURL)
	if err != nil {
		return fmt.Errorf("invalid base_url: %w", err)
	}
	if !u.IsAbs() || u.Host == "" {
		return errors.New("base_url must be an absolute URL (scheme and host)")
	}
	return nil
}

// loadServiceProviders reads every SAML metadata file listed in
// c.ServiceProviders and parses each one as a single EntityDescriptor.
// It indexes the results by entity ID in c.parsedServiceProviders.
//
// The map is replaced only when every file loads successfully.
// Loading fails if a file cannot be read or parsed, or if a descriptor
// has no entityID (GetServiceProvider could never match it). It also
// fails if two files declare the same entityID, because one of them
// would otherwise be shadowed without notice.
func (c *Config) loadServiceProviders() error {
	sps := make(map[string]*saml.EntityDescriptor, len(c.ServiceProviders))
	for _, path := range c.ServiceProviders {
		data, err := ioutil.ReadFile(path)
		if err != nil {
			// The *os.PathError already names the file.
			return fmt.Errorf("reading service provider metadata: %w", err)
		}
		var ent saml.EntityDescriptor
		if err := xml.Unmarshal(data, &ent); err != nil {
			return fmt.Errorf("parsing service provider metadata %q: %w", path, err)
		}
		if ent.EntityID == "" {
			return fmt.Errorf("service provider metadata %q has no entityID", path)
		}
		if _, dup := sps[ent.EntityID]; dup {
			return fmt.Errorf("service provider metadata %q: duplicate entityID %q", path, ent.EntityID)
		}
		sps[ent.EntityID] = &ent
	}
	c.parsedServiceProviders = sps
	return nil
}

// GetServiceProvider returns the parsed metadata of the service provider
// whose entity ID is serviceProviderID. This lets Config serve as the
// IdP's ServiceProviderProvider. The request r is not consulted.
//
// Unknown IDs yield os.ErrNotExist, returned unwrapped, because the saml
// package's ServiceProviderProvider contract asks for exactly that error.
func (c *Config) GetServiceProvider(r *http.Request, serviceProviderID string) (*saml.EntityDescriptor, error) {
	if desc, found := c.parsedServiceProviders[serviceProviderID]; found {
		return desc, nil
	}
	return nil, os.ErrNotExist
}

// userInfo is one user record from the YAML users file, which
// newUserFileBackend reads as a list of these records. The file format
// is meant to be compatible with the one used by
// git.autistici.org/id/auth/server.
//
// TODO: Make it retrieve the email addresses as extra data in the SSO
// token (this feature is currently unsupported by the SSO server,
// even though the auth-server provides the information).
type userInfo struct {
	// Name is the lookup key matched against the authenticated user
	// name. It is also copied into the session's UserName,
	// UserCommonName and UserGivenName.
	Name  string `yaml:"name"`
	// Email is copied into the session's UserEmail.
	Email string `yaml:"email"`
}

// userFileBackend is the IdP's session provider (see GetSession). It
// serves a fixed set of users loaded once from a YAML file by
// newUserFileBackend.
type userFileBackend struct {
	// users maps each user name to its record.
	users map[string]userInfo
}

// newUserFileBackend loads the YAML users file at path, a list of
// userInfo records, and returns a backend that indexes them by name.
// If a name appears more than once, the last entry wins. Read and parse
// errors are returned with the file they came from.
func newUserFileBackend(path string) (*userFileBackend, error) {
	data, err := ioutil.ReadFile(path)
	if err != nil {
		// The *os.PathError already names the file.
		return nil, fmt.Errorf("reading users file: %w", err)
	}
	var userList []userInfo
	if err := yaml.Unmarshal(data, &userList); err != nil {
		return nil, fmt.Errorf("parsing users file %q: %w", path, err)
	}
	users := make(map[string]userInfo, len(userList))
	for _, u := range userList {
		users[u.Name] = u
	}
	return &userFileBackend{users: users}, nil
}

// GetSession turns the user authenticated upstream into a new SAML
// session. The session has random ID and Index values and lasts
// sessionMaxAge. When no session can be produced, it writes an HTTP
// error to w and returns nil, following the saml.SessionProvider
// convention.
//
// NOTE(review): the user name comes from the X-Authenticated-User
// request header. This is only safe behind a wrapper that sets the
// header itself and never forwards a client-supplied value. Confirm
// against httpsso.
func (b *userFileBackend) GetSession(w http.ResponseWriter, r *http.Request, req *saml.IdpAuthnRequest) *saml.Session {
	username := r.Header.Get("X-Authenticated-User")
	if username == "" {
		// Presumably set by the SSO wrapper that NewSAMLIDP puts in front
		// of the IdP, so its absence points at a server-side problem.
		http.Error(w, "No user found", http.StatusInternalServerError)
		return nil
	}
	user, ok := b.users[username]
	if !ok {
		// The user is authenticated but not known to this IdP. That is
		// an authorization failure, not a server error.
		http.Error(w, "User not found", http.StatusForbidden)
		return nil
	}

	// Read the clock once so the session lasts exactly sessionMaxAge.
	now := saml.TimeNow()
	return &saml.Session{
		ID:             base64.StdEncoding.EncodeToString(randomBytes(32)),
		CreateTime:     now,
		ExpireTime:     now.Add(sessionMaxAge),
		Index:          hex.EncodeToString(randomBytes(32)),
		UserName:       user.Name,
		UserEmail:      user.Email,
		UserCommonName: user.Name,
		UserGivenName:  user.Name,
	}
}

// NewSAMLIDP validates config and builds an http.Handler for a SAML
// identity provider whose users authenticate through the SSO service
// described by config.
//
// The handler serves two paths below config.BaseURL:
//   - "<base>/sso": the IdP SSO endpoint, wrapped by the httpsso
//     wrapper (presumably so that only SSO-authenticated requests
//     reach the IdP);
//   - "<base>/metadata": the IdP metadata, served without the wrapper.
//
// Service provider metadata, credentials and the users file are all
// loaded here, once. The first configuration error is returned.
func NewSAMLIDP(config *Config) (http.Handler, error) {
	if err := config.check(); err != nil {
		return nil, err
	}
	if err := config.loadServiceProviders(); err != nil {
		return nil, err
	}

	cert, err := tls.LoadX509KeyPair(config.CertificateFile, config.PrivateKeyFile)
	if err != nil {
		return nil, fmt.Errorf("loading SAML key pair (%s, %s): %w", config.CertificateFile, config.PrivateKeyFile, err)
	}
	// Before Go 1.23, tls.LoadX509KeyPair leaves Leaf nil, which would
	// leave the IdP without a certificate. LoadX509KeyPair fails when no
	// certificate is present, so Certificate[0] exists here.
	if cert.Leaf == nil {
		cert.Leaf, err = x509.ParseCertificate(cert.Certificate[0])
		if err != nil {
			return nil, fmt.Errorf("parsing SAML certificate %s: %w", config.CertificateFile, err)
		}
	}

	pkey, err := ioutil.ReadFile(config.SSOPublicKeyFile)
	if err != nil {
		return nil, fmt.Errorf("reading sso_public_key_file: %w", err)
	}

	ssoWrapper, err := httpsso.NewSSOWrapper(config.SSOLoginServerURL, pkey, config.SSODomain, []byte(config.SessionAuthKey), []byte(config.SessionEncKey))
	if err != nil {
		return nil, fmt.Errorf("creating SSO wrapper: %w", err)
	}

	baseURL, err := url.Parse(config.BaseURL)
	if err != nil {
		return nil, fmt.Errorf("parsing base_url: %w", err)
	}
	// Derive the endpoints from *value* copies of baseURL. Copying the
	// pointer would make every derived URL, and baseURL itself, share
	// one Path that each append keeps extending. A trailing slash on
	// the base is dropped so the endpoints never contain "//".
	basePath := strings.TrimSuffix(baseURL.Path, "/")
	ssoURL := *baseURL
	ssoURL.Path = basePath + "/sso"
	metadataURL := *baseURL
	metadataURL.Path = basePath + "/metadata"
	// SSO service name: host plus base path, always with a trailing "/".
	svc := baseURL.Host + basePath + "/"

	users, err := newUserFileBackend(config.UsersFile)
	if err != nil {
		return nil, err
	}

	idp := &saml.IdentityProvider{
		Key:                     cert.PrivateKey,
		Certificate:             cert.Leaf,
		Logger:                  logger.DefaultLogger,
		MetadataURL:             metadataURL,
		SSOURL:                  ssoURL,
		ServiceProviderProvider: config,
		SessionProvider:         users,
	}
	// idp.Handler() routes on idp.SSOURL.Path and idp.MetadataURL.Path,
	// which match the paths registered on root below.
	h := idp.Handler()

	root := mux.NewRouter()
	root.Handle(ssoURL.Path, ssoWrapper.Wrap(h, svc, nil))
	root.Handle(metadataURL.Path, h)
	return root, nil
}

// randomBytes returns a new slice of n bytes filled from crypto/rand.
// It panics if the randomness source cannot supply them all, because
// callers use the result for session identifiers and have no fallback.
func randomBytes(n int) []byte {
	buf := make([]byte, n)
	_, err := io.ReadFull(rand.Reader, buf)
	if err != nil {
		panic(err)
	}
	return buf
}

// sessionMaxAge is how long a SAML session created by GetSession stays
// valid. Its ExpireTime is its creation time plus this duration.
var sessionMaxAge = 5 * time.Minute