From d0234c764d282dbb6a817fad3aec6f288970607c Mon Sep 17 00:00:00 2001
From: joe <joe@autistici.org>
Date: Sun, 12 May 2013 12:25:55 +0200
Subject: [PATCH] Fixing issue with two-way referenced relations in the
 sqlalchemy interface (closes issue #13)

---
 configdb/db/interface/sa_generator.py | 20 ++++++++++++++++----
 configdb/db/interface/sa_interface.py |  4 ++--
 2 files changed, 18 insertions(+), 6 deletions(-)

diff --git a/configdb/db/interface/sa_generator.py b/configdb/db/interface/sa_generator.py
index 48a1ebc..ba45157 100644
--- a/configdb/db/interface/sa_generator.py
+++ b/configdb/db/interface/sa_generator.py
@@ -8,6 +8,7 @@ class SqlAlchemyGenerator(object):
 
     def __init__(self, schema_obj):
         self.schema = schema_obj
+        self.defined_relations = []
 
     def _audit_table_def(self):
         return """
@@ -76,6 +77,13 @@ class %(class_name)s(Base):
             field.name, ', '.join(args))
 
     def _sa_field_relation_def(self, entity, field):
+        assoc_table = self._sa_assoc_table_name(field)
+        if assoc_table in self.defined_relations:
+            # In this case, we don't need a backref (it already exists)
+            return '%s = relationship("%s", secondary=%s)' % (
+                field.name,
+                field.remote_name.capitalize(),
+                assoc_table + '_table')
         return '%s = relationship("%s", secondary=%s, backref="%s")' % (
             field.name,
             field.remote_name.capitalize(),
@@ -83,12 +91,17 @@ class %(class_name)s(Base):
             pl.plural(entity.name))
 
     def _sa_field_assoc_table_def(self, entity, field):
+        table_name = self._sa_assoc_table_name(field)
+        if table_name in self.defined_relations:
+            #we have a 2-way used relation, and do not want to re-define it
+            return ""
+        self.defined_relations.append(table_name)
         return """
 %s_table = Table('%s', Base.metadata,
     Column('left_id', Integer, ForeignKey('%s.id')),
     Column('right_id', Integer, ForeignKey('%s.id')))""" % (
-            self._sa_assoc_table_name(field),
-            self._sa_assoc_table_name(field),
+            table_name,
+            table_name,
             field.local_name,
             field.remote_name)
 
@@ -96,7 +109,7 @@ class %(class_name)s(Base):
         tbls = sorted([field.local_name, field.remote_name])
         tbls.append(field.relation_id)
         return '%s_%s_assoc_%s' % tuple(tbls)
-        
+
     def generate(self):
         out = ['from sqlalchemy import *',
                'from sqlalchemy.orm import *',
@@ -106,4 +119,3 @@ class %(class_name)s(Base):
             out.append(self._sa_entity_aux_tables(ent))
             out.append(self._sa_entity_def(ent))
         return '\n'.join(out)
-
diff --git a/configdb/db/interface/sa_interface.py b/configdb/db/interface/sa_interface.py
index b39442b..00600fa 100644
--- a/configdb/db/interface/sa_interface.py
+++ b/configdb/db/interface/sa_interface.py
@@ -56,7 +56,7 @@ class SqlAlchemyDbInterface(base.DbInterface):
         self.engine = create_engine(uri, pool_recycle=1800)
         self.Session.configure(bind=self.engine)
         Base.metadata.create_all(self.engine)
-        
+
     def _load_schema(self):
         with tempfile.NamedTemporaryFile() as schema_file:
             schema_gen = sa_generator.SqlAlchemyGenerator(self._schema)
@@ -94,7 +94,7 @@ class SqlAlchemyDbInterface(base.DbInterface):
     def get_by_name(self, entity_name, object_name, session):
         return session.query(self._get_class(entity_name)).filter_by(
             name=object_name).first()
-    
+
     def find(self, entity_name, query, session):
         classobj = self._get_class(entity_name)
         entity = self._schema.get_entity(entity_name)
-- 
GitLab