Skip to content
Snippets Groups Projects
Commit 070c74a4 authored by ale's avatar ale
Browse files

Support multiple sessions for each user

Breaking API change, adding a "session_id" field to Open and Close
calls.

Multiple sessions can be attached to the same decrypted key, each
expiring (or being closed) independently: the key will only be dropped
when the last session terminates.
parent becc7e22
Branches
No related tags found
1 merge request!32Support multiple sessions for each user
......@@ -5,6 +5,7 @@ import (
"log"
"os"
"github.com/prometheus/client_golang/prometheus"
"gopkg.in/yaml.v3"
"git.autistici.org/ai3/go-common/serverutil"
......@@ -57,6 +58,7 @@ func main() {
if err != nil {
log.Fatal(err)
}
prometheus.MustRegister(server.NewKeystoreCollector(ks))
srv := server.NewServer(ks)
......
......@@ -216,7 +216,7 @@ func TestKeyProxy_LookupPrivateKeys_FromKeystore(t *testing.T) {
defer cleanup()
// Unlock the key in KeyStore with the correct password.
if err := ks.Open(context.Background(), "test@example.com", string(testPw1), 600); err != nil {
if err := ks.Open(context.Background(), "test@example.com", string(testPw1), "session", 600); err != nil {
t.Fatalf("ks.Open: %v", err)
}
......
package keystore
type OpenRequest struct {
Username string `json:"username"`
Password string `json:"password"`
TTL int `json:"ttl"`
Username string `json:"username"`
Password string `json:"password"`
TTL int `json:"ttl"`
SessionID string `json:"session_id"`
}
type OpenResponse struct{}
......@@ -19,7 +20,8 @@ type GetResponse struct {
}
type CloseRequest struct {
Username string `json:"username"`
Username string `json:"username"`
SessionID string `json:"session_id"`
}
type CloseResponse struct{}
......@@ -3,10 +3,11 @@ package server
import "github.com/prometheus/client_golang/prometheus"
var (
totalKeysInMemory = prometheus.NewGauge(prometheus.GaugeOpts{
Name: "keystored_keys_total",
Help: "Total number of unlocked keys in-memory.",
})
totalKeysInMemory = prometheus.NewDesc(
"keystored_keys_total",
"Total number of unlocked keys in-memory.",
nil, nil,
)
requestsCounter = prometheus.NewCounterVec(prometheus.CounterOpts{
Name: "keystored_requests_total",
Help: "Counter of requests by method and status.",
......@@ -22,12 +23,29 @@ var (
)
func init() {
prometheus.MustRegister(totalKeysInMemory)
prometheus.MustRegister(requestsCounter)
prometheus.MustRegister(decryptedKeysCounter)
prometheus.MustRegister(unlockedKeysServedCounter)
}
func (s *KeyStore) updateKeyspaceSize() {
totalKeysInMemory.Set(float64(len(s.userKeys)))
type keystoreCollector struct {
ks *KeyStore
}
// NewKeystoreCollector returns a prometheus.Collector that will
// export metrics for the given KeyStore instance.
func NewKeystoreCollector(ks *KeyStore) prometheus.Collector {
return &keystoreCollector{ks: ks}
}
func (c *keystoreCollector) Describe(ch chan<- *prometheus.Desc) {
prometheus.DescribeByCollect(c, ch)
}
func (c *keystoreCollector) Collect(ch chan<- prometheus.Metric) {
ch <- prometheus.MustNewConstMetric(
totalKeysInMemory,
prometheus.GaugeValue,
float64(c.ks.userKeys.numCachedKeys()),
)
}
......@@ -25,8 +25,14 @@ var (
)
type userKey struct {
pkey []byte
expiry time.Time
pkey []byte
}
type userSession struct {
id string
username string
expiry time.Time
index int
}
// Config for the KeyStore.
......@@ -68,12 +74,12 @@ func (c *Config) check() error {
// token for the user whose secrets you would like to obtain.
//
type KeyStore struct {
mx sync.Mutex
userKeys map[string]userKey
db backend.Database
service string
validator sso.Validator
mx sync.Mutex
userKeys *userKeyMap
}
func newKeyStoreWithBackend(config *Config, db backend.Database) (*KeyStore, error) {
......@@ -87,7 +93,7 @@ func newKeyStoreWithBackend(config *Config, db backend.Database) (*KeyStore, err
}
s := &KeyStore{
userKeys: make(map[string]userKey),
userKeys: newUserKeyMap(),
service: config.SSOService,
validator: v,
db: db,
......@@ -120,14 +126,8 @@ func NewKeyStore(config *Config) (*KeyStore, error) {
func (s *KeyStore) expire(t time.Time) {
s.mx.Lock()
for u, k := range s.userKeys {
if k.expiry.Before(t) {
log.Printf("forgetting key for %s", u)
wipeBytes(k.pkey)
delete(s.userKeys, u)
}
}
s.updateKeyspaceSize()
s.userKeys.expire(t)
//s.updateKeyspaceSize()
s.mx.Unlock()
}
......@@ -139,10 +139,17 @@ func (s *KeyStore) expireLoop() {
// Open the user's key store with the given password. If successful,
// the unencrypted user key will be stored for at most ttlSeconds, or
// until Close is called.
// until Close is called with the same session ID.
//
// Note that the key is fetched from the backend and decrypted even if
// we already have it in memory (for instance belonging to a separate
// session), because this acts as an implicit ACL check: does the user
// have access to the key because it can decrypt it with the provided
// credentials?
//
// A Context is needed because this method might issue an RPC.
func (s *KeyStore) Open(ctx context.Context, username, password string, ttlSeconds int) error {
//
func (s *KeyStore) Open(ctx context.Context, username, password, sessionID string, ttlSeconds int) error {
if ttlSeconds == 0 {
return errInvalidTTL
}
......@@ -170,11 +177,8 @@ func (s *KeyStore) Open(ctx context.Context, username, password string, ttlSecon
}
s.mx.Lock()
s.userKeys[username] = userKey{
pkey: pem,
expiry: time.Now().Add(time.Duration(ttlSeconds) * time.Second),
}
s.updateKeyspaceSize()
s.userKeys.addSessionWithKey(sessionID, username, &userKey{pkey: pem}, ttlSeconds)
//s.updateKeyspaceSize()
s.mx.Unlock()
return nil
}
......@@ -196,7 +200,7 @@ func (s *KeyStore) Get(username, ssoTicket string) ([]byte, error) {
s.mx.Lock()
defer s.mx.Unlock()
u, ok := s.userKeys[username]
u, ok := s.userKeys.get(username)
if !ok {
return nil, errNoKeys
}
......@@ -205,16 +209,10 @@ func (s *KeyStore) Get(username, ssoTicket string) ([]byte, error) {
// Close the user's key store and wipe the associated unencrypted key
// from memory. Returns true if a key was actually discarded.
func (s *KeyStore) Close(username string) bool {
func (s *KeyStore) Close(sessionID string) bool {
s.mx.Lock()
defer s.mx.Unlock()
k, ok := s.userKeys[username]
if ok {
wipeBytes(k.pkey)
delete(s.userKeys, username)
s.updateKeyspaceSize()
}
return ok
return s.userKeys.deleteSession(sessionID)
}
func wipeBytes(b []byte) {
......
......@@ -109,7 +109,7 @@ func TestKeystore_OpenAndGet(t *testing.T) {
defer cleanup()
// Decrypt the private key with the right password.
err := keystore.Open(context.Background(), "testuser", string(pw), 60)
err := keystore.Open(context.Background(), "testuser", string(pw), "session", 60)
if err != nil {
t.Fatal("keystore.Open():", err)
}
......@@ -132,19 +132,60 @@ func TestKeystore_OpenAndGet(t *testing.T) {
}
// Call Close() and forget the key.
keystore.Close("testuser")
keystore.Close("session")
if _, err := keystore.Get("testuser", ssoTicket); err == nil {
t.Fatal("keystore.Get() returned no error after Close()")
}
}
func TestKeystore_OpenAndGet_MultipleSessions(t *testing.T) {
c, keystore, cleanup := newTestContext(t)
defer cleanup()
// Decrypt the private key with the right password.
for _, id := range []string{"session1", "session2"} {
err := keystore.Open(context.Background(), "testuser", string(pw), id, 60)
if err != nil {
t.Fatalf("keystore.Open(%s): %v", id, err)
}
}
// Call expire() now to make sure we don't wipe data that is
// not expired yet.
keystore.expire(time.Now())
// Sign a valid SSO ticket and use it to obtain the private
// key we just stored.
ssoTicket := c.sign("testuser", "keystore/", "domain")
result, err := keystore.Get("testuser", ssoTicket)
if err != nil {
t.Fatal("keystore.Get():", err)
}
expectedPEM, _ := privKey.PEM()
if !bytes.Equal(result, expectedPEM) {
t.Fatalf("keystore.Get() returned bad key: got %v, expected %v", result, expectedPEM)
}
// Call Close() on the first session, key should still be around.
keystore.Close("session1")
if _, err := keystore.Get("testuser", ssoTicket); err != nil {
t.Fatalf("keystore.Get() after Close(session1): %v", err)
}
// Closing the second session should wipe the key.
keystore.Close("session2")
if _, err := keystore.Get("testuser", ssoTicket); err == nil {
t.Fatal("keystore.Get() returned no error after Close(session2)")
}
}
func TestKeystore_OpenAndGet_NoKeys(t *testing.T) {
c, keystore, cleanup := newTestContext(t)
defer cleanup()
// Check the return value of Open() when the user has no keys.
username := "no-keys-user"
err := keystore.Open(context.Background(), username, string(pw), 60)
err := keystore.Open(context.Background(), username, string(pw), "session", 60)
if err != errNoKeys {
t.Fatalf("keystore.Open() returned unexpected err=%v", err)
}
......@@ -181,7 +222,7 @@ func TestKeystore_Expire(t *testing.T) {
defer cleanup()
// Decrypt the private key with the right password.
err := keystore.Open(context.Background(), "testuser", string(pw), 60)
err := keystore.Open(context.Background(), "testuser", string(pw), "session", 60)
if err != nil {
t.Fatal("keystore.Open():", err)
}
......
package server
import (
"container/heap"
"time"
)
type sessionPQ []*userSession
func (pq sessionPQ) Len() int {
return len(pq)
}
func (pq sessionPQ) Swap(i, j int) {
pq[i], pq[j] = pq[j], pq[i]
pq[i].index = i
pq[j].index = j
}
func (pq sessionPQ) Less(i, j int) bool {
return pq[i].expiry.Before(pq[j].expiry)
}
func (pq *sessionPQ) Push(x interface{}) {
n := len(*pq)
item := x.(*userSession)
item.index = n
*pq = append(*pq, item)
}
func (pq *sessionPQ) Pop() interface{} {
old := *pq
n := len(old)
x := old[n-1]
old[n-1] = nil
*pq = old[0 : n-1]
return x
}
type sessionMap struct {
sessions map[string]*userSession
pq sessionPQ
}
func newSessionMap() *sessionMap {
return &sessionMap{
sessions: make(map[string]*userSession),
}
}
func (m *sessionMap) add(sessionID, username string, ttl time.Duration) {
sess := &userSession{
id: sessionID,
username: username,
expiry: time.Now().Add(ttl),
}
m.sessions[sessionID] = sess
heap.Push(&m.pq, sess)
}
func (m *sessionMap) del(sessionID string) *userSession {
sess, ok := m.sessions[sessionID]
if !ok {
return nil
}
delete(m.sessions, sessionID)
heap.Remove(&m.pq, sess.index)
sess.index = -1
return sess
}
func (m *sessionMap) expireNext(deadline time.Time) *userSession {
// Peek and return first expired session.
if len(m.pq) > 0 && m.pq[0].expiry.Before(deadline) {
sess := heap.Pop(&m.pq).(*userSession)
delete(m.sessions, sess.id)
sess.index = -1
return sess
}
return nil
}
type userKeyMap struct {
sessions *sessionMap
userKeys map[string]*userKey
userSessions map[string][]string
}
func newUserKeyMap() *userKeyMap {
return &userKeyMap{
sessions: newSessionMap(),
userKeys: make(map[string]*userKey),
userSessions: make(map[string][]string),
}
}
func (u *userKeyMap) numCachedKeys() int {
return len(u.userKeys)
}
func (u *userKeyMap) get(username string) (*userKey, bool) {
k, ok := u.userKeys[username]
return k, ok
}
func (u *userKeyMap) addSessionWithKey(sessionID, username string, key *userKey, ttlSeconds int) {
u.sessions.add(sessionID, username, time.Duration(ttlSeconds)*time.Second)
u.userKeys[username] = key
u.userSessions[username] = append(u.userSessions[username], sessionID)
}
func (u *userKeyMap) deleteSession(sessionID string) bool {
if sess := u.sessions.del(sessionID); sess != nil {
u.cleanupSession(sess)
return true
}
return false
}
func (u *userKeyMap) expire(deadline time.Time) {
for {
sess := u.sessions.expireNext(deadline)
if sess == nil {
return
}
u.cleanupSession(sess)
}
}
func (u *userKeyMap) cleanupSession(sess *userSession) {
var ids []string
for _, id := range u.userSessions[sess.username] {
if id != sess.id {
ids = append(ids, id)
}
}
if len(ids) == 0 {
// No more sessions for this user, delete key.
k := u.userKeys[sess.username]
delete(u.userKeys, sess.username)
wipeBytes(k.pkey)
delete(u.userSessions, sess.username)
} else {
u.userSessions[sess.username] = ids
}
}
......@@ -21,7 +21,7 @@ func (s *keyStoreServer) handleOpen(w http.ResponseWriter, r *http.Request) {
return
}
err := s.KeyStore.Open(r.Context(), req.Username, req.Password, req.TTL)
err := s.KeyStore.Open(r.Context(), req.Username, req.Password, req.SessionID, req.TTL)
if err == errNoKeys {
log.Printf("Open(%s): no encrypted keys found in database", req.Username)
} else if err != nil {
......@@ -76,7 +76,7 @@ func (s *keyStoreServer) handleClose(w http.ResponseWriter, r *http.Request) {
return
}
if s.KeyStore.Close(req.Username) {
if s.KeyStore.Close(req.SessionID) {
log.Printf("Close(%s): discarded key", req.Username)
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment