diff --git a/cmd/keystored/main.go b/cmd/keystored/main.go
index a74d640f9b3d52461b3955a10e71cffdcdd7d7bc..4cdb544956088eda8bdcc0f68240596a9f9880e0 100644
--- a/cmd/keystored/main.go
+++ b/cmd/keystored/main.go
@@ -5,6 +5,7 @@ import (
 	"log"
 	"os"
 
+	"github.com/prometheus/client_golang/prometheus"
 	"gopkg.in/yaml.v3"
 
 	"git.autistici.org/ai3/go-common/serverutil"
@@ -57,6 +58,7 @@ func main() {
 	if err != nil {
 		log.Fatal(err)
 	}
+	prometheus.MustRegister(server.NewKeystoreCollector(ks))
 
 	srv := server.NewServer(ks)
 
diff --git a/dovecot/keyproxy_test.go b/dovecot/keyproxy_test.go
index 68a53bdb824bc86bb71cb1fe98fc5b2439cc58c9..516760157500404deee19f6c7a14a2a796976ef1 100644
--- a/dovecot/keyproxy_test.go
+++ b/dovecot/keyproxy_test.go
@@ -216,7 +216,7 @@ func TestKeyProxy_LookupPrivateKeys_FromKeystore(t *testing.T) {
 	defer cleanup()
 
 	// Unlock the key in KeyStore with the correct password.
-	if err := ks.Open(context.Background(), "test@example.com", string(testPw1), 600); err != nil {
+	if err := ks.Open(context.Background(), "test@example.com", string(testPw1), "session", 600); err != nil {
 		t.Fatalf("ks.Open: %v", err)
 	}
 
diff --git a/protocol.go b/protocol.go
index fc6113e36dbba6ef1d05775119dbd32e8dabec4f..58b9e3f263b5c8e227d6163cd5778bfa7f523f85 100644
--- a/protocol.go
+++ b/protocol.go
@@ -1,9 +1,10 @@
 package keystore
 
 type OpenRequest struct {
-	Username string `json:"username"`
-	Password string `json:"password"`
-	TTL      int    `json:"ttl"`
+	Username  string `json:"username"`
+	Password  string `json:"password"`
+	TTL       int    `json:"ttl"`
+	SessionID string `json:"session_id"`
 }
 
 type OpenResponse struct{}
@@ -19,7 +20,8 @@ type GetResponse struct {
 }
 
 type CloseRequest struct {
-	Username string `json:"username"`
+	Username  string `json:"username"`
+	SessionID string `json:"session_id"`
 }
 
 type CloseResponse struct{}
diff --git a/server/instrumentation.go b/server/instrumentation.go
index 3c7622235390b5d09a313509939216d97fdc0d3f..af32af3a9264df1c214ffba9873a4acf368a4034 100644
--- a/server/instrumentation.go
+++ b/server/instrumentation.go
@@ -3,10 +3,11 @@ package server
 import "github.com/prometheus/client_golang/prometheus"
 
 var (
-	totalKeysInMemory = prometheus.NewGauge(prometheus.GaugeOpts{
-		Name: "keystored_keys_total",
-		Help: "Total number of unlocked keys in-memory.",
-	})
+	totalKeysInMemory = prometheus.NewDesc(
+		"keystored_keys_total",
+		"Total number of unlocked keys in-memory.",
+		nil, nil,
+	)
 	requestsCounter = prometheus.NewCounterVec(prometheus.CounterOpts{
 		Name: "keystored_requests_total",
 		Help: "Counter of requests by method and status.",
@@ -22,12 +23,29 @@ var (
 )
 
 func init() {
-	prometheus.MustRegister(totalKeysInMemory)
 	prometheus.MustRegister(requestsCounter)
 	prometheus.MustRegister(decryptedKeysCounter)
 	prometheus.MustRegister(unlockedKeysServedCounter)
 }
 
-func (s *KeyStore) updateKeyspaceSize() {
-	totalKeysInMemory.Set(float64(len(s.userKeys)))
+type keystoreCollector struct {
+	ks *KeyStore
+}
+
+// NewKeystoreCollector returns a prometheus.Collector that will
+// export metrics for the given KeyStore instance.
+func NewKeystoreCollector(ks *KeyStore) prometheus.Collector {
+	return &keystoreCollector{ks: ks}
+}
+
+func (c *keystoreCollector) Describe(ch chan<- *prometheus.Desc) {
+	prometheus.DescribeByCollect(c, ch)
+}
+
+func (c *keystoreCollector) Collect(ch chan<- prometheus.Metric) {
+	ch <- prometheus.MustNewConstMetric(
+		totalKeysInMemory,
+		prometheus.GaugeValue,
+		float64(c.ks.userKeys.numCachedKeys()),
+	)
 }
diff --git a/server/keystore.go b/server/keystore.go
index 0ffcda59d724440ddd50bf84bb6c2777dbd48dd4..88abc7c8c583c2f9ca111f718207047fc6a0fc5c 100644
--- a/server/keystore.go
+++ b/server/keystore.go
@@ -25,8 +25,14 @@ var (
 )
 
 type userKey struct {
-	pkey   []byte
-	expiry time.Time
+	pkey []byte
+}
+
+type userSession struct {
+	id       string
+	username string
+	expiry   time.Time
+	index    int
 }
 
 // Config for the KeyStore.
@@ -68,12 +74,12 @@ func (c *Config) check() error {
 // token for the user whose secrets you would like to obtain.
 //
 type KeyStore struct {
-	mx       sync.Mutex
-	userKeys map[string]userKey
-
 	db        backend.Database
 	service   string
 	validator sso.Validator
+
+	mx       sync.Mutex
+	userKeys *userKeyMap
 }
 
 func newKeyStoreWithBackend(config *Config, db backend.Database) (*KeyStore, error) {
@@ -87,7 +93,7 @@ func newKeyStoreWithBackend(config *Config, db backend.Database) (*KeyStore, err
 	}
 
 	s := &KeyStore{
-		userKeys:  make(map[string]userKey),
+		userKeys:  newUserKeyMap(),
 		service:   config.SSOService,
 		validator: v,
 		db:        db,
@@ -120,14 +126,8 @@ func NewKeyStore(config *Config) (*KeyStore, error) {
 
 func (s *KeyStore) expire(t time.Time) {
 	s.mx.Lock()
-	for u, k := range s.userKeys {
-		if k.expiry.Before(t) {
-			log.Printf("forgetting key for %s", u)
-			wipeBytes(k.pkey)
-			delete(s.userKeys, u)
-		}
-	}
-	s.updateKeyspaceSize()
+	s.userKeys.expire(t)
+	//s.updateKeyspaceSize()
 	s.mx.Unlock()
 }
 
@@ -139,10 +139,17 @@ func (s *KeyStore) expireLoop() {
 
 // Open the user's key store with the given password. If successful,
 // the unencrypted user key will be stored for at most ttlSeconds, or
-// until Close is called.
+// until Close is called with the same session ID.
+//
+// Note that the key is fetched from the backend and decrypted even if
+// we already have it in memory (for instance belonging to a separate
+// session), because this acts as an implicit ACL check: does the user
+// have access to the key because it can decrypt it with the provided
+// credentials?
 //
 // A Context is needed because this method might issue an RPC.
-func (s *KeyStore) Open(ctx context.Context, username, password string, ttlSeconds int) error {
+//
+func (s *KeyStore) Open(ctx context.Context, username, password, sessionID string, ttlSeconds int) error {
 	if ttlSeconds == 0 {
 		return errInvalidTTL
 	}
@@ -170,11 +177,8 @@ func (s *KeyStore) Open(ctx context.Context, username, password string, ttlSecon
 	}
 
 	s.mx.Lock()
-	s.userKeys[username] = userKey{
-		pkey:   pem,
-		expiry: time.Now().Add(time.Duration(ttlSeconds) * time.Second),
-	}
-	s.updateKeyspaceSize()
+	s.userKeys.addSessionWithKey(sessionID, username, &userKey{pkey: pem}, ttlSeconds)
+	//s.updateKeyspaceSize()
 	s.mx.Unlock()
 	return nil
 }
@@ -196,7 +200,7 @@ func (s *KeyStore) Get(username, ssoTicket string) ([]byte, error) {
 
 	s.mx.Lock()
 	defer s.mx.Unlock()
-	u, ok := s.userKeys[username]
+	u, ok := s.userKeys.get(username)
 	if !ok {
 		return nil, errNoKeys
 	}
@@ -205,16 +209,10 @@ func (s *KeyStore) Get(username, ssoTicket string) ([]byte, error) {
 
 // Close the user's key store and wipe the associated unencrypted key
 // from memory. Returns true if a key was actually discarded.
-func (s *KeyStore) Close(username string) bool {
+func (s *KeyStore) Close(sessionID string) bool {
 	s.mx.Lock()
 	defer s.mx.Unlock()
-	k, ok := s.userKeys[username]
-	if ok {
-		wipeBytes(k.pkey)
-		delete(s.userKeys, username)
-		s.updateKeyspaceSize()
-	}
-	return ok
+	return s.userKeys.deleteSession(sessionID)
 }
 
 func wipeBytes(b []byte) {
diff --git a/server/keystore_test.go b/server/keystore_test.go
index 081196225391cd094d5512ee86f5c849f5be3d87..cb020831ce9bac23880bb9545c38125784c2d80d 100644
--- a/server/keystore_test.go
+++ b/server/keystore_test.go
@@ -109,7 +109,7 @@ func TestKeystore_OpenAndGet(t *testing.T) {
 	defer cleanup()
 
 	// Decrypt the private key with the right password.
-	err := keystore.Open(context.Background(), "testuser", string(pw), 60)
+	err := keystore.Open(context.Background(), "testuser", string(pw), "session", 60)
 	if err != nil {
 		t.Fatal("keystore.Open():", err)
 	}
@@ -132,19 +132,60 @@ func TestKeystore_OpenAndGet(t *testing.T) {
 	}
 
 	// Call Close() and forget the key.
-	keystore.Close("testuser")
+	keystore.Close("session")
 	if _, err := keystore.Get("testuser", ssoTicket); err == nil {
 		t.Fatal("keystore.Get() returned no error after Close()")
 	}
 }
 
+func TestKeystore_OpenAndGet_MultipleSessions(t *testing.T) {
+	c, keystore, cleanup := newTestContext(t)
+	defer cleanup()
+
+	// Decrypt the private key with the right password.
+	for _, id := range []string{"session1", "session2"} {
+		err := keystore.Open(context.Background(), "testuser", string(pw), id, 60)
+		if err != nil {
+			t.Fatalf("keystore.Open(%s): %v", id, err)
+		}
+	}
+
+	// Call expire() now to make sure we don't wipe data that is
+	// not expired yet.
+	keystore.expire(time.Now())
+
+	// Sign a valid SSO ticket and use it to obtain the private
+	// key we just stored.
+	ssoTicket := c.sign("testuser", "keystore/", "domain")
+	result, err := keystore.Get("testuser", ssoTicket)
+	if err != nil {
+		t.Fatal("keystore.Get():", err)
+	}
+
+	expectedPEM, _ := privKey.PEM()
+	if !bytes.Equal(result, expectedPEM) {
+		t.Fatalf("keystore.Get() returned bad key: got %v, expected %v", result, expectedPEM)
+	}
+
+	// Call Close() on the first session, key should still be around.
+	keystore.Close("session1")
+	if _, err := keystore.Get("testuser", ssoTicket); err != nil {
+		t.Fatalf("keystore.Get() after Close(session1): %v", err)
+	}
+	// Closing the second session should wipe the key.
+	keystore.Close("session2")
+	if _, err := keystore.Get("testuser", ssoTicket); err == nil {
+		t.Fatal("keystore.Get() returned no error after Close(session2)")
+	}
+}
+
 func TestKeystore_OpenAndGet_NoKeys(t *testing.T) {
 	c, keystore, cleanup := newTestContext(t)
 	defer cleanup()
 
 	// Check the return value of Open() when the user has no keys.
 	username := "no-keys-user"
-	err := keystore.Open(context.Background(), username, string(pw), 60)
+	err := keystore.Open(context.Background(), username, string(pw), "session", 60)
 	if err != errNoKeys {
 		t.Fatalf("keystore.Open() returned unexpected err=%v", err)
 	}
@@ -181,7 +222,7 @@ func TestKeystore_Expire(t *testing.T) {
 	defer cleanup()
 
 	// Decrypt the private key with the right password.
-	err := keystore.Open(context.Background(), "testuser", string(pw), 60)
+	err := keystore.Open(context.Background(), "testuser", string(pw), "session", 60)
 	if err != nil {
 		t.Fatal("keystore.Open():", err)
 	}
diff --git a/server/map.go b/server/map.go
new file mode 100644
index 0000000000000000000000000000000000000000..3db45432fd69cafd0bd02861c7d3975b81776084
--- /dev/null
+++ b/server/map.go
@@ -0,0 +1,147 @@
+package server
+
+import (
+	"container/heap"
+	"time"
+)
+
+type sessionPQ []*userSession
+
+func (pq sessionPQ) Len() int {
+	return len(pq)
+}
+
+func (pq sessionPQ) Swap(i, j int) {
+	pq[i], pq[j] = pq[j], pq[i]
+	pq[i].index = i
+	pq[j].index = j
+}
+
+func (pq sessionPQ) Less(i, j int) bool {
+	return pq[i].expiry.Before(pq[j].expiry)
+}
+
+func (pq *sessionPQ) Push(x interface{}) {
+	n := len(*pq)
+	item := x.(*userSession)
+	item.index = n
+	*pq = append(*pq, item)
+}
+
+func (pq *sessionPQ) Pop() interface{} {
+	old := *pq
+	n := len(old)
+	x := old[n-1]
+	old[n-1] = nil
+	*pq = old[0 : n-1]
+	return x
+}
+
+type sessionMap struct {
+	sessions map[string]*userSession
+	pq       sessionPQ
+}
+
+func newSessionMap() *sessionMap {
+	return &sessionMap{
+		sessions: make(map[string]*userSession),
+	}
+}
+
+func (m *sessionMap) add(sessionID, username string, ttl time.Duration) {
+	sess := &userSession{
+		id:       sessionID,
+		username: username,
+		expiry:   time.Now().Add(ttl),
+	}
+	m.sessions[sessionID] = sess
+	heap.Push(&m.pq, sess)
+}
+
+func (m *sessionMap) del(sessionID string) *userSession {
+	sess, ok := m.sessions[sessionID]
+	if !ok {
+		return nil
+	}
+	delete(m.sessions, sessionID)
+	heap.Remove(&m.pq, sess.index)
+	sess.index = -1
+	return sess
+}
+
+func (m *sessionMap) expireNext(deadline time.Time) *userSession {
+	// Peek and return first expired session.
+	if len(m.pq) > 0 && m.pq[0].expiry.Before(deadline) {
+		sess := heap.Pop(&m.pq).(*userSession)
+		delete(m.sessions, sess.id)
+		sess.index = -1
+		return sess
+	}
+	return nil
+}
+
+type userKeyMap struct {
+	sessions     *sessionMap
+	userKeys     map[string]*userKey
+	userSessions map[string][]string
+}
+
+func newUserKeyMap() *userKeyMap {
+	return &userKeyMap{
+		sessions:     newSessionMap(),
+		userKeys:     make(map[string]*userKey),
+		userSessions: make(map[string][]string),
+	}
+}
+
+func (u *userKeyMap) numCachedKeys() int {
+	return len(u.userKeys)
+}
+
+func (u *userKeyMap) get(username string) (*userKey, bool) {
+	k, ok := u.userKeys[username]
+	return k, ok
+}
+
+func (u *userKeyMap) addSessionWithKey(sessionID, username string, key *userKey, ttlSeconds int) {
+	u.sessions.add(sessionID, username, time.Duration(ttlSeconds)*time.Second)
+	u.userKeys[username] = key
+	u.userSessions[username] = append(u.userSessions[username], sessionID)
+}
+
+func (u *userKeyMap) deleteSession(sessionID string) bool {
+	if sess := u.sessions.del(sessionID); sess != nil {
+		u.cleanupSession(sess)
+		return true
+	}
+	return false
+}
+
+func (u *userKeyMap) expire(deadline time.Time) {
+	for {
+		sess := u.sessions.expireNext(deadline)
+		if sess == nil {
+			return
+		}
+		u.cleanupSession(sess)
+	}
+}
+
+func (u *userKeyMap) cleanupSession(sess *userSession) {
+	var ids []string
+	for _, id := range u.userSessions[sess.username] {
+		if id != sess.id {
+			ids = append(ids, id)
+		}
+	}
+
+	if len(ids) == 0 {
+		// No more sessions for this user, delete key.
+		k := u.userKeys[sess.username]
+		delete(u.userKeys, sess.username)
+		wipeBytes(k.pkey)
+		delete(u.userSessions, sess.username)
+	} else {
+		u.userSessions[sess.username] = ids
+	}
+}
diff --git a/server/server.go b/server/server.go
index f22f37e4920810fe85a5a84df919d6a9c92c1bbf..149324bc0f662afbfa1f8e38d74177ce5f80bdfa 100644
--- a/server/server.go
+++ b/server/server.go
@@ -21,7 +21,7 @@ func (s *keyStoreServer) handleOpen(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 
-	err := s.KeyStore.Open(r.Context(), req.Username, req.Password, req.TTL)
+	err := s.KeyStore.Open(r.Context(), req.Username, req.Password, req.SessionID, req.TTL)
 	if err == errNoKeys {
 		log.Printf("Open(%s): no encrypted keys found in database", req.Username)
 	} else if err != nil {
@@ -76,7 +76,7 @@ func (s *keyStoreServer) handleClose(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 
-	if s.KeyStore.Close(req.Username) {
+	if s.KeyStore.Close(req.SessionID) {
 		log.Printf("Close(%s): discarded key", req.Username)
 	}