sso.go 5.81 KiB
package sso
import (
"encoding/base64"
"errors"
"strconv"
"strings"
"time"
"golang.org/x/crypto/ed25519"
)
// ErrMissingRequiredField is returned when a ticket lacks a User, Service or
// Domain.
var ErrMissingRequiredField = errors.New("missing required field")

// ErrBadNonceLength is intended for nonces of invalid length; no code in this
// file returns it.
var ErrBadNonceLength = errors.New("bad nonce length")

// ErrDeserialization is returned when a signed ticket cannot be parsed back
// into its fields.
var ErrDeserialization = errors.New("deserialization error")

// ErrUnsupportedTicketVersion is returned when a ticket's version field does
// not match the version this package produces.
var ErrUnsupportedTicketVersion = errors.New("unsupported ticket version")

// ErrMessageTooShort is returned when a decoded ticket is shorter than an
// Ed25519 signature.
var ErrMessageTooShort = errors.New("encoded message too short")

// ErrBadSignature is returned when a ticket's signature does not verify
// against the validator's public key.
var ErrBadSignature = errors.New("bad signature")

// ErrBadService is returned when a ticket was issued for a different service.
var ErrBadService = errors.New("service mismatch")

// ErrBadDomain is returned when a ticket was issued for a different auth
// domain.
var ErrBadDomain = errors.New("auth domain mismatch")

// ErrBadNonce is returned when a ticket's nonce differs from the expected one.
var ErrBadNonce = errors.New("nonce mismatch")

// ErrExpired is returned when a ticket's expiry time has passed.
var ErrExpired = errors.New("ticket expired")

// ErrUnauthorized is returned when a ticket carries none of the allowed
// groups.
var ErrUnauthorized = errors.New("unauthorized")
// ticketVersion is the first field of every serialized ticket; tickets that
// carry any other value are rejected when deserialized.
const ticketVersion = "4"

// Wire-format delimiters: separatorCh splits the top-level ticket fields and
// groupSeparatorCh splits the entries of the group list.
const (
	separatorCh      = "|"
	groupSeparatorCh = ","
)

// charsToEscape is the set of bytes escapeField percent-encodes: both
// delimiters plus "%" itself, which introduces an escape sequence.
const charsToEscape = "%|,"

// signatureLen is the length of the Ed25519 signature that prefixes every
// signed ticket.
const signatureLen = ed25519.SignatureSize
// fieldEscaper percent-encodes the bytes that are significant in the ticket
// wire format. A strings.Replacer performs every substitution in a single
// left-to-right pass, so, unlike chained strings.Replace calls, there is no
// hidden requirement to escape "%" before the other characters to avoid
// double-encoding. A Replacer is safe for concurrent use.
var fieldEscaper = strings.NewReplacer("%", "%25", "|", "%7C", ",", "%2C")

// escapeField percent-encodes "%", "|" and "," in s so the result can be
// embedded in a serialized ticket without being mistaken for a field or group
// separator. unescapeField reverses it.
func escapeField(s string) string {
	// Cheap early-out for the common case of a field with nothing to escape.
	if !strings.ContainsAny(s, charsToEscape) {
		return s
	}
	return fieldEscaper.Replace(s)
}
// fieldUnescaper reverses the percent-encoding applied by escapeField. The
// Replacer scans left to right in one pass and never re-examines text it has
// already produced. So "%252C" decodes to "%2C", not ",", with no dependence
// on the order of the replacement pairs.
var fieldUnescaper = strings.NewReplacer("%25", "%", "%7C", "|", "%2C", ",")

// unescapeField decodes the "%25", "%7C" and "%2C" sequences produced by
// escapeField back into "%", "|" and ",". Other text, including lowercase
// variants such as "%7c", is left unchanged.
func unescapeField(s string) string {
	return fieldUnescaper.Replace(s)
}
// A Ticket attests a user's identity within the scope of a specific
// service, when properly signed.
//
// User, Service and Domain are required (see validateFields). All fields are
// carried in the wire format produced by serialize, which is what gets signed.
type Ticket struct {
	// User is the identity being attested. Required.
	User string
	// Service is the service the ticket is scoped to. Validate requires it to
	// equal the service passed by the caller. Required.
	Service string
	// Domain is the auth domain the ticket was issued for. Validate requires
	// it to equal the validator's domain. Required.
	Domain string
	// Nonce is not required for signing. Validate requires it to equal the
	// nonce passed by the caller.
	Nonce string
	// Groups holds group names associated with the user. Validate compares
	// them against the caller's allowed groups.
	Groups []string
	// Expires is when the ticket stops being valid. Only whole seconds are
	// serialized. Validate rejects tickets whose Expires is before now.
	Expires time.Time
}
// validateFields returns ErrMissingRequiredField unless User, Service and
// Domain are all non-empty. Nonce, Groups and Expires are not checked.
func (t *Ticket) validateFields() error {
	switch "" {
	case t.User, t.Service, t.Domain:
		return ErrMissingRequiredField
	default:
		return nil
	}
}
// serialize encodes t in the "|"-separated wire format
//
//	version|user|service|domain|nonce|expires|group1,group2,...
//
// Every text field and every group name goes through escapeField, so embedded
// separators survive. expires is the Unix time in whole seconds.
func (t *Ticket) serialize() string {
	groups := make([]string, len(t.Groups))
	for i, g := range t.Groups {
		groups[i] = escapeField(g)
	}
	expires := strconv.FormatInt(t.Expires.Unix(), 10)
	fields := []string{
		ticketVersion,
		escapeField(t.User),
		escapeField(t.Service),
		escapeField(t.Domain),
		escapeField(t.Nonce),
		expires,
		strings.Join(groups, groupSeparatorCh),
	}
	return strings.Join(fields, separatorCh)
}
// deserializeTicket parses the "|"-separated form produced by
// (*Ticket).serialize. It does not check any signature, so callers must have
// verified the signature before trusting the result.
//
// It returns ErrUnsupportedTicketVersion when the version field is not
// ticketVersion. It returns ErrDeserialization when there are not exactly
// seven fields or the expiry is not a base-10 int64. An empty group field
// yields a nil Groups slice.
func deserializeTicket(s string) (*Ticket, error) {
	parts := strings.Split(s, separatorCh)
	// Check the version before the field count: a ticket written by another
	// version may legitimately have a different number of fields, and the
	// version error is the more useful diagnosis. strings.Split always
	// returns at least one element, so parts[0] is safe.
	if parts[0] != ticketVersion {
		return nil, ErrUnsupportedTicketVersion
	}
	if len(parts) != 7 {
		return nil, ErrDeserialization
	}
	secs, err := strconv.ParseInt(parts[5], 10, 64)
	if err != nil {
		return nil, ErrDeserialization
	}
	// A ticket without groups serializes to an empty field, and
	// strings.Split("", ",") returns [""]. Treat it as "no groups" so the
	// ticket round-trips to nil instead of gaining a single empty-named group.
	var groups []string
	if parts[6] != "" {
		encoded := strings.Split(parts[6], groupSeparatorCh)
		groups = make([]string, 0, len(encoded))
		for _, g := range encoded {
			groups = append(groups, unescapeField(g))
		}
	}
	return &Ticket{
		User:    unescapeField(parts[1]),
		Service: unescapeField(parts[2]),
		Domain:  unescapeField(parts[3]),
		Nonce:   unescapeField(parts[4]),
		Expires: time.Unix(secs, 0),
		Groups:  groups,
	}, nil
}
// NewTicket creates a new Ticket, filling in all required values.
//
// The ticket expires validity after the current time, with the sub-second part
// dropped: only whole seconds go over the wire, so this keeps the in-memory
// ticket equal to its serialized form. groups is stored as given, not copied.
func NewTicket(user, service, domain, nonce string, groups []string, validity time.Duration) *Ticket {
	deadline := time.Now().Add(validity).Unix()
	ticket := &Ticket{
		User:    user,
		Service: service,
		Domain:  domain,
		Nonce:   nonce,
		Groups:  groups,
	}
	ticket.Expires = time.Unix(deadline, 0)
	return ticket
}
// A Signer can sign tickets.
//
// Sign returns the signed form of ticket as a string, or an error if the
// ticket cannot be signed.
type Signer interface {
	Sign(ticket *Ticket) (string, error)
}
// ssoSigner is the Ed25519 implementation of Signer returned by NewSigner.
type ssoSigner struct {
	// key is the Ed25519 private key Sign uses to sign serialized tickets.
	key ed25519.PrivateKey
}
// NewSigner creates a new ED25519 signer with the given private key, which
// must be exactly ed25519.PrivateKeySize bytes long. Otherwise it returns an
// error.
//
// The key is copied, so the caller may reuse or wipe privateKey afterwards
// without silently corrupting the returned Signer.
func NewSigner(privateKey []byte) (Signer, error) {
	if len(privateKey) != ed25519.PrivateKeySize {
		return nil, errors.New("bad key size")
	}
	key := make(ed25519.PrivateKey, ed25519.PrivateKeySize)
	copy(key, privateKey)
	return &ssoSigner{key: key}, nil
}
// Sign checks that t has its required fields, serializes it, and returns the
// Ed25519 signature followed by the serialized ticket, encoded as unpadded
// URL-safe base64. Field-validation errors are returned unchanged.
func (s *ssoSigner) Sign(t *Ticket) (string, error) {
	err := t.validateFields()
	if err != nil {
		return "", err
	}
	payload := []byte(t.serialize())
	sig := ed25519.Sign(s.key, payload)
	// Layout: signature || payload. The validator splits at the fixed
	// signature length.
	buf := make([]byte, 0, len(sig)+len(payload))
	buf = append(buf, sig...)
	buf = append(buf, payload...)
	return base64.RawURLEncoding.EncodeToString(buf), nil
}
// A Validator can verify that a ticket is valid.
//
// The ssoValidator implementation decodes and verifies encoded, checks it
// against the expected nonce and service and, when allowedGroups is non-nil,
// requires at least one matching group. It returns the decoded Ticket on
// success.
type Validator interface {
	Validate(encoded, nonce, service string, allowedGroups []string) (*Ticket, error)
}
// ssoValidator is the Ed25519 implementation of Validator returned by
// NewValidator.
type ssoValidator struct {
	// publicKey verifies ticket signatures in parse.
	publicKey ed25519.PublicKey
	// domain must equal a ticket's Domain for Validate to accept it.
	domain string
}
// NewValidator creates a new ED25519 validator for a specific domain, with
// the provided public key.
//
// publicKey must be exactly ed25519.PublicKeySize bytes long and domain must
// be non-empty; otherwise an error is returned. The key is copied, so later
// changes to the caller's slice cannot change which signatures verify.
func NewValidator(publicKey []byte, domain string) (Validator, error) {
	if len(publicKey) != ed25519.PublicKeySize {
		return nil, errors.New("bad key size")
	}
	if domain == "" {
		return nil, errors.New("empty domain")
	}
	key := make(ed25519.PublicKey, ed25519.PublicKeySize)
	copy(key, publicKey)
	return &ssoValidator{
		publicKey: key,
		domain:    domain,
	}, nil
}
// parse base64-decodes encoded (unpadded URL-safe alphabet) and splits it into
// the leading Ed25519 signature and the serialized ticket. It verifies the
// signature with v.publicKey and only then deserializes the ticket.
// Base64 errors are returned as-is. Input too short to hold a signature yields
// ErrMessageTooShort, and a signature that fails to verify yields
// ErrBadSignature.
func (v *ssoValidator) parse(encoded string) (*Ticket, error) {
	raw, err := base64.RawURLEncoding.DecodeString(encoded)
	switch {
	case err != nil:
		return nil, err
	case len(raw) < signatureLen:
		return nil, ErrMessageTooShort
	}
	sig, msg := raw[:signatureLen], raw[signatureLen:]
	if ok := ed25519.Verify(v.publicKey, msg, sig); !ok {
		return nil, ErrBadSignature
	}
	return deserializeTicket(string(msg))
}
// anyGroupAllowed reports whether at least one entry of groups also appears
// in allowedGroups. It returns false if either slice is empty.
func anyGroupAllowed(allowedGroups, groups []string) bool {
	for i := range allowedGroups {
		for j := range groups {
			if allowedGroups[i] == groups[j] {
				return true
			}
		}
	}
	return false
}
// Validate decodes and verifies encoded (see parse). It then checks, in this
// order:
//   - the ticket's Domain matches the validator's domain;
//   - its Service equals service;
//   - it has not expired;
//   - its Nonce equals nonce;
//   - when allowedGroups is non-nil, it carries at least one of allowedGroups.
//
// The first failing check determines the returned error. An empty service is
// rejected before any decoding.
func (v *ssoValidator) Validate(encoded, nonce, service string, allowedGroups []string) (*Ticket, error) {
	if service == "" {
		return nil, errors.New("empty service")
	}
	ticket, err := v.parse(encoded)
	if err != nil {
		return nil, err
	}
	switch {
	case ticket.Domain != v.domain:
		return nil, ErrBadDomain
	case ticket.Service != service:
		return nil, ErrBadService
	case ticket.Expires.Before(time.Now()):
		return nil, ErrExpired
	case ticket.Nonce != nonce:
		return nil, ErrBadNonce
	case allowedGroups != nil && !anyGroupAllowed(allowedGroups, ticket.Groups):
		// A nil allowedGroups disables the group check entirely, while a
		// non-nil empty slice rejects every ticket.
		return nil, ErrUnauthorized
	}
	return ticket, nil
}