diff --git a/noblogsmv/state.py b/noblogsmv/state.py index 701b6463048a51fd6618d4beb0de2ee64ad0668e..88df33d6ba4198b548ef8cc5f8338c63c435c2ae 100644 --- a/noblogsmv/state.py +++ b/noblogsmv/state.py @@ -78,7 +78,7 @@ class LevelDbSession(object): return self._decode(self.snap.Get(key)) def put_many(self, values): - wb = self.db.WriteBatch() + wb = leveldb.WriteBatch() for key, value in values: wb.Put(key, self._encode(value)) self.db.Write(wb) @@ -111,13 +111,14 @@ class StateDatabase(object): def put(self, session, work): session.put(work.key, work) - def set(self, session, work): - # The Mutable takes care of keeping track of changes to 'data'. - self.put(session, work) + set = put def get(self, session, key): return session.get(key) + def put_many(self, session, values): + session.put_many(values) + def scan(self, session, only_pending=True): """Iterate over all the database keys.""" result = [] @@ -307,8 +308,9 @@ class StateMachine(object): def load_data(self, input_stream): with transaction(self.db) as session: - for key, value in input_stream: - self.db.put(session, WorkUnit(key, STATE_INIT, value)) + self.db.put_many(session, + ((key, WorkUnit(key, STATE_INIT, value)) + for key, value in input_stream)) def process(self, key, state, value, progress): value.pop('error_msg', '')