Commit ec1bcfe0 authored by ale's avatar ale
Browse files

Do not return an error when a user has no encryption keys on Open

And make HTTP result codes slightly more meaningful.
parent 3dff474e
...@@ -27,14 +27,17 @@ similarly JSON-encoded. ...@@ -27,14 +27,17 @@ similarly JSON-encoded.
Retrieve the encrypted key for a user, decrypt it with the provided Retrieve the encrypted key for a user, decrypt it with the provided
password, and store it in memory. password, and store it in memory.
OpenRequest is an object with the OpenRequest is an object with the following attributes:
following attributes:
* `username` * `username`
* `password` to decrypt the user's key with * `password` to decrypt the user's key with
* `ttl` (seconds) time after which the credentials are automatically * `ttl` (seconds) time after which the credentials are automatically
forgotten forgotten
If the user has no encrypted keys in the database, the request will
still return successfully: no action will be performed, and no errors
will be returned.
`/api/get` (*GetRequest*) -> *GetResponse* `/api/get` (*GetRequest*) -> *GetResponse*
Retrieve the key for a user. GetRequest must contain the following Retrieve the key for a user. GetRequest must contain the following
......
...@@ -4,6 +4,7 @@ import ( ...@@ -4,6 +4,7 @@ import (
"context" "context"
"errors" "errors"
"io/ioutil" "io/ioutil"
"log"
"strings" "strings"
"sync" "sync"
"time" "time"
...@@ -15,9 +16,10 @@ import ( ...@@ -15,9 +16,10 @@ import (
) )
var ( var (
ErrNoKeys = errors.New("no keys available") errNoKeys = errors.New("no keys available")
ErrBadUser = errors.New("username does not match authentication token") errBadUser = errors.New("username does not match authentication token")
ErrInvalidTTL = errors.New("invalid ttl") errUnauthorized = errors.New("unauthorized")
errInvalidTTL = errors.New("invalid ttl")
) )
// Database represents the interface to the underlying backend for // Database represents the interface to the underlying backend for
...@@ -130,7 +132,7 @@ func (s *KeyStore) expire() { ...@@ -130,7 +132,7 @@ func (s *KeyStore) expire() {
// A Context is needed because this method might issue an RPC. // A Context is needed because this method might issue an RPC.
func (s *KeyStore) Open(ctx context.Context, username, password string, ttlSeconds int) error { func (s *KeyStore) Open(ctx context.Context, username, password string, ttlSeconds int) error {
if ttlSeconds == 0 { if ttlSeconds == 0 {
return ErrInvalidTTL return errInvalidTTL
} }
encKeys, err := s.db.GetPrivateKeys(ctx, username) encKeys, err := s.db.GetPrivateKeys(ctx, username)
...@@ -138,7 +140,8 @@ func (s *KeyStore) Open(ctx context.Context, username, password string, ttlSecon ...@@ -138,7 +140,8 @@ func (s *KeyStore) Open(ctx context.Context, username, password string, ttlSecon
return err return err
} }
if len(encKeys) == 0 { if len(encKeys) == 0 {
return ErrNoKeys // No keys found. Not an error.
return nil
} }
// Naive and inefficient way of decrypting multiple keys: it // Naive and inefficient way of decrypting multiple keys: it
...@@ -163,17 +166,20 @@ func (s *KeyStore) Get(username, ssoTicket string) ([]byte, error) { ...@@ -163,17 +166,20 @@ func (s *KeyStore) Get(username, ssoTicket string) ([]byte, error) {
// Validate the SSO ticket. // Validate the SSO ticket.
tkt, err := s.validator.Validate(ssoTicket, "", s.service, nil) tkt, err := s.validator.Validate(ssoTicket, "", s.service, nil)
if err != nil { if err != nil {
return nil, err // Log authentication failures for debugging purposes.
log.Printf("Validate(%s) error: %v", username, err)
return nil, errUnauthorized
} }
if tkt.User != username { if tkt.User != username {
return nil, ErrBadUser log.Printf("Validate(%s) user mismatch: sso=%s", username, tkt.User)
return nil, errBadUser
} }
s.mx.Lock() s.mx.Lock()
defer s.mx.Unlock() defer s.mx.Unlock()
u, ok := s.userKeys[username] u, ok := s.userKeys[username]
if !ok { if !ok {
return nil, ErrNoKeys return nil, errNoKeys
} }
return u.pkey, nil return u.pkey, nil
} }
......
...@@ -24,7 +24,7 @@ func (s *keyStoreServer) handleOpen(w http.ResponseWriter, r *http.Request) { ...@@ -24,7 +24,7 @@ func (s *keyStoreServer) handleOpen(w http.ResponseWriter, r *http.Request) {
err := s.KeyStore.Open(r.Context(), req.Username, req.Password, req.TTL) err := s.KeyStore.Open(r.Context(), req.Username, req.Password, req.TTL)
if err != nil { if err != nil {
log.Printf("Open(%s) error: %v", req.Username, err) log.Printf("Open(%s) error: %v", req.Username, err)
http.Error(w, err.Error(), http.StatusUnauthorized) http.Error(w, err.Error(), http.StatusInternalServerError)
return return
} }
...@@ -40,7 +40,15 @@ func (s *keyStoreServer) handleGet(w http.ResponseWriter, r *http.Request) { ...@@ -40,7 +40,15 @@ func (s *keyStoreServer) handleGet(w http.ResponseWriter, r *http.Request) {
key, err := s.KeyStore.Get(req.Username, req.SSOTicket) key, err := s.KeyStore.Get(req.Username, req.SSOTicket)
if err != nil { if err != nil {
log.Printf("Get(%s) error: %v", req.Username, err) log.Printf("Get(%s) error: %v", req.Username, err)
http.Error(w, err.Error(), http.StatusUnauthorized) // Return an appropriate error code.
switch err {
case errUnauthorized, errBadUser:
http.Error(w, err.Error(), http.StatusUnauthorized)
case errNoKeys:
http.NotFound(w, r)
default:
http.Error(w, err.Error(), http.StatusInternalServerError)
}
return return
} }
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment