import sqlalchemy as sa
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker

try:
    from sqlalchemy.orm import declarative_base
except ImportError:  # SQLAlchemy < 1.4 only ships the legacy location
    from sqlalchemy.ext.declarative import declarative_base

from sqlalchemy_utils import (
    escape_like,
    sort_query,
    InstrumentedList,
    PhoneNumber,
    PhoneNumberType,
    merge
)


class DatabaseTestCase(object):
    """Base fixture giving every test a fresh in-memory SQLite database.

    Subclasses declare their models in ``create_models`` (attaching them to
    ``self.Base``); the tables are created before each test and dropped
    after it.
    """

    def create_models(self):
        """Hook for subclasses: declare models on ``self.Base``."""
        pass

    def setup_method(self, method):
        self.engine = create_engine('sqlite:///')
        self.Base = declarative_base()
        self.create_models()
        self.Base.metadata.create_all(self.engine)

        Session = sessionmaker(bind=self.engine)
        self.session = Session()

    def teardown_method(self, method):
        # Release the session's connection first, then drop the tables while
        # the in-memory database still exists, and only then dispose of the
        # engine. An in-memory SQLite database lives only as long as its
        # connection, so disposing first would make drop_all run against a
        # brand-new empty database.
        self.session.close()
        self.Base.metadata.drop_all(self.engine)
        self.engine.dispose()


class TestCase(DatabaseTestCase):
    """Shared fixture providing ``User``, ``Category`` and ``Article`` models."""

    def create_models(self):
        class User(self.Base):
            __tablename__ = 'user'
            id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
            name = sa.Column(sa.Unicode(255))
            phone_number = sa.Column(PhoneNumberType())

        class Category(self.Base):
            __tablename__ = 'category'
            id = sa.Column(sa.Integer, primary_key=True)
            name = sa.Column(sa.Unicode(255))

        class Article(self.Base):
            __tablename__ = 'article'
            id = sa.Column(sa.Integer, primary_key=True)
            name = sa.Column(sa.Unicode(255))
            category_id = sa.Column(sa.Integer, sa.ForeignKey(Category.id))

            category = sa.orm.relationship(
                Category,
                primaryjoin=category_id == Category.id,
                backref=sa.orm.backref(
                    'articles',
                    # Use the sqlalchemy_utils list so ``.any()`` is available.
                    collection_class=InstrumentedList
                )
            )

        self.User = User
        self.Category = Category
        self.Article = Article


class TestInstrumentedList(TestCase):
    def test_any_returns_true_if_member_has_attr_defined(self):
        category = self.Category()
        category.articles.append(self.Article())
        category.articles.append(self.Article(name=u'some name'))
        assert category.articles.any('name')

    def test_any_returns_false_if_no_member_has_attr_defined(self):
        category = self.Category()
        category.articles.append(self.Article())
        assert not category.articles.any('name')


class TestEscapeLike(TestCase):
    def test_escapes_wildcards(self):
        assert escape_like('_*%') == '*_***%'


class TestSortQuery(TestCase):
    def test_without_sort_param_returns_the_query_object_untouched(self):
        query = self.session.query(self.Article)
        sorted_query = sort_query(query, '')
        assert query == sorted_query

    def test_sort_by_column_ascending(self):
        query = sort_query(self.session.query(self.Article), 'name')
        assert 'ORDER BY article.name ASC' in str(query)

    def test_sort_by_column_descending(self):
        query = sort_query(self.session.query(self.Article), '-name')
        assert 'ORDER BY article.name DESC' in str(query)

    def test_skips_unknown_columns(self):
        query = self.session.query(self.Article)
        sorted_query = sort_query(query, '-unknown')
        assert query == sorted_query

    def test_sort_by_calculated_value_ascending(self):
        query = self.session.query(
            self.Category, sa.func.count(self.Article.id).label('articles')
        )
        query = sort_query(query, 'articles')
        assert 'ORDER BY articles ASC' in str(query)

    def test_sort_by_calculated_value_descending(self):
        query = self.session.query(
            self.Category, sa.func.count(self.Article.id).label('articles')
        )
        query = sort_query(query, '-articles')
        assert 'ORDER BY articles DESC' in str(query)

    def test_sort_by_joined_table_column(self):
        query = self.session.query(self.Article).join(self.Article.category)
        sorted_query = sort_query(query, 'category-name')
        assert 'category.name ASC' in str(sorted_query)


class TestPhoneNumber(object):
    def setup_method(self, method):
        self.valid_phone_numbers = [
            '040 1234567',
            '+358 401234567',
            '09 2501234',
            '+358 92501234',
            '0800 939393',
            '09 4243 0456',
            '0600 900 500'
        ]
        self.invalid_phone_numbers = [
            'abc',
            '+040 1234567',
            '0111234567',
            '358'
        ]

    def test_valid_phone_numbers(self):
        for raw_number in self.valid_phone_numbers:
            phone_number = PhoneNumber(raw_number, 'FI')
            assert phone_number.is_valid_number()

    def test_invalid_phone_numbers(self):
        for raw_number in self.invalid_phone_numbers:
            try:
                phone_number = PhoneNumber(raw_number, 'FI')
            except Exception:
                # Refusing to parse the input at all is an acceptable way of
                # rejecting it. The assert below must stay outside this try,
                # otherwise a failed validity check would be swallowed.
                continue
            assert not phone_number.is_valid_number()

    def test_phone_number_attributes(self):
        phone_number = PhoneNumber('+358401234567')
        assert phone_number.e164 == u'+358401234567'
        assert phone_number.international == u'+358 40 1234567'
        assert phone_number.national == u'040 1234567'


class TestPhoneNumberType(TestCase):
    def setup_method(self, method):
        super(TestPhoneNumberType, self).setup_method(method)
        self.phone_number = PhoneNumber(
            '040 1234567',
            'FI'
        )
        self.user = self.User()
        self.user.name = u'Someone'
        self.user.phone_number = self.phone_number
        self.session.add(self.user)
        self.session.commit()

    def test_query_returns_phone_number_object(self):
        queried_user = self.session.query(self.User).first()
        assert queried_user.phone_number == self.phone_number

    def test_phone_number_is_stored_as_string(self):
        # Raw SQL must be wrapped in text() (mandatory in SQLAlchemy 2.0).
        result = self.session.execute(
            sa.text('SELECT phone_number FROM user WHERE id=:param'),
            {'param': self.user.id}
        )
        assert result.first()[0] == u'+358401234567'


class TestMerge(DatabaseTestCase):
    def create_models(self):
        class User(self.Base):
            __tablename__ = 'user'
            id = sa.Column(sa.Integer, primary_key=True)
            name = sa.Column(sa.Unicode(255))

            def __repr__(self):
                return 'User(%r)' % self.name

        class BlogPost(self.Base):
            __tablename__ = 'blog_post'
            id = sa.Column(sa.Integer, primary_key=True)
            title = sa.Column(sa.Unicode(255))
            content = sa.Column(sa.UnicodeText)
            author_id = sa.Column(sa.Integer, sa.ForeignKey('user.id'))

            author = sa.orm.relationship(User)

        self.User = User
        self.BlogPost = BlogPost

    def test_updates_foreign_keys(self):
        john = self.User(name=u'John')
        jack = self.User(name=u'Jack')
        post = self.BlogPost(title=u'Some title', author=john)
        post2 = self.BlogPost(title=u'Other title', author=jack)
        self.session.add(john)
        self.session.add(jack)
        self.session.add(post)
        self.session.add(post2)
        self.session.commit()
        merge(john, jack)
        assert post.author == jack
        assert post2.author == jack

    def test_deletes_from_entity(self):
        john = self.User(name=u'John')
        jack = self.User(name=u'Jack')
        self.session.add(john)
        self.session.add(jack)
        self.session.commit()
        merge(john, jack)
        assert john in self.session.deleted


class TestMergeManyToManyAssociations(DatabaseTestCase):
    def create_models(self):
        class User(self.Base):
            __tablename__ = 'user'
            id = sa.Column(sa.Integer, primary_key=True)
            name = sa.Column(sa.Unicode(255))

            def __repr__(self):
                return 'User(%r)' % self.name

        # Plain secondary table (no association object) linking users to teams.
        team_member = sa.Table(
            'team_member', self.Base.metadata,
            sa.Column(
                'user_id', sa.Integer,
                sa.ForeignKey('user.id', ondelete='CASCADE'),
                primary_key=True
            ),
            sa.Column(
                'team_id', sa.Integer,
                sa.ForeignKey('team.id', ondelete='CASCADE'),
                primary_key=True
            )
        )

        class Team(self.Base):
            __tablename__ = 'team'
            id = sa.Column(sa.Integer, primary_key=True)
            name = sa.Column(sa.Unicode(255))

            members = sa.orm.relationship(
                User,
                secondary=team_member,
                backref='teams'
            )

        self.User = User
        self.Team = Team

    def test_when_association_only_exists_in_from_entity(self):
        john = self.User(name=u'John')
        jack = self.User(name=u'Jack')
        team = self.Team(name=u'Team')
        team.members.append(john)
        self.session.add(john)
        self.session.add(jack)
        self.session.commit()
        merge(john, jack)
        assert john not in team.members
        assert jack in team.members

    def test_when_association_exists_in_both(self):
        john = self.User(name=u'John')
        jack = self.User(name=u'Jack')
        team = self.Team(name=u'Team')
        team.members.append(john)
        team.members.append(jack)
        self.session.add(john)
        self.session.add(jack)
        self.session.commit()
        merge(john, jack)
        assert john not in team.members
        assert jack in team.members
        # The duplicate association row must have been collapsed into one.
        count = self.session.execute(
            sa.text('SELECT COUNT(1) FROM team_member')
        ).scalar()
        assert count == 1
class TestMergeManyToManyAssociationObjects(DatabaseTestCase):
    """Exercise ``merge`` on a many-to-many relation built from an
    association object.

    ``TeamMember`` rows link a ``User`` to a ``Team`` and carry an extra
    ``role`` column. Both backrefs (``Team.members`` and
    ``User.memberships``) cascade ``all, delete-orphan`` onto those rows.
    """

    def create_models(self):
        class Team(self.Base):
            __tablename__ = 'team'
            id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
            name = sa.Column(sa.Unicode(255))

        class User(self.Base):
            __tablename__ = 'user'
            id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
            name = sa.Column(sa.Unicode(255))

        class TeamMember(self.Base):
            __tablename__ = 'team_member'
            # Composite primary key: (user_id, team_id).
            user_id = sa.Column(
                sa.Integer,
                sa.ForeignKey(User.id, ondelete='CASCADE'),
                primary_key=True
            )
            team_id = sa.Column(
                sa.Integer,
                sa.ForeignKey(Team.id, ondelete='CASCADE'),
                primary_key=True
            )
            role = sa.Column(sa.Unicode(255))

            team = sa.orm.relationship(
                Team,
                primaryjoin=team_id == Team.id,
                backref=sa.orm.backref(
                    'members', cascade='all, delete-orphan'
                ),
            )
            user = sa.orm.relationship(
                User,
                primaryjoin=user_id == User.id,
                backref=sa.orm.backref(
                    'memberships', cascade='all, delete-orphan'
                ),
            )

        self.Team, self.User, self.TeamMember = Team, User, TeamMember

    def _save_team(self, users, member_users):
        """Create a team with one membership per entry of ``member_users``,
        add ``users`` followed by the team to the session, commit, and
        return the team.
        """
        team = self.Team(name=u'Team')
        for member_user in member_users:
            team.members.append(self.TeamMember(user=member_user))
        self.session.add_all(list(users) + [team])
        self.session.commit()
        return team

    def test_when_association_only_exists_in_from_entity(self):
        john, jack = self.User(name=u'John'), self.User(name=u'Jack')
        team = self._save_team([john, jack], [john])

        merge(john, jack)
        self.session.commit()

        member_users = [membership.user for membership in team.members]
        assert john not in member_users
        assert jack in member_users

    def test_when_association_exists_in_both(self):
        john, jack = self.User(name=u'John'), self.User(name=u'Jack')
        team = self._save_team([john, jack], [john, jack])

        merge(john, jack)

        member_users = [membership.user for membership in team.members]
        assert john not in member_users
        assert jack in member_users
        assert self.session.query(self.TeamMember).count() == 1