Refactor docs, add merge_references

This commit is contained in:
Konsta Vesterinen
2014-05-14 17:20:34 +03:00
parent 04a5292425
commit 42650d9f32
15 changed files with 449 additions and 450 deletions

View File

@@ -4,12 +4,13 @@ Changelog
Here you can see the full list of changes between each SQLAlchemy-Utils release.
0.26.1 (2014-05-xx)
0.26.1 (2014-05-14)
^^^^^^^^^^^^^^^^^^^
- Added get_bind
- Added group_foreign_keys
- Added get_mapper
- Added merge_references
0.26.0 (2014-05-07)

View File

@@ -5,18 +5,6 @@ Database helpers
.. module:: sqlalchemy_utils.functions
is_indexed_foreign_key
^^^^^^^^^^^^^^^^^^^^^^
.. autofunction:: is_indexed_foreign_key
non_indexed_foreign_keys
^^^^^^^^^^^^^^^^^^^^^^^^
.. autofunction:: non_indexed_foreign_keys
database_exists
^^^^^^^^^^^^^^^

View File

@@ -0,0 +1,40 @@
Foreign key helpers
===================
.. module:: sqlalchemy_utils.functions
dependent_objects
^^^^^^^^^^^^^^^^^
.. autofunction:: dependent_objects
get_referencing_foreign_keys
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. autofunction:: get_referencing_foreign_keys
group_foreign_keys
^^^^^^^^^^^^^^^^^^
.. autofunction:: group_foreign_keys
is_indexed_foreign_key
^^^^^^^^^^^^^^^^^^^^^^
.. autofunction:: is_indexed_foreign_key
merge_references
^^^^^^^^^^^^^^^^
.. autofunction:: merge_references
non_indexed_foreign_keys
^^^^^^^^^^^^^^^^^^^^^^^^
.. autofunction:: non_indexed_foreign_keys

View File

@@ -1,5 +1,5 @@
Generic relationship
====================
Generic relationships
=====================
Generic relationship is a form of relationship that supports creating a 1 to many relationship to any target model.

View File

@@ -16,6 +16,7 @@ SQLAlchemy-Utils provides custom data types and various utility functions for SQ
decorators
generic_relationship
database_helpers
model_helpers
foreign_key_helpers
orm_helpers
utility_classes
license

View File

@@ -1,15 +1,9 @@
Model helpers
=============
ORM helpers
===========
.. module:: sqlalchemy_utils.functions
dependent_objects
^^^^^^^^^^^^^^^^^
.. autofunction:: dependent_objects
escape_like
^^^^^^^^^^^
@@ -46,24 +40,12 @@ get_primary_keys
.. autofunction:: get_primary_keys
get_referencing_foreign_keys
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. autofunction:: get_referencing_foreign_keys
get_tables
^^^^^^^^^^
.. autofunction:: get_tables
group_foreign_keys
^^^^^^^^^^^^^^^^^^
.. autofunction:: group_foreign_keys
query_entities
^^^^^^^^^^^^^^

View File

@@ -44,7 +44,7 @@ for name, requirements in extras_require.items():
setup(
name='SQLAlchemy-Utils',
version='0.26.0',
version='0.26.1',
url='https://github.com/kvesteri/sqlalchemy-utils',
license='BSD',
author='Konsta Vesterinen, Ryan Leckey, Janne Vanhala, Vesa Uimonen',

View File

@@ -20,6 +20,7 @@ from .functions import (
get_tables,
group_foreign_keys,
identity,
merge_references,
mock_engine,
naturally_equivalent,
render_expression,
@@ -32,7 +33,6 @@ from .listeners import (
force_auto_coercion,
force_instant_defaults
)
from .merge import merge, Merger
from .generic import generic_relationship
from .proxy_dict import ProxyDict, proxy_dict
from .query_chain import QueryChain
@@ -67,7 +67,7 @@ from .types import (
)
__version__ = '0.26.0'
__version__ = '0.26.1'
__all__ = (
@@ -95,7 +95,7 @@ __all__ = (
group_foreign_keys,
identity,
instrumented_list,
merge,
merge_references,
mock_engine,
naturally_equivalent,
proxy_dict,
@@ -120,7 +120,6 @@ __all__ = (
IPAddressType,
JSONType,
LocaleType,
Merger,
NumericRangeType,
Password,
PasswordType,

View File

@@ -8,20 +8,23 @@ from .database import (
drop_database,
escape_like,
is_auto_assigned_date_column,
)
from .foreign_keys import (
dependent_objects,
get_referencing_foreign_keys,
group_foreign_keys,
is_indexed_foreign_key,
merge_references,
non_indexed_foreign_keys,
)
from .orm import (
dependent_objects,
get_bind,
get_columns,
get_declarative_base,
get_mapper,
get_primary_keys,
get_referencing_foreign_keys,
get_tables,
getdotattr,
group_foreign_keys,
has_changes,
identity,
naturally_equivalent,

View File

@@ -1,4 +1,3 @@
from collections import defaultdict
from sqlalchemy.engine.url import make_url
import sqlalchemy as sa
from sqlalchemy.schema import MetaData, Table, ForeignKeyConstraint
@@ -186,55 +185,3 @@ def drop_database(url):
else:
text = "DROP DATABASE %s" % database
engine.execute(text)
def non_indexed_foreign_keys(metadata, engine=None):
"""
Finds all non indexed foreign keys from all tables of given MetaData.
Very useful for optimizing postgresql database and finding out which
foreign keys need indexes.
:param metadata: MetaData object to inspect tables from
"""
reflected_metadata = MetaData()
if metadata.bind is None and engine is None:
raise Exception(
'Either pass a metadata object with bind or '
'pass engine as a second parameter'
)
constraints = defaultdict(list)
for table_name in metadata.tables.keys():
table = Table(
table_name,
reflected_metadata,
autoload=True,
autoload_with=metadata.bind or engine
)
for constraint in table.constraints:
if not isinstance(constraint, ForeignKeyConstraint):
continue
if not is_indexed_foreign_key(constraint):
constraints[table.name].append(constraint)
return dict(constraints)
def is_indexed_foreign_key(constraint):
"""
Whether or not given foreign key constraint's columns have been indexed.
:param constraint: ForeignKeyConstraint object to check the indexes
"""
for index in constraint.table.indexes:
index_column_names = set(
column.name for column in index.columns
)
if index_column_names == set(constraint.columns):
return True
return False

View File

@@ -0,0 +1,350 @@
from collections import defaultdict
from itertools import groupby
import six
import sqlalchemy as sa
from sqlalchemy.engine import reflection
from sqlalchemy.orm import object_session, mapperlib
from sqlalchemy.schema import MetaData, Table, ForeignKeyConstraint
from .orm import get_mapper, get_tables
from ..query_chain import QueryChain
def get_foreign_key_values(fk, obj):
return {
fk.constraint.columns[index].key:
getattr(obj, element.column.key)
for
index, element
in
enumerate(fk.constraint.elements)
}
def group_foreign_keys(foreign_keys):
"""
Return a groupby iterator that groups given foreign keys by table.
:param foreign_keys: a sequence of foreign keys
::
foreign_keys = get_referencing_foreign_keys(User)
for table, fks in group_foreign_keys(foreign_keys):
# do something
pass
.. seealso:: :func:`get_referencing_foreign_keys`
.. versionadded: 0.26.1
"""
foreign_keys = sorted(
foreign_keys, key=lambda key: key.constraint.table.name
)
return groupby(foreign_keys, lambda key: key.constraint.table)
def get_referencing_foreign_keys(mixed):
"""
Returns referencing foreign keys for given Table object or declarative
class.
:param mixed:
SA Table object or SA declarative class
::
get_referencing_foreign_keys(User) # set([ForeignKey('user.id')])
get_referencing_foreign_keys(User.__table__)
This function also understands inheritance. This means it returns
all foreign keys that reference any table in the class inheritance tree.
Let's say you have three classes which use joined table inheritance,
namely TextItem, Article and BlogPost with Article and BlogPost inheriting
TextItem.
::
# This will check all foreign keys that reference either article table
# or textitem table.
get_referencing_foreign_keys(Article)
.. seealso:: :func:`get_tables`
"""
if isinstance(mixed, sa.Table):
tables = [mixed]
else:
tables = get_tables(mixed)
referencing_foreign_keys = set()
for table in mixed.metadata.tables.values():
if table not in tables:
for constraint in table.constraints:
if isinstance(constraint, sa.sql.schema.ForeignKeyConstraint):
for fk in constraint.elements:
if any(fk.references(t) for t in tables):
referencing_foreign_keys.add(fk)
return referencing_foreign_keys
def merge_references(from_, to, foreign_keys=None):
"""
Merge the references of an entity into another entity.
Consider the following models::
class User(self.Base):
__tablename__ = 'user'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
def __repr__(self):
return 'User(name=%r)' % self.name
class BlogPost(self.Base):
__tablename__ = 'blog_post'
id = sa.Column(sa.Integer, primary_key=True)
title = sa.Column(sa.Unicode(255))
author_id = sa.Column(sa.Integer, sa.ForeignKey('user.id'))
author = sa.orm.relationship(User)
Now lets add some data::
john = self.User(name=u'John')
jack = self.User(name=u'Jack')
post = self.BlogPost(title=u'Some title', author=john)
post2 = self.BlogPost(title=u'Other title', author=jack)
self.session.add_all([
john,
jack,
post,
post2
])
self.session.commit()
If we wanted to merge all John's references to Jack it would be as easy as
::
merge_references(john, jack)
self.session.commit()
post.author # User(name='Jack')
post2.author # User(name='Jack')
:param from_: an entity to merge into another entity
:param to: an entity to merge another entity into
:param foreign_keys: A sequence of foreign keys. By default this is None
indicating all referencing foreign keys should be used.
.. seealso: :func:`dependent_objects`
.. versionadded: 0.26.1
"""
if from_.__tablename__ != to.__tablename__:
raise TypeError('The tables of given arguments do not match.')
session = object_session(from_)
foreign_keys = get_referencing_foreign_keys(from_)
for fk in foreign_keys:
old_values = get_foreign_key_values(fk, from_)
new_values = get_foreign_key_values(fk, to)
criteria = (
getattr(fk.constraint.table.c, key) == value
for key, value in six.iteritems(old_values)
)
try:
mapper = get_mapper(fk.constraint.table)
except ValueError:
query = (
fk.constraint.table
.update()
.where(sa.and_(*criteria))
.values(new_values)
)
session.execute(query)
else:
print old_values, new_values
(
session.query(mapper.class_)
.filter_by(**old_values)
.update(
new_values,
'evaluate'
)
)
def dependent_objects(obj, foreign_keys=None):
"""
Return a :class:`~sqlalchemy_utils.query_chain.QueryChain` that iterates
through all dependent objects for given SQLAlchemy object.
Consider a User object is referenced in various articles and also in
various orders. Getting all these dependent objects is as easy as:
::
from sqlalchemy_utils import dependent_objects
dependent_objects(user)
If you expect an object to have lots of dependent_objects it might be good
to limit the results::
dependent_objects(user).limit(5)
The common use case is checking for all restrict dependent objects before
deleting parent object and inform the user if there are dependent objects
with ondelete='RESTRICT' foreign keys. If this kind of checking is not used
it will lead to nasty IntegrityErrors being raised.
In the following example we delete given user if it doesn't have any
foreign key restricted dependent objects.
::
from sqlalchemy_utils import get_referencing_foreign_keys
user = session.query(User).get(some_user_id)
deps = list(
dependent_objects(
user,
(
fk for fk in get_referencing_foreign_keys(User)
# On most databases RESTRICT is the default mode hence we
# check for None values also
if fk.ondelete == 'RESTRICT' or fk.ondelete is None
)
).limit(5)
)
if deps:
# Do something to inform the user
pass
else:
session.delete(user)
:param obj: SQLAlchemy declarative model object
:param foreign_keys:
A sequence of foreign keys to use for searching the dependent_objects
for given object. By default this is None, indicating that all foreign
keys referencing the object will be used.
.. note::
This function does not support exotic mappers that use multiple tables
.. seealso:: :func:`get_referencing_foreign_keys`
.. seealso:: :func:`merge_references`
.. versionadded: 0.26.0
"""
if foreign_keys is None:
foreign_keys = get_referencing_foreign_keys(obj)
session = object_session(obj)
chain = QueryChain([])
classes = obj.__class__._decl_class_registry
for table, keys in group_foreign_keys(foreign_keys):
for class_ in classes.values():
if hasattr(class_, '__table__') and class_.__table__ == table:
criteria = []
visited_constraints = []
for key in keys:
if key.constraint not in visited_constraints:
visited_constraints.append(key.constraint)
subcriteria = [
getattr(class_, column.key) ==
getattr(
obj,
key.constraint.elements[index].column.key
)
for index, column
in enumerate(key.constraint.columns)
]
criteria.append(sa.and_(*subcriteria))
query = session.query(class_).filter(
sa.or_(
*criteria
)
)
chain.queries.append(query)
return chain
def non_indexed_foreign_keys(metadata, engine=None):
"""
Finds all non indexed foreign keys from all tables of given MetaData.
Very useful for optimizing postgresql database and finding out which
foreign keys need indexes.
:param metadata: MetaData object to inspect tables from
"""
reflected_metadata = MetaData()
if metadata.bind is None and engine is None:
raise Exception(
'Either pass a metadata object with bind or '
'pass engine as a second parameter'
)
constraints = defaultdict(list)
for table_name in metadata.tables.keys():
table = Table(
table_name,
reflected_metadata,
autoload=True,
autoload_with=metadata.bind or engine
)
for constraint in table.constraints:
if not isinstance(constraint, ForeignKeyConstraint):
continue
if not is_indexed_foreign_key(constraint):
constraints[table.name].append(constraint)
return dict(constraints)
def is_indexed_foreign_key(constraint):
"""
Whether or not given foreign key constraint's columns have been indexed.
:param constraint: ForeignKeyConstraint object to check the indexes
"""
for index in constraint.table.indexes:
index_column_names = set(
column.name for column in index.columns
)
if index_column_names == set(constraint.columns):
return True
return False

View File

@@ -61,8 +61,12 @@ def get_mapper(mixed):
]
if len(mappers) > 1:
raise ValueError(
"Could not get mapper for '%r'. Multiple mappers found."
% mixed
"Multiple mappers found for table '%s'."
% mixed.name
)
elif not mappers:
raise ValueError(
"Could not get mapper for table '%s'."
)
else:
return mappers[0]
@@ -104,188 +108,6 @@ def get_bind(obj):
return conn
def dependent_objects(obj, foreign_keys=None):
"""
Return a :class:`~sqlalchemy_utils.query_chain.QueryChain` that iterates
through all dependent objects for given SQLAlchemy object.
Consider a User object is referenced in various articles and also in
various orders. Getting all these dependent objects is as easy as:
::
from sqlalchemy_utils import dependent_objects
dependent_objects(user)
If you expect an object to have lots of dependent_objects it might be good
to limit the results::
dependent_objects(user).limit(5)
The common use case is checking for all restrict dependent objects before
deleting parent object and inform the user if there are dependent objects
with ondelete='RESTRICT' foreign keys. If this kind of checking is not used
it will lead to nasty IntegrityErrors being raised.
In the following example we delete given user if it doesn't have any
foreign key restricted dependent objects.
::
from sqlalchemy_utils import get_referencing_foreign_keys
user = session.query(User).get(some_user_id)
deps = list(
dependent_objects(
user,
(
fk for fk in get_referencing_foreign_keys(User)
# On most databases RESTRICT is the default mode hence we
# check for None values also
if fk.ondelete == 'RESTRICT' or fk.ondelete is None
)
).limit(5)
)
if deps:
# Do something to inform the user
pass
else:
session.delete(user)
:param obj: SQLAlchemy declarative model object
:param foreign_keys:
A sequence of foreign keys to use for searching the dependent_objects
for given object. By default this is None, indicating that all foreign
keys referencing the object will be used.
.. note::
This function does not support exotic mappers that use multiple tables
.. seealso:: :func:`get_referencing_foreign_keys`
.. versionadded: 0.26.0
"""
if foreign_keys is None:
foreign_keys = get_referencing_foreign_keys(obj)
session = object_session(obj)
chain = QueryChain([])
classes = obj.__class__._decl_class_registry
for table, keys in group_foreign_keys(foreign_keys):
for class_ in classes.values():
if hasattr(class_, '__table__') and class_.__table__ == table:
criteria = []
visited_constraints = []
for key in keys:
if key.constraint not in visited_constraints:
visited_constraints.append(key.constraint)
subcriteria = [
getattr(class_, column.key) ==
getattr(
obj,
key.constraint.elements[index].column.key
)
for index, column
in enumerate(key.constraint.columns)
]
criteria.append(sa.and_(*subcriteria))
query = session.query(class_).filter(
sa.or_(
*criteria
)
)
chain.queries.append(query)
return chain
def group_foreign_keys(foreign_keys):
"""
Return a groupby iterator that groups given foreign keys by table.
:param foreign_keys: a sequence of foreign keys
::
foreign_keys = get_referencing_foreign_keys(User)
for table, fks in group_foreign_keys(foreign_keys):
# do something
pass
.. also:: :func:`get_referencing_foreign_keys`
.. versionadded: 0.26.1
"""
foreign_keys = sorted(
foreign_keys, key=lambda key: key.constraint.table.name
)
return groupby(foreign_keys, lambda key: key.constraint.table)
def get_referencing_foreign_keys(mixed):
"""
Returns referencing foreign keys for given Table object or declarative
class.
:param mixed:
SA Table object or SA declarative class
::
get_referencing_foreign_keys(User) # set([ForeignKey('user.id')])
get_referencing_foreign_keys(User.__table__)
This function also understands inheritance. This means it returns
all foreign keys that reference any table in the class inheritance tree.
Let's say you have three classes which use joined table inheritance,
namely TextItem, Article and BlogPost with Article and BlogPost inheriting
TextItem.
::
# This will check all foreign keys that reference either article table
# or textitem table.
get_referencing_foreign_keys(Article)
.. seealso:: :func:`get_tables`
"""
if isinstance(mixed, sa.Table):
tables = [mixed]
else:
tables = get_tables(mixed)
referencing_foreign_keys = set()
for table in mixed.metadata.tables.values():
if table not in tables:
for constraint in table.constraints:
if isinstance(constraint, sa.sql.schema.ForeignKeyConstraint):
for fk in constraint.elements:
if any(fk.references(t) for t in tables):
referencing_foreign_keys.add(fk)
return referencing_foreign_keys
def get_primary_keys(mixed):
"""
Return an OrderedDict of all primary keys for given Table object,

View File

@@ -1,124 +0,0 @@
import six
import sqlalchemy as sa
from sqlalchemy.engine import reflection
from sqlalchemy.orm import object_session, mapperlib
def dependent_foreign_keys(model_class):
"""
Returns dependent foreign keys as dicts for given model class.
** Experimental function **
"""
session = object_session(model_class)
engine = session.bind
inspector = reflection.Inspector.from_engine(engine)
table_names = inspector.get_table_names()
dependent_foreign_keys = {}
for table_name in table_names:
fks = inspector.get_foreign_keys(table_name)
if fks:
dependent_foreign_keys[table_name] = []
for fk in fks:
if fk['referred_table'] == model_class.__tablename__:
dependent_foreign_keys[table_name].append(fk)
return dependent_foreign_keys
class Merger(object):
def memory_merge(self, session, table_name, old_values, new_values):
# try to fetch mappers for given table and update in memory objects as
# well as database table
found = False
for mapper in mapperlib._mapper_registry:
class_ = mapper.class_
if table_name == class_.__table__.name:
try:
(
session.query(mapper.class_)
.filter_by(**old_values)
.update(
new_values,
'fetch'
)
)
except sa.exc.IntegrityError:
pass
found = True
return found
def raw_merge(self, session, table, old_values, new_values):
conditions = []
for key, value in six.iteritems(old_values):
conditions.append(getattr(table.c, key) == value)
sql = (
table
.update()
.where(sa.and_(
*conditions
))
.values(
new_values
)
)
try:
session.execute(sql)
except sa.exc.IntegrityError:
pass
def merge_update(self, table_name, from_, to, foreign_key):
session = object_session(from_)
constrained_columns = foreign_key['constrained_columns']
referred_columns = foreign_key['referred_columns']
metadata = from_.metadata
table = metadata.tables[table_name]
new_values = {}
for index, column in enumerate(constrained_columns):
new_values[column] = getattr(
to, referred_columns[index]
)
old_values = {}
for index, column in enumerate(constrained_columns):
old_values[column] = getattr(
from_, referred_columns[index]
)
if not self.memory_merge(session, table_name, old_values, new_values):
self.raw_merge(session, table, old_values, new_values)
def __call__(self, from_, to):
"""
Merges entity into another entity. After merging deletes the from_
argument entity.
"""
if from_.__tablename__ != to.__tablename__:
raise Exception()
session = object_session(from_)
foreign_keys = dependent_foreign_keys(from_)
for table_name in foreign_keys:
for foreign_key in foreign_keys[table_name]:
self.merge_update(table_name, from_, to, foreign_key)
session.delete(from_)
def merge(from_, to, merger=Merger):
"""
Merges entity into another entity. After merging deletes the from_ argument
entity.
After merging the from_ entity is deleted from database.
:param from_: an entity to merge into another entity
:param to: an entity to merge another entity into
:param merger: Merger class, by default this is sqlalchemy_utils.Merger
class
"""
return Merger()(from_, to)

View File

@@ -71,3 +71,18 @@ class TestGetMapperWithMultipleMappersFound(object):
alias = sa.orm.aliased(self.Building.__table__)
with raises(ValueError):
get_mapper(alias)
class TestGetMapperForTableWithoutMapper(object):
def setup_method(self, method):
metadata = sa.MetaData()
self.building = sa.Table('building', metadata)
def test_table(self):
with raises(ValueError):
get_mapper(self.building)
def test_table_alias(self):
alias = sa.orm.aliased(self.building)
with raises(ValueError):
get_mapper(alias)

View File

@@ -1,10 +1,10 @@
import sqlalchemy as sa
from sqlalchemy_utils import merge
from sqlalchemy_utils import merge_references
from tests import TestCase
class TestMerge(TestCase):
class TestMergeReferences(TestCase):
def create_models(self):
class User(self.Base):
__tablename__ = 'user'
@@ -36,21 +36,29 @@ class TestMerge(TestCase):
self.session.add(post)
self.session.add(post2)
self.session.commit()
merge(john, jack)
merge_references(john, jack)
self.session.commit()
assert post.author == jack
assert post2.author == jack
def test_deletes_from_entity(self):
def test_object_merging_whenever_possible(self):
john = self.User(name=u'John')
jack = self.User(name=u'Jack')
post = self.BlogPost(title=u'Some title', author=john)
post2 = self.BlogPost(title=u'Other title', author=jack)
self.session.add(john)
self.session.add(jack)
self.session.add(post)
self.session.add(post2)
self.session.commit()
merge(john, jack)
assert john in self.session.deleted
# Load the author for post
assert post.author_id == john.id
merge_references(john, jack)
assert post.author_id == jack.id
assert post2.author_id == jack.id
class TestMergeManyToManyAssociations(TestCase):
class TestMergeReferencesWithManyToManyAssociations(TestCase):
def create_models(self):
class User(self.Base):
__tablename__ = 'user'
@@ -88,7 +96,7 @@ class TestMergeManyToManyAssociations(TestCase):
self.User = User
self.Team = Team
def test_when_association_only_exists_in_from_entity(self):
def test_supports_associations(self):
john = self.User(name=u'John')
jack = self.User(name=u'Jack')
team = self.Team(name=u'Team')
@@ -96,29 +104,12 @@ class TestMergeManyToManyAssociations(TestCase):
self.session.add(john)
self.session.add(jack)
self.session.commit()
merge(john, jack)
merge_references(john, jack)
assert john not in team.members
assert jack in team.members
# def test_when_association_exists_in_both(self):
# john = self.User(name=u'John')
# jack = self.User(name=u'Jack')
# team = self.Team(name=u'Team')
# team.members.append(john)
# team.members.append(jack)
# self.session.add(john)
# self.session.add(jack)
# self.session.commit()
# merge(john, jack)
# assert john not in team.members
# assert jack in team.members
# count = self.session.execute(
# 'SELECT COUNT(1) FROM team_member'
# ).fetchone()[0]
# assert count == 1
class TestMergeManyToManyAssociationObjects(TestCase):
class TestMergeReferencesWithManyToManyAssociationObjects(TestCase):
def create_models(self):
class Team(self.Base):
__tablename__ = 'team'
@@ -164,7 +155,7 @@ class TestMergeManyToManyAssociationObjects(TestCase):
self.TeamMember = TeamMember
self.Team = Team
def test_when_association_only_exists_in_from_entity(self):
def test_supports_associations(self):
john = self.User(name=u'John')
jack = self.User(name=u'Jack')
team = self.Team(name=u'Team')
@@ -173,24 +164,8 @@ class TestMergeManyToManyAssociationObjects(TestCase):
self.session.add(jack)
self.session.add(team)
self.session.commit()
merge(john, jack)
merge_references(john, jack)
self.session.commit()
users = [member.user for member in team.members]
assert john not in users
assert jack in users
# def test_when_association_exists_in_both(self):
# john = self.User(name=u'John')
# jack = self.User(name=u'Jack')
# team = self.Team(name=u'Team')
# team.members.append(self.TeamMember(user=john))
# team.members.append(self.TeamMember(user=jack))
# self.session.add(john)
# self.session.add(jack)
# self.session.add(team)
# self.session.commit()
# merge(john, jack)
# users = [member.user for member in team.members]
# assert john not in users
# assert jack in users
# assert self.session.query(self.TeamMember).count() == 1