Skip to content
Snippets Groups Projects
Commit 38d0d11c authored by ale's avatar ale
Browse files

Implement basic dump/restore functionality

parent c4fe03f2
No related branches found
No related tags found
1 merge request!35Add state load/dump logic to the main server binary
package server package server
import ( import (
"compress/gzip"
"context" "context"
"errors" "errors"
"io"
"io/ioutil" "io/ioutil"
"log" "log"
"strings" "strings"
...@@ -125,6 +127,28 @@ func (s *KeyStore) expireLoop() { ...@@ -125,6 +127,28 @@ func (s *KeyStore) expireLoop() {
} }
} }
// Dump the keystore in-memory contents.
func (s *KeyStore) Dump(w io.Writer) error {
gzw := gzip.NewWriter(w)
if err := s.userKeys.dump(gzw); err != nil {
return err
}
if err := gzw.Flush(); err != nil {
return err
}
return gzw.Close()
}
// Load a keystore data dump.
func (s *KeyStore) Load(r io.Reader) error {
gzr, err := gzip.NewReader(r)
if err != nil {
return err
}
defer gzr.Close()
return s.userKeys.load(gzr)
}
// Open the user's key store with the given password. If successful, // Open the user's key store with the given password. If successful,
// the unencrypted user key will be stored for at most ttlSeconds, or // the unencrypted user key will be stored for at most ttlSeconds, or
// until Close is called with the same session ID. // until Close is called with the same session ID.
...@@ -195,7 +219,7 @@ func (s *KeyStore) Get(username, ssoTicket string) ([]byte, error) { ...@@ -195,7 +219,7 @@ func (s *KeyStore) Get(username, ssoTicket string) ([]byte, error) {
if !ok { if !ok {
return nil, errNoKeys return nil, errNoKeys
} }
return u.pkey, nil return u.PKey, nil
} }
// Close the user's key store and wipe the associated unencrypted key // Close the user's key store and wipe the associated unencrypted key
......
...@@ -283,3 +283,40 @@ func TestKeystore_Expire(t *testing.T) { ...@@ -283,3 +283,40 @@ func TestKeystore_Expire(t *testing.T) {
t.Fatal("keystore.Get():", err) t.Fatal("keystore.Get():", err)
} }
} }
func TestKeystore_DumpAndReload(t *testing.T) {
_, keystore, cleanup := newTestContext(t)
defer cleanup()
// Decrypt the private key with the right password.
err := keystore.Open(context.Background(), "testuser", string(pw), "session", 60)
if err != nil {
t.Fatal("keystore.Open():", err)
}
// Now dump the keystore state, and reload it in a new one.
var buf bytes.Buffer
if err := keystore.Dump(&buf); err != nil {
t.Fatal("keystore.Dump():", err)
}
c, keystore2, cleanup2 := newTestContext(t)
defer cleanup2()
if err := keystore2.Load(&buf); err != nil {
t.Fatal("keystore2.Load():", err)
}
// Sign a valid SSO ticket and use it to obtain the private
// key we just stored.
ssoTicket := c.sign("testuser", "keystore/", "domain")
result, err := keystore2.Get("testuser", ssoTicket)
if err != nil {
t.Fatal("keystore2.Get():", err)
}
expectedPEM, _ := privKey.PEM()
if !bytes.Equal(result, expectedPEM) {
t.Fatalf("keystore2.Get() returned bad key: got %v, expected %v", result, expectedPEM)
}
}
package server package server
import ( import (
"bytes"
"container/heap" "container/heap"
"encoding/gob"
"io"
"time" "time"
) )
type userKey struct { type userKey struct {
pkey []byte PKey []byte
} }
func newUserKey(pem []byte) *userKey { func newUserKey(pem []byte) *userKey {
return &userKey{pkey: pem} return &userKey{PKey: pem}
} }
type userSession struct { type userSession struct {
id string ID string
username string Username string
expiry time.Time Expiry time.Time
index int Index int
} }
// Priority queue of userSession objects, kept ordered by expiration // Priority queue of userSession objects, kept ordered by expiration
...@@ -30,18 +33,18 @@ func (pq sessionPQ) Len() int { ...@@ -30,18 +33,18 @@ func (pq sessionPQ) Len() int {
func (pq sessionPQ) Swap(i, j int) { func (pq sessionPQ) Swap(i, j int) {
pq[i], pq[j] = pq[j], pq[i] pq[i], pq[j] = pq[j], pq[i]
pq[i].index = i pq[i].Index = i
pq[j].index = j pq[j].Index = j
} }
func (pq sessionPQ) Less(i, j int) bool { func (pq sessionPQ) Less(i, j int) bool {
return pq[i].expiry.Before(pq[j].expiry) return pq[i].Expiry.Before(pq[j].Expiry)
} }
func (pq *sessionPQ) Push(x interface{}) { func (pq *sessionPQ) Push(x interface{}) {
n := len(*pq) n := len(*pq)
item := x.(*userSession) item := x.(*userSession)
item.index = n item.Index = n
*pq = append(*pq, item) *pq = append(*pq, item)
} }
...@@ -67,9 +70,9 @@ func newSessionMap() *sessionMap { ...@@ -67,9 +70,9 @@ func newSessionMap() *sessionMap {
func (m *sessionMap) add(sessionID, username string, ttl time.Duration) { func (m *sessionMap) add(sessionID, username string, ttl time.Duration) {
sess := &userSession{ sess := &userSession{
id: sessionID, ID: sessionID,
username: username, Username: username,
expiry: time.Now().Add(ttl), Expiry: time.Now().Add(ttl),
} }
m.sessions[sessionID] = sess m.sessions[sessionID] = sess
heap.Push(&m.pq, sess) heap.Push(&m.pq, sess)
...@@ -81,60 +84,82 @@ func (m *sessionMap) del(sessionID string) *userSession { ...@@ -81,60 +84,82 @@ func (m *sessionMap) del(sessionID string) *userSession {
return nil return nil
} }
delete(m.sessions, sessionID) delete(m.sessions, sessionID)
heap.Remove(&m.pq, sess.index) heap.Remove(&m.pq, sess.Index)
sess.index = -1 sess.Index = -1
return sess return sess
} }
// Peek at the oldest entry and pop it if expired. // Peek at the oldest entry and pop it if expired.
func (m *sessionMap) expireNext(deadline time.Time) *userSession { func (m *sessionMap) expireNext(deadline time.Time) *userSession {
if len(m.pq) > 0 && m.pq[0].expiry.Before(deadline) { if len(m.pq) > 0 && m.pq[0].Expiry.Before(deadline) {
sess := heap.Pop(&m.pq).(*userSession) sess := heap.Pop(&m.pq).(*userSession)
delete(m.sessions, sess.id) delete(m.sessions, sess.ID)
sess.index = -1 sess.Index = -1
return sess return sess
} }
return nil return nil
} }
func (m *sessionMap) GobEncode() ([]byte, error) {
var buf bytes.Buffer
if err := gob.NewEncoder(&buf).Encode([]*userSession(m.pq)); err != nil {
return nil, err
}
return buf.Bytes(), nil
}
func (m *sessionMap) GobDecode(b []byte) error {
var tmp []*userSession
if err := gob.NewDecoder(bytes.NewReader(b)).Decode(&tmp); err != nil {
return err
}
sessions := make(map[string]*userSession)
for _, sess := range tmp {
sessions[sess.ID] = sess
}
m.pq = sessionPQ(tmp)
m.sessions = sessions
return nil
}
// Maintain association between sessions (that may expire) and user // Maintain association between sessions (that may expire) and user
// keys. Enforces the constraint that keys will be removed once they // keys. Enforces the constraint that keys will be removed once they
// have no sessions attached. // have no sessions attached.
type userKeyMap struct { type userKeyMap struct {
sessions *sessionMap Sessions *sessionMap
userKeys map[string]*userKey UserKeys map[string]*userKey
userSessions map[string][]string UserSessions map[string][]string
} }
func newUserKeyMap() *userKeyMap { func newUserKeyMap() *userKeyMap {
return &userKeyMap{ return &userKeyMap{
sessions: newSessionMap(), Sessions: newSessionMap(),
userKeys: make(map[string]*userKey), UserKeys: make(map[string]*userKey),
userSessions: make(map[string][]string), UserSessions: make(map[string][]string),
} }
} }
func (u *userKeyMap) numCachedKeys() int { func (u *userKeyMap) numCachedKeys() int {
return len(u.userKeys) return len(u.UserKeys)
} }
func (u *userKeyMap) numSessions() int { func (u *userKeyMap) numSessions() int {
return u.sessions.pq.Len() return u.Sessions.pq.Len()
} }
func (u *userKeyMap) get(username string) (*userKey, bool) { func (u *userKeyMap) get(username string) (*userKey, bool) {
k, ok := u.userKeys[username] k, ok := u.UserKeys[username]
return k, ok return k, ok
} }
func (u *userKeyMap) addSessionWithKey(sessionID, username string, key *userKey, ttlSeconds int) { func (u *userKeyMap) addSessionWithKey(sessionID, username string, key *userKey, ttlSeconds int) {
u.sessions.add(sessionID, username, time.Duration(ttlSeconds)*time.Second) u.Sessions.add(sessionID, username, time.Duration(ttlSeconds)*time.Second)
u.userKeys[username] = key u.UserKeys[username] = key
u.userSessions[username] = append(u.userSessions[username], sessionID) u.UserSessions[username] = append(u.UserSessions[username], sessionID)
} }
func (u *userKeyMap) deleteSession(sessionID string) bool { func (u *userKeyMap) deleteSession(sessionID string) bool {
if sess := u.sessions.del(sessionID); sess != nil { if sess := u.Sessions.del(sessionID); sess != nil {
u.cleanupSession(sess) u.cleanupSession(sess)
return true return true
} }
...@@ -143,7 +168,7 @@ func (u *userKeyMap) deleteSession(sessionID string) bool { ...@@ -143,7 +168,7 @@ func (u *userKeyMap) deleteSession(sessionID string) bool {
func (u *userKeyMap) expire(deadline time.Time) { func (u *userKeyMap) expire(deadline time.Time) {
for { for {
sess := u.sessions.expireNext(deadline) sess := u.Sessions.expireNext(deadline)
if sess == nil { if sess == nil {
return return
} }
...@@ -153,21 +178,29 @@ func (u *userKeyMap) expire(deadline time.Time) { ...@@ -153,21 +178,29 @@ func (u *userKeyMap) expire(deadline time.Time) {
func (u *userKeyMap) cleanupSession(sess *userSession) { func (u *userKeyMap) cleanupSession(sess *userSession) {
var ids []string var ids []string
for _, id := range u.userSessions[sess.username] { for _, id := range u.UserSessions[sess.Username] {
if id != sess.id { if id != sess.ID {
ids = append(ids, id) ids = append(ids, id)
} }
} }
if len(ids) == 0 { if len(ids) == 0 {
// No more sessions for this user, delete key. // No more sessions for this user, delete key.
k := u.userKeys[sess.username] k := u.UserKeys[sess.Username]
delete(u.userKeys, sess.username) delete(u.UserKeys, sess.Username)
wipeBytes(k.pkey) wipeBytes(k.PKey)
delete(u.userSessions, sess.username) delete(u.UserSessions, sess.Username)
} else { } else {
u.userSessions[sess.username] = ids u.UserSessions[sess.Username] = ids
}
} }
func (u *userKeyMap) dump(w io.Writer) error {
return gob.NewEncoder(w).Encode(u)
}
func (u *userKeyMap) load(r io.Reader) error {
return gob.NewDecoder(r).Decode(u)
} }
func wipeBytes(b []byte) { func wipeBytes(b []byte) {
...@@ -175,3 +208,13 @@ func wipeBytes(b []byte) { ...@@ -175,3 +208,13 @@ func wipeBytes(b []byte) {
b[i] = 0 b[i] = 0
} }
} }
func init() {
// Register the types referenced by userKeyMap with the gob
// encoder. Use short names so that maps of objects take less
// space in the output.
gob.RegisterName("k", &userKey{})
gob.RegisterName("s", &userSession{})
gob.RegisterName("m", &sessionMap{})
gob.RegisterName("ukm", &userKeyMap{})
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment