diff --git a/cmd/keystored/main.go b/cmd/keystored/main.go index a74d640f9b3d52461b3955a10e71cffdcdd7d7bc..4cdb544956088eda8bdcc0f68240596a9f9880e0 100644 --- a/cmd/keystored/main.go +++ b/cmd/keystored/main.go @@ -5,6 +5,7 @@ import ( "log" "os" + "github.com/prometheus/client_golang/prometheus" "gopkg.in/yaml.v3" "git.autistici.org/ai3/go-common/serverutil" @@ -57,6 +58,7 @@ func main() { if err != nil { log.Fatal(err) } + prometheus.MustRegister(server.NewKeystoreCollector(ks)) srv := server.NewServer(ks) diff --git a/dovecot/keyproxy_test.go b/dovecot/keyproxy_test.go index 68a53bdb824bc86bb71cb1fe98fc5b2439cc58c9..516760157500404deee19f6c7a14a2a796976ef1 100644 --- a/dovecot/keyproxy_test.go +++ b/dovecot/keyproxy_test.go @@ -216,7 +216,7 @@ func TestKeyProxy_LookupPrivateKeys_FromKeystore(t *testing.T) { defer cleanup() // Unlock the key in KeyStore with the correct password. - if err := ks.Open(context.Background(), "test@example.com", string(testPw1), 600); err != nil { + if err := ks.Open(context.Background(), "test@example.com", string(testPw1), "session", 600); err != nil { t.Fatalf("ks.Open: %v", err) } diff --git a/protocol.go b/protocol.go index fc6113e36dbba6ef1d05775119dbd32e8dabec4f..58b9e3f263b5c8e227d6163cd5778bfa7f523f85 100644 --- a/protocol.go +++ b/protocol.go @@ -1,9 +1,10 @@ package keystore type OpenRequest struct { - Username string `json:"username"` - Password string `json:"password"` - TTL int `json:"ttl"` + Username string `json:"username"` + Password string `json:"password"` + TTL int `json:"ttl"` + SessionID string `json:"session_id"` } type OpenResponse struct{} @@ -19,7 +20,8 @@ type GetResponse struct { } type CloseRequest struct { - Username string `json:"username"` + Username string `json:"username"` + SessionID string `json:"session_id"` } type CloseResponse struct{} diff --git a/server/instrumentation.go b/server/instrumentation.go index 3c7622235390b5d09a313509939216d97fdc0d3f..af32af3a9264df1c214ffba9873a4acf368a4034 100644 --- a/server/instrumentation.go +++ b/server/instrumentation.go @@ -3,10 +3,11 @@ package server import "github.com/prometheus/client_golang/prometheus" var ( - totalKeysInMemory = prometheus.NewGauge(prometheus.GaugeOpts{ - Name: "keystored_keys_total", - Help: "Total number of unlocked keys in-memory.", - }) + totalKeysInMemory = prometheus.NewDesc( + "keystored_keys_total", + "Total number of unlocked keys in-memory.", + nil, nil, + ) requestsCounter = prometheus.NewCounterVec(prometheus.CounterOpts{ Name: "keystored_requests_total", Help: "Counter of requests by method and status.", @@ -22,12 +23,29 @@ var ( ) func init() { - prometheus.MustRegister(totalKeysInMemory) prometheus.MustRegister(requestsCounter) prometheus.MustRegister(decryptedKeysCounter) prometheus.MustRegister(unlockedKeysServedCounter) } -func (s *KeyStore) updateKeyspaceSize() { - totalKeysInMemory.Set(float64(len(s.userKeys))) +type keystoreCollector struct { + ks *KeyStore +} + +// NewKeystoreCollector returns a prometheus.Collector that will +// export metrics for the given KeyStore instance. +func NewKeystoreCollector(ks *KeyStore) prometheus.Collector { + return &keystoreCollector{ks: ks} +} + +func (c *keystoreCollector) Describe(ch chan<- *prometheus.Desc) { + prometheus.DescribeByCollect(c, ch) +} + +func (c *keystoreCollector) Collect(ch chan<- prometheus.Metric) { + ch <- prometheus.MustNewConstMetric( + totalKeysInMemory, + prometheus.GaugeValue, + float64(c.ks.userKeys.numCachedKeys()), + ) } diff --git a/server/keystore.go b/server/keystore.go index 0ffcda59d724440ddd50bf84bb6c2777dbd48dd4..88abc7c8c583c2f9ca111f718207047fc6a0fc5c 100644 --- a/server/keystore.go +++ b/server/keystore.go @@ -25,8 +25,14 @@ var ( ) type userKey struct { - pkey []byte - expiry time.Time + pkey []byte +} + +type userSession struct { + id string + username string + expiry time.Time + index int } // Config for the KeyStore. @@ -68,12 +74,12 @@ func (c *Config) check() error { // token for the user whose secrets you would like to obtain. // type KeyStore struct { - mx sync.Mutex - userKeys map[string]userKey - db backend.Database service string validator sso.Validator + + mx sync.Mutex + userKeys *userKeyMap } func newKeyStoreWithBackend(config *Config, db backend.Database) (*KeyStore, error) { @@ -87,7 +93,7 @@ func newKeyStoreWithBackend(config *Config, db backend.Database) (*KeyStore, err } s := &KeyStore{ - userKeys: make(map[string]userKey), + userKeys: newUserKeyMap(), service: config.SSOService, validator: v, db: db, @@ -120,14 +126,8 @@ func NewKeyStore(config *Config) (*KeyStore, error) { func (s *KeyStore) expire(t time.Time) { s.mx.Lock() - for u, k := range s.userKeys { - if k.expiry.Before(t) { - log.Printf("forgetting key for %s", u) - wipeBytes(k.pkey) - delete(s.userKeys, u) - } - } - s.updateKeyspaceSize() + s.userKeys.expire(t) + //s.updateKeyspaceSize() s.mx.Unlock() } @@ -139,10 +139,17 @@ func (s *KeyStore) expireLoop() { // Open the user's key store with the given password. If successful, // the unencrypted user key will be stored for at most ttlSeconds, or -// until Close is called. +// until Close is called with the same session ID. +// +// Note that the key is fetched from the backend and decrypted even if +// we already have it in memory (for instance belonging to a separate +// session), because this acts as an implicit ACL check: does the user +// have access to the key because it can decrypt it with the provided +// credentials? // // A Context is needed because this method might issue an RPC. -func (s *KeyStore) Open(ctx context.Context, username, password string, ttlSeconds int) error { +// +func (s *KeyStore) Open(ctx context.Context, username, password, sessionID string, ttlSeconds int) error { if ttlSeconds == 0 { return errInvalidTTL } @@ -170,11 +177,8 @@ func (s *KeyStore) Open(ctx context.Context, username, password string, ttlSecon } s.mx.Lock() - s.userKeys[username] = userKey{ - pkey: pem, - expiry: time.Now().Add(time.Duration(ttlSeconds) * time.Second), - } - s.updateKeyspaceSize() + s.userKeys.addSessionWithKey(sessionID, username, &userKey{pkey: pem}, ttlSeconds) + //s.updateKeyspaceSize() s.mx.Unlock() return nil } @@ -196,7 +200,7 @@ func (s *KeyStore) Get(username, ssoTicket string) ([]byte, error) { s.mx.Lock() defer s.mx.Unlock() - u, ok := s.userKeys[username] + u, ok := s.userKeys.get(username) if !ok { return nil, errNoKeys } @@ -205,16 +209,10 @@ func (s *KeyStore) Get(username, ssoTicket string) ([]byte, error) { // Close the user's key store and wipe the associated unencrypted key // from memory. Returns true if a key was actually discarded. -func (s *KeyStore) Close(username string) bool { +func (s *KeyStore) Close(sessionID string) bool { s.mx.Lock() defer s.mx.Unlock() - k, ok := s.userKeys[username] - if ok { - wipeBytes(k.pkey) - delete(s.userKeys, username) - s.updateKeyspaceSize() - } - return ok + return s.userKeys.deleteSession(sessionID) } func wipeBytes(b []byte) { diff --git a/server/keystore_test.go b/server/keystore_test.go index 081196225391cd094d5512ee86f5c849f5be3d87..cb020831ce9bac23880bb9545c38125784c2d80d 100644 --- a/server/keystore_test.go +++ b/server/keystore_test.go @@ -109,7 +109,7 @@ func TestKeystore_OpenAndGet(t *testing.T) { defer cleanup() // Decrypt the private key with the right password. - err := keystore.Open(context.Background(), "testuser", string(pw), 60) + err := keystore.Open(context.Background(), "testuser", string(pw), "session", 60) if err != nil { t.Fatal("keystore.Open():", err) } @@ -132,19 +132,60 @@ func TestKeystore_OpenAndGet(t *testing.T) { } // Call Close() and forget the key. - keystore.Close("testuser") + keystore.Close("session") if _, err := keystore.Get("testuser", ssoTicket); err == nil { t.Fatal("keystore.Get() returned no error after Close()") } } +func TestKeystore_OpenAndGet_MultipleSessions(t *testing.T) { + c, keystore, cleanup := newTestContext(t) + defer cleanup() + + // Decrypt the private key with the right password. + for _, id := range []string{"session1", "session2"} { + err := keystore.Open(context.Background(), "testuser", string(pw), id, 60) + if err != nil { + t.Fatalf("keystore.Open(%s): %v", id, err) + } + } + + // Call expire() now to make sure we don't wipe data that is + // not expired yet. + keystore.expire(time.Now()) + + // Sign a valid SSO ticket and use it to obtain the private + // key we just stored. + ssoTicket := c.sign("testuser", "keystore/", "domain") + result, err := keystore.Get("testuser", ssoTicket) + if err != nil { + t.Fatal("keystore.Get():", err) + } + + expectedPEM, _ := privKey.PEM() + if !bytes.Equal(result, expectedPEM) { + t.Fatalf("keystore.Get() returned bad key: got %v, expected %v", result, expectedPEM) + } + + // Call Close() on the first session, key should still be around. + keystore.Close("session1") + if _, err := keystore.Get("testuser", ssoTicket); err != nil { + t.Fatalf("keystore.Get() after Close(session1): %v", err) + } + // Closing the second session should wipe the key. + keystore.Close("session2") + if _, err := keystore.Get("testuser", ssoTicket); err == nil { + t.Fatal("keystore.Get() returned no error after Close(session2)") + } +} + func TestKeystore_OpenAndGet_NoKeys(t *testing.T) { c, keystore, cleanup := newTestContext(t) defer cleanup() // Check the return value of Open() when the user has no keys. username := "no-keys-user" - err := keystore.Open(context.Background(), username, string(pw), 60) + err := keystore.Open(context.Background(), username, string(pw), "session", 60) if err != errNoKeys { t.Fatalf("keystore.Open() returned unexpected err=%v", err) } @@ -181,7 +222,7 @@ func TestKeystore_Expire(t *testing.T) { defer cleanup() // Decrypt the private key with the right password. - err := keystore.Open(context.Background(), "testuser", string(pw), 60) + err := keystore.Open(context.Background(), "testuser", string(pw), "session", 60) if err != nil { t.Fatal("keystore.Open():", err) } diff --git a/server/map.go b/server/map.go new file mode 100644 index 0000000000000000000000000000000000000000..3db45432fd69cafd0bd02861c7d3975b81776084 --- /dev/null +++ b/server/map.go @@ -0,0 +1,147 @@ +package server + +import ( + "container/heap" + "time" +) + +type sessionPQ []*userSession + +func (pq sessionPQ) Len() int { + return len(pq) +} + +func (pq sessionPQ) Swap(i, j int) { + pq[i], pq[j] = pq[j], pq[i] + pq[i].index = i + pq[j].index = j +} + +func (pq sessionPQ) Less(i, j int) bool { + return pq[i].expiry.Before(pq[j].expiry) +} + +func (pq *sessionPQ) Push(x interface{}) { + n := len(*pq) + item := x.(*userSession) + item.index = n + *pq = append(*pq, item) +} + +func (pq *sessionPQ) Pop() interface{} { + old := *pq + n := len(old) + x := old[n-1] + old[n-1] = nil + *pq = old[0 : n-1] + return x +} + +type sessionMap struct { + sessions map[string]*userSession + pq sessionPQ +} + +func newSessionMap() *sessionMap { + return &sessionMap{ + sessions: make(map[string]*userSession), + } +} + +func (m *sessionMap) add(sessionID, username string, ttl time.Duration) { + sess := &userSession{ + id: sessionID, + username: username, + expiry: time.Now().Add(ttl), + } + m.sessions[sessionID] = sess + heap.Push(&m.pq, sess) +} + +func (m *sessionMap) del(sessionID string) *userSession { + sess, ok := m.sessions[sessionID] + if !ok { + return nil + } + delete(m.sessions, sessionID) + heap.Remove(&m.pq, sess.index) + sess.index = -1 + return sess +} + +func (m *sessionMap) expireNext(deadline time.Time) *userSession { + // Peek and return first expired session. + if len(m.pq) > 0 && m.pq[0].expiry.Before(deadline) { + sess := heap.Pop(&m.pq).(*userSession) + delete(m.sessions, sess.id) + sess.index = -1 + return sess + } + return nil +} + +type userKeyMap struct { + sessions *sessionMap + userKeys map[string]*userKey + userSessions map[string][]string +} + +func newUserKeyMap() *userKeyMap { + return &userKeyMap{ + sessions: newSessionMap(), + userKeys: make(map[string]*userKey), + userSessions: make(map[string][]string), + } +} + +func (u *userKeyMap) numCachedKeys() int { + return len(u.userKeys) +} + +func (u *userKeyMap) get(username string) (*userKey, bool) { + k, ok := u.userKeys[username] + return k, ok +} + +func (u *userKeyMap) addSessionWithKey(sessionID, username string, key *userKey, ttlSeconds int) { + u.sessions.add(sessionID, username, time.Duration(ttlSeconds)*time.Second) + u.userKeys[username] = key + u.userSessions[username] = append(u.userSessions[username], sessionID) +} + +func (u *userKeyMap) deleteSession(sessionID string) bool { + if sess := u.sessions.del(sessionID); sess != nil { + u.cleanupSession(sess) + return true + } + return false +} + +func (u *userKeyMap) expire(deadline time.Time) { + for { + sess := u.sessions.expireNext(deadline) + if sess == nil { + return + } + u.cleanupSession(sess) + } +} + +func (u *userKeyMap) cleanupSession(sess *userSession) { + var ids []string + for _, id := range u.userSessions[sess.username] { + if id != sess.id { + ids = append(ids, id) + } + } + + if len(ids) == 0 { + // No more sessions for this user, delete key. + k := u.userKeys[sess.username] + delete(u.userKeys, sess.username) + wipeBytes(k.pkey) + delete(u.userSessions, sess.username) + } else { + u.userSessions[sess.username] = ids + } +} diff --git a/server/server.go b/server/server.go index f22f37e4920810fe85a5a84df919d6a9c92c1bbf..149324bc0f662afbfa1f8e38d74177ce5f80bdfa 100644 --- a/server/server.go +++ b/server/server.go @@ -21,7 +21,7 @@ func (s *keyStoreServer) handleOpen(w http.ResponseWriter, r *http.Request) { return } - err := s.KeyStore.Open(r.Context(), req.Username, req.Password, req.TTL) + err := s.KeyStore.Open(r.Context(), req.Username, req.Password, req.SessionID, req.TTL) if err == errNoKeys { log.Printf("Open(%s): no encrypted keys found in database", req.Username) } else if err != nil { @@ -76,7 +76,7 @@ func (s *keyStoreServer) handleClose(w http.ResponseWriter, r *http.Request) { return } - if s.KeyStore.Close(req.Username) { + if s.KeyStore.Close(req.SessionID) { log.Printf("Close(%s): discarded key", req.Username) }