From 158781984dbddd983bb421cc18f878e2f3610038 Mon Sep 17 00:00:00 2001
From: joe <joe@autistici.org>
Date: Sun, 1 Dec 2013 02:51:06 +0100
Subject: [PATCH] etcd interface passing all tests, supports audit.

The etcd interface now passes all tests and also supports very crude
audit logging. At the moment, only the latest 100 audit messages are kept.

Due to a change of behaviour in dateutil, we had to slightly adjust our
exception handling in the schema in order to pass the tests on Python
2.7.5.

The etcd driver has still not been road-tested.
---
 configdb/db/db_api.py                    |   8 +-
 configdb/db/interface/etcd_interface.py  | 174 ++++++++++++++++++++---
 configdb/tests/db_api_test_base.py       |  15 +-
 configdb/tests/db_interface_test_base.py |   3 +-
 configdb/tests/test_db_api_etcd.py       |  21 +++
 configdb/tests/test_etcd_interface.py    |  21 +++
 configdb/tests/test_schema.py            |   3 +-
 7 files changed, 211 insertions(+), 34 deletions(-)
 create mode 100644 configdb/tests/test_db_api_etcd.py
 create mode 100644 configdb/tests/test_etcd_interface.py

diff --git a/configdb/db/db_api.py b/configdb/db/db_api.py
index 50f7076..a6b0362 100644
--- a/configdb/db/db_api.py
+++ b/configdb/db/db_api.py
@@ -52,7 +52,7 @@ class AdmDbApi(object):
         # Deserialize the input data.
         try:
             data = entity.from_net(data)
-        except ValueError, e:
+        except (ValueError, TypeError), e:
             raise exceptions.ValidationError(
                 'Validation error in deserialization: %s' % str(e))
 
@@ -106,7 +106,7 @@ class AdmDbApi(object):
                             'no such object, %s=%s' % (
                                 field.remote_name, rel_name))
                     relation.append(rel_obj)
-                for rel_name in to_remove: 
+                for rel_name in to_remove:
                     rel_obj = self.db.get_by_name(
                         field.remote_name, rel_name, session)
                     relation.remove(rel_obj)
@@ -117,7 +117,7 @@ class AdmDbApi(object):
         if entity_name in self.schema.sys_schema_tables:
             #Avoid updating timestamp for tables that are not part of the schema.
             return True
-        
+
         data = {'name': entity_name, 'ts': time.time() }
         ts = self.schema.get_entity('__timestamp')
         data = self._unpack(ts, data)
@@ -130,7 +130,7 @@ class AdmDbApi(object):
         else:
             obj = self.db.create('__timestamp', data, session)
         return True
-        
+
     @with_session
     @with_timestamp
     def update(self, session, entity_name, object_name, data, auth_context):
diff --git a/configdb/db/interface/etcd_interface.py b/configdb/db/interface/etcd_interface.py
index faef00c..4043eb4 100644
--- a/configdb/db/interface/etcd_interface.py
+++ b/configdb/db/interface/etcd_interface.py
@@ -1,40 +1,86 @@
 # You need the etcd python client library you can find here:
 # https://github.com/lavagetto/python-etcd
-import etcd
-import cPickle as pickle
 import os
+import time
+import base64
+import urllib
 from urlparse import urlparse
+import cPickle as pickle
+import json
+
+import etcd
+
+from configdb import exceptions
+from configdb.db.interface import base
 from configdb.db.interface import inmemory_interface
 
-class EtcdSession(object):
+
+class EtcdSession(inmemory_interface.InMemorySession):
     """A EtcdInterface session."""
+
     def __init__(self,db):
         self.db = db
-        raise NotImplementedError
 
-    def _mkpath(self, entity_name, object_name):
-        return os.path.join(self.db.root, entity_name, object_name)
+    def _escape(self,s):
+        # Hack alert! Since etcd interprets any '/' as a dir separator,
+        # we simply replace it with a double backslash in the path.
+        # this of course introduces a potential bug.
+        s = s.replace('/','\\\\')
+        return urllib.quote(s, safe='')
+
+    def _mkpath(self, entity_name, obj_name=None):
+        path = os.path.join(self.db.root, self._escape(entity_name))
+        if obj_name:
+            path = os.path.join(path, self._escape(obj_name))
+        return path
+
 
     def add(self, obj):
         path = self._mkpath(obj._entity_name, obj.name)
-        #TODO: test for presence of an old object and do test_and_set
-        self.db.conn.set(path, self.db._serialize(obj))
+        try:
+            idx = self.db.conn.read(path).modifiedIndex
+            opts = {'prevIndex': idx}
+        except KeyError:
+            opts = {'prevExists': False}
 
+        # Will raise ValueError if the test fails.
+        try:
+            self.db.conn.write(path, self.db._serialize(obj), **opts)
+        except ValueError:
+            raise exceptions.IntegrityError('Bad revision')
 
     def delete(self, obj):
-        raise NotImplementedError
+        self._delete_by_name(obj._entity_name, obj.name)
+
 
     def _delete_by_name(self, entity_name, obj_name):
-        raise NotImplementedError
+        path = self._mkpath(entity_name, obj_name)
+        try:
+            #etcd has no way to atomically delete objects. Meh!
+            self.db.conn.delete(path)
+        except KeyError:
+            pass
+
 
     def _deserialize_if_not_none(self, data):
-        raise NotImplementedError
+        if data:
+            return self.db._deserialize(data)
+        else:
+            return None
 
     def _get(self, entity_name, obj_name):
-        raise NotImplementedError
+        path = self._mkpath(entity_name, obj_name)
+        try:
+            data = self.db.conn.read(path).value
+            return self._deserialize_if_not_none(data)
+        except KeyError:
+            pass
 
     def _find(self, entity_name):
-        raise NotImplementedError
+        path = self._mkpath(entity_name)
+        for r in self.db.conn.read(path, recursive = True).kvs:
+            if not r.dir:
+                yield self._deserialize_if_not_none(r.value)
 
     def commit(self):
         pass
@@ -52,22 +98,31 @@ class EtcdInterface(base.DbInterface):
 
     """
 
+    AUDIT_SUPPORT = True
+    AUDIT_LOG_LENGTH = 100
+
     def __init__(self, url, schema, root='/configdb', timeout=30):
         self.root = root
+        self.schema = schema
         try:
             p = urlparse(url)
             host, port = p.netloc.split(':')
         except ValueError:
-            raise ValueError('Url {} is not in the host:port format'.format(p.netloc))
+            raise ValueError(
+                'Url {} is not in the host:port format'.format(p.netloc))
 
-        self.conn = etcd.Client(host=host, port=port, protocol = p.schema, allow_reconnect = True)
+        #TODO: find a way to allow use of SSL client certificates.
+        self.conn = etcd.Client(
+            host=host, port=int(port), protocol = p.scheme, allow_reconnect = True)
 
 
     def _serialize(self, obj):
-        return pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
+        return base64.b64encode(
+            pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL))
+
 
     def _deserialize(self, data):
-        return pickle.loads(data)
+        return pickle.loads(base64.b64decode(data))
 
     def session(self):
         return base.session_context_manager(EtcdSession(self))
@@ -76,15 +131,88 @@ class EtcdInterface(base.DbInterface):
         return session._get(entity_name, object_name)
 
     def find(self, entity_name, query, session):
-        raise NotImplementedError
+        entity = self.schema.get_entity(entity_name)
+        return self._run_query(entity, query,
+                               session._find(entity_name))
+
 
     def create(self, entity_name, attrs, session):
         entity = self.schema.get_entity(entity_name)
-        object =
-        raise NotImplementedError
+        obj = inmemory_interface.InMemoryObject(entity, attrs)
+        session.add(obj)
+        return obj
 
-    def delete(self, entity_name, object_name, session):
-        raise NotImplementedError
+    def delete(self, entity_name, obj_name, session):
+        session._delete_by_name(entity_name, obj_name)
 
     def close(self):
-        pass
+        self.conn.http.clear()
+
+
+    def _get_audit_slot(self):
+        path = os.path.join(self.root, '_audit', '_slots')
+        retries = 10
+        while retries > 0:
+            try:
+                res = self.conn.read(path)
+            except:
+                # we do not check for existence, on purpose
+                self.conn.write(path, 0)
+                return "0"
+            slot = (int(res.value) + 1) % self.AUDIT_LOG_LENGTH
+            try:
+                self.conn.write(path, slot, prevIndex = res.modifiedIndex)
+                return str(slot)
+            except:
+                retries -= 1
+        #we could not apply for a slot, it seems; just give up writing
+        return None
+
+    def add_audit(self, entity_name, obj_name, operation,
+                  data, auth_ctx, session):
+        """Add an entry in the audit log."""
+        if data is not None:
+            data = self.schema.get_entity(entity_name).to_net(data)
+        slot = self._get_audit_slot()
+        if slot is None:
+            return
+        path = os.path.join(self.root, '_audit', slot)
+
+        audit = {
+            'entity': entity_name,
+            'object': obj_name,
+            'op': operation,
+            'user': auth_ctx.get_username(),
+            'data': base64.b64encode(json.dumps(data)) if data else None,
+            'ts': time.time()
+        }
+        self.conn.write(path, json.dumps(audit))
+        try:
+            self.conn.write(path, json.dumps(audit), prevExists=False)
+        except ValueError:
+            pass
+
+    def get_audit(self, query, session):
+        """Query the audit log."""
+        # This is actually very expensive and this is why we have a limited number of slots
+        path = os.path.join(self.root, '_audit')
+        data = self.conn.read(path, recursive=True)
+        log = []
+
+        for result in data.kvs:
+            obj = json.loads(result.value)
+            if obj['data']:
+                obj['data'] = base64.b64decode(obj['data'])
+            matches = True
+
+            for (k,v) in query.iteritems():
+                if k not in obj:
+                    matches = False
+                    break
+                if obj[k] != v:
+                    matches = False
+                    break
+
+            if matches:
+                log.append(obj)
+        return log
diff --git a/configdb/tests/db_api_test_base.py b/configdb/tests/db_api_test_base.py
index 3fa992f..e5302ec 100644
--- a/configdb/tests/db_api_test_base.py
+++ b/configdb/tests/db_api_test_base.py
@@ -87,6 +87,7 @@ class DbApiTestBase(object):
             self.api.find('host',
                           {'roles': {'type': 'eq', 'value': 'role1'}},
                           self.ctx))
+
         self.assertEquals(1, len(result))
         self.assertEquals('obz', result[0].name)
 
@@ -157,7 +158,7 @@ class DbApiTestBase(object):
 
     def test_create_with_relations(self):
         host_data = {'name': 'utz', 'ip': '2.3.4.5',
-                     'roles': ['role1']}
+                     'roles': ['a/i']}
         self.assertTrue(self.api.create('host', host_data, self.ctx))
 
     def test_create_unknown_entity(self):
@@ -194,13 +195,17 @@ class DbApiTestBase(object):
     def test_update_adds_audit_log(self):
         if not self.api.db.AUDIT_SUPPORT:
             return
+        old_result = list(self.api.get_audit({'entity': 'host',
+                                              'object': 'obz',
+                                              'op': 'update'}, self.ctx))
+
         result = self.api.update('host', 'obz', {'ip': '2.3.4.5'}, self.ctx)
         self.assertTrue(result)
 
         result = list(self.api.get_audit({'entity': 'host',
                                           'object': 'obz',
                                           'op': 'update'}, self.ctx))
-        self.assertEquals(1, len(result))
+        self.assertEquals(1, len(result) - len(old_result))
 
     # FIXME: should renaming even work?
     #
@@ -285,13 +290,17 @@ class DbApiTestBase(object):
     def test_delete_adds_audit_log(self):
         if not self.api.db.AUDIT_SUPPORT:
             return
+        old_result = list(self.api.get_audit({'entity': 'host',
+                                              'object': 'obz',
+                                              'op': 'delete'}, self.ctx))
+
         self.assertTrue(
             self.api.delete('host', 'obz', self.ctx))
 
         result = list(self.api.get_audit({'entity': 'host',
                                           'object': 'obz',
                                           'op': 'delete'}, self.ctx))
-        self.assertEquals(1, len(result))
+        self.assertEquals(1, len(result) - len(old_result))
 
     def test_delete_twice(self):
         self.assertTrue(
diff --git a/configdb/tests/db_interface_test_base.py b/configdb/tests/db_interface_test_base.py
index ee77b11..681cd31 100644
--- a/configdb/tests/db_interface_test_base.py
+++ b/configdb/tests/db_interface_test_base.py
@@ -18,7 +18,7 @@ class DbInterfaceTestBase(object):
             s.add(a)
             s.add(b)
         return db
-        
+
     def test_init_ok(self):
         db = self.init_db()
         self.assertTrue(db is not None)
@@ -94,4 +94,3 @@ class DbInterfaceTestBase(object):
         r = self._find(db, 'host', {'roles': {'type': 'eq', 'value':'zzzz'}})
         self.assertEquals(0, len(r))
         db.close()
-                
diff --git a/configdb/tests/test_db_api_etcd.py b/configdb/tests/test_db_api_etcd.py
new file mode 100644
index 0000000..525fa5d
--- /dev/null
+++ b/configdb/tests/test_db_api_etcd.py
@@ -0,0 +1,21 @@
+import os
+from nose.exc import SkipTest
+try:
+    from configdb.db.interface import etcd_interface
+    if os.getenv('SKIP_ETCD') is not None:
+        raise SkipTest('Etcd tests disabled')
+except ImportError:
+    raise SkipTest('Etcd not found')
+
+from configdb.tests import *
+from configdb.tests.db_api_test_base import DbApiTestBase
+
+
+@attr('etcd')
+class EtcdInterfaceTest(DbApiTestBase, TestBase):
+
+    TESTROOT = '/configdb-test-%d' % os.getpid()
+
+    def init_db(self):
+        return etcd_interface.EtcdInterface(
+            'http://127.0.0.1:4001', self.get_schema(), self.TESTROOT)
diff --git a/configdb/tests/test_etcd_interface.py b/configdb/tests/test_etcd_interface.py
new file mode 100644
index 0000000..c4c4817
--- /dev/null
+++ b/configdb/tests/test_etcd_interface.py
@@ -0,0 +1,21 @@
+import os
+from nose.exc import SkipTest
+try:
+    from configdb.db.interface import etcd_interface
+    if os.getenv('SKIP_ETCD') is not None:
+        raise SkipTest('Etcd tests disabled')
+except ImportError:
+    raise SkipTest('Etcd not found')
+
+from configdb.tests import *
+from configdb.tests.db_interface_test_base import DbInterfaceTestBase
+
+
+@attr('etcd')
+class EtcdInterfaceTest(DbInterfaceTestBase, TestBase):
+
+    TESTROOT = '/configdb-test-%d' % os.getpid()
+
+    def init_db(self):
+        return etcd_interface.EtcdInterface(
+            'http://127.0.0.1:4001', self.get_schema(), self.TESTROOT)
diff --git a/configdb/tests/test_schema.py b/configdb/tests/test_schema.py
index 66d4fdb..1cee6b1 100644
--- a/configdb/tests/test_schema.py
+++ b/configdb/tests/test_schema.py
@@ -213,7 +213,7 @@ class SchemaSerializationTest(TestBase):
     def test_deserialization_error(self):
         data = {'stamp': 'not-a-timestamp'}
         self.assertRaises(
-            ValueError,
+            TypeError,
             self.ent.from_net, data)
 
 
@@ -325,4 +325,3 @@ class SchemaAclTest(TestBase):
             ent, ['name', 'role'],
             acl.AuthContext('testuser'),
             'w', None)
-
-- 
GitLab