From d0234c764d282dbb6a817fad3aec6f288970607c Mon Sep 17 00:00:00 2001 From: joe <joe@autistici.org> Date: Sun, 12 May 2013 12:25:55 +0200 Subject: [PATCH] Fixing issue with two-way referenced relations in the sqlalchemy interface (closes issue #13) --- configdb/db/interface/sa_generator.py | 20 ++++++++++++++++---- configdb/db/interface/sa_interface.py | 4 ++-- 2 files changed, 18 insertions(+), 6 deletions(-) diff --git a/configdb/db/interface/sa_generator.py b/configdb/db/interface/sa_generator.py index 48a1ebc..ba45157 100644 --- a/configdb/db/interface/sa_generator.py +++ b/configdb/db/interface/sa_generator.py @@ -8,6 +8,7 @@ class SqlAlchemyGenerator(object): def __init__(self, schema_obj): self.schema = schema_obj + self.defined_relations = [] def _audit_table_def(self): return """ @@ -76,6 +77,13 @@ class %(class_name)s(Base): field.name, ', '.join(args)) def _sa_field_relation_def(self, entity, field): + assoc_table = self._sa_assoc_table_name(field) + if assoc_table in self.defined_relations: + # In this case, we don't need a backref (it already exists) + return '%s = relationship("%s", secondary=%s)' % ( + field.name, + field.remote_name.capitalize(), + assoc_table + '_table') return '%s = relationship("%s", secondary=%s, backref="%s")' % ( field.name, field.remote_name.capitalize(), @@ -83,12 +91,17 @@ class %(class_name)s(Base): pl.plural(entity.name)) def _sa_field_assoc_table_def(self, entity, field): + table_name = self._sa_assoc_table_name(field) + if table_name in self.defined_relations: + #we have a 2-way used relation, and do not want to re-define it + return "" + self.defined_relations.append(table_name) return """ %s_table = Table('%s', Base.metadata, Column('left_id', Integer, ForeignKey('%s.id')), Column('right_id', Integer, ForeignKey('%s.id')))""" % ( - self._sa_assoc_table_name(field), - self._sa_assoc_table_name(field), + table_name, + table_name, field.local_name, field.remote_name) @@ -96,7 +109,7 @@ class %(class_name)s(Base): tbls = sorted([field.local_name, field.remote_name]) tbls.append(field.relation_id) return '%s_%s_assoc_%s' % tuple(tbls) - + def generate(self): out = ['from sqlalchemy import *', 'from sqlalchemy.orm import *', @@ -106,4 +119,3 @@ class %(class_name)s(Base): out.append(self._sa_entity_aux_tables(ent)) out.append(self._sa_entity_def(ent)) return '\n'.join(out) - diff --git a/configdb/db/interface/sa_interface.py b/configdb/db/interface/sa_interface.py index b39442b..00600fa 100644 --- a/configdb/db/interface/sa_interface.py +++ b/configdb/db/interface/sa_interface.py @@ -56,7 +56,7 @@ class SqlAlchemyDbInterface(base.DbInterface): self.engine = create_engine(uri, pool_recycle=1800) self.Session.configure(bind=self.engine) Base.metadata.create_all(self.engine) - + def _load_schema(self): with tempfile.NamedTemporaryFile() as schema_file: schema_gen = sa_generator.SqlAlchemyGenerator(self._schema) @@ -94,7 +94,7 @@ class SqlAlchemyDbInterface(base.DbInterface): def get_by_name(self, entity_name, object_name, session): return session.query(self._get_class(entity_name)).filter_by( name=object_name).first() - + def find(self, entity_name, query, session): classobj = self._get_class(entity_name) entity = self._schema.get_entity(entity_name) -- GitLab