package sso

import (
	"encoding/base64"
	"errors"
	"strconv"
	"strings"
	"time"

	"golang.org/x/crypto/ed25519"
)

var (
	// ErrMissingRequiredField is returned when a ticket does not
	// contain a required field (user, service or domain).
	ErrMissingRequiredField = errors.New("missing required field")

	// ErrDeserialization is returned when an encoded ticket cannot be
	// decoded: the input is not valid base64, or the decoded ticket
	// does not have the expected field layout or a valid expiry.
	ErrDeserialization = errors.New("deserialization error")

	// ErrUnsupportedTicketVersion is returned for unsupported
	// ticket versions (either too old or too recent).
	ErrUnsupportedTicketVersion = errors.New("unsupported ticket version")

	// ErrMessageTooShort means that the decoded input is shorter than
	// the fixed signature length.
	ErrMessageTooShort = errors.New("encoded message too short")

	// ErrBadSignature is returned when the signature does not
	// match the given public key.
	ErrBadSignature = errors.New("bad signature")

	// ErrBadService is returned when validation fails due to an
	// SSO service mismatch.
	ErrBadService = errors.New("service mismatch")

	// ErrBadDomain is returned when validation fails due to an SSO
	// domain mismatch.
	ErrBadDomain = errors.New("auth domain mismatch")

	// ErrBadNonce is returned when validation fails due to a
	// nonce mismatch.
	ErrBadNonce = errors.New("nonce mismatch")

	// ErrExpired means the ticket has expired.
	ErrExpired = errors.New("ticket expired")

	// ErrUnauthorized is returned when the user lacks the
	// necessary group membership.
	ErrUnauthorized = errors.New("unauthorized")
)

const (
	ticketVersion    = "4"
	separatorCh      = "|"
	groupSeparatorCh = ","
	charsToEscape    = "%|,"
	signatureLen     = ed25519.SignatureSize
)

var (
	// fieldEscaper percent-encodes the characters that are structural
	// in the serialized format. It works in a single pass, so the "%25"
	// it emits is never escaped a second time.
	fieldEscaper = strings.NewReplacer("%", "%25", "|", "%7C", ",", "%2C")

	// fieldUnescaper is the inverse of fieldEscaper, also single-pass.
	fieldUnescaper = strings.NewReplacer("%25", "%", "%7C", "|", "%2C", ",")
)

// escapeField percent-encodes '%', '|' and ',' in s so that it can be
// embedded in a serialized ticket field or group list.
func escapeField(s string) string {
	// Fast path: most fields contain none of the special characters.
	if !strings.ContainsAny(s, charsToEscape) {
		return s
	}
	return fieldEscaper.Replace(s)
}

// unescapeField reverses escapeField.
func unescapeField(s string) string {
	// Every escape sequence starts with '%'; skip the work otherwise.
	if !strings.Contains(s, "%") {
		return s
	}
	return fieldUnescaper.Replace(s)
}

// A Ticket attests a user's identity within the scope of a specific
// service, when properly signed.
type Ticket struct {
	User    string
	Service string
	Domain  string
	// Nonce is optional; it is not required by validateFields.
	Nonce  string
	Groups []string
	// Expires is serialized with one-second precision.
	Expires time.Time
}

// validateFields checks that the mandatory fields (User, Service,
// Domain) are set, returning ErrMissingRequiredField otherwise.
func (t *Ticket) validateFields() error {
	if t.User == "" || t.Service == "" || t.Domain == "" {
		return ErrMissingRequiredField
	}
	return nil
}

// serialize encodes the ticket as
//
//	version|user|service|domain|nonce|expiry-unix-seconds|group,group,...
//
// with every string field escaped via escapeField.
func (t *Ticket) serialize() string {
	escapedGroups := make([]string, 0, len(t.Groups))
	for _, g := range t.Groups {
		escapedGroups = append(escapedGroups, escapeField(g))
	}
	return strings.Join(
		[]string{
			ticketVersion,
			escapeField(t.User),
			escapeField(t.Service),
			escapeField(t.Domain),
			escapeField(t.Nonce),
			strconv.FormatInt(t.Expires.Unix(), 10),
			strings.Join(escapedGroups, groupSeparatorCh),
		},
		separatorCh,
	)
}

// deserializeTicket parses the output of serialize. It returns
// ErrUnsupportedTicketVersion when the version field is not
// ticketVersion, and ErrDeserialization for any other malformation.
func deserializeTicket(s string) (*Ticket, error) {
	parts := strings.Split(s, separatorCh)
	// Check the version first: other versions may use a different
	// number of fields, and should be reported as such.
	if parts[0] != ticketVersion {
		return nil, ErrUnsupportedTicketVersion
	}
	if len(parts) != 7 {
		return nil, ErrDeserialization
	}
	secs, err := strconv.ParseInt(parts[5], 10, 64)
	if err != nil {
		return nil, ErrDeserialization
	}
	// serialize writes an empty field for a ticket without groups;
	// splitting "" would otherwise yield a single bogus "" group.
	var groups []string
	if parts[6] != "" {
		rawGroups := strings.Split(parts[6], groupSeparatorCh)
		groups = make([]string, 0, len(rawGroups))
		for _, g := range rawGroups {
			groups = append(groups, unescapeField(g))
		}
	}
	return &Ticket{
		User:    unescapeField(parts[1]),
		Service: unescapeField(parts[2]),
		Domain:  unescapeField(parts[3]),
		Nonce:   unescapeField(parts[4]),
		Expires: time.Unix(secs, 0),
		Groups:  groups,
	}, nil
}

// NewTicket creates a new Ticket, filling in all required values. The
// ticket expires validity from now, truncated to whole seconds.
func NewTicket(user, service, domain, nonce string, groups []string, validity time.Duration) *Ticket {
	// Let's set nanoseconds to 0 - they don't go over the wire anyway.
	expires := time.Now().Add(validity)
	expires = time.Unix(expires.Unix(), 0)
	return &Ticket{
		User:    user,
		Service: service,
		Domain:  domain,
		Nonce:   nonce,
		Groups:  groups,
		Expires: expires,
	}
}

// A Signer can sign tickets.
type Signer interface {
	Sign(*Ticket) (string, error)
}

// ssoSigner signs tickets with an ED25519 private key.
type ssoSigner struct {
	key ed25519.PrivateKey
}

// NewSigner creates a new ED25519 signer with the given private key.
// It returns an error if the key is not ed25519.PrivateKeySize bytes.
func NewSigner(privateKey []byte) (Signer, error) {
	if len(privateKey) != ed25519.PrivateKeySize {
		return nil, errors.New("bad key size")
	}
	// Copy the key so that the caller zeroing or reusing its buffer
	// cannot corrupt later signatures.
	key := make(ed25519.PrivateKey, len(privateKey))
	copy(key, privateKey)
	return &ssoSigner{key: key}, nil
}

// Sign validates the ticket's required fields, signs its serialized
// form, and returns base64url(signature || serialized ticket) without
// padding.
func (s *ssoSigner) Sign(t *Ticket) (string, error) {
	if err := t.validateFields(); err != nil {
		return "", err
	}
	serialized := []byte(t.serialize())
	signature := ed25519.Sign(s.key, serialized)
	signed := make([]byte, 0, len(signature)+len(serialized))
	signed = append(signed, signature...)
	signed = append(signed, serialized...)
	return base64.RawURLEncoding.EncodeToString(signed), nil
}

// A Validator can verify that a ticket is valid.
type Validator interface {
	Validate(encoded, nonce, service string, allowedGroups []string) (*Ticket, error)
}

// ssoValidator verifies tickets signed for a single domain.
type ssoValidator struct {
	publicKey ed25519.PublicKey
	domain    string
}

// NewValidator creates a new ED25519 validator for a specific domain,
// with the provided public key. It returns an error if the key has the
// wrong size or the domain is empty.
func NewValidator(publicKey []byte, domain string) (Validator, error) { if len(publicKey) != ed25519.PublicKeySize { return nil, errors.New("bad key size") } if domain == "" { return nil, errors.New("empty domain") } return &ssoValidator{ publicKey: publicKey, domain: domain, }, nil } func (v *ssoValidator) parse(encoded string) (*Ticket, error) { decoded, err := base64.RawURLEncoding.DecodeString(encoded) if err != nil { return nil, err } if len(decoded) < signatureLen { return nil, ErrMessageTooShort } serialized := decoded[signatureLen:] signature := decoded[:signatureLen] if !ed25519.Verify(v.publicKey, serialized, signature) { return nil, ErrBadSignature } return deserializeTicket(string(serialized)) } func anyGroupAllowed(allowedGroups, groups []string) bool { for _, g := range groups { for _, ag := range allowedGroups { if g == ag { return true } } } return false } func (v *ssoValidator) Validate(encoded, nonce, service string, allowedGroups []string) (*Ticket, error) { if service == "" { return nil, errors.New("empty service") } t, err := v.parse(encoded) if err != nil { return nil, err } if t.Domain != v.domain { return nil, ErrBadDomain } if t.Service != service { return nil, ErrBadService } if t.Expires.Before(time.Now()) { return nil, ErrExpired } if nonce != "" && t.Nonce != nonce { return nil, ErrBadNonce } // Only perform a group check if allowedGroups is not nil. if allowedGroups != nil && !anyGroupAllowed(allowedGroups, t.Groups) { return nil, ErrUnauthorized } return t, nil } // InspectTicket reads a ticket without validating it (beyond syntax), // returning user and service. The results are untrusted. 
func InspectTicket(encoded string) (string, string, error) { decoded, err := base64.RawURLEncoding.DecodeString(encoded) if err != nil { return "", "", err } if len(decoded) < signatureLen { return "", "", ErrMessageTooShort } serialized := decoded[signatureLen:] t, err := deserializeTicket(string(serialized)) if err != nil { return "", "", err } return t.User, t.Service, nil }