package serverutil

import (
	"crypto/tls"
	"errors"
	"fmt"
	"log"
	"net/http"
	"regexp"
	"strings"

	common "git.autistici.org/ai3/go-common"
)

// TLSAuthACL describes a single access control entry. Path and
// CommonName are regular expressions anchored at both ends: they must
// match the entire request path and the entire CommonName of the
// client's leaf certificate, respectively. Each expression is grouped
// before anchoring, so a top-level alternation such as "a|b" matches
// only "a" or "b". The first ACL whose Path matches a request
// identifies the ACL to be applied.
//
// compile must succeed before matchPath or matchCN are called.
type TLSAuthACL struct {
	Path       string `yaml:"path"`
	CommonName string `yaml:"cn"`

	// Compiled, anchored forms of Path and CommonName (set by compile).
	pathRx *regexp.Regexp
	cnRx   *regexp.Regexp
}

// compile turns Path and CommonName into anchored regular expressions.
// It returns an error, and leaves the compiled fields untouched, if
// either expression is invalid.
func (p *TLSAuthACL) compile() error {
	anchor := func(field, expr string) (*regexp.Regexp, error) {
		// Validate the bare expression first: wrapping it in a group
		// could otherwise "repair" unbalanced input such as "a)(b".
		if _, err := regexp.Compile(expr); err != nil {
			return nil, fmt.Errorf("invalid ACL %s regexp %q: %w", field, expr, err)
		}
		// The non-capturing group keeps alternations fully anchored;
		// "^" + "a|b" + "$" would match anything starting with "a" or
		// ending with "b".
		return regexp.Compile("^(?:" + expr + ")$")
	}

	pathRx, err := anchor("path", p.Path)
	if err != nil {
		return err
	}
	cnRx, err := anchor("cn", p.CommonName)
	if err != nil {
		return err
	}
	p.pathRx, p.cnRx = pathRx, cnRx
	return nil
}

// matchPath reports whether the request's URL path matches the ACL's
// Path expression.
func (p *TLSAuthACL) matchPath(req *http.Request) bool {
	return p.pathRx.MatchString(req.URL.Path)
}

// matchCN reports whether the CommonName of the client's leaf
// certificate matches the ACL's CommonName expression. Requests
// without TLS state or without peer certificates never match.
func (p *TLSAuthACL) matchCN(req *http.Request) bool {
	if req.TLS == nil || len(req.TLS.PeerCertificates) == 0 {
		return false
	}
	// Only the first certificate is the leaf the connection is verified
	// against; later entries are whatever else the client chose to send,
	// so their CommonNames must not grant access.
	return p.cnRx.MatchString(req.TLS.PeerCertificates[0].Subject.CommonName)
}

// TLSAuthACLListFlag is a convenience type that lets callers populate a
// list of TLSAuthACL objects through the standard 'flag' package: every
// occurrence of the flag appends one entry. It implements the
// flag.Value interface.
type TLSAuthACLListFlag []*TLSAuthACL

// String renders the list as comma-separated "path:cn" pairs, in order.
func (l TLSAuthACLListFlag) String() string {
	entries := make([]string, 0, len(l))
	for i := range l {
		entries = append(entries, l[i].Path+":"+l[i].CommonName)
	}
	return strings.Join(entries, ",")
}

// Set parses a single "path:cn" value and appends it to the list. The
// value is split at the first ':' only, so the CommonName part may
// itself contain colons; a value with no ':' at all is rejected.
func (l *TLSAuthACLListFlag) Set(value string) error {
	sep := strings.IndexByte(value, ':')
	if sep < 0 {
		return errors.New("bad acl format")
	}
	entry := &TLSAuthACL{
		Path:       value[:sep],
		CommonName: value[sep+1:],
	}
	*l = append(*l, entry)
	return nil
}

// TLSAuthConfig holds the access control lists used for TLS
// authentication. Each ACL is checked against the request path and
// against the CommonName in the subject of the peer certificate.
type TLSAuthConfig struct {
	Allow []*TLSAuthACL `yaml:"allow"`
}

// match reports whether req is authorized. A nil config or an empty
// Allow list authorizes everything (fail-open). Otherwise the first ACL
// whose Path matches decides the outcome on its own, and a request whose
// path matches no ACL is denied.
func (c *TLSAuthConfig) match(req *http.Request) bool {
	// Fail *OPEN* if unconfigured.
	if c == nil || len(c.Allow) == 0 {
		return true
	}

	var selected *TLSAuthACL
	for _, candidate := range c.Allow {
		if candidate.matchPath(req) {
			selected = candidate
			break
		}
	}
	if selected == nil {
		return false
	}
	return selected.matchCN(req)
}

// serverCiphers is the cipher-suite list installed by TLSConfig: ECDHE
// key exchange with AES-128-GCM only, one entry for RSA server keys and
// one for ECDSA server keys.
// NOTE(review): per the crypto/tls documentation, Config.CipherSuites only
// governs TLS 1.2 and earlier; TLS 1.3 suites are not configurable.
var serverCiphers = []uint16{
	tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
	tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
}

// TLSServerConfig configures a TLS server with client authentication
// and authorization based on the client X509 certificate.
//
// Cert and Key are the file paths passed to tls.LoadX509KeyPair. CA, when
// non-empty, is loaded with common.LoadCA and makes the server require
// and verify client certificates. Auth holds the ACLs enforced by
// TLSAuthWrapper; a nil Auth lets every request through.
type TLSServerConfig struct {
	Cert string         `yaml:"cert"`
	Key  string         `yaml:"key"`
	CA   string         `yaml:"ca"`
	Auth *TLSAuthConfig `yaml:"acl"`
}

// TLSConfig returns a tls.Config created with the current configuration.
//
// The server certificate and key are loaded from the Cert and Key paths.
// Protocol versions below TLS 1.2 are refused, cipher suites are limited
// to serverCiphers, and ALPN advertises "h2" and "http/1.1". When CA is
// set, clients must present a certificate that verifies against it.
// Errors from loading the key pair or the CA are returned wrapped with
// the offending paths.
func (c *TLSServerConfig) TLSConfig() (*tls.Config, error) {
	cert, err := tls.LoadX509KeyPair(c.Cert, c.Key)
	if err != nil {
		return nil, fmt.Errorf("loading TLS certificate %q and key %q: %w", c.Cert, c.Key, err)
	}

	// The cipher list is copied so that a caller modifying the returned
	// config cannot alter the package-level defaults for later calls.
	// PreferServerCipherSuites is ignored by recent Go releases but is
	// kept for older toolchains.
	tlsConf := &tls.Config{
		Certificates:             []tls.Certificate{cert},
		CipherSuites:             append([]uint16(nil), serverCiphers...),
		MinVersion:               tls.VersionTLS12,
		PreferServerCipherSuites: true,
		NextProtos:               []string{"h2", "http/1.1"},
	}

	// Require client certificates if a CA is specified.
	if c.CA != "" {
		cas, err := common.LoadCA(c.CA)
		if err != nil {
			return nil, fmt.Errorf("loading client CA %q: %w", c.CA, err)
		}

		tlsConf.ClientAuth = tls.RequireAndVerifyClientCert
		tlsConf.ClientCAs = cas
	}

	// BuildNameToCertificate is deliberately not called: it is deprecated,
	// and with NameToCertificate left nil crypto/tls picks from
	// Certificates directly, which is equivalent for a single certificate.
	return tlsConf, nil
}

// TLSAuthWrapper protects a root HTTP handler with TLS authentication.
//
// The ACL regular expressions in c.Auth are compiled once, here, and an
// invalid expression is returned as an error. The returned handler
// forwards a request to h when the ACLs allow it; otherwise it logs the
// attempt and replies with 403 Forbidden. A nil c.Auth (or an empty
// Allow list) lets every request through.
func (c *TLSServerConfig) TLSAuthWrapper(h http.Handler) (http.Handler, error) {
	// Capture the ACL set now so that requests are always checked against
	// the ACLs compiled below, even if c.Auth is reassigned later.
	auth := c.Auth

	// Compile regexps.
	if auth != nil {
		for _, acl := range auth.Allow {
			if err := acl.compile(); err != nil {
				return nil, err
			}
		}
	}

	// Build the wrapper function to check client certificates
	// identities (looking at the CN part of the X509 subject).
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if auth.match(r) {
			h.ServeHTTP(w, r)
			return
		}

		// Log the failed access, useful for debugging. Client-controlled
		// values (URL path, certificate CN) are quoted with %q so they
		// cannot inject newlines or control characters into the log.
		var tlsmsg string
		if r.TLS != nil && len(r.TLS.PeerCertificates) > 0 {
			tlsmsg = fmt.Sprintf("TLS client %q at ", r.TLS.PeerCertificates[0].Subject.CommonName)
		}
		log.Printf("unauthorized access to %q from %s%s", r.URL.Path, tlsmsg, r.RemoteAddr)
		http.Error(w, "Forbidden", http.StatusForbidden)
	}), nil
}