package server

import (
	"context"
	"errors"
	"io/ioutil"
	"log"
	"strings"

	ldaputil "git.autistici.org/ai3/go-common/ldap"
	ct "git.autistici.org/ai3/go-common/ldap/compositetypes"
	"github.com/duo-labs/webauthn/webauthn"
	"github.com/go-ldap/ldap/v3"
	"gopkg.in/yaml.v3"

	"git.autistici.org/id/auth/backend"
)

// ldapServiceParams defines a search to be performed when looking up
// a user for a service. NewServiceBackend decodes it from the service
// spec's parameters.
type ldapServiceParams struct {
	// SearchBase, SearchFilter and Scope define parameters for
	// the LDAP search. The search should return a single object.
	// SearchBase or SearchFilter should contain the string "%s",
	// which will be replaced with the username before performing
	// a query (DN-escaped in SearchBase, filter-escaped in
	// SearchFilter). SearchBase and SearchFilter are required; an
	// empty Scope means a whole-subtree search, otherwise it is
	// parsed with ldaputil.ParseScope.
	SearchBase   string `yaml:"search_base"`
	SearchFilter string `yaml:"search_filter"`
	Scope        string `yaml:"scope"`

	// Attrs tells us which LDAP attributes to query to find user
	// attributes. It is encoded as a {user_attribute:
	// ldap_attribute} map. The user attributes read when building
	// a User are 'email', 'shard', 'password', 'totp_secret',
	// 'app_specific_password' and 'u2f_registration'. Entries here
	// override those in defaultLDAPAttributeMap. A user attribute
	// mapped to the empty string is not read from the LDAP entry.
	Attrs map[string]string `yaml:"attrs"`
}

// The default attribute mapping just happens to match our schema.
//
// defaultLDAPAttributeMap maps user attributes to the LDAP attributes
// they are read from when a service does not override them in its
// "attrs" parameter. There is no default for 'email' or 'shard', so
// those stay empty unless a service maps them explicitly.
var defaultLDAPAttributeMap = map[string]string{
	"password":              "userPassword",
	"totp_secret":           "totpSecret",
	"app_specific_password": "appSpecificPassword",
	"u2f_registration":      "u2fRegistration",
}

// dropCryptPrefix strips a leading "{crypt}" password-scheme tag from
// a stored password value, leaving the bare crypt(3)-style hash.
//
// The tag is matched case-insensitively, so "{crypt}", "{CRYPT}" and
// mixed-case variants such as "{Crypt}" are all recognized. Any other
// value, including one tagged with a different scheme, is returned
// unchanged.
func dropCryptPrefix(s string) string {
	const prefix = "{crypt}"
	if len(s) >= len(prefix) && strings.EqualFold(s[:len(prefix)], prefix) {
		return s[len(prefix):]
	}
	return s
}

// getStringFromLDAPEntry returns entry's value for the LDAP attribute
// attr, as reported by GetAttributeValue. An empty attr means the user
// attribute is not mapped: the entry is not consulted and "" is
// returned.
func getStringFromLDAPEntry(entry *ldap.Entry, attr string) string {
	if attr != "" {
		return entry.GetAttributeValue(attr)
	}
	return ""
}

// getListFromLDAPEntry returns all of entry's values for the LDAP
// attribute attr, as reported by GetAttributeValues. An empty attr
// means the user attribute is not mapped, and yields a nil slice
// without consulting the entry.
func getListFromLDAPEntry(entry *ldap.Entry, attr string) []string {
	var values []string
	if attr != "" {
		values = entry.GetAttributeValues(attr)
	}
	return values
}

// decodeAppSpecificPasswordList unmarshals each serialized
// app-specific password in encodedAsps into a backend value.
//
// A value that fails to unmarshal is skipped and the remaining values
// are still decoded. The result is nil when nothing decodes.
func decodeAppSpecificPasswordList(encodedAsps []string) []*backend.AppSpecificPassword {
	var decoded []*backend.AppSpecificPassword
	for i := range encodedAsps {
		asp, err := ct.UnmarshalAppSpecificPassword(encodedAsps[i])
		if err != nil {
			continue
		}
		decoded = append(decoded, &backend.AppSpecificPassword{
			Service:           asp.Service,
			EncryptedPassword: []byte(asp.EncryptedPassword),
		})
	}
	return decoded
}

// decodeU2FRegistrationList turns serialized U2F registrations into
// WebAuthn credentials.
//
// Each value is first unmarshaled and then decoded. A value that fails
// either step is skipped, and the rest are still processed. The result
// is nil when nothing decodes.
func decodeU2FRegistrationList(encRegs []string) []webauthn.Credential {
	var creds []webauthn.Credential
	for _, raw := range encRegs {
		reg, err := ct.UnmarshalU2FRegistration(raw)
		if err != nil {
			continue
		}
		cred, err := reg.Decode()
		if err != nil {
			continue
		}
		creds = append(creds, cred)
	}
	return creds
}

// ldapConfig is the global configuration for the LDAP user backend,
// decoded from the backend's YAML parameters by New.
//
// The bind password comes either inline (BindPw) or from a file
// (BindPwFile). Exactly one of the two must be set.
type ldapConfig struct {
	URI        string `yaml:"uri"`
	BindDN     string `yaml:"bind_dn"`
	BindPw     string `yaml:"bind_pw"`
	BindPwFile string `yaml:"bind_pw_file"`
}

// valid returns an error if the configuration is invalid.
//
// uri and bind_dn are required, and exactly one of bind_pw or
// bind_pw_file must be set. Each failure has its own message, so
// "neither set" and "both set" can be told apart.
func (c *ldapConfig) valid() error {
	switch {
	case c.URI == "":
		return errors.New("empty uri")
	case c.BindDN == "":
		return errors.New("empty bind_dn")
	case c.BindPw == "" && c.BindPwFile == "":
		return errors.New("one of bind_pw_file or bind_pw must be set")
	case c.BindPw != "" && c.BindPwFile != "":
		return errors.New("only one of bind_pw_file or bind_pw must be set")
	}
	return nil
}

// ldapBackend is the backend.UserBackend returned by New.
//
// It holds the validated global configuration and the LDAP connection
// pool, which is handed to every service backend created by
// NewServiceBackend and closed by Close. NOTE(review): config is not
// read by any other code in this file.
type ldapBackend struct {
	config *ldapConfig
	pool   *ldaputil.ConnectionPool
}

// New returns a new LDAP backend.
//
// params holds the backend's YAML configuration, which is decoded into
// an ldapConfig and validated. When bind_pw_file is set, the bind
// password is read from that file and surrounding whitespace is
// trimmed; the path is first passed through backend.ResolvePath with
// configDir. The returned backend owns an LDAP connection pool that is
// released by Close.
func New(params *yaml.Node, configDir string) (backend.UserBackend, error) {
	// yaml.Node.Decode dereferences its receiver, so report a missing
	// configuration section as an error rather than panicking.
	if params == nil {
		return nil, errors.New("missing ldap backend configuration")
	}

	// Unmarshal and validate configuration.
	var lc ldapConfig
	if err := params.Decode(&lc); err != nil {
		return nil, err
	}
	if err := lc.valid(); err != nil {
		return nil, err
	}

	// Read the bind password.
	bindPw := lc.BindPw
	if lc.BindPwFile != "" {
		pwData, err := ioutil.ReadFile(backend.ResolvePath(lc.BindPwFile, configDir))
		if err != nil {
			return nil, err
		}
		bindPw = strings.TrimSpace(string(pwData))
		// valid() only checks that a file name was given. A simple bind
		// with a DN but an empty password is an "unauthenticated bind"
		// (RFC 4513 §5.1.2). Many servers accept it as anonymous access
		// instead of failing, so reject it explicitly.
		if bindPw == "" {
			return nil, errors.New("empty bind password in bind_pw_file")
		}
	}

	// Initialize the connection pool. NOTE(review): the last argument
	// is presumably the pool size; confirm against ldaputil.
	const poolSize = 5
	pool, err := ldaputil.NewConnectionPool(lc.URI, lc.BindDN, bindPw, poolSize)
	if err != nil {
		return nil, err
	}

	return &ldapBackend{
		config: &lc,
		pool:   pool,
	}, nil
}

// Close closes the backend's LDAP connection pool.
//
// Service backends created by NewServiceBackend use this same pool, so
// they presumably cannot be used after Close. Confirm with callers.
func (b *ldapBackend) Close() {
	b.pool.Close()
}

// NewServiceBackend returns a ServiceBackend that looks up users for a
// single service.
//
// The service's search is configured by ldapServiceParams, decoded from
// spec.Params. Lookups share this backend's connection pool.
func (b *ldapBackend) NewServiceBackend(spec *backend.Spec) (backend.ServiceBackend, error) {
	var params ldapServiceParams
	if err := spec.Params.Decode(&params); err != nil {
		return nil, err
	}
	// Do not return newLDAPServiceBackend's results directly. On error,
	// its nil *ldapServiceBackend would become a non-nil
	// backend.ServiceBackend interface value wrapping a nil pointer.
	sb, err := newLDAPServiceBackend(b.pool, &params)
	if err != nil {
		return nil, err
	}
	return sb, nil
}

// ldapServiceBackend looks up users for one service. It is built and
// validated by newLDAPServiceBackend.
type ldapServiceBackend struct {
	pool     *ldaputil.ConnectionPool // shared with the parent ldapBackend
	base     string                   // search base; "%s" is replaced with the DN-escaped username
	filter   string                   // search filter; "%s" is replaced with the filter-escaped username
	scope    int                      // LDAP search scope
	attrList []string                 // LDAP attributes requested by every search
	attrs    map[string]string        // user attribute name -> LDAP attribute name
}

// newLDAPServiceBackend validates params and builds a service backend
// that performs user lookups over pool.
//
// SearchBase and SearchFilter are required, and at least one of them
// must contain the "%s" username placeholder. Without it, every lookup
// would run the same query and return the same entry whatever username
// was requested. An empty Scope selects a whole-subtree search.
//
// The attribute map is defaultLDAPAttributeMap overlaid with
// params.Attrs. The LDAP attributes requested from the server are its
// non-empty values, each listed once.
func newLDAPServiceBackend(pool *ldaputil.ConnectionPool, params *ldapServiceParams) (*ldapServiceBackend, error) {
	if params.SearchBase == "" {
		return nil, errors.New("empty search_base")
	}
	if params.SearchFilter == "" {
		return nil, errors.New("empty search_filter")
	}
	if !strings.Contains(params.SearchBase, "%s") && !strings.Contains(params.SearchFilter, "%s") {
		return nil, errors.New("search_base or search_filter must contain %s")
	}
	scope := ldap.ScopeWholeSubtree
	if params.Scope != "" {
		s, err := ldaputil.ParseScope(params.Scope)
		if err != nil {
			return nil, err
		}
		scope = s
	}

	// Merge in attributes from the default map if unset.
	attrs := make(map[string]string, len(defaultLDAPAttributeMap)+len(params.Attrs))
	for k, v := range defaultLDAPAttributeMap {
		attrs[k] = v
	}
	for k, v := range params.Attrs {
		attrs[k] = v
	}

	// Build the list of LDAP attributes passed to NewSearchRequest.
	// Skip empty names: a user attribute mapped to "" is disabled, and
	// "" is not a valid attribute description. Request each LDAP
	// attribute once, even if several user attributes map to it.
	seen := make(map[string]bool, len(attrs))
	var attrList []string
	for _, v := range attrs {
		if v == "" || seen[v] {
			continue
		}
		seen[v] = true
		attrList = append(attrList, v)
	}

	return &ldapServiceBackend{
		pool:     pool,
		base:     params.SearchBase,
		filter:   params.SearchFilter,
		scope:    scope,
		attrList: attrList,
		attrs:    attrs,
	}, nil
}

// searchRequest builds the LDAP search for username.
//
// Every "%s" in the configured base is replaced with the username
// escaped for a DN, and every "%s" in the configured filter with the
// username escaped for a search filter. Aliases are never
// dereferenced, and only the configured attributes are requested.
func (b *ldapServiceBackend) searchRequest(username string) *ldap.SearchRequest {
	dnSafe := escapeDN(username)
	filterSafe := ldap.EscapeFilter(username)
	return ldap.NewSearchRequest(
		strings.ReplaceAll(b.base, "%s", dnSafe),
		b.scope,
		ldap.NeverDerefAliases,
		0,     // size limit: none requested
		0,     // time limit: none requested
		false, // typesOnly: return attribute values too
		strings.ReplaceAll(b.filter, "%s", filterSafe),
		b.attrList,
		nil, // no controls
	)
}

// userFromResponse builds a User named username from the result of the
// search built by searchRequest.
//
// It returns false when the search found no entry. It also returns
// false, after logging, when the search found more than one entry. The
// search is expected to identify exactly one object, and picking an
// arbitrary entry from an ambiguous result could attach one object's
// credentials to another user's name.
func (b *ldapServiceBackend) userFromResponse(username string, result *ldap.SearchResult) (*backend.User, bool) {
	if result == nil || len(result.Entries) == 0 {
		return nil, false
	}
	if n := len(result.Entries); n > 1 {
		log.Printf("LDAP search for user %q returned %d entries, expected 1", username, n)
		return nil, false
	}

	entry := result.Entries[0]

	// Apply the attribute map. We don't care if an attribute is
	// not defined in the map, as the get* functions will silently
	// ignore an empty attribute name.
	u := backend.User{
		Name:                  username,
		Email:                 getStringFromLDAPEntry(entry, b.attrs["email"]),
		Shard:                 getStringFromLDAPEntry(entry, b.attrs["shard"]),
		EncryptedPassword:     []byte(dropCryptPrefix(getStringFromLDAPEntry(entry, b.attrs["password"]))),
		TOTPSecret:            getStringFromLDAPEntry(entry, b.attrs["totp_secret"]),
		AppSpecificPasswords:  decodeAppSpecificPasswordList(getListFromLDAPEntry(entry, b.attrs["app_specific_password"])),
		WebAuthnRegistrations: decodeU2FRegistrationList(getListFromLDAPEntry(entry, b.attrs["u2f_registration"])),
	}

	return &u, true
}

// GetUser looks up name in LDAP and returns the corresponding User, or
// false if it cannot be found.
//
// A NoSuchObject result code is treated as an ordinary "not found".
// That happens, for example, when the username is substituted into a
// search base that does not exist. Any other search error is logged
// and also reported as not found.
func (b *ldapServiceBackend) GetUser(ctx context.Context, name string) (*backend.User, bool) {
	req := b.searchRequest(name)
	result, err := b.pool.Search(ctx, req)
	switch {
	case err == nil:
		return b.userFromResponse(name, result)
	case ldap.IsErrorWithCode(err, ldap.LDAPResultNoSuchObject):
		return nil, false
	default:
		log.Printf("LDAP error: %v", err)
		return nil, false
	}
}

// hexChars is the lowercase hex alphabet used to write "\XX" escapes.
var hexChars = "0123456789abcdef"

// mustEscape reports whether byte c must be hex-escaped wherever it
// appears in a DN attribute value.
func mustEscape(c byte) bool {
	return (c > 0x7f || c == '<' || c == '>' || c == '\\' || c == '*' ||
		c == '"' || c == ',' || c == '+' || c == ';' || c == 0)
}

// escapeDN escapes s for use as an attribute value inside a DN (an RDN
// value), following RFC 4514 §2.4. Each byte that needs escaping is
// written as a backslash followed by two lowercase hex digits:
//
//   - the characters '"', '+', ',', ';', '<', '>' and '\' anywhere;
//   - NUL and any byte above 0x7f, i.e. every byte of a non-ASCII
//     UTF-8 sequence;
//   - '*', which RFC 4514 does not require but which is harmless to escape;
//   - a leading ' ' or '#', and a trailing ' '.
//
// If nothing needs escaping, s is returned unchanged.
func escapeDN(s string) string {
	last := len(s) - 1
	// needsEscape adds RFC 4514's position-dependent rules (leading
	// space or '#', trailing space) to the positional-agnostic mustEscape.
	needsEscape := func(i int) bool {
		c := s[i]
		switch {
		case mustEscape(c):
			return true
		case i == 0 && (c == ' ' || c == '#'):
			return true
		case i == last && c == ' ':
			return true
		}
		return false
	}

	escape := 0
	for i := 0; i < len(s); i++ {
		if needsEscape(i) {
			escape++
		}
	}
	if escape == 0 {
		return s
	}

	// Each escaped byte grows from 1 to 3 bytes ("\XX").
	buf := make([]byte, 0, len(s)+escape*2)
	for i := 0; i < len(s); i++ {
		c := s[i]
		if needsEscape(i) {
			buf = append(buf, '\\', hexChars[c>>4], hexChars[c&0xf])
		} else {
			buf = append(buf, c)
		}
	}
	return string(buf)
}