Skip to content
Snippets Groups Projects
sso.go 5.81 KiB
package sso

import (
	"encoding/base64"
	"errors"
	"strconv"
	"strings"
	"time"

	"golang.org/x/crypto/ed25519"
)

// Sentinel errors returned by this package. They are created with
// errors.New, so callers can test for them with == or errors.Is.
var (
	// ErrMissingRequiredField is returned by Sign when the ticket's User,
	// Service or Domain is empty.
	ErrMissingRequiredField = errors.New("missing required field")

	// ErrBadNonceLength presumably reports a nonce of unexpected length.
	// NOTE(review): no code in this file returns it; confirm it is still used.
	ErrBadNonceLength = errors.New("bad nonce length")

	// ErrDeserialization is returned when a signed payload does not have the
	// expected number of fields or its expiry is not an integer.
	ErrDeserialization = errors.New("deserialization error")

	// ErrUnsupportedTicketVersion is returned when a ticket's version field
	// differs from the version this code writes.
	ErrUnsupportedTicketVersion = errors.New("unsupported ticket version")

	// ErrMessageTooShort is returned when a decoded ticket is too short to
	// even contain a signature.
	ErrMessageTooShort = errors.New("encoded message too short")

	// ErrBadSignature is returned when a ticket's signature does not verify
	// against the validator's public key.
	ErrBadSignature = errors.New("bad signature")

	// ErrBadService is returned when a ticket was issued for a different
	// service than the one being validated.
	ErrBadService = errors.New("service mismatch")

	// ErrBadDomain is returned when a ticket's domain differs from the
	// validator's domain.
	ErrBadDomain = errors.New("auth domain mismatch")

	// ErrBadNonce is returned when a ticket's nonce differs from the
	// expected one.
	ErrBadNonce = errors.New("nonce mismatch")

	// ErrExpired is returned when a ticket's expiry time is in the past.
	ErrExpired = errors.New("ticket expired")

	// ErrUnauthorized is returned when group checking is enabled and none of
	// the ticket's groups is in the allowed set.
	ErrUnauthorized = errors.New("unauthorized")
)

// Wire-format constants for serialized tickets.
const (
	// ticketVersion is the first field of every serialized ticket;
	// deserializeTicket rejects any other value.
	ticketVersion = "4"

	// separatorCh separates the top-level fields of a serialized ticket.
	separatorCh = "|"

	// groupSeparatorCh separates group names inside the groups field.
	groupSeparatorCh = ","

	// charsToEscape lists the characters escapeField percent-encodes: the
	// escape character itself plus both separators.
	charsToEscape = "%|,"

	// signatureLen is the length of the ED25519 signature that prefixes the
	// serialized ticket before base64 encoding.
	signatureLen = ed25519.SignatureSize
)

// escapeField percent-encodes the characters that carry structural meaning
// in a serialized ticket ('%', '|' and ','), so the result can be joined
// with the field and group separators without ambiguity. Strings that
// contain none of those characters are returned unchanged.
func escapeField(s string) string {
	if !strings.ContainsAny(s, charsToEscape) {
		return s
	}
	// '%' must be encoded first: the encodings of '|' and ',' introduce new
	// '%' characters that must not be escaped a second time.
	pairs := [...][2]string{
		{"%", "%25"},
		{"|", "%7C"},
		{",", "%2C"},
	}
	for _, p := range pairs {
		s = strings.Replace(s, p[0], p[1], -1)
	}
	return s
}

// unescapeField reverses escapeField, turning "%2C", "%7C" and "%25" back
// into ',', '|' and '%' respectively.
func unescapeField(s string) string {
	// "%25" is decoded last, so a literal "%2C" or "%7C" in the original
	// field (encoded as "%252C" / "%257C") is not mistaken for an escaped
	// separator.
	pairs := [...][2]string{
		{"%2C", ","},
		{"%7C", "|"},
		{"%25", "%"},
	}
	for _, p := range pairs {
		s = strings.Replace(s, p[0], p[1], -1)
	}
	return s
}

// A Ticket attests a user's identity within the scope of a specific
// service, when properly signed.
type Ticket struct {
	// User is the identity being attested. Must be non-empty to be signed.
	User    string
	// Service is the service the ticket is scoped to; Validate rejects the
	// ticket unless it equals the service being validated for. Must be
	// non-empty to be signed.
	Service string
	// Domain is the auth domain; Validate rejects the ticket unless it
	// equals the validator's domain. Must be non-empty to be signed.
	Domain  string
	// Nonce is compared verbatim against the nonce passed to Validate. It
	// is not required to be non-empty when signing.
	Nonce   string
	// Groups lists the user's groups, checked against allowedGroups in
	// Validate. Separator characters in names are escaped on the wire.
	Groups  []string
	// Expires is when the ticket stops being valid. Only whole Unix seconds
	// are serialized, so any sub-second part is lost in transit.
	Expires time.Time
}

// validateFields returns ErrMissingRequiredField if any of the mandatory
// fields User, Service or Domain is empty. Nonce, Groups and Expires are
// not checked.
func (t *Ticket) validateFields() error {
	for _, field := range []string{t.User, t.Service, t.Domain} {
		if field == "" {
			return ErrMissingRequiredField
		}
	}
	return nil
}
// serialize renders the ticket in its wire format: seven fields joined by
// separatorCh, namely the version, the escaped User, Service, Domain and
// Nonce, the expiry as decimal Unix seconds, and the escaped groups joined
// by groupSeparatorCh. A ticket without groups gets an empty last field.
func (t *Ticket) serialize() string {
	groups := make([]string, len(t.Groups))
	for i, g := range t.Groups {
		groups[i] = escapeField(g)
	}

	fields := [...]string{
		ticketVersion,
		escapeField(t.User),
		escapeField(t.Service),
		escapeField(t.Domain),
		escapeField(t.Nonce),
		strconv.FormatInt(t.Expires.Unix(), 10),
		strings.Join(groups, groupSeparatorCh),
	}
	return strings.Join(fields[:], separatorCh)
}

// deserializeTicket parses the wire format written by (*Ticket).serialize:
// seven separatorCh-delimited fields (version, user, service, domain, nonce,
// expiry in Unix seconds, groupSeparatorCh-delimited groups), with text
// fields decoded by unescapeField. It performs no signature check; parse
// calls it only after verifying the signature.
//
// It returns ErrDeserialization if the field count is wrong or the expiry is
// not a base-10 int64, and ErrUnsupportedTicketVersion if the version field
// is not ticketVersion. An empty groups field yields a Ticket with nil
// Groups.
func deserializeTicket(s string) (*Ticket, error) {
	parts := strings.Split(s, separatorCh)
	if len(parts) != 7 {
		return nil, ErrDeserialization
	}

	if parts[0] != ticketVersion {
		return nil, ErrUnsupportedTicketVersion
	}

	secs, err := strconv.ParseInt(parts[5], 10, 64)
	if err != nil {
		return nil, ErrDeserialization
	}

	// serialize writes an empty groups field for a ticket without groups,
	// but strings.Split("", sep) returns [""], which would turn "no groups"
	// into a single group with an empty name. Treat the empty field as no
	// groups. (The format cannot distinguish this from Groups == [""]; both
	// now decode to nil.)
	var groups []string
	if parts[6] != "" {
		rawGroups := strings.Split(parts[6], groupSeparatorCh)
		groups = make([]string, 0, len(rawGroups))
		for _, g := range rawGroups {
			groups = append(groups, unescapeField(g))
		}
	}

	return &Ticket{
		User:    unescapeField(parts[1]),
		Service: unescapeField(parts[2]),
		Domain:  unescapeField(parts[3]),
		Nonce:   unescapeField(parts[4]),
		Expires: time.Unix(secs, 0),
		Groups:  groups,
	}, nil
}

// NewTicket creates a new Ticket, filling in all required values. The
// ticket expires validity from now. The expiry is truncated to whole
// seconds because only Unix seconds go over the wire. groups is stored
// as-is, not copied.
func NewTicket(user, service, domain, nonce string, groups []string, validity time.Duration) *Ticket {
	deadline := time.Now().Add(validity).Unix()

	t := &Ticket{
		User:    user,
		Service: service,
		Domain:  domain,
		Nonce:   nonce,
		Groups:  groups,
	}
	t.Expires = time.Unix(deadline, 0)
	return t
}

// A Signer can sign tickets.
//
// Sign returns the encoded, signed form of t, which a Validator can decode.
// The Signer returned by NewSigner fails with ErrMissingRequiredField when
// t's User, Service or Domain is empty.
type Signer interface {
	Sign(t *Ticket) (string, error)
}
// ssoSigner is the Signer implementation returned by NewSigner; it signs
// tickets with an ED25519 private key.
type ssoSigner struct {
	// key is the private key given to NewSigner (length checked there). The
	// caller's slice is stored as-is, not copied.
	key ed25519.PrivateKey
}

// NewSigner creates a new ED25519 signer with the given private key. It
// returns an error if privateKey is not ed25519.PrivateKeySize bytes long.
// The slice is retained without copying, so the caller should not modify
// it afterwards.
func NewSigner(privateKey []byte) (Signer, error) {
	if len(privateKey) == ed25519.PrivateKeySize {
		return &ssoSigner{key: privateKey}, nil
	}
	return nil, errors.New("bad key size")
}

// Sign checks t's required fields, signs its serialized form with the
// signer's ED25519 key and returns the signature followed by the serialized
// ticket, encoded as unpadded URL-safe base64.
func (s *ssoSigner) Sign(t *Ticket) (string, error) {
	err := t.validateFields()
	if err != nil {
		return "", err
	}

	payload := []byte(t.serialize())
	msg := make([]byte, 0, signatureLen+len(payload))
	msg = append(msg, ed25519.Sign(s.key, payload)...)
	msg = append(msg, payload...)
	return base64.RawURLEncoding.EncodeToString(msg), nil
}

// A Validator can verify that a ticket is valid.
//
// Validate decodes and authenticates the encoded ticket, then checks it
// against the expected nonce and service. When allowedGroups is non-nil,
// it also checks the ticket's groups against allowedGroups. On success it
// returns the decoded Ticket.
type Validator interface {
	Validate(encoded, nonce, service string, allowedGroups []string) (*Ticket, error)
}

// ssoValidator is the Validator implementation returned by NewValidator.
type ssoValidator struct {
	// publicKey verifies ticket signatures in parse (length checked in
	// NewValidator).
	publicKey ed25519.PublicKey
	// domain is the auth domain every validated ticket must carry.
	domain    string
}

// NewValidator creates a new ED25519 validator for a specific domain,
// with the provided public key. It fails if publicKey is not
// ed25519.PublicKeySize bytes long or if domain is empty.
func NewValidator(publicKey []byte, domain string) (Validator, error) {
	switch {
	case len(publicKey) != ed25519.PublicKeySize:
		return nil, errors.New("bad key size")
	case domain == "":
		return nil, errors.New("empty domain")
	}
	v := &ssoValidator{publicKey: publicKey, domain: domain}
	return v, nil
}

// parse decodes an unpadded URL-safe base64 ticket and verifies the leading
// ED25519 signature over the rest of the data against v.publicKey. If the
// signature is valid, it deserializes the signed payload.
//
// It returns the base64 decoding error unchanged, ErrMessageTooShort if the
// decoded data cannot hold a signature, ErrBadSignature if verification
// fails, or whatever deserializeTicket returns.
func (v *ssoValidator) parse(encoded string) (*Ticket, error) {
	raw, err := base64.RawURLEncoding.DecodeString(encoded)
	switch {
	case err != nil:
		return nil, err
	case len(raw) < signatureLen:
		return nil, ErrMessageTooShort
	}

	sig, payload := raw[:signatureLen], raw[signatureLen:]
	if ed25519.Verify(v.publicKey, payload, sig) {
		return deserializeTicket(string(payload))
	}
	return nil, ErrBadSignature
}

// anyGroupAllowed reports whether at least one entry of groups also appears
// in allowedGroups (exact string comparison). It returns false when either
// slice is empty.
func anyGroupAllowed(allowedGroups, groups []string) bool {
	isAllowed := func(name string) bool {
		for _, allowed := range allowedGroups {
			if allowed == name {
				return true
			}
		}
		return false
	}

	for _, name := range groups {
		if isAllowed(name) {
			return true
		}
	}
	return false
}

// Validate decodes and authenticates encoded via parse, then checks the
// ticket's fields in this order:
//   - its domain matches the validator's domain;
//   - its service matches service;
//   - it has not expired;
//   - its nonce equals nonce.
//
// If allowedGroups is non-nil, the ticket must also carry at least one of
// those groups. A non-nil empty allowedGroups therefore rejects every
// ticket. It returns the decoded Ticket on success.
func (v *ssoValidator) Validate(encoded, nonce, service string, allowedGroups []string) (*Ticket, error) {
	if service == "" {
		return nil, errors.New("empty service")
	}

	t, err := v.parse(encoded)
	if err != nil {
		return nil, err
	}

	switch {
	case t.Domain != v.domain:
		return nil, ErrBadDomain
	case t.Service != service:
		return nil, ErrBadService
	case t.Expires.Before(time.Now()):
		return nil, ErrExpired
	case t.Nonce != nonce:
		return nil, ErrBadNonce
	case allowedGroups != nil && !anyGroupAllowed(allowedGroups, t.Groups):
		// A nil allowedGroups disables the group check entirely.
		return nil, ErrUnauthorized
	}
	return t, nil
}