diff --git a/noblogsmv/state.py b/noblogsmv/state.py index 703eb94c74757c5efc9a7106f61394da5a3fa75b..5f4b91d881ad0d6ef20863c8387f0d309d59efb3 100644 --- a/noblogsmv/state.py +++ b/noblogsmv/state.py @@ -2,12 +2,18 @@ import contextlib import json import logging import os -import sqlite3 import sys import threading import time import Queue +import sqlalchemy as sa +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import sessionmaker +from sqlalchemy.types import TypeDecorator, VARCHAR +from sqlalchemy.ext.mutable import Mutable + +Base = declarative_base() log = logging.getLogger(__name__) @@ -20,8 +26,59 @@ STATE_ERROR = 'error' EXIT_STATES = (STATE_DONE, STATE_ERROR) + +# Encode the data field as JSON +class JSONEncodedDict(TypeDecorator): + "Represents an immutable structure as a json-encoded string." + + impl = VARCHAR + + def process_bind_param(self, value, dialect): + if value is not None: + value = json.dumps(value) + return value + + def process_result_value(self, value, dialect): + if value is not None: + value = json.loads(value) + return value + + +class MutableDict(Mutable, dict): + + @classmethod + def coerce(cls, key, value): + "Convert plain dictionaries to MutableDict." + + if not isinstance(value, MutableDict): + if isinstance(value, dict): + return MutableDict(value) + + # this call will raise ValueError + return Mutable.coerce(key, value) + else: + return value + + def __setitem__(self, key, value): + "Detect dictionary set events and emit change events." + + dict.__setitem__(self, key, value) + self.changed() + + def __delitem__(self, key): + "Detect dictionary del events and emit change events." + + dict.__delitem__(self, key) + self.changed() + + # The base work unit has a primary key, a state, and some opaque data. -class WorkUnit(object): +class WorkUnit(Base): + + __tablename__ = 'work' + key = sa.Column(sa.String(128), primary_key=True, unique=True) + state = sa.Column(sa.String(6), index=True) + data = sa.Column(MutableDict.as_mutable(JSONEncodedDict)) def __init__(self, key, state, data): self.key = key @@ -40,8 +97,10 @@ class StateDatabase(object): def __init__(self, path, codec=json): self.path = path self.codec = codec - self._local = threading.local() - self._initialize() + + self.engine = sa.create_engine('sqlite:///' + path) + self.session = sessionmaker(bind=self.engine) + Base.metadata.create_all(self.engine) @property def conn(self): @@ -50,94 +109,62 @@ class StateDatabase(object): return self._local.conn def close(self): - self.conn.close() - - def _initialize(self): - try: - cursor = self.conn.cursor() - cursor.execute( - ''' -create table kv ( - key varchar primary key, - state varchar, - value text) -''') - self.conn.commit() - except: - pass - - def put(self, work): - cursor = self.conn.cursor() - cursor.execute( - 'insert into kv values (?, ?, ?)', - (work.key, work.state, self.codec.dumps(work.data))) - - def get(self, key, cursor=None): - if not cursor: - cursor = self.conn.cursor() - cursor.execute('select state, value from kv where key = ?', (key,)) - row = cursor.fetchone() - if not row: - return None - return WorkUnit(key, row[0], self.codec.loads(row[1])) - - def set(self, work, cursor=None): - if not cursor: - cursor = self.conn.cursor() - cursor.execute( - 'update kv set state=?, value=? where key=?', - (work.state, self.codec.dumps(work.data), work.key)) - - def scan(self, only_pending=True): + pass + + def put(self, session, work): + session.add(work) + + def set(self, session, work): + # The PickleField does not detect changes in the contained + # object, so we have to create a new one and add that instead. + #tmp = WorkUnit(work.key, work.state, work.data) + #self.put(session, tmp) + work.data = work.data + self.put(session, work) + + def get(self, session, key): + return session.query(WorkUnit).get(key) + + def scan(self, session, only_pending=True): """Iterate over all the database keys.""" - cursor = self.conn.cursor() - sql = 'select key from kv' + q = session.query(WorkUnit) if only_pending: - sql += ' where state not in (\'%s\')' % ("', '".join(EXIT_STATES)) - cursor.execute(sql) - while True: - row = cursor.fetchone() - if not row: - break - yield row[0] + q = q.filter(~WorkUnit.state.in_(EXIT_STATES)) + return (x.key for x in q) - def dump(self): + def dump(self, session): """Return every object in the db.""" - cursor = self.conn.cursor() - cursor.execute('select key, state, value from kv') - while True: - row = cursor.fetchone() - if not row: - break - yield WorkUnit(row[0], row[1], self.codec.loads(row[2])) + return session.query(WorkUnit) - def are_we_done(self): + def are_we_done(self, session): """Check if all the entries are in a final state.""" - cursor = self.conn.cursor() - cursor.execute( - 'select count(*) from kv where state not in (\'%s\')' % ( - "', '".join(EXIT_STATES))) - row = cursor.fetchone() - if not row: - return False # error, default to false - return row[0] == 0 - - def count_by_state(self): - cursor = self.conn.cursor() - cursor.execute('select state, count(*) from kv group by state') - counts = {} - for row in cursor.fetchall(): - counts[row[0]] = row[1] - return counts + n = session.query(WorkUnit).filter(~WorkUnit.state.in_(EXIT_STATES)).count() + return n == 0 + + def count_by_state(self, session): + return dict(session.query('state', 'c').from_statement( + 'select state, count(*) as c from work group by state').all()) + + +@contextlib.contextmanager +def transaction(db): + session = db.session() + try: + yield session + session.commit() + except: + session.rollback() + raise + finally: + session.close() @contextlib.contextmanager def work_transaction(db, key): - cursor = db.conn.cursor() - work = db.get(key, cursor) - yield work - db.set(work, cursor) - db.conn.commit() + with transaction(db) as session: + work = db.get(session, key) + yield work + db.put(session, work) class WorkerProgressReporter(object): @@ -243,9 +270,9 @@ class StateMachine(object): self.state_count = {} def load_data(self, input_stream): - for key, value in input_stream: - self.db.put(WorkUnit(key, STATE_INIT, value)) - self.db.conn.commit() + with transaction(self.db) as session: + for key, value in input_stream: + self.db.put(session, WorkUnit(key, STATE_INIT, value)) def process(self, key, state, value, progress): value.pop('error_msg', '') @@ -274,10 +301,12 @@ class StateMachine(object): return [t.info for t in self.threads] def get_state(self): - return self.db.dump() + with transaction(self.db) as session: + return self.db.dump(session) def compute_stats(self): - self.state_count = self.db.count_by_state() + with transaction(self.db) as session: + self.state_count = self.db.count_by_state(session) def run(self): input_queue = Queue.Queue() @@ -291,11 +320,15 @@ class StateMachine(object): self.running = True # Inject initial state. - for key in self.db.scan(): - input_queue.put(key) + with transaction(self.db) as session: + for key in self.db.scan(session): + input_queue.put(key) # Wait until everything is done. - while not self.db.are_we_done(): + while True: + with transaction(self.db) as session: + if self.db.are_we_done(session): + break time.sleep(3) # Kill the workers. diff --git a/noblogsmv/test/test_state.py b/noblogsmv/test/test_state.py index 79e472f29f2679bd4e02ee811ea2ef402dc0a7e2..7a9a3698800ebcabbddfe40e502a407b7d29c634 100644 --- a/noblogsmv/test/test_state.py +++ b/noblogsmv/test/test_state.py @@ -1,7 +1,8 @@ import collections import os import shutil -import sqlite3 +import sqlalchemy +import sqlalchemy.exc import tempfile import threading import time @@ -10,7 +11,8 @@ import urllib2 from noblogsmv import state -TEST_WORK_UNIT = state.WorkUnit('key', state.STATE_INIT, {'a': 42}) +def _test_work_unit(): + return state.WorkUnit('key', state.STATE_INIT, {'a': 42}) class TestBase(unittest.TestCase): @@ -34,53 +36,67 @@ class DatabaseTest(TestBase): TestBase.tearDown(self) def test_put(self): - self.db.put(TEST_WORK_UNIT) - w2 = self.db.get(TEST_WORK_UNIT.key) - self.assertEquals(TEST_WORK_UNIT, w2) + with state.transaction(self.db) as session: + t = _test_work_unit() + self.db.put(session, t) + w2 = self.db.get(session, t.key) + self.assertEquals(t, w2) def test_put_is_unique(self): - self.db.put(TEST_WORK_UNIT) - self.assertRaises(sqlite3.IntegrityError, - self.db.put, TEST_WORK_UNIT) + def _saveit(): + with state.transaction(self.db) as session: + self.db.put(session, _test_work_unit()) + _saveit() + self.assertRaises(sqlalchemy.exc.SQLAlchemyError, _saveit) def test_set(self): - self.db.put(TEST_WORK_UNIT) - w = self.db.get(TEST_WORK_UNIT.key) - w.data['b'] = 21 - self.db.set(w) - w2 = self.db.get(w.key) - self.assertEquals(21, w2.data['b']) + t = _test_work_unit() + key = t.key + with state.transaction(self.db) as session: + self.db.put(session, t) + with state.transaction(self.db) as session: + w = self.db.get(session, key) + w.data['b'] = 21 + self.db.set(session, w) + with state.transaction(self.db) as session: + w2 = self.db.get(session, key) + self.assertEquals({'a': 42, 'b': 21}, w2.data) def test_scan(self): - self.db.put(state.WorkUnit('key1', state.STATE_INIT, {})) - self.db.put(state.WorkUnit('key2', state.STATE_DONE, {})) - self.db.put(state.WorkUnit('key3', 'mystate', {})) - result = set(self.db.scan()) - self.assertEquals(set(['key1', 'key3']), result) + with state.transaction(self.db) as session: + self.db.put(session, state.WorkUnit('key1', state.STATE_INIT, {})) + self.db.put(session, state.WorkUnit('key2', state.STATE_DONE, {})) + self.db.put(session, state.WorkUnit('key3', 'mystate', {})) + result = set(self.db.scan(session)) + self.assertEquals(set(['key1', 'key3']), result) def test_dump(self): - self.db.put(state.WorkUnit('key1', state.STATE_INIT, {})) - self.db.put(state.WorkUnit('key2', state.STATE_DONE, {})) - self.db.put(state.WorkUnit('key3', 'mystate', {})) - result = list(self.db.dump()) - self.assertEquals(3, len(result)) + with state.transaction(self.db) as session: + self.db.put(session, state.WorkUnit('key1', state.STATE_INIT, {})) + self.db.put(session, state.WorkUnit('key2', state.STATE_DONE, {})) + self.db.put(session, state.WorkUnit('key3', 'mystate', {})) + result = list(self.db.dump(session)) + self.assertEquals(3, len(result)) def test_are_we_done(self): - self.db.put(TEST_WORK_UNIT) - self.assertFalse(self.db.are_we_done()) - w = self.db.get(TEST_WORK_UNIT.key) - w.state = state.STATE_DONE - self.db.set(w) - self.assertTrue(self.db.are_we_done()) + t = _test_work_unit() + with state.transaction(self.db) as session: + self.db.put(session, t) + self.assertFalse(self.db.are_we_done(session)) + w = self.db.get(session, t.key) + w.state = state.STATE_DONE + self.db.set(session, w) + self.assertTrue(self.db.are_we_done(session)) def test_loadtest(self): nkeys = 100 nloops = 100 out = {'errors': 0} - for i in xrange(nkeys): - self.db.put(state.WorkUnit('key%d' % i, state.STATE_INIT, {i: i})) - self.db.conn.commit() + with state.transaction(self.db) as session: + for i in xrange(nkeys): + self.db.put(session, + state.WorkUnit('key%d' % i, state.STATE_INIT, {i: i})) def loadfn(i): key = 'key%d' % i @@ -114,14 +130,19 @@ class StateMachineLowLevelTest(TestBase): smc = _mk_state_machine({'test': state.nop}) sm = smc(2, os.path.join(self.tmpdir, 'sm.db')) sm.load_data(TEST_DATA) - self.assertFalse(sm.db.are_we_done()) + with state.transaction(sm.db) as session: + self.assertFalse(sm.db.are_we_done(session)) class StateMachineTest(TestBase): def _scrape_data(self, db): # Scrape the contents of the db into the test object. - self.output_data = [db.get(x) for x in db.scan(only_pending=False)] + def _d(w): + return {'key': w.key, 'state': w.state, 'data': w.data} + with state.transaction(db) as session: + self.output_data = [_d(db.get(session, x)) + for x in db.scan(session, only_pending=False)] def _run_sm(self, states, inputs, timeout=10): smc = _mk_state_machine(states) @@ -166,7 +187,7 @@ class StateMachineTest(TestBase): value['b'] = 42 return state.STATE_DONE self._run_sm({'init': mod_fn}, TEST_DATA) - bvalues = [w.data['b'] for w in self.output_data] + bvalues = [w['data']['b'] for w in self.output_data] self.assertEquals([42, 42], bvalues) diff --git a/setup.py b/setup.py index 2cda674f4203e5ab3fb5e7e7ea7b22e872760f9c..9a4642ec0dedf3c08ccacc864d9b3d0ae2801776 100644 --- a/setup.py +++ b/setup.py @@ -9,7 +9,7 @@ setup( author="ale", author_email="ale@incal.net", url="https://git.autistici.org/ai/noblogsmv", - install_requires=['Flask'], + install_requires=['Flask', 'SQLAlchemy'], setup_requires=[], zip_safe=True, packages=find_packages(),