diff --git a/server/keystore.go b/server/keystore.go index 0a2c6cc90adab7cd4f1bbb35e14fd1cf078c25a0..33ca3cdfbd73bf79b3eccfc69baf918416fcebb9 100644 --- a/server/keystore.go +++ b/server/keystore.go @@ -1,8 +1,10 @@ package server import ( + "compress/gzip" "context" "errors" + "io" "io/ioutil" "log" "strings" @@ -125,6 +127,28 @@ func (s *KeyStore) expireLoop() { } } +// Dump the keystore in-memory contents. +func (s *KeyStore) Dump(w io.Writer) error { + gzw := gzip.NewWriter(w) + if err := s.userKeys.dump(gzw); err != nil { + return err + } + if err := gzw.Flush(); err != nil { + return err + } + return gzw.Close() +} + +// Load a keystore data dump. +func (s *KeyStore) Load(r io.Reader) error { + gzr, err := gzip.NewReader(r) + if err != nil { + return err + } + defer gzr.Close() + return s.userKeys.load(gzr) +} + // Open the user's key store with the given password. If successful, // the unencrypted user key will be stored for at most ttlSeconds, or // until Close is called with the same session ID. @@ -195,7 +219,7 @@ func (s *KeyStore) Get(username, ssoTicket string) ([]byte, error) { if !ok { return nil, errNoKeys } - return u.pkey, nil + return u.PKey, nil } // Close the user's key store and wipe the associated unencrypted key diff --git a/server/keystore_test.go b/server/keystore_test.go index d32b3aca1f55f31c2959ac6244f271b58465abda..fb5ed2e0b5d3a45ba484e7b903670692cafdc5b3 100644 --- a/server/keystore_test.go +++ b/server/keystore_test.go @@ -283,3 +283,40 @@ func TestKeystore_Expire(t *testing.T) { t.Fatal("keystore.Get():", err) } } + +func TestKeystore_DumpAndReload(t *testing.T) { + _, keystore, cleanup := newTestContext(t) + defer cleanup() + + // Decrypt the private key with the right password. + err := keystore.Open(context.Background(), "testuser", string(pw), "session", 60) + if err != nil { + t.Fatal("keystore.Open():", err) + } + + // Now dump the keystore state, and reload it in a new one. + var buf bytes.Buffer + if err := keystore.Dump(&buf); err != nil { + t.Fatal("keystore.Dump():", err) + } + + c, keystore2, cleanup2 := newTestContext(t) + defer cleanup2() + + if err := keystore2.Load(&buf); err != nil { + t.Fatal("keystore2.Load():", err) + } + + // Sign a valid SSO ticket and use it to obtain the private + // key we just stored. + ssoTicket := c.sign("testuser", "keystore/", "domain") + result, err := keystore2.Get("testuser", ssoTicket) + if err != nil { + t.Fatal("keystore2.Get():", err) + } + + expectedPEM, _ := privKey.PEM() + if !bytes.Equal(result, expectedPEM) { + t.Fatalf("keystore2.Get() returned bad key: got %v, expected %v", result, expectedPEM) + } +} diff --git a/server/map.go b/server/map.go index 198039c99eed721973e346042ad94e52bf8976e2..88634e83fd322b50993197b5c8acf0878aea1a64 100644 --- a/server/map.go +++ b/server/map.go @@ -1,23 +1,26 @@ package server import ( + "bytes" "container/heap" + "encoding/gob" + "io" "time" ) type userKey struct { - pkey []byte + PKey []byte } func newUserKey(pem []byte) *userKey { - return &userKey{pkey: pem} + return &userKey{PKey: pem} } type userSession struct { - id string - username string - expiry time.Time - index int + ID string + Username string + Expiry time.Time + Index int } // Priority queue of userSession objects, kept ordered by expiration @@ -30,18 +33,18 @@ func (pq sessionPQ) Len() int { func (pq sessionPQ) Swap(i, j int) { pq[i], pq[j] = pq[j], pq[i] - pq[i].index = i - pq[j].index = j + pq[i].Index = i + pq[j].Index = j } func (pq sessionPQ) Less(i, j int) bool { - return pq[i].expiry.Before(pq[j].expiry) + return pq[i].Expiry.Before(pq[j].Expiry) } func (pq *sessionPQ) Push(x interface{}) { n := len(*pq) item := x.(*userSession) - item.index = n + item.Index = n *pq = append(*pq, item) } @@ -67,9 +70,9 @@ func newSessionMap() *sessionMap { func (m *sessionMap) add(sessionID, username string, ttl time.Duration) { sess := &userSession{ - id: sessionID, - username: username, - expiry: time.Now().Add(ttl), + ID: sessionID, + Username: username, + Expiry: time.Now().Add(ttl), } m.sessions[sessionID] = sess heap.Push(&m.pq, sess) @@ -81,60 +84,82 @@ func (m *sessionMap) del(sessionID string) *userSession { return nil } delete(m.sessions, sessionID) - heap.Remove(&m.pq, sess.index) - sess.index = -1 + heap.Remove(&m.pq, sess.Index) + sess.Index = -1 return sess } // Peek at the oldest entry and pop it if expired. func (m *sessionMap) expireNext(deadline time.Time) *userSession { - if len(m.pq) > 0 && m.pq[0].expiry.Before(deadline) { + if len(m.pq) > 0 && m.pq[0].Expiry.Before(deadline) { sess := heap.Pop(&m.pq).(*userSession) - delete(m.sessions, sess.id) - sess.index = -1 + delete(m.sessions, sess.ID) + sess.Index = -1 return sess } return nil } +func (m *sessionMap) GobEncode() ([]byte, error) { + var buf bytes.Buffer + if err := gob.NewEncoder(&buf).Encode([]*userSession(m.pq)); err != nil { + return nil, err + } + return buf.Bytes(), nil +} + +func (m *sessionMap) GobDecode(b []byte) error { + var tmp []*userSession + if err := gob.NewDecoder(bytes.NewReader(b)).Decode(&tmp); err != nil { + return err + } + sessions := make(map[string]*userSession) + for _, sess := range tmp { + sessions[sess.ID] = sess + } + m.pq = sessionPQ(tmp) + m.sessions = sessions + return nil +} + // Maintain association between sessions (that may expire) and user // keys. Enforces the constraint that keys will be removed once they // have no sessions attached. type userKeyMap struct { - sessions *sessionMap - userKeys map[string]*userKey - userSessions map[string][]string + Sessions *sessionMap + UserKeys map[string]*userKey + UserSessions map[string][]string } func newUserKeyMap() *userKeyMap { return &userKeyMap{ - sessions: newSessionMap(), - userKeys: make(map[string]*userKey), - userSessions: make(map[string][]string), + Sessions: newSessionMap(), + UserKeys: make(map[string]*userKey), + UserSessions: make(map[string][]string), } } func (u *userKeyMap) numCachedKeys() int { - return len(u.userKeys) + return len(u.UserKeys) } func (u *userKeyMap) numSessions() int { - return u.sessions.pq.Len() + return u.Sessions.pq.Len() } func (u *userKeyMap) get(username string) (*userKey, bool) { - k, ok := u.userKeys[username] + k, ok := u.UserKeys[username] return k, ok } func (u *userKeyMap) addSessionWithKey(sessionID, username string, key *userKey, ttlSeconds int) { - u.sessions.add(sessionID, username, time.Duration(ttlSeconds)*time.Second) - u.userKeys[username] = key - u.userSessions[username] = append(u.userSessions[username], sessionID) + u.Sessions.add(sessionID, username, time.Duration(ttlSeconds)*time.Second) + u.UserKeys[username] = key + u.UserSessions[username] = append(u.UserSessions[username], sessionID) } func (u *userKeyMap) deleteSession(sessionID string) bool { - if sess := u.sessions.del(sessionID); sess != nil { + if sess := u.Sessions.del(sessionID); sess != nil { u.cleanupSession(sess) return true } @@ -143,7 +168,7 @@ func (u *userKeyMap) deleteSession(sessionID string) bool { func (u *userKeyMap) expire(deadline time.Time) { for { - sess := u.sessions.expireNext(deadline) + sess := u.Sessions.expireNext(deadline) if sess == nil { return } @@ -153,25 +178,43 @@ func (u *userKeyMap) expire(deadline time.Time) { func (u *userKeyMap) cleanupSession(sess *userSession) { var ids []string - for _, id := range u.userSessions[sess.username] { - if id != sess.id { + for _, id := range u.UserSessions[sess.Username] { + if id != sess.ID { ids = append(ids, id) } } if len(ids) == 0 { // No more sessions for this user, delete key. - k := u.userKeys[sess.username] - delete(u.userKeys, sess.username) - wipeBytes(k.pkey) - delete(u.userSessions, sess.username) + k := u.UserKeys[sess.Username] + delete(u.UserKeys, sess.Username) + wipeBytes(k.PKey) + delete(u.UserSessions, sess.Username) } else { - u.userSessions[sess.username] = ids + u.UserSessions[sess.Username] = ids } } +func (u *userKeyMap) dump(w io.Writer) error { + return gob.NewEncoder(w).Encode(u) +} + +func (u *userKeyMap) load(r io.Reader) error { + return gob.NewDecoder(r).Decode(u) +} + func wipeBytes(b []byte) { for i := 0; i < len(b); i++ { b[i] = 0 } } + +func init() { + // Register the types referenced by userKeyMap with the gob + // encoder. Use short names so that maps of objects take less + // space in the output. + gob.RegisterName("k", &userKey{}) + gob.RegisterName("s", &userSession{}) + gob.RegisterName("m", &sessionMap{}) + gob.RegisterName("ukm", &userKeyMap{}) +}