Merge pull request #181 from jmagnusson/pytest-fixtures

Use pytest fixtures to reduce complexity and repetition
This commit is contained in:
Konsta Vesterinen
2016-01-20 09:23:11 +02:00
128 changed files with 5414 additions and 4287 deletions

14
.editorconfig Normal file
View File

@@ -0,0 +1,14 @@
# EditorConfig helps developers define and maintain consistent
# coding styles between different editors and IDEs
# editorconfig.org
root = true
[*]
indent_style = space
end_of_line = lf
charset = utf-8
trim_trailing_whitespace = true
insert_final_newline = true
indent_size = 4

5
.gitignore vendored
View File

@@ -15,6 +15,8 @@ var
sdist sdist
develop-eggs develop-eggs
.installed.cfg .installed.cfg
.cache
.eggs
lib lib
lib64 lib64
docs/_build docs/_build
@@ -42,3 +44,6 @@ nosetests.xml
Session.vim Session.vim
.netrwhist .netrwhist
*~ *~
# Sublime Text
*.sublime-*

View File

@@ -1,5 +1,5 @@
[settings] [settings]
known_first_party=sqlalchemy_utils,tests known_first_party=sqlalchemy_utils
line_length=79 line_length=79
multi_line_output=3 multi_line_output=3
not_skip=__init__.py not_skip=__init__.py

View File

@@ -1,3 +1,6 @@
sudo: false
language: python
addons: addons:
postgresql: "9.4" postgresql: "9.4"
@@ -6,22 +9,29 @@ before_script:
- psql -c 'create extension hstore;' -U postgres -d sqlalchemy_utils_test - psql -c 'create extension hstore;' -U postgres -d sqlalchemy_utils_test
- mysql -e 'create database sqlalchemy_utils_test;' - mysql -e 'create database sqlalchemy_utils_test;'
language: python matrix:
python: include:
- 2.6 - python: 2.6
- 2.7 env:
- 3.3 - "TOXENV=py26"
- 3.4 - python: 2.7
- 3.5 env:
- "TOXENV=py27"
env: - python: 3.3
- EXTRAS=test env:
- EXTRAS=test_all - "TOXENV=py33"
- python: 3.4
env:
- "TOXENV=py34"
- python: 3.5
env:
- "TOXENV=py35"
- python: 3.5
env:
- "TOXENV=lint"
install: install:
- pip install -e .[$EXTRAS] - pip install tox
script: script:
- isort --recursive --diff sqlalchemy_utils tests && isort --recursive --check-only sqlalchemy_utils tests - tox
- flake8 sqlalchemy_utils tests
- py.test

View File

@@ -1,4 +1,4 @@
include CHANGES.rst LICENSE README.rst include CHANGES.rst LICENSE README.rst conftest.py .isort.cfg
recursive-include tests * recursive-include tests *
recursive-exclude tests *.pyc recursive-exclude tests *.pyc
recursive-include docs * recursive-include docs *

198
conftest.py Normal file
View File

@@ -0,0 +1,198 @@
import os
import warnings
import pytest
import sqlalchemy as sa
from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base, synonym_for
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import sessionmaker
from sqlalchemy_utils import (
aggregates,
coercion_listener,
i18n,
InstrumentedList
)
from sqlalchemy_utils.types.pg_composite import remove_composite_listeners
@sa.event.listens_for(sa.engine.Engine, 'before_cursor_execute')
def count_sql_calls(conn, cursor, statement, parameters, context, executemany):
try:
conn.query_count += 1
except AttributeError:
conn.query_count = 0
warnings.simplefilter('error', sa.exc.SAWarning)
sa.event.listen(sa.orm.mapper, 'mapper_configured', coercion_listener)
def get_locale():
class Locale():
territories = {'FI': 'Finland'}
return Locale()
@pytest.fixture(scope='session')
def db_name():
return os.environ.get('SQLALCHEMY_UTILS_TEST_DB', 'sqlalchemy_utils_test')
@pytest.fixture(scope='session')
def postgresql_db_user():
return os.environ.get('SQLALCHEMY_UTILS_TEST_POSTGRESQL_USER', 'postgres')
@pytest.fixture(scope='session')
def mysql_db_user():
return os.environ.get('SQLALCHEMY_UTILS_TEST_MYSQL_USER', 'root')
@pytest.fixture
def postgresql_dsn(postgresql_db_user, db_name):
return 'postgres://{0}@localhost/{1}'.format(postgresql_db_user, db_name)
@pytest.fixture
def mysql_dsn(mysql_db_user, db_name):
return 'mysql+pymysql://{0}@localhost/{1}'.format(mysql_db_user, db_name)
@pytest.fixture
def sqlite_memory_dsn():
return 'sqlite:///:memory:'
@pytest.fixture
def sqlite_file_dsn():
return 'sqlite:///{0}.db'.format(db_name)
@pytest.fixture
def dsn(request):
if 'postgresql_dsn' in request.fixturenames:
return request.getfuncargvalue('postgresql_dsn')
elif 'mysql_dsn' in request.fixturenames:
return request.getfuncargvalue('mysql_dsn')
elif 'sqlite_file_dsn' in request.fixturenames:
return request.getfuncargvalue('sqlite_file_dsn')
elif 'sqlite_memory_dsn' in request.fixturenames:
pass # Return default
return request.getfuncargvalue('sqlite_memory_dsn')
@pytest.fixture
def engine(dsn):
engine = create_engine(dsn)
# engine.echo = True
return engine
@pytest.fixture
def connection(engine):
return engine.connect()
@pytest.fixture
def Base():
return declarative_base()
@pytest.fixture
def User(Base):
class User(Base):
__tablename__ = 'user'
id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
name = sa.Column(sa.Unicode(255))
return User
@pytest.fixture
def Category(Base):
class Category(Base):
__tablename__ = 'category'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
title = sa.Column(sa.Unicode(255))
@hybrid_property
def full_name(self):
return u'%s %s' % (self.title, self.name)
@full_name.expression
def full_name(self):
return sa.func.concat(self.title, ' ', self.name)
@hybrid_property
def articles_count(self):
return len(self.articles)
@articles_count.expression
def articles_count(cls):
Article = Base._decl_class_registry['Article']
return (
sa.select([sa.func.count(Article.id)])
.where(Article.category_id == cls.id)
.correlate(Article.__table__)
.label('article_count')
)
@property
def name_alias(self):
return self.name
@synonym_for('name')
@property
def name_synonym(self):
return self.name
return Category
@pytest.fixture
def Article(Base, Category):
class Article(Base):
__tablename__ = 'article'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255), index=True)
category_id = sa.Column(sa.Integer, sa.ForeignKey(Category.id))
category = sa.orm.relationship(
Category,
primaryjoin=category_id == Category.id,
backref=sa.orm.backref(
'articles',
collection_class=InstrumentedList
)
)
return Article
@pytest.fixture
def init_models(User, Category, Article):
pass
@pytest.fixture
def session(request, engine, connection, Base, init_models):
sa.orm.configure_mappers()
Base.metadata.create_all(connection)
Session = sessionmaker(bind=connection)
session = Session()
i18n.get_locale = get_locale
def teardown():
aggregates.manager.reset()
session.close_all()
Base.metadata.drop_all(connection)
remove_composite_listeners()
connection.close()
engine.dispose()
request.addfinalizer(teardown)
return session

View File

@@ -11,6 +11,8 @@ SQLAlchemy-Utils has been tested against the following Python platforms.
- cPython 2.6 - cPython 2.6
- cPython 2.7 - cPython 2.7
- cPython 3.3 - cPython 3.3
- cPython 3.4
- cPython 3.5
Installing an official release Installing an official release

View File

@@ -89,11 +89,11 @@ setup(
'Operating System :: OS Independent', 'Operating System :: OS Independent',
'Programming Language :: Python', 'Programming Language :: Python',
'Programming Language :: Python :: 2', 'Programming Language :: Python :: 2',
'Programming Language :: Python :: 2.6',
'Programming Language :: Python :: 2.7', 'Programming Language :: Python :: 2.7',
'Programming Language :: Python :: 3', 'Programming Language :: Python :: 3',
'Programming Language :: Python :: 3.3', 'Programming Language :: Python :: 3.3',
'Programming Language :: Python :: 3.4', 'Programming Language :: Python :: 3.4',
'Programming Language :: Python :: 3.5',
'Topic :: Internet :: WWW/HTTP :: Dynamic Content', 'Topic :: Internet :: WWW/HTTP :: Dynamic Content',
'Topic :: Software Development :: Libraries :: Python Modules' 'Topic :: Software Development :: Libraries :: Python Modules'
] ]

View File

@@ -365,7 +365,6 @@ TODO
from collections import defaultdict from collections import defaultdict
from weakref import WeakKeyDictionary from weakref import WeakKeyDictionary
import six
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy.ext.declarative import declared_attr from sqlalchemy.ext.declarative import declared_attr
from sqlalchemy.sql.functions import _FunctionGenerator from sqlalchemy.sql.functions import _FunctionGenerator
@@ -519,7 +518,7 @@ class AggregationManager(object):
) )
def update_generator_registry(self): def update_generator_registry(self):
for class_, attrs in six.iteritems(aggregated_attrs): for class_, attrs in aggregated_attrs.items():
for expr, path, column in attrs: for expr, path, column in attrs:
value = AggregatedValue( value = AggregatedValue(
class_=class_, class_=class_,
@@ -539,7 +538,7 @@ class AggregationManager(object):
if class_ in self.generator_registry: if class_ in self.generator_registry:
object_dict[class_].append(obj) object_dict[class_].append(obj)
for class_, objects in six.iteritems(object_dict): for class_, objects in object_dict.items():
for aggregate_value in self.generator_registry[class_]: for aggregate_value in self.generator_registry[class_]:
query = aggregate_value.update_query(objects) query = aggregate_value.update_query(objects)
if query is not None: if query is not None:

View File

@@ -10,7 +10,7 @@ from sqlalchemy.sql.expression import (
) )
from sqlalchemy.sql.functions import GenericFunction from sqlalchemy.sql.functions import GenericFunction
from sqlalchemy_utils.functions.orm import quote from .functions.orm import quote
class explain(Executable, ClauseElement): class explain(Executable, ClauseElement):

View File

@@ -7,8 +7,7 @@ import sqlalchemy as sa
from sqlalchemy.engine.url import make_url from sqlalchemy.engine.url import make_url
from sqlalchemy.exc import OperationalError, ProgrammingError from sqlalchemy.exc import OperationalError, ProgrammingError
from sqlalchemy_utils.expressions import explain_analyze from ..expressions import explain_analyze
from ..utils import starts_with from ..utils import starts_with
from .orm import quote from .orm import quote

View File

@@ -1,7 +1,6 @@
from collections import defaultdict from collections import defaultdict
from itertools import groupby from itertools import groupby
import six
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy.exc import NoInspectionAvailable from sqlalchemy.exc import NoInspectionAvailable
from sqlalchemy.orm import object_session from sqlalchemy.orm import object_session
@@ -167,7 +166,7 @@ def merge_references(from_, to, foreign_keys=None):
new_values = get_foreign_key_values(fk, to) new_values = get_foreign_key_values(fk, to)
criteria = ( criteria = (
getattr(fk.constraint.table.c, key) == value getattr(fk.constraint.table.c, key) == value
for key, value in six.iteritems(old_values) for key, value in old_values.items()
) )
try: try:
mapper = get_mapper(fk.constraint.table) mapper = get_mapper(fk.constraint.table)

View File

@@ -19,7 +19,7 @@ from sqlalchemy.orm.query import _ColumnEntity
from sqlalchemy.orm.session import object_session from sqlalchemy.orm.session import object_session
from sqlalchemy.orm.util import AliasedInsp from sqlalchemy.orm.util import AliasedInsp
from sqlalchemy_utils.utils import is_sequence from ..utils import is_sequence
def get_class_by_table(base, table, data=None): def get_class_by_table(base, table, data=None):

View File

@@ -8,9 +8,8 @@ from sqlalchemy.orm.interfaces import MapperProperty, PropComparator
from sqlalchemy.orm.session import _state_session from sqlalchemy.orm.session import _state_session
from sqlalchemy.util import set_creation_order from sqlalchemy.util import set_creation_order
from sqlalchemy_utils.functions import identity
from .exceptions import ImproperlyConfigured from .exceptions import ImproperlyConfigured
from .functions import identity
class GenericAttributeImpl(attributes.ScalarAttributeImpl): class GenericAttributeImpl(attributes.ScalarAttributeImpl):

View File

@@ -8,6 +8,7 @@ from .exceptions import ImproperlyConfigured
try: try:
import babel import babel
import babel.dates
except ImportError: except ImportError:
babel = None babel = None

View File

@@ -154,9 +154,9 @@ from collections import defaultdict, Iterable, namedtuple
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy_utils.functions import getdotattr, has_changes from .functions import getdotattr, has_changes
from sqlalchemy_utils.path import AttrPath from .path import AttrPath
from sqlalchemy_utils.utils import is_sequence from .utils import is_sequence
Callback = namedtuple('Callback', ['func', 'path', 'backref', 'fullpath']) Callback = namedtuple('Callback', ['func', 'path', 'backref', 'fullpath'])

View File

@@ -1,7 +1,7 @@
import six import six
from sqlalchemy_utils import i18n from .. import i18n
from sqlalchemy_utils.utils import str_coercible from ..utils import str_coercible
@str_coercible @str_coercible

View File

@@ -1,8 +1,8 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import six import six
from sqlalchemy_utils import i18n, ImproperlyConfigured from .. import i18n, ImproperlyConfigured
from sqlalchemy_utils.utils import str_coercible from ..utils import str_coercible
@str_coercible @str_coercible

View File

@@ -4,8 +4,8 @@ try:
except ImportError: except ImportError:
# Python 2.6 port # Python 2.6 port
from total_ordering import total_ordering from total_ordering import total_ordering
from sqlalchemy_utils import i18n from .. import i18n
from sqlalchemy_utils.utils import str_coercible from ..utils import str_coercible
@str_coercible @str_coercible

View File

@@ -1,7 +1,6 @@
import six import six
from sqlalchemy_utils.utils import str_coercible from ..utils import str_coercible
from .weekday import WeekDay from .weekday import WeekDay

View File

@@ -6,8 +6,7 @@ from datetime import datetime
import six import six
from sqlalchemy import types from sqlalchemy import types
from sqlalchemy_utils.exceptions import ImproperlyConfigured from ..exceptions import ImproperlyConfigured
from .scalar_coercible import ScalarCoercible from .scalar_coercible import ScalarCoercible
arrow = None arrow = None

View File

@@ -1,8 +1,7 @@
import six import six
from sqlalchemy import types from sqlalchemy import types
from sqlalchemy_utils.exceptions import ImproperlyConfigured from ..exceptions import ImproperlyConfigured
from .scalar_coercible import ScalarCoercible from .scalar_coercible import ScalarCoercible
colour = None colour = None

View File

@@ -1,8 +1,7 @@
import six import six
from sqlalchemy import types from sqlalchemy import types
from sqlalchemy_utils.primitives import Country from ..primitives import Country
from .scalar_coercible import ScalarCoercible from .scalar_coercible import ScalarCoercible

View File

@@ -1,9 +1,8 @@
import six import six
from sqlalchemy import types from sqlalchemy import types
from sqlalchemy_utils import i18n, ImproperlyConfigured from .. import i18n, ImproperlyConfigured
from sqlalchemy_utils.primitives import Currency from ..primitives import Currency
from .scalar_coercible import ScalarCoercible from .scalar_coercible import ScalarCoercible

View File

@@ -5,8 +5,7 @@ import datetime
import six import six
from sqlalchemy.types import Binary, String, TypeDecorator from sqlalchemy.types import Binary, String, TypeDecorator
from sqlalchemy_utils.exceptions import ImproperlyConfigured from ..exceptions import ImproperlyConfigured
from .scalar_coercible import ScalarCoercible from .scalar_coercible import ScalarCoercible
cryptography = None cryptography = None
@@ -84,7 +83,7 @@ class AesEngine(EncryptionDecryptionBaseEngine):
value = str(value) value = str(value)
decryptor = self.cipher.decryptor() decryptor = self.cipher.decryptor()
decrypted = base64.b64decode(value) decrypted = base64.b64decode(value)
decrypted = decryptor.update(decrypted)+decryptor.finalize() decrypted = decryptor.update(decrypted) + decryptor.finalize()
decrypted = decrypted.rstrip(self.PADDING) decrypted = decrypted.rstrip(self.PADDING)
if not isinstance(decrypted, six.string_types): if not isinstance(decrypted, six.string_types):
decrypted = decrypted.decode('utf-8') decrypted = decrypted.decode('utf-8')

View File

@@ -1,8 +1,7 @@
import six import six
from sqlalchemy import types from sqlalchemy import types
from sqlalchemy_utils.exceptions import ImproperlyConfigured from ..exceptions import ImproperlyConfigured
from .scalar_coercible import ScalarCoercible from .scalar_coercible import ScalarCoercible
ip_address = None ip_address = None

View File

@@ -5,8 +5,7 @@ from sqlalchemy import types
from sqlalchemy.dialects import oracle, postgresql from sqlalchemy.dialects import oracle, postgresql
from sqlalchemy.ext.mutable import Mutable from sqlalchemy.ext.mutable import Mutable
from sqlalchemy_utils.exceptions import ImproperlyConfigured from ..exceptions import ImproperlyConfigured
from .scalar_coercible import ScalarCoercible from .scalar_coercible import ScalarCoercible
passlib = None passlib = None

View File

@@ -109,7 +109,7 @@ from sqlalchemy.types import (
UserDefinedType UserDefinedType
) )
from sqlalchemy_utils import ImproperlyConfigured from .. import ImproperlyConfigured
psycopg2 = None psycopg2 = None
CompositeCaster = None CompositeCaster = None

View File

@@ -1,8 +1,7 @@
from sqlalchemy import types from sqlalchemy import types
from sqlalchemy_utils.exceptions import ImproperlyConfigured from ..exceptions import ImproperlyConfigured
from sqlalchemy_utils.utils import str_coercible from ..utils import str_coercible
from .scalar_coercible import ScalarCoercible from .scalar_coercible import ScalarCoercible
try: try:

View File

@@ -1,8 +1,7 @@
import six import six
from sqlalchemy import types from sqlalchemy import types
from sqlalchemy_utils.exceptions import ImproperlyConfigured from ..exceptions import ImproperlyConfigured
from .scalar_coercible import ScalarCoercible from .scalar_coercible import ScalarCoercible

View File

@@ -1,10 +1,9 @@
import six import six
from sqlalchemy import types from sqlalchemy import types
from sqlalchemy_utils import i18n from .. import i18n
from sqlalchemy_utils.exceptions import ImproperlyConfigured from ..exceptions import ImproperlyConfigured
from sqlalchemy_utils.primitives import WeekDay, WeekDays from ..primitives import WeekDay, WeekDays
from .bit import BitType from .bit import BitType
from .scalar_coercible import ScalarCoercible from .scalar_coercible import ScalarCoercible

View File

@@ -1,132 +1,3 @@
import warnings
import sqlalchemy as sa
from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base, synonym_for
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import sessionmaker
from sqlalchemy_utils import (
aggregates,
coercion_listener,
i18n,
InstrumentedList
)
from sqlalchemy_utils.types.pg_composite import remove_composite_listeners
@sa.event.listens_for(sa.engine.Engine, 'before_cursor_execute')
def count_sql_calls(conn, cursor, statement, parameters, context, executemany):
try:
conn.query_count += 1
except AttributeError:
conn.query_count = 0
warnings.simplefilter('error', sa.exc.SAWarning)
sa.event.listen(sa.orm.mapper, 'mapper_configured', coercion_listener)
def get_locale():
class Locale():
territories = {'FI': 'Finland'}
return Locale()
class TestCase(object):
dns = 'sqlite:///:memory:'
create_tables = True
def setup_method(self, method):
self.engine = create_engine(self.dns)
# self.engine.echo = True
self.connection = self.engine.connect()
self.Base = declarative_base()
self.create_models()
sa.orm.configure_mappers()
if self.create_tables:
self.Base.metadata.create_all(self.connection)
Session = sessionmaker(bind=self.connection)
self.session = Session()
i18n.get_locale = get_locale
def teardown_method(self, method):
aggregates.manager.reset()
self.session.close_all()
if self.create_tables:
self.Base.metadata.drop_all(self.connection)
remove_composite_listeners()
self.connection.close()
self.engine.dispose()
def create_models(self):
class User(self.Base):
__tablename__ = 'user'
id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
name = sa.Column(sa.Unicode(255))
class Category(self.Base):
__tablename__ = 'category'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
title = sa.Column(sa.Unicode(255))
@hybrid_property
def full_name(self):
return u'%s %s' % (self.title, self.name)
@full_name.expression
def full_name(self):
return sa.func.concat(self.title, ' ', self.name)
@hybrid_property
def articles_count(self):
return len(self.articles)
@articles_count.expression
def articles_count(cls):
return (
sa.select([sa.func.count(self.Article.id)])
.where(self.Article.category_id == self.Category.id)
.correlate(self.Article.__table__)
.label('article_count')
)
@property
def name_alias(self):
return self.name
@synonym_for('name')
@property
def name_synonym(self):
return self.name
class Article(self.Base):
__tablename__ = 'article'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255), index=True)
category_id = sa.Column(sa.Integer, sa.ForeignKey(Category.id))
category = sa.orm.relationship(
Category,
primaryjoin=category_id == Category.id,
backref=sa.orm.backref(
'articles',
collection_class=InstrumentedList
)
)
self.User = User
self.Category = Category
self.Article = Article
def assert_contains(clause, query): def assert_contains(clause, query):
# Test that query executes # Test that query executes
query.all() query.all()

View File

@@ -1,61 +1,76 @@
import pytest
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy_utils.aggregates import aggregated from sqlalchemy_utils.aggregates import aggregated
from tests import TestCase
class TestAggregateValueGenerationWithBackrefs(TestCase): @pytest.fixture
def create_models(self): def Thread(Base):
class Thread(self.Base): class Thread(Base):
__tablename__ = 'thread' __tablename__ = 'thread'
id = sa.Column(sa.Integer, primary_key=True) id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255)) name = sa.Column(sa.Unicode(255))
@aggregated('comments', sa.Column(sa.Integer, default=0)) @aggregated('comments', sa.Column(sa.Integer, default=0))
def comment_count(self): def comment_count(self):
return sa.func.count('1') return sa.func.count('1')
return Thread
class Comment(self.Base):
__tablename__ = 'comment'
id = sa.Column(sa.Integer, primary_key=True)
content = sa.Column(sa.Unicode(255))
thread_id = sa.Column(sa.Integer, sa.ForeignKey('thread.id'))
thread = sa.orm.relationship(Thread, backref='comments') @pytest.fixture
def Comment(Base, Thread):
class Comment(Base):
__tablename__ = 'comment'
id = sa.Column(sa.Integer, primary_key=True)
content = sa.Column(sa.Unicode(255))
thread_id = sa.Column(sa.Integer, sa.ForeignKey('thread.id'))
self.Thread = Thread thread = sa.orm.relationship(Thread, backref='comments')
self.Comment = Comment return Comment
def test_assigns_aggregates_on_insert(self):
thread = self.Thread() @pytest.fixture
def init_models(Thread, Comment):
pass
class TestAggregateValueGenerationWithBackrefs(object):
def test_assigns_aggregates_on_insert(self, session, Thread, Comment):
thread = Thread()
thread.name = u'some article name' thread.name = u'some article name'
self.session.add(thread) session.add(thread)
comment = self.Comment(content=u'Some content', thread=thread) comment = Comment(content=u'Some content', thread=thread)
self.session.add(comment) session.add(comment)
self.session.commit() session.commit()
self.session.refresh(thread) session.refresh(thread)
assert thread.comment_count == 1 assert thread.comment_count == 1
def test_assigns_aggregates_on_separate_insert(self): def test_assigns_aggregates_on_separate_insert(
thread = self.Thread() self,
session,
Thread,
Comment
):
thread = Thread()
thread.name = u'some article name' thread.name = u'some article name'
self.session.add(thread) session.add(thread)
self.session.commit() session.commit()
comment = self.Comment(content=u'Some content', thread=thread) comment = Comment(content=u'Some content', thread=thread)
self.session.add(comment) session.add(comment)
self.session.commit() session.commit()
self.session.refresh(thread) session.refresh(thread)
assert thread.comment_count == 1 assert thread.comment_count == 1
def test_assigns_aggregates_on_delete(self): def test_assigns_aggregates_on_delete(self, session, Thread, Comment):
thread = self.Thread() thread = Thread()
thread.name = u'some article name' thread.name = u'some article name'
self.session.add(thread) session.add(thread)
self.session.commit() session.commit()
comment = self.Comment(content=u'Some content', thread=thread) comment = Comment(content=u'Some content', thread=thread)
self.session.add(comment) session.add(comment)
self.session.commit() session.commit()
self.session.delete(comment) session.delete(comment)
self.session.commit() session.commit()
self.session.refresh(thread) session.refresh(thread)
assert thread.comment_count == 0 assert thread.comment_count == 0

View File

@@ -1,67 +1,76 @@
from decimal import Decimal from decimal import Decimal
import pytest
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy_utils.aggregates import aggregated from sqlalchemy_utils.aggregates import aggregated
from tests import TestCase
class TestLazyEvaluatedSelectExpressionsForAggregates(TestCase): @pytest.fixture
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' def Product(Base):
class Product(Base):
__tablename__ = 'product'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
price = sa.Column(sa.Numeric)
def create_models(self): catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id'))
class Catalog(self.Base): return Product
__tablename__ = 'catalog'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
@aggregated('products', sa.Column(sa.Numeric, default=0))
def net_worth(self):
return sa.func.sum(Product.price)
products = sa.orm.relationship('Product', backref='catalog') @pytest.fixture
def Catalog(Base, Product):
class Catalog(Base):
__tablename__ = 'catalog'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
class Product(self.Base): @aggregated('products', sa.Column(sa.Numeric, default=0))
__tablename__ = 'product' def net_worth(self):
id = sa.Column(sa.Integer, primary_key=True) return sa.func.sum(Product.price)
name = sa.Column(sa.Unicode(255))
price = sa.Column(sa.Numeric)
catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id')) products = sa.orm.relationship('Product', backref='catalog')
return Catalog
self.Catalog = Catalog
self.Product = Product
def test_assigns_aggregates_on_insert(self): @pytest.fixture
catalog = self.Catalog( def init_models(Product, Catalog):
pass
@pytest.mark.usefixtures('postgresql_dsn')
class TestLazyEvaluatedSelectExpressionsForAggregates(object):
def test_assigns_aggregates_on_insert(self, session, Product, Catalog):
catalog = Catalog(
name=u'Some catalog' name=u'Some catalog'
) )
self.session.add(catalog) session.add(catalog)
self.session.commit() session.commit()
product = self.Product( product = Product(
name=u'Some product', name=u'Some product',
price=Decimal('1000'), price=Decimal('1000'),
catalog=catalog catalog=catalog
) )
self.session.add(product) session.add(product)
self.session.commit() session.commit()
self.session.refresh(catalog) session.refresh(catalog)
assert catalog.net_worth == Decimal('1000') assert catalog.net_worth == Decimal('1000')
def test_assigns_aggregates_on_update(self): def test_assigns_aggregates_on_update(self, session, Product, Catalog):
catalog = self.Catalog( catalog = Catalog(
name=u'Some catalog' name=u'Some catalog'
) )
self.session.add(catalog) session.add(catalog)
self.session.commit() session.commit()
product = self.Product( product = Product(
name=u'Some product', name=u'Some product',
price=Decimal('1000'), price=Decimal('1000'),
catalog=catalog catalog=catalog
) )
self.session.add(product) session.add(product)
self.session.commit() session.commit()
product.price = Decimal('500') product.price = Decimal('500')
self.session.commit() session.commit()
self.session.refresh(catalog) session.refresh(catalog)
assert catalog.net_worth == Decimal('500') assert catalog.net_worth == Decimal('500')

View File

@@ -1,101 +1,121 @@
from decimal import Decimal from decimal import Decimal
import pytest
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy_utils.aggregates import aggregated from sqlalchemy_utils.aggregates import aggregated
from tests import TestCase
class TestLazyEvaluatedSelectExpressionsForAggregates(TestCase): @pytest.fixture
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' def Product(Base):
class Product(Base):
__tablename__ = 'product'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
price = sa.Column(sa.Numeric)
def create_models(self): catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id'))
class Catalog(self.Base): return Product
__tablename__ = 'catalog'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
type = sa.Column(sa.Unicode(255))
__mapper_args__ = {
'polymorphic_on': type
}
@aggregated('products', sa.Column(sa.Numeric, default=0)) @pytest.fixture
def net_worth(self): def Catalog(Base, Product):
return sa.func.sum(Product.price) class Catalog(Base):
__tablename__ = 'catalog'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
type = sa.Column(sa.Unicode(255))
products = sa.orm.relationship('Product', backref='catalog') __mapper_args__ = {
'polymorphic_on': type
}
class CostumeCatalog(Catalog): @aggregated('products', sa.Column(sa.Numeric, default=0))
__tablename__ = 'costume_catalog' def net_worth(self):
id = sa.Column( return sa.func.sum(Product.price)
sa.Integer, sa.ForeignKey(Catalog.id), primary_key=True
)
__mapper_args__ = { products = sa.orm.relationship('Product', backref='catalog')
'polymorphic_identity': 'costumes', return Catalog
}
class CarCatalog(Catalog):
__tablename__ = 'car_catalog'
id = sa.Column(
sa.Integer, sa.ForeignKey(Catalog.id), primary_key=True
)
__mapper_args__ = { @pytest.fixture
'polymorphic_identity': 'cars', def CostumeCatalog(Catalog):
} class CostumeCatalog(Catalog):
__tablename__ = 'costume_catalog'
id = sa.Column(
sa.Integer, sa.ForeignKey(Catalog.id), primary_key=True
)
class Product(self.Base): __mapper_args__ = {
__tablename__ = 'product' 'polymorphic_identity': 'costumes',
id = sa.Column(sa.Integer, primary_key=True) }
name = sa.Column(sa.Unicode(255)) return CostumeCatalog
price = sa.Column(sa.Numeric)
catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id'))
self.Catalog = Catalog @pytest.fixture
self.CostumeCatalog = CostumeCatalog def CarCatalog(Catalog):
self.CarCatalog = CarCatalog class CarCatalog(Catalog):
self.Product = Product __tablename__ = 'car_catalog'
id = sa.Column(
sa.Integer, sa.ForeignKey(Catalog.id), primary_key=True
)
def test_columns_inherited_from_parent(self): __mapper_args__ = {
assert self.CarCatalog.net_worth 'polymorphic_identity': 'cars',
assert self.CostumeCatalog.net_worth }
assert self.Catalog.net_worth return CarCatalog
assert not hasattr(self.CarCatalog.__table__.c, 'net_worth')
assert not hasattr(self.CostumeCatalog.__table__.c, 'net_worth')
def test_assigns_aggregates_on_insert(self):
catalog = self.Catalog( @pytest.fixture
def init_models(Product, Catalog, CostumeCatalog, CarCatalog):
pass
@pytest.mark.usefixtures('postgresql_dsn')
class TestLazyEvaluatedSelectExpressionsForAggregates(object):
def test_columns_inherited_from_parent(
self,
Catalog,
CarCatalog,
CostumeCatalog
):
assert CarCatalog.net_worth
assert CostumeCatalog.net_worth
assert Catalog.net_worth
assert not hasattr(CarCatalog.__table__.c, 'net_worth')
assert not hasattr(CostumeCatalog.__table__.c, 'net_worth')
def test_assigns_aggregates_on_insert(self, session, Product, Catalog):
catalog = Catalog(
name=u'Some catalog' name=u'Some catalog'
) )
self.session.add(catalog) session.add(catalog)
self.session.commit() session.commit()
product = self.Product( product = Product(
name=u'Some product', name=u'Some product',
price=Decimal('1000'), price=Decimal('1000'),
catalog=catalog catalog=catalog
) )
self.session.add(product) session.add(product)
self.session.commit() session.commit()
self.session.refresh(catalog) session.refresh(catalog)
assert catalog.net_worth == Decimal('1000') assert catalog.net_worth == Decimal('1000')
def test_assigns_aggregates_on_update(self): def test_assigns_aggregates_on_update(self, session, Catalog, Product):
catalog = self.Catalog( catalog = Catalog(
name=u'Some catalog' name=u'Some catalog'
) )
self.session.add(catalog) session.add(catalog)
self.session.commit() session.commit()
product = self.Product( product = Product(
name=u'Some product', name=u'Some product',
price=Decimal('1000'), price=Decimal('1000'),
catalog=catalog catalog=catalog
) )
self.session.add(product) session.add(product)
self.session.commit() session.commit()
product.price = Decimal('500') product.price = Decimal('500')
self.session.commit() session.commit()
self.session.refresh(catalog) session.refresh(catalog)
assert catalog.net_worth == Decimal('500') assert catalog.net_worth == Decimal('500')

View File

@@ -1,72 +1,81 @@
import pytest
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy_utils.aggregates import aggregated from sqlalchemy_utils.aggregates import aggregated
from tests import TestCase
class TestAggregatesWithManyToManyRelationships(TestCase): @pytest.fixture
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' def User(Base):
user_group = sa.Table(
'user_group',
Base.metadata,
sa.Column('user_id', sa.Integer, sa.ForeignKey('user.id')),
sa.Column('group_id', sa.Integer, sa.ForeignKey('group.id'))
)
def create_models(self): class User(Base):
user_group = sa.Table( __tablename__ = 'user'
'user_group', id = sa.Column(sa.Integer, primary_key=True)
self.Base.metadata, name = sa.Column(sa.Unicode(255))
sa.Column('user_id', sa.Integer, sa.ForeignKey('user.id')),
sa.Column('group_id', sa.Integer, sa.ForeignKey('group.id')) @aggregated('groups', sa.Column(sa.Integer, default=0))
def group_count(self):
return sa.func.count('1')
groups = sa.orm.relationship(
'Group',
backref='users',
secondary=user_group
) )
return User
class User(self.Base):
__tablename__ = 'user'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
@aggregated('groups', sa.Column(sa.Integer, default=0)) @pytest.fixture
def group_count(self): def Group(Base):
return sa.func.count('1') class Group(Base):
__tablename__ = 'group'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
return Group
groups = sa.orm.relationship(
'Group',
backref='users',
secondary=user_group
)
class Group(self.Base): @pytest.fixture
__tablename__ = 'group' def init_models(User, Group):
id = sa.Column(sa.Integer, primary_key=True) pass
name = sa.Column(sa.Unicode(255))
self.User = User
self.Group = Group
def test_assigns_aggregates_on_insert(self): @pytest.mark.usefixtures('postgresql_dsn')
user = self.User( class TestAggregatesWithManyToManyRelationships(object):
def test_assigns_aggregates_on_insert(self, session, User, Group):
user = User(
name=u'John Matrix' name=u'John Matrix'
) )
self.session.add(user) session.add(user)
self.session.commit() session.commit()
group = self.Group( group = Group(
name=u'Some group', name=u'Some group',
users=[user] users=[user]
) )
self.session.add(group) session.add(group)
self.session.commit() session.commit()
self.session.refresh(user) session.refresh(user)
assert user.group_count == 1 assert user.group_count == 1
def test_updates_aggregates_on_delete(self): def test_updates_aggregates_on_delete(self, session, User, Group):
user = self.User( user = User(
name=u'John Matrix' name=u'John Matrix'
) )
self.session.add(user) session.add(user)
self.session.commit() session.commit()
group = self.Group( group = Group(
name=u'Some group', name=u'Some group',
users=[user] users=[user]
) )
self.session.add(group) session.add(group)
self.session.commit() session.commit()
self.session.refresh(user) session.refresh(user)
user.groups = [] user.groups = []
self.session.commit() session.commit()
self.session.refresh(user) session.refresh(user)
assert user.group_count == 0 assert user.group_count == 0

View File

@@ -1,80 +1,92 @@
import pytest
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy_utils import aggregated from sqlalchemy_utils import aggregated
from tests import TestCase
class TestAggregateManyToManyAndManyToMany(TestCase): @pytest.fixture
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' def Category(Base):
class Category(Base):
__tablename__ = 'category'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
return Category
def create_models(self):
catalog_products = sa.Table( @pytest.fixture
'catalog_product', def Catalog(Base, Category):
self.Base.metadata, class Catalog(Base):
sa.Column('catalog_id', sa.Integer, sa.ForeignKey('catalog.id')), __tablename__ = 'catalog'
sa.Column('product_id', sa.Integer, sa.ForeignKey('product.id')) id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
@aggregated(
'products.categories',
sa.Column(sa.Integer, default=0)
)
def category_count(self):
return sa.func.count(sa.distinct(Category.id))
return Catalog
@pytest.fixture
def Product(Base, Catalog, Category):
catalog_products = sa.Table(
'catalog_product',
Base.metadata,
sa.Column('catalog_id', sa.Integer, sa.ForeignKey('catalog.id')),
sa.Column('product_id', sa.Integer, sa.ForeignKey('product.id'))
)
product_categories = sa.Table(
'category_product',
Base.metadata,
sa.Column('category_id', sa.Integer, sa.ForeignKey('category.id')),
sa.Column('product_id', sa.Integer, sa.ForeignKey('product.id'))
)
class Product(Base):
__tablename__ = 'product'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
price = sa.Column(sa.Numeric)
catalog_id = sa.Column(
sa.Integer, sa.ForeignKey('catalog.id')
) )
product_categories = sa.Table( catalogs = sa.orm.relationship(
'category_product', Catalog,
self.Base.metadata, backref='products',
sa.Column('category_id', sa.Integer, sa.ForeignKey('category.id')), secondary=catalog_products
sa.Column('product_id', sa.Integer, sa.ForeignKey('product.id'))
) )
class Catalog(self.Base): categories = sa.orm.relationship(
__tablename__ = 'catalog' Category,
id = sa.Column(sa.Integer, primary_key=True) backref='products',
name = sa.Column(sa.Unicode(255)) secondary=product_categories
)
return Product
@aggregated(
'products.categories',
sa.Column(sa.Integer, default=0)
)
def category_count(self):
return sa.func.count(sa.distinct(Category.id))
class Category(self.Base): @pytest.fixture
__tablename__ = 'category' def init_models(Category, Catalog, Product):
id = sa.Column(sa.Integer, primary_key=True) pass
name = sa.Column(sa.Unicode(255))
class Product(self.Base):
__tablename__ = 'product'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
price = sa.Column(sa.Numeric)
catalog_id = sa.Column( @pytest.mark.usefixtures('postgresql_dsn')
sa.Integer, sa.ForeignKey('catalog.id') class TestAggregateManyToManyAndManyToMany(object):
)
catalogs = sa.orm.relationship( def test_insert(self, session, Product, Category, Catalog):
Catalog, category = Category()
backref='products',
secondary=catalog_products
)
categories = sa.orm.relationship(
Category,
backref='products',
secondary=product_categories
)
self.Catalog = Catalog
self.Category = Category
self.Product = Product
def test_insert(self):
category = self.Category()
products = [ products = [
self.Product(categories=[category]), Product(categories=[category]),
self.Product(categories=[category]) Product(categories=[category])
] ]
catalog = self.Catalog(products=products) catalog = Catalog(products=products)
self.session.add(catalog) session.add(catalog)
catalog2 = self.Catalog(products=products) catalog2 = Catalog(products=products)
self.session.add(catalog) session.add(catalog)
self.session.commit() session.commit()
assert catalog.category_count == 1 assert catalog.category_count == 1
assert catalog2.category_count == 1 assert catalog2.category_count == 1

View File

@@ -1,81 +1,96 @@
import pytest
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy_utils.aggregates import aggregated from sqlalchemy_utils.aggregates import aggregated
from tests import TestCase
class TestAggregateValueGenerationForSimpleModelPaths(TestCase): @pytest.fixture
def create_models(self): def Comment(Base):
class Thread(self.Base): class Comment(Base):
__tablename__ = 'thread' __tablename__ = 'comment'
id = sa.Column(sa.Integer, primary_key=True) id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255)) content = sa.Column(sa.Unicode(255))
thread_id = sa.Column(sa.Integer, sa.ForeignKey('thread.id'))
return Comment
@aggregated(
'comments',
sa.Column(sa.Integer, default=0)
)
def comment_count(self):
return sa.func.count('1')
@aggregated('comments', sa.Column(sa.Integer)) @pytest.fixture
def last_comment_id(self): def Thread(Base, Comment):
return sa.func.max(Comment.id) class Thread(Base):
__tablename__ = 'thread'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
comments = sa.orm.relationship( @aggregated(
'Comment', 'comments',
backref='thread' sa.Column(sa.Integer, default=0)
) )
def comment_count(self):
return sa.func.count('1')
Thread.last_comment = sa.orm.relationship( @aggregated('comments', sa.Column(sa.Integer))
def last_comment_id(self):
return sa.func.max(Comment.id)
comments = sa.orm.relationship(
'Comment', 'Comment',
primaryjoin='Thread.last_comment_id == Comment.id', backref='thread'
foreign_keys=[Thread.last_comment_id],
viewonly=True
) )
class Comment(self.Base): Thread.last_comment = sa.orm.relationship(
__tablename__ = 'comment' 'Comment',
id = sa.Column(sa.Integer, primary_key=True) primaryjoin='Thread.last_comment_id == Comment.id',
content = sa.Column(sa.Unicode(255)) foreign_keys=[Thread.last_comment_id],
thread_id = sa.Column(sa.Integer, sa.ForeignKey('thread.id')) viewonly=True
)
return Thread
self.Thread = Thread
self.Comment = Comment
def test_assigns_aggregates_on_insert(self): @pytest.fixture
thread = self.Thread() def init_models(Comment, Thread):
pass
class TestAggregateValueGenerationForSimpleModelPaths(object):
def test_assigns_aggregates_on_insert(self, session, Thread, Comment):
thread = Thread()
thread.name = u'some article name' thread.name = u'some article name'
self.session.add(thread) session.add(thread)
comment = self.Comment(content=u'Some content', thread=thread) comment = Comment(content=u'Some content', thread=thread)
self.session.add(comment) session.add(comment)
self.session.commit() session.commit()
self.session.refresh(thread) session.refresh(thread)
assert thread.comment_count == 1 assert thread.comment_count == 1
assert thread.last_comment_id == comment.id assert thread.last_comment_id == comment.id
def test_assigns_aggregates_on_separate_insert(self): def test_assigns_aggregates_on_separate_insert(
thread = self.Thread() self,
session,
Thread,
Comment
):
thread = Thread()
thread.name = u'some article name' thread.name = u'some article name'
self.session.add(thread) session.add(thread)
self.session.commit() session.commit()
comment = self.Comment(content=u'Some content', thread=thread) comment = Comment(content=u'Some content', thread=thread)
self.session.add(comment) session.add(comment)
self.session.commit() session.commit()
self.session.refresh(thread) session.refresh(thread)
assert thread.comment_count == 1 assert thread.comment_count == 1
assert thread.last_comment_id == 1 assert thread.last_comment_id == 1
def test_assigns_aggregates_on_delete(self): def test_assigns_aggregates_on_delete(self, session, Thread, Comment):
thread = self.Thread() thread = Thread()
thread.name = u'some article name' thread.name = u'some article name'
self.session.add(thread) session.add(thread)
self.session.commit() session.commit()
comment = self.Comment(content=u'Some content', thread=thread) comment = Comment(content=u'Some content', thread=thread)
self.session.add(comment) session.add(comment)
self.session.commit() session.commit()
self.session.delete(comment) session.delete(comment)
self.session.commit() session.commit()
self.session.refresh(thread) session.refresh(thread)
assert thread.comment_count == 0 assert thread.comment_count == 0
assert thread.last_comment_id is None assert thread.last_comment_id is None

View File

@@ -1,76 +1,88 @@
import pytest
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy_utils import aggregated from sqlalchemy_utils import aggregated
from tests import TestCase
class TestAggregateOneToManyAndManyToMany(TestCase): @pytest.fixture
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' def Category(Base):
class Category(Base):
__tablename__ = 'category'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
return Category
def create_models(self):
product_categories = sa.Table( @pytest.fixture
'category_product', def Catalog(Base, Category):
self.Base.metadata, class Catalog(Base):
sa.Column('category_id', sa.Integer, sa.ForeignKey('category.id')), __tablename__ = 'catalog'
sa.Column('product_id', sa.Integer, sa.ForeignKey('product.id')) id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
@aggregated(
'products.categories',
sa.Column(sa.Integer, default=0)
)
def category_count(self):
return sa.func.count(sa.distinct(Category.id))
return Catalog
@pytest.fixture
def Product(Base, Catalog, Category):
product_categories = sa.Table(
'category_product',
Base.metadata,
sa.Column('category_id', sa.Integer, sa.ForeignKey('category.id')),
sa.Column('product_id', sa.Integer, sa.ForeignKey('product.id'))
)
class Product(Base):
__tablename__ = 'product'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
price = sa.Column(sa.Numeric)
catalog_id = sa.Column(
sa.Integer, sa.ForeignKey('catalog.id')
) )
class Catalog(self.Base): catalog = sa.orm.relationship(
__tablename__ = 'catalog' Catalog,
id = sa.Column(sa.Integer, primary_key=True) backref='products'
name = sa.Column(sa.Unicode(255)) )
@aggregated( categories = sa.orm.relationship(
'products.categories', Category,
sa.Column(sa.Integer, default=0) backref='products',
) secondary=product_categories
def category_count(self): )
return sa.func.count(sa.distinct(Category.id)) return Product
class Category(self.Base):
__tablename__ = 'category'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
class Product(self.Base): @pytest.fixture
__tablename__ = 'product' def init_models(Category, Catalog, Product):
id = sa.Column(sa.Integer, primary_key=True) pass
name = sa.Column(sa.Unicode(255))
price = sa.Column(sa.Numeric)
catalog_id = sa.Column(
sa.Integer, sa.ForeignKey('catalog.id')
)
catalog = sa.orm.relationship( @pytest.mark.usefixtures('postgresql_dsn')
Catalog, class TestAggregateOneToManyAndManyToMany(object):
backref='products'
)
categories = sa.orm.relationship( def test_insert(self, session, Category, Catalog, Product):
Category, category = Category()
backref='products',
secondary=product_categories
)
self.Catalog = Catalog
self.Category = Category
self.Product = Product
def test_insert(self):
category = self.Category()
products = [ products = [
self.Product(categories=[category]), Product(categories=[category]),
self.Product(categories=[category]) Product(categories=[category])
] ]
catalog = self.Catalog(products=products) catalog = Catalog(products=products)
self.session.add(catalog) session.add(catalog)
products2 = [ products2 = [
self.Product(categories=[category]), Product(categories=[category]),
self.Product(categories=[category]) Product(categories=[category])
] ]
catalog2 = self.Catalog(products=products2) catalog2 = Catalog(products=products2)
self.session.add(catalog) session.add(catalog)
self.session.commit() session.commit()
assert catalog.category_count == 1 assert catalog.category_count == 1
assert catalog2.category_count == 1 assert catalog2.category_count == 1

View File

@@ -1,64 +1,76 @@
from decimal import Decimal from decimal import Decimal
import pytest
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy_utils.aggregates import aggregated from sqlalchemy_utils.aggregates import aggregated
from tests import TestCase
class TestAggregateOneToManyAndOneToMany(TestCase): @pytest.fixture
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' def Catalog(Base):
class Catalog(Base):
__tablename__ = 'catalog'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
def create_models(self): @aggregated(
class Catalog(self.Base): 'categories.products',
__tablename__ = 'catalog' sa.Column(sa.Integer, default=0)
id = sa.Column(sa.Integer, primary_key=True) )
name = sa.Column(sa.Unicode(255)) def product_count(self):
return sa.func.count('1')
@aggregated( categories = sa.orm.relationship('Category', backref='catalog')
'categories.products', return Catalog
sa.Column(sa.Integer, default=0)
)
def product_count(self):
return sa.func.count('1')
categories = sa.orm.relationship('Category', backref='catalog')
class Category(self.Base): @pytest.fixture
__tablename__ = 'category' def Category(Base):
id = sa.Column(sa.Integer, primary_key=True) class Category(Base):
name = sa.Column(sa.Unicode(255)) __tablename__ = 'category'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id')) catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id'))
products = sa.orm.relationship('Product', backref='category') products = sa.orm.relationship('Product', backref='category')
return Category
class Product(self.Base):
__tablename__ = 'product'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
price = sa.Column(sa.Numeric)
category_id = sa.Column(sa.Integer, sa.ForeignKey('category.id')) @pytest.fixture
def Product(Base):
class Product(Base):
__tablename__ = 'product'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
price = sa.Column(sa.Numeric)
self.Catalog = Catalog category_id = sa.Column(sa.Integer, sa.ForeignKey('category.id'))
self.Category = Category return Product
self.Product = Product
def test_assigns_aggregates(self):
category = self.Category(name=u'Some category') @pytest.fixture
catalog = self.Catalog( def init_models(Catalog, Category, Product):
pass
@pytest.mark.usefixtures('postgresql_dsn')
class TestAggregateOneToManyAndOneToMany(object):
def test_assigns_aggregates(self, session, Category, Catalog, Product):
category = Category(name=u'Some category')
catalog = Catalog(
categories=[category] categories=[category]
) )
catalog.name = u'Some catalog' catalog.name = u'Some catalog'
self.session.add(catalog) session.add(catalog)
self.session.commit() session.commit()
product = self.Product( product = Product(
name=u'Some product', name=u'Some product',
price=Decimal('1000'), price=Decimal('1000'),
category=category category=category
) )
self.session.add(product) session.add(product)
self.session.commit() session.commit()
self.session.refresh(catalog) session.refresh(catalog)
assert catalog.product_count == 1 assert catalog.product_count == 1

View File

@@ -1,88 +1,129 @@
import pytest
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy_utils import aggregated from sqlalchemy_utils import aggregated
from tests import TestCase
class Test3LevelDeepOneToMany(TestCase): @pytest.fixture
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' def Catalog(Base):
class Catalog(Base):
__tablename__ = 'catalog'
id = sa.Column(sa.Integer, primary_key=True)
def create_models(self): @aggregated(
class Catalog(self.Base): 'categories.sub_categories.products',
__tablename__ = 'catalog' sa.Column(sa.Integer, default=0)
id = sa.Column(sa.Integer, primary_key=True) )
def product_count(self):
return sa.func.count('1')
@aggregated( categories = sa.orm.relationship('Category', backref='catalog')
'categories.sub_categories.products', return Catalog
sa.Column(sa.Integer, default=0)
)
def product_count(self):
return sa.func.count('1')
categories = sa.orm.relationship('Category', backref='catalog')
class Category(self.Base): @pytest.fixture
__tablename__ = 'category' def Category(Base):
id = sa.Column(sa.Integer, primary_key=True) class Category(Base):
catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id')) __tablename__ = 'category'
id = sa.Column(sa.Integer, primary_key=True)
catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id'))
sub_categories = sa.orm.relationship( sub_categories = sa.orm.relationship(
'SubCategory', backref='category' 'SubCategory', backref='category'
) )
return Category
class SubCategory(self.Base):
__tablename__ = 'sub_category'
id = sa.Column(sa.Integer, primary_key=True)
category_id = sa.Column(sa.Integer, sa.ForeignKey('category.id'))
products = sa.orm.relationship('Product', backref='sub_category')
class Product(self.Base): @pytest.fixture
__tablename__ = 'product' def SubCategory(Base):
id = sa.Column(sa.Integer, primary_key=True) class SubCategory(Base):
price = sa.Column(sa.Numeric) __tablename__ = 'sub_category'
id = sa.Column(sa.Integer, primary_key=True)
category_id = sa.Column(sa.Integer, sa.ForeignKey('category.id'))
products = sa.orm.relationship('Product', backref='sub_category')
return SubCategory
sub_category_id = sa.Column(
sa.Integer, sa.ForeignKey('sub_category.id')
)
self.Catalog = Catalog @pytest.fixture
self.Category = Category def Product(Base):
self.SubCategory = SubCategory class Product(Base):
self.Product = Product __tablename__ = 'product'
id = sa.Column(sa.Integer, primary_key=True)
price = sa.Column(sa.Numeric)
def test_assigns_aggregates(self): sub_category_id = sa.Column(
catalog = self.catalog_factory() sa.Integer, sa.ForeignKey('sub_category.id')
self.session.commit() )
self.session.refresh(catalog) return Product
assert catalog.product_count == 1
def catalog_factory(self):
product = self.Product() @pytest.fixture
sub_category = self.SubCategory( def init_models(Catalog, Category, SubCategory, Product):
pass
@pytest.fixture
def catalog_factory(Product, SubCategory, Category, Catalog, session):
def catalog_factory():
product = Product()
sub_category = SubCategory(
products=[product] products=[product]
) )
category = self.Category(sub_categories=[sub_category]) category = Category(sub_categories=[sub_category])
catalog = self.Catalog(categories=[category]) catalog = Catalog(categories=[category])
self.session.add(catalog) session.add(catalog)
return catalog
return catalog_factory
@pytest.mark.usefixtures('postgresql_dsn')
class Test3LevelDeepOneToMany(object):
def test_assigns_aggregates(self, session, catalog_factory):
catalog = catalog_factory()
session.commit()
session.refresh(catalog)
assert catalog.product_count == 1
def catalog_factory(
self,
session,
Product,
SubCategory,
Category,
Catalog
):
product = Product()
sub_category = SubCategory(
products=[product]
)
category = Category(sub_categories=[sub_category])
catalog = Catalog(categories=[category])
session.add(catalog)
return catalog return catalog
def test_only_updates_affected_aggregates(self): def test_only_updates_affected_aggregates(
catalog = self.catalog_factory() self,
catalog2 = self.catalog_factory() session,
self.session.commit() catalog_factory,
Product
):
catalog = catalog_factory()
catalog2 = catalog_factory()
session.commit()
# force set catalog2 product_count to zero in order to check if it gets # force set catalog2 product_count to zero in order to check if it gets
# updated when the other catalog's product count gets updated # updated when the other catalog's product count gets updated
self.session.execute( session.execute(
'UPDATE catalog SET product_count = 0 WHERE id = %d' 'UPDATE catalog SET product_count = 0 WHERE id = %d'
% catalog2.id % catalog2.id
) )
catalog.categories[0].sub_categories[0].products.append( catalog.categories[0].sub_categories[0].products.append(
self.Product() Product()
) )
self.session.commit() session.commit()
self.session.refresh(catalog) session.refresh(catalog)
self.session.refresh(catalog2) session.refresh(catalog2)
assert catalog.product_count == 2 assert catalog.product_count == 2
assert catalog2.product_count == 0 assert catalog2.product_count == 0

View File

@@ -1,7 +1,7 @@
import pytest
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy_utils import aggregated, TSVectorType from sqlalchemy_utils import aggregated, TSVectorType
from tests import TestCase
def tsvector_reduce_concat(vectors): def tsvector_reduce_concat(vectors):
@@ -13,45 +13,54 @@ def tsvector_reduce_concat(vectors):
) )
class TestSearchVectorAggregates(TestCase): @pytest.fixture
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' def Product(Base):
class Product(Base):
__tablename__ = 'product'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
price = sa.Column(sa.Numeric)
def create_models(self): catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id'))
class Catalog(self.Base): return Product
__tablename__ = 'catalog'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
@aggregated('products', sa.Column(TSVectorType))
def product_search_vector(self):
return tsvector_reduce_concat(
sa.func.to_tsvector(Product.name)
)
products = sa.orm.relationship('Product', backref='catalog') @pytest.fixture
def Catalog(Base, Product):
class Catalog(Base):
__tablename__ = 'catalog'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
class Product(self.Base): @aggregated('products', sa.Column(TSVectorType))
__tablename__ = 'product' def product_search_vector(self):
id = sa.Column(sa.Integer, primary_key=True) return tsvector_reduce_concat(
name = sa.Column(sa.Unicode(255)) sa.func.to_tsvector(Product.name)
price = sa.Column(sa.Numeric) )
catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id')) products = sa.orm.relationship('Product', backref='catalog')
return Catalog
self.Catalog = Catalog
self.Product = Product
def test_assigns_aggregates_on_insert(self): @pytest.fixture
catalog = self.Catalog( def init_models(Product, Catalog):
pass
@pytest.mark.usefixtures('postgresql_dsn')
class TestSearchVectorAggregates(object):
def test_assigns_aggregates_on_insert(self, session, Product, Catalog):
catalog = Catalog(
name=u'Some catalog' name=u'Some catalog'
) )
self.session.add(catalog) session.add(catalog)
self.session.commit() session.commit()
product = self.Product( product = Product(
name=u'Product XYZ', name=u'Product XYZ',
catalog=catalog catalog=catalog
) )
self.session.add(product) session.add(product)
self.session.commit() session.commit()
self.session.refresh(catalog) session.refresh(catalog)
assert catalog.product_search_vector == "'product':1 'xyz':2" assert catalog.product_search_vector == "'product':1 'xyz':2"

View File

@@ -1,61 +1,76 @@
import pytest
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy_utils.aggregates import aggregated from sqlalchemy_utils.aggregates import aggregated
from tests import TestCase
class TestAggregateValueGenerationForSimpleModelPaths(TestCase): @pytest.fixture
def create_models(self): def Thread(Base):
class Thread(self.Base): class Thread(Base):
__tablename__ = 'thread' __tablename__ = 'thread'
id = sa.Column(sa.Integer, primary_key=True) id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255)) name = sa.Column(sa.Unicode(255))
@aggregated('comments', sa.Column(sa.Integer, default=0)) @aggregated('comments', sa.Column(sa.Integer, default=0))
def comment_count(self): def comment_count(self):
return sa.func.count('1') return sa.func.count('1')
comments = sa.orm.relationship('Comment', backref='thread') comments = sa.orm.relationship('Comment', backref='thread')
return Thread
class Comment(self.Base):
__tablename__ = 'comment'
id = sa.Column(sa.Integer, primary_key=True)
content = sa.Column(sa.Unicode(255))
thread_id = sa.Column(sa.Integer, sa.ForeignKey('thread.id'))
self.Thread = Thread @pytest.fixture
self.Comment = Comment def Comment(Base):
class Comment(Base):
__tablename__ = 'comment'
id = sa.Column(sa.Integer, primary_key=True)
content = sa.Column(sa.Unicode(255))
thread_id = sa.Column(sa.Integer, sa.ForeignKey('thread.id'))
return Comment
def test_assigns_aggregates_on_insert(self):
thread = self.Thread() @pytest.fixture
def init_models(Thread, Comment):
pass
class TestAggregateValueGenerationForSimpleModelPaths(object):
def test_assigns_aggregates_on_insert(self, session, Thread, Comment):
thread = Thread()
thread.name = u'some article name' thread.name = u'some article name'
self.session.add(thread) session.add(thread)
comment = self.Comment(content=u'Some content', thread=thread) comment = Comment(content=u'Some content', thread=thread)
self.session.add(comment) session.add(comment)
self.session.commit() session.commit()
self.session.refresh(thread) session.refresh(thread)
assert thread.comment_count == 1 assert thread.comment_count == 1
def test_assigns_aggregates_on_separate_insert(self): def test_assigns_aggregates_on_separate_insert(
thread = self.Thread() self,
session,
Thread,
Comment
):
thread = Thread()
thread.name = u'some article name' thread.name = u'some article name'
self.session.add(thread) session.add(thread)
self.session.commit() session.commit()
comment = self.Comment(content=u'Some content', thread=thread) comment = Comment(content=u'Some content', thread=thread)
self.session.add(comment) session.add(comment)
self.session.commit() session.commit()
self.session.refresh(thread) session.refresh(thread)
assert thread.comment_count == 1 assert thread.comment_count == 1
def test_assigns_aggregates_on_delete(self): def test_assigns_aggregates_on_delete(self, session, Thread, Comment):
thread = self.Thread() thread = Thread()
thread.name = u'some article name' thread.name = u'some article name'
self.session.add(thread) session.add(thread)
self.session.commit() session.commit()
comment = self.Comment(content=u'Some content', thread=thread) comment = Comment(content=u'Some content', thread=thread)
self.session.add(comment) session.add(comment)
self.session.commit() session.commit()
self.session.delete(comment) session.delete(comment)
self.session.commit() session.commit()
self.session.refresh(thread) session.refresh(thread)
assert thread.comment_count == 0 assert thread.comment_count == 0

View File

@@ -1,59 +1,74 @@
import pytest
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy_utils.aggregates import aggregated from sqlalchemy_utils.aggregates import aggregated
from tests import TestCase
class TestAggregatedWithColumnAlias(TestCase): @pytest.fixture
def create_models(self): def Thread(Base):
class Thread(self.Base): class Thread(Base):
__tablename__ = 'thread' __tablename__ = 'thread'
id = sa.Column(sa.Integer, primary_key=True) id = sa.Column(sa.Integer, primary_key=True)
@aggregated( @aggregated(
'comments', 'comments',
sa.Column('_comment_count', sa.Integer, default=0) sa.Column('_comment_count', sa.Integer, default=0)
) )
def comment_count(self): def comment_count(self):
return sa.func.count('1') return sa.func.count('1')
comments = sa.orm.relationship('Comment', backref='thread') comments = sa.orm.relationship('Comment', backref='thread')
return Thread
class Comment(self.Base):
__tablename__ = 'comment'
id = sa.Column(sa.Integer, primary_key=True)
thread_id = sa.Column(sa.Integer, sa.ForeignKey('thread.id'))
self.Thread = Thread @pytest.fixture
self.Comment = Comment def Comment(Base):
class Comment(Base):
__tablename__ = 'comment'
id = sa.Column(sa.Integer, primary_key=True)
thread_id = sa.Column(sa.Integer, sa.ForeignKey('thread.id'))
return Comment
def test_assigns_aggregates_on_insert(self):
thread = self.Thread() @pytest.fixture
self.session.add(thread) def init_models(Thread, Comment):
comment = self.Comment(thread=thread) pass
self.session.add(comment)
self.session.commit()
self.session.refresh(thread) class TestAggregatedWithColumnAlias(object):
def test_assigns_aggregates_on_insert(self, session, Thread, Comment):
thread = Thread()
session.add(thread)
comment = Comment(thread=thread)
session.add(comment)
session.commit()
session.refresh(thread)
assert thread.comment_count == 1 assert thread.comment_count == 1
def test_assigns_aggregates_on_separate_insert(self): def test_assigns_aggregates_on_separate_insert(
thread = self.Thread() self,
self.session.add(thread) session,
self.session.commit() Thread,
comment = self.Comment(thread=thread) Comment
self.session.add(comment) ):
self.session.commit() thread = Thread()
self.session.refresh(thread) session.add(thread)
session.commit()
comment = Comment(thread=thread)
session.add(comment)
session.commit()
session.refresh(thread)
assert thread.comment_count == 1 assert thread.comment_count == 1
def test_assigns_aggregates_on_delete(self): def test_assigns_aggregates_on_delete(self, session, Thread, Comment):
thread = self.Thread() thread = Thread()
self.session.add(thread) session.add(thread)
self.session.commit() session.commit()
comment = self.Comment(thread=thread) comment = Comment(thread=thread)
self.session.add(comment) session.add(comment)
self.session.commit() session.commit()
self.session.delete(comment) session.delete(comment)
self.session.commit() session.commit()
self.session.refresh(thread) session.refresh(thread)
assert thread.comment_count == 0 assert thread.comment_count == 0

View File

@@ -1,47 +1,56 @@
import pytest
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy_utils.aggregates import aggregated from sqlalchemy_utils.aggregates import aggregated
from tests import TestCase
class TestAggregateValueGenerationWithCascadeDelete(TestCase): @pytest.fixture
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' def Thread(Base):
class Thread(Base):
__tablename__ = 'thread'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
def create_models(self): @aggregated('comments', sa.Column(sa.Integer, default=0))
class Thread(self.Base): def comment_count(self):
__tablename__ = 'thread' return sa.func.count('1')
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
@aggregated('comments', sa.Column(sa.Integer, default=0)) comments = sa.orm.relationship(
def comment_count(self): 'Comment',
return sa.func.count('1') passive_deletes=True,
backref='thread'
)
return Thread
comments = sa.orm.relationship(
'Comment',
passive_deletes=True,
backref='thread'
)
class Comment(self.Base): @pytest.fixture
__tablename__ = 'comment' def Comment(Base):
id = sa.Column(sa.Integer, primary_key=True) class Comment(Base):
content = sa.Column(sa.Unicode(255)) __tablename__ = 'comment'
thread_id = sa.Column( id = sa.Column(sa.Integer, primary_key=True)
sa.Integer, content = sa.Column(sa.Unicode(255))
sa.ForeignKey('thread.id', ondelete='CASCADE') thread_id = sa.Column(
) sa.Integer,
sa.ForeignKey('thread.id', ondelete='CASCADE')
)
return Comment
self.Thread = Thread
self.Comment = Comment
def test_something(self): @pytest.fixture
thread = self.Thread() def init_models(Thread, Comment):
pass
@pytest.mark.usefixtures('postgresql_dsn')
class TestAggregateValueGenerationWithCascadeDelete(object):
def test_something(self, session, Thread, Comment):
thread = Thread()
thread.name = u'some article name' thread.name = u'some article name'
self.session.add(thread) session.add(thread)
comment = self.Comment(content=u'Some content', thread=thread) comment = Comment(content=u'Some content', thread=thread)
self.session.add(comment) session.add(comment)
self.session.commit() session.commit()
self.session.expire_all() session.expire_all()
self.session.delete(thread) session.delete(thread)
self.session.commit() session.commit()

View File

@@ -1,29 +1,35 @@
import pytest
from sqlalchemy_utils import analyze from sqlalchemy_utils import analyze
from tests import TestCase
class TestAnalyzeWithPostgres(TestCase): @pytest.mark.usefixtures('postgresql_dsn')
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' class TestAnalyzeWithPostgres(object):
def test_runtime(self): def test_runtime(self, session, connection, Article):
query = self.session.query(self.Article) query = session.query(Article)
assert analyze(self.connection, query).runtime assert analyze(connection, query).runtime
def test_node_types_with_join(self): def test_node_types_with_join(self, session, connection, Article):
query = ( query = (
self.session.query(self.Article) session.query(Article)
.join(self.Article.category) .join(Article.category)
) )
analysis = analyze(self.connection, query) analysis = analyze(connection, query)
assert analysis.node_types == [ assert analysis.node_types == [
u'Hash Join', u'Seq Scan', u'Hash', u'Seq Scan' u'Hash Join', u'Seq Scan', u'Hash', u'Seq Scan'
] ]
def test_node_types_with_index_only_scan(self): def test_node_types_with_index_only_scan(
self,
session,
connection,
Article
):
query = ( query = (
self.session.query(self.Article.name) session.query(Article.name)
.order_by(self.Article.name) .order_by(Article.name)
.limit(10) .limit(10)
) )
analysis = analyze(self.connection, query) analysis = analyze(connection, query)
assert analysis.node_types == [u'Limit', u'Index Only Scan'] assert analysis.node_types == [u'Limit', u'Index Only Scan']

View File

@@ -1,11 +1,8 @@
import os import pytest
import sqlalchemy as sa import sqlalchemy as sa
from flexmock import flexmock from flexmock import flexmock
from pytest import mark
from sqlalchemy_utils import create_database, database_exists, drop_database from sqlalchemy_utils import create_database, database_exists, drop_database
from tests import TestCase
pymysql = None pymysql = None
try: try:
@@ -14,38 +11,73 @@ except ImportError:
pass pass
class DatabaseTest(TestCase): class DatabaseTest(object):
def test_create_and_drop(self): def test_create_and_drop(self, dsn):
assert not database_exists(self.url) assert not database_exists(dsn)
create_database(self.url) create_database(dsn)
assert database_exists(self.url) assert database_exists(dsn)
drop_database(self.url) drop_database(dsn)
assert not database_exists(self.url) assert not database_exists(dsn)
class TestDatabaseSQLite(DatabaseTest): @pytest.mark.usefixtures('sqlite_memory_dsn')
url = 'sqlite:///sqlalchemy_utils.db' class TestDatabaseSQLiteMemory(object):
def setup(self): def test_exists_memory(self, dsn):
if os.path.exists('sqlalchemy_utils.db'): assert database_exists(dsn)
os.remove('sqlalchemy_utils.db')
def test_exists_memory(self):
assert database_exists('sqlite:///:memory:')
@mark.skipif('pymysql is None') @pytest.mark.usefixtures('sqlite_file_dsn')
class TestDatabaseSQLiteFile(DatabaseTest):
pass
@pytest.mark.skipif('pymysql is None')
@pytest.mark.usefixtures('mysql_dsn')
class TestDatabaseMySQL(DatabaseTest): class TestDatabaseMySQL(DatabaseTest):
url = 'mysql+pymysql://travis@localhost/db_test_sqlalchemy_util'
@pytest.fixture
def db_name(self):
return 'db_test_sqlalchemy_util'
@mark.skipif('pymysql is None') @pytest.mark.skipif('pymysql is None')
@pytest.mark.usefixtures('mysql_dsn')
class TestDatabaseMySQLWithQuotedName(DatabaseTest): class TestDatabaseMySQLWithQuotedName(DatabaseTest):
url = 'mysql+pymysql://travis@localhost/db_test_sqlalchemy-util'
@pytest.fixture
def db_name(self):
return 'db_test_sqlalchemy-util'
@pytest.mark.usefixtures('postgresql_dsn')
class TestDatabasePostgres(DatabaseTest):
@pytest.fixture
def db_name(self):
return 'db_test_sqlalchemy_util'
def test_template(self):
(
flexmock(sa.engine.Engine)
.should_receive('execute')
.with_args(
"CREATE DATABASE db_test_sqlalchemy_util ENCODING 'utf8' "
"TEMPLATE my_template"
)
)
create_database(
'postgres://postgres@localhost/db_test_sqlalchemy_util',
template='my_template'
)
@pytest.mark.usefixtures('postgresql_dsn')
class TestDatabasePostgresWithQuotedName(DatabaseTest): class TestDatabasePostgresWithQuotedName(DatabaseTest):
url = 'postgres://postgres@localhost/db_test_sqlalchemy-util'
@pytest.fixture
def db_name(self):
return 'db_test_sqlalchemy-util'
def test_template(self): def test_template(self):
( (
@@ -61,21 +93,3 @@ class TestDatabasePostgresWithQuotedName(DatabaseTest):
'postgres://postgres@localhost/db_test_sqlalchemy-util', 'postgres://postgres@localhost/db_test_sqlalchemy-util',
template='my-template' template='my-template'
) )
class TestDatabasePostgres(DatabaseTest):
url = 'postgres://postgres@localhost/db_test_sqlalchemy_util'
def test_template(self):
(
flexmock(sa.engine.Engine)
.should_receive('execute')
.with_args(
"CREATE DATABASE db_test_sqlalchemy_util ENCODING 'utf8' "
"TEMPLATE my_template"
)
)
create_database(
'postgres://postgres@localhost/db_test_sqlalchemy_util',
template='my_template'
)

View File

@@ -1,18 +1,23 @@
import pytest
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy_utils import dependent_objects, get_referencing_foreign_keys from sqlalchemy_utils import dependent_objects, get_referencing_foreign_keys
from tests import TestCase
class TestDependentObjects(TestCase): class TestDependentObjects(object):
def create_models(self):
class User(self.Base): @pytest.fixture
def User(self, Base):
class User(Base):
__tablename__ = 'user' __tablename__ = 'user'
id = sa.Column(sa.Integer, primary_key=True) id = sa.Column(sa.Integer, primary_key=True)
first_name = sa.Column(sa.Unicode(255)) first_name = sa.Column(sa.Unicode(255))
last_name = sa.Column(sa.Unicode(255)) last_name = sa.Column(sa.Unicode(255))
return User
class Article(self.Base): @pytest.fixture
def Article(self, Base, User):
class Article(Base):
__tablename__ = 'article' __tablename__ = 'article'
id = sa.Column(sa.Integer, primary_key=True) id = sa.Column(sa.Integer, primary_key=True)
author_id = sa.Column(sa.Integer, sa.ForeignKey('user.id')) author_id = sa.Column(sa.Integer, sa.ForeignKey('user.id'))
@@ -22,8 +27,11 @@ class TestDependentObjects(TestCase):
author = sa.orm.relationship(User, foreign_keys=[author_id]) author = sa.orm.relationship(User, foreign_keys=[author_id])
owner = sa.orm.relationship(User, foreign_keys=[owner_id]) owner = sa.orm.relationship(User, foreign_keys=[owner_id])
return Article
class BlogPost(self.Base): @pytest.fixture
def BlogPost(self, Base, User):
class BlogPost(Base):
__tablename__ = 'blog_post' __tablename__ = 'blog_post'
id = sa.Column(sa.Integer, primary_key=True) id = sa.Column(sa.Integer, primary_key=True)
owner_id = sa.Column( owner_id = sa.Column(
@@ -31,21 +39,22 @@ class TestDependentObjects(TestCase):
) )
owner = sa.orm.relationship(User) owner = sa.orm.relationship(User)
return BlogPost
self.User = User @pytest.fixture
self.Article = Article def init_models(self, User, Article, BlogPost):
self.BlogPost = BlogPost pass
def test_returns_all_dependent_objects(self): def test_returns_all_dependent_objects(self, session, User, Article):
user = self.User(first_name=u'John') user = User(first_name=u'John')
articles = [ articles = [
self.Article(author=user), Article(author=user),
self.Article(), Article(),
self.Article(owner=user), Article(owner=user),
self.Article(author=user, owner=user) Article(author=user, owner=user)
] ]
self.session.add_all(articles) session.add_all(articles)
self.session.commit() session.commit()
deps = list(dependent_objects(user)) deps = list(dependent_objects(user))
assert len(deps) == 3 assert len(deps) == 3
@@ -53,23 +62,29 @@ class TestDependentObjects(TestCase):
assert articles[2] in deps assert articles[2] in deps
assert articles[3] in deps assert articles[3] in deps
def test_with_foreign_keys_parameter(self): def test_with_foreign_keys_parameter(
user = self.User(first_name=u'John') self,
session,
User,
Article,
BlogPost
):
user = User(first_name=u'John')
objects = [ objects = [
self.Article(author=user), Article(author=user),
self.Article(), Article(),
self.Article(owner=user), Article(owner=user),
self.Article(author=user, owner=user), Article(author=user, owner=user),
self.BlogPost(owner=user) BlogPost(owner=user)
] ]
self.session.add_all(objects) session.add_all(objects)
self.session.commit() session.commit()
deps = list( deps = list(
dependent_objects( dependent_objects(
user, user,
( (
fk for fk in get_referencing_foreign_keys(self.User) fk for fk in get_referencing_foreign_keys(User)
if fk.ondelete == 'RESTRICT' or fk.ondelete is None if fk.ondelete == 'RESTRICT' or fk.ondelete is None
) )
).limit(5) ).limit(5)
@@ -79,15 +94,20 @@ class TestDependentObjects(TestCase):
assert objects[3] in deps assert objects[3] in deps
class TestDependentObjectsWithColumnAliases(TestCase): class TestDependentObjectsWithColumnAliases(object):
def create_models(self):
class User(self.Base): @pytest.fixture
def User(self, Base):
class User(Base):
__tablename__ = 'user' __tablename__ = 'user'
id = sa.Column(sa.Integer, primary_key=True) id = sa.Column(sa.Integer, primary_key=True)
first_name = sa.Column(sa.Unicode(255)) first_name = sa.Column(sa.Unicode(255))
last_name = sa.Column(sa.Unicode(255)) last_name = sa.Column(sa.Unicode(255))
return User
class Article(self.Base): @pytest.fixture
def Article(self, Base, User):
class Article(Base):
__tablename__ = 'article' __tablename__ = 'article'
id = sa.Column(sa.Integer, primary_key=True) id = sa.Column(sa.Integer, primary_key=True)
author_id = sa.Column( author_id = sa.Column(
@@ -100,8 +120,11 @@ class TestDependentObjectsWithColumnAliases(TestCase):
author = sa.orm.relationship(User, foreign_keys=[author_id]) author = sa.orm.relationship(User, foreign_keys=[author_id])
owner = sa.orm.relationship(User, foreign_keys=[owner_id]) owner = sa.orm.relationship(User, foreign_keys=[owner_id])
return Article
class BlogPost(self.Base): @pytest.fixture
def BlogPost(self, Base, User):
class BlogPost(Base):
__tablename__ = 'blog_post' __tablename__ = 'blog_post'
id = sa.Column(sa.Integer, primary_key=True) id = sa.Column(sa.Integer, primary_key=True)
owner_id = sa.Column( owner_id = sa.Column(
@@ -110,21 +133,22 @@ class TestDependentObjectsWithColumnAliases(TestCase):
) )
owner = sa.orm.relationship(User) owner = sa.orm.relationship(User)
return BlogPost
self.User = User @pytest.fixture
self.Article = Article def init_models(self, User, Article, BlogPost):
self.BlogPost = BlogPost pass
def test_returns_all_dependent_objects(self): def test_returns_all_dependent_objects(self, session, User, Article):
user = self.User(first_name=u'John') user = User(first_name=u'John')
articles = [ articles = [
self.Article(author=user), Article(author=user),
self.Article(), Article(),
self.Article(owner=user), Article(owner=user),
self.Article(author=user, owner=user) Article(author=user, owner=user)
] ]
self.session.add_all(articles) session.add_all(articles)
self.session.commit() session.commit()
deps = list(dependent_objects(user)) deps = list(dependent_objects(user))
assert len(deps) == 3 assert len(deps) == 3
@@ -132,23 +156,29 @@ class TestDependentObjectsWithColumnAliases(TestCase):
assert articles[2] in deps assert articles[2] in deps
assert articles[3] in deps assert articles[3] in deps
def test_with_foreign_keys_parameter(self): def test_with_foreign_keys_parameter(
user = self.User(first_name=u'John') self,
session,
User,
Article,
BlogPost
):
user = User(first_name=u'John')
objects = [ objects = [
self.Article(author=user), Article(author=user),
self.Article(), Article(),
self.Article(owner=user), Article(owner=user),
self.Article(author=user, owner=user), Article(author=user, owner=user),
self.BlogPost(owner=user) BlogPost(owner=user)
] ]
self.session.add_all(objects) session.add_all(objects)
self.session.commit() session.commit()
deps = list( deps = list(
dependent_objects( dependent_objects(
user, user,
( (
fk for fk in get_referencing_foreign_keys(self.User) fk for fk in get_referencing_foreign_keys(User)
if fk.ondelete == 'RESTRICT' or fk.ondelete is None if fk.ondelete == 'RESTRICT' or fk.ondelete is None
) )
).limit(5) ).limit(5)
@@ -158,50 +188,64 @@ class TestDependentObjectsWithColumnAliases(TestCase):
assert objects[3] in deps assert objects[3] in deps
class TestDependentObjectsWithManyReferences(TestCase): class TestDependentObjectsWithManyReferences(object):
def create_models(self):
class User(self.Base): @pytest.fixture
def User(self, Base):
class User(Base):
__tablename__ = 'user' __tablename__ = 'user'
id = sa.Column(sa.Integer, primary_key=True) id = sa.Column(sa.Integer, primary_key=True)
first_name = sa.Column(sa.Unicode(255)) first_name = sa.Column(sa.Unicode(255))
last_name = sa.Column(sa.Unicode(255)) last_name = sa.Column(sa.Unicode(255))
return User
class BlogPost(self.Base): @pytest.fixture
def BlogPost(self, Base, User):
class BlogPost(Base):
__tablename__ = 'blog_post' __tablename__ = 'blog_post'
id = sa.Column(sa.Integer, primary_key=True) id = sa.Column(sa.Integer, primary_key=True)
author_id = sa.Column(sa.Integer, sa.ForeignKey('user.id')) author_id = sa.Column(sa.Integer, sa.ForeignKey('user.id'))
author = sa.orm.relationship(User) author = sa.orm.relationship(User)
return BlogPost
class Article(self.Base): @pytest.fixture
def Article(self, Base, User):
class Article(Base):
__tablename__ = 'article' __tablename__ = 'article'
id = sa.Column(sa.Integer, primary_key=True) id = sa.Column(sa.Integer, primary_key=True)
author_id = sa.Column(sa.Integer, sa.ForeignKey('user.id')) author_id = sa.Column(sa.Integer, sa.ForeignKey('user.id'))
author = sa.orm.relationship(User) author = sa.orm.relationship(User)
return Article
self.User = User @pytest.fixture
self.Article = Article def init_models(self, User, BlogPost, Article):
self.BlogPost = BlogPost pass
def test_with_many_dependencies(self): def test_with_many_dependencies(self, session, User, Article, BlogPost):
user = self.User(first_name=u'John') user = User(first_name=u'John')
objects = [ objects = [
self.Article(author=user), Article(author=user),
self.BlogPost(author=user) BlogPost(author=user)
] ]
self.session.add_all(objects) session.add_all(objects)
self.session.commit() session.commit()
deps = list(dependent_objects(user)) deps = list(dependent_objects(user))
assert len(deps) == 2 assert len(deps) == 2
class TestDependentObjectsWithCompositeKeys(TestCase): class TestDependentObjectsWithCompositeKeys(object):
def create_models(self):
class User(self.Base): @pytest.fixture
def User(self, Base):
class User(Base):
__tablename__ = 'user' __tablename__ = 'user'
first_name = sa.Column(sa.Unicode(255), primary_key=True) first_name = sa.Column(sa.Unicode(255), primary_key=True)
last_name = sa.Column(sa.Unicode(255), primary_key=True) last_name = sa.Column(sa.Unicode(255), primary_key=True)
return User
class Article(self.Base): @pytest.fixture
def Article(self, Base, User):
class Article(Base):
__tablename__ = 'article' __tablename__ = 'article'
id = sa.Column(sa.Integer, primary_key=True) id = sa.Column(sa.Integer, primary_key=True)
author_first_name = sa.Column(sa.Unicode(255)) author_first_name = sa.Column(sa.Unicode(255))
@@ -214,20 +258,22 @@ class TestDependentObjectsWithCompositeKeys(TestCase):
) )
author = sa.orm.relationship(User) author = sa.orm.relationship(User)
return Article
self.User = User @pytest.fixture
self.Article = Article def init_models(self, User, Article):
pass
def test_returns_all_dependent_objects(self): def test_returns_all_dependent_objects(self, session, User, Article):
user = self.User(first_name=u'John', last_name=u'Smith') user = User(first_name=u'John', last_name=u'Smith')
articles = [ articles = [
self.Article(author=user), Article(author=user),
self.Article(), Article(),
self.Article(), Article(),
self.Article(author=user) Article(author=user)
] ]
self.session.add_all(articles) session.add_all(articles)
self.session.commit() session.commit()
deps = list(dependent_objects(user)) deps = list(dependent_objects(user))
assert len(deps) == 2 assert len(deps) == 2
@@ -235,14 +281,19 @@ class TestDependentObjectsWithCompositeKeys(TestCase):
assert articles[3] in deps assert articles[3] in deps
class TestDependentObjectsWithSingleTableInheritance(TestCase): class TestDependentObjectsWithSingleTableInheritance(object):
def create_models(self):
class Category(self.Base): @pytest.fixture
def Category(self, Base):
class Category(Base):
__tablename__ = 'category' __tablename__ = 'category'
id = sa.Column(sa.Integer, primary_key=True) id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255)) name = sa.Column(sa.Unicode(255))
return Category
class TextItem(self.Base): @pytest.fixture
def TextItem(self, Base, Category):
class TextItem(Base):
__tablename__ = 'text_item' __tablename__ = 'text_item'
id = sa.Column(sa.Integer, primary_key=True) id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255)) name = sa.Column(sa.Unicode(255))
@@ -261,33 +312,39 @@ class TestDependentObjectsWithSingleTableInheritance(TestCase):
__mapper_args__ = { __mapper_args__ = {
'polymorphic_on': type, 'polymorphic_on': type,
} }
return TextItem
@pytest.fixture
def Article(self, TextItem):
class Article(TextItem): class Article(TextItem):
__mapper_args__ = { __mapper_args__ = {
'polymorphic_identity': u'article' 'polymorphic_identity': u'article'
} }
return Article
@pytest.fixture
def BlogPost(self, TextItem):
class BlogPost(TextItem): class BlogPost(TextItem):
__mapper_args__ = { __mapper_args__ = {
'polymorphic_identity': u'blog_post' 'polymorphic_identity': u'blog_post'
} }
return BlogPost
self.Category = Category @pytest.fixture
self.TextItem = TextItem def init_models(self, Category, TextItem, Article, BlogPost):
self.Article = Article pass
self.BlogPost = BlogPost
def test_returns_all_dependent_objects(self): def test_returns_all_dependent_objects(self, session, Category, Article):
category1 = self.Category(name=u'Category #1') category1 = Category(name=u'Category #1')
category2 = self.Category(name=u'Category #2') category2 = Category(name=u'Category #2')
articles = [ articles = [
self.Article(category=category1), Article(category=category1),
self.Article(category=category1), Article(category=category1),
self.Article(category=category2), Article(category=category2),
self.Article(category=category2), Article(category=category2),
] ]
self.session.add_all(articles) session.add_all(articles)
self.session.commit() session.commit()
deps = list(dependent_objects(category1)) deps = list(dependent_objects(category1))
assert len(deps) == 2 assert len(deps) == 2

View File

@@ -1,7 +1,6 @@
from sqlalchemy_utils import escape_like from sqlalchemy_utils import escape_like
from tests import TestCase
class TestEscapeLike(TestCase): class TestEscapeLike(object):
def test_escapes_wildcards(self): def test_escapes_wildcards(self):
assert escape_like('_*%') == '*_***%' assert escape_like('_*%') == '*_***%'

View File

@@ -1,21 +1,20 @@
from pytest import raises import pytest
from sqlalchemy_utils import get_bind from sqlalchemy_utils import get_bind
from tests import TestCase
class TestGetBind(TestCase): class TestGetBind(object):
def test_with_session(self): def test_with_session(self, session, connection):
assert get_bind(self.session) == self.connection assert get_bind(session) == connection
def test_with_connection(self): def test_with_connection(self, session, connection):
assert get_bind(self.connection) == self.connection assert get_bind(connection) == connection
def test_with_model_object(self): def test_with_model_object(self, session, connection, Article):
article = self.Article() article = Article()
self.session.add(article) session.add(article)
assert get_bind(article) == self.connection assert get_bind(article) == connection
def test_with_unknown_type(self): def test_with_unknown_type(self):
with raises(TypeError): with pytest.raises(TypeError):
get_bind(None) get_bind(None)

View File

@@ -1,15 +1,14 @@
import pytest
import sqlalchemy as sa import sqlalchemy as sa
from pytest import raises
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy_utils import get_class_by_table from sqlalchemy_utils import get_class_by_table
class TestGetClassByTableWithJoinedTableInheritance(object): class TestGetClassByTableWithJoinedTableInheritance(object):
def setup_method(self, method):
self.Base = declarative_base()
class Entity(self.Base): @pytest.fixture
def Entity(self, Base):
class Entity(Base):
__tablename__ = 'entity' __tablename__ = 'entity'
id = sa.Column(sa.Integer, primary_key=True) id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.String) name = sa.Column(sa.String)
@@ -18,7 +17,10 @@ class TestGetClassByTableWithJoinedTableInheritance(object):
'polymorphic_on': type, 'polymorphic_on': type,
'polymorphic_identity': 'entity' 'polymorphic_identity': 'entity'
} }
return Entity
@pytest.fixture
def User(self, Entity):
class User(Entity): class User(Entity):
__tablename__ = 'user' __tablename__ = 'user'
id = sa.Column( id = sa.Column(
@@ -29,31 +31,29 @@ class TestGetClassByTableWithJoinedTableInheritance(object):
__mapper_args__ = { __mapper_args__ = {
'polymorphic_identity': 'user' 'polymorphic_identity': 'user'
} }
return User
self.Entity = Entity def test_returns_class(self, Base, User, Entity):
self.User = User assert get_class_by_table(Base, User.__table__) == User
def test_returns_class(self):
assert get_class_by_table(self.Base, self.User.__table__) == self.User
assert get_class_by_table( assert get_class_by_table(
self.Base, Base,
self.Entity.__table__ Entity.__table__
) == self.Entity ) == Entity
def test_table_with_no_associated_class(self): def test_table_with_no_associated_class(self, Base):
table = sa.Table( table = sa.Table(
'some_table', 'some_table',
self.Base.metadata, Base.metadata,
sa.Column('id', sa.Integer) sa.Column('id', sa.Integer)
) )
assert get_class_by_table(self.Base, table) is None assert get_class_by_table(Base, table) is None
class TestGetClassByTableWithSingleTableInheritance(object): class TestGetClassByTableWithSingleTableInheritance(object):
def setup_method(self, method):
self.Base = declarative_base()
class Entity(self.Base): @pytest.fixture
def Entity(self, Base):
class Entity(Base):
__tablename__ = 'entity' __tablename__ = 'entity'
id = sa.Column(sa.Integer, primary_key=True) id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.String) name = sa.Column(sa.String)
@@ -62,38 +62,39 @@ class TestGetClassByTableWithSingleTableInheritance(object):
'polymorphic_on': type, 'polymorphic_on': type,
'polymorphic_identity': 'entity' 'polymorphic_identity': 'entity'
} }
return Entity
@pytest.fixture
def User(self, Entity):
class User(Entity): class User(Entity):
__mapper_args__ = { __mapper_args__ = {
'polymorphic_identity': 'user' 'polymorphic_identity': 'user'
} }
return User
self.Entity = Entity def test_multiple_classes_without_data_parameter(self, Base, Entity, User):
self.User = User with pytest.raises(ValueError):
def test_multiple_classes_without_data_parameter(self):
with raises(ValueError):
assert get_class_by_table( assert get_class_by_table(
self.Base, Base,
self.Entity.__table__ Entity.__table__
) )
def test_multiple_classes_with_data_parameter(self): def test_multiple_classes_with_data_parameter(self, Base, Entity, User):
assert get_class_by_table( assert get_class_by_table(
self.Base, Base,
self.Entity.__table__, Entity.__table__,
{'type': 'entity'} {'type': 'entity'}
) == self.Entity ) == Entity
assert get_class_by_table( assert get_class_by_table(
self.Base, Base,
self.Entity.__table__, Entity.__table__,
{'type': 'user'} {'type': 'user'}
) == self.User ) == User
def test_multiple_classes_with_bogus_data(self): def test_multiple_classes_with_bogus_data(self, Base, Entity, User):
with raises(ValueError): with pytest.raises(ValueError):
assert get_class_by_table( assert get_class_by_table(
self.Base, Base,
self.Entity.__table__, Entity.__table__,
{'type': 'unknown'} {'type': 'unknown'}
) )

View File

@@ -1,42 +1,44 @@
from copy import copy from copy import copy
import pytest
import sqlalchemy as sa import sqlalchemy as sa
from pytest import raises
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy_utils import get_column_key from sqlalchemy_utils import get_column_key
@pytest.fixture
def Building(Base):
class Building(Base):
__tablename__ = 'building'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column('_name', sa.Unicode(255))
return Building
@pytest.fixture
def Movie(Base):
class Movie(Base):
__tablename__ = 'movie'
id = sa.Column(sa.Integer, primary_key=True)
return Movie
class TestGetColumnKey(object): class TestGetColumnKey(object):
def setup_method(self, method):
Base = declarative_base()
class Building(Base): def test_supports_aliases(self, Building):
__tablename__ = 'building'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column('_name', sa.Unicode(255))
class Movie(Base):
__tablename__ = 'movie'
id = sa.Column(sa.Integer, primary_key=True)
self.Building = Building
self.Movie = Movie
def test_supports_aliases(self):
assert ( assert (
get_column_key(self.Building, self.Building.__table__.c.id) == get_column_key(Building, Building.__table__.c.id) ==
'id' 'id'
) )
assert ( assert (
get_column_key(self.Building, self.Building.__table__.c._name) == get_column_key(Building, Building.__table__.c._name) ==
'name' 'name'
) )
def test_supports_vague_matching_of_column_objects(self): def test_supports_vague_matching_of_column_objects(self, Building):
column = copy(self.Building.__table__.c._name) column = copy(Building.__table__.c._name)
assert get_column_key(self.Building, column) == 'name' assert get_column_key(Building, column) == 'name'
def test_throws_value_error_for_unknown_column(self): def test_throws_value_error_for_unknown_column(self, Building, Movie):
with raises(sa.orm.exc.UnmappedColumnError): with pytest.raises(sa.orm.exc.UnmappedColumnError):
get_column_key(self.Building, self.Movie.__table__.c.id) get_column_key(Building, Movie.__table__.c.id)

View File

@@ -1,65 +1,65 @@
import pytest
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy_utils import get_columns from sqlalchemy_utils import get_columns
@pytest.fixture
def Building(Base):
class Building(Base):
__tablename__ = 'building'
id = sa.Column('_id', sa.Integer, primary_key=True)
name = sa.Column('_name', sa.Unicode(255))
return Building
class TestGetColumns(object): class TestGetColumns(object):
def setup_method(self, method):
Base = declarative_base()
class Building(Base): def test_table(self, Building):
__tablename__ = 'building'
id = sa.Column('_id', sa.Integer, primary_key=True)
name = sa.Column('_name', sa.Unicode(255))
self.Building = Building
def test_table(self):
assert isinstance( assert isinstance(
get_columns(self.Building.__table__), get_columns(Building.__table__),
sa.sql.base.ImmutableColumnCollection sa.sql.base.ImmutableColumnCollection
) )
def test_instrumented_attribute(self): def test_instrumented_attribute(self, Building):
assert get_columns(self.Building.id) == [self.Building.__table__.c._id] assert get_columns(Building.id) == [Building.__table__.c._id]
def test_column_property(self): def test_column_property(self, Building):
assert get_columns(self.Building.id.property) == [ assert get_columns(Building.id.property) == [
self.Building.__table__.c._id Building.__table__.c._id
] ]
def test_column(self): def test_column(self, Building):
assert get_columns(self.Building.__table__.c._id) == [ assert get_columns(Building.__table__.c._id) == [
self.Building.__table__.c._id Building.__table__.c._id
] ]
def test_declarative_class(self): def test_declarative_class(self, Building):
assert isinstance( assert isinstance(
get_columns(self.Building), get_columns(Building),
sa.util._collections.OrderedProperties sa.util._collections.OrderedProperties
) )
def test_declarative_object(self): def test_declarative_object(self, Building):
assert isinstance( assert isinstance(
get_columns(self.Building()), get_columns(Building()),
sa.util._collections.OrderedProperties sa.util._collections.OrderedProperties
) )
def test_mapper(self): def test_mapper(self, Building):
assert isinstance( assert isinstance(
get_columns(self.Building.__mapper__), get_columns(Building.__mapper__),
sa.util._collections.OrderedProperties sa.util._collections.OrderedProperties
) )
def test_class_alias(self): def test_class_alias(self, Building):
assert isinstance( assert isinstance(
get_columns(sa.orm.aliased(self.Building)), get_columns(sa.orm.aliased(Building)),
sa.util._collections.OrderedProperties sa.util._collections.OrderedProperties
) )
def test_table_alias(self): def test_table_alias(self, Building):
alias = sa.orm.aliased(self.Building.__table__) alias = sa.orm.aliased(Building.__table__)
assert isinstance( assert isinstance(
get_columns(alias), get_columns(alias),
sa.sql.base.ImmutableColumnCollection sa.sql.base.ImmutableColumnCollection

View File

@@ -1,41 +1,41 @@
import pytest
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy_utils import get_hybrid_properties from sqlalchemy_utils import get_hybrid_properties
@pytest.fixture
def Category(Base):
class Category(Base):
__tablename__ = 'category'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
@hybrid_property
def lowercase_name(self):
return self.name.lower()
@lowercase_name.expression
def lowercase_name(cls):
return sa.func.lower(cls.name)
return Category
class TestGetHybridProperties(object): class TestGetHybridProperties(object):
def setup_method(self, method):
Base = declarative_base()
class Category(Base): def test_declarative_model(self, Category):
__tablename__ = 'category'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
@hybrid_property
def lowercase_name(self):
return self.name.lower()
@lowercase_name.expression
def lowercase_name(cls):
return sa.func.lower(cls.name)
self.Category = Category
def test_declarative_model(self):
assert ( assert (
list(get_hybrid_properties(self.Category).keys()) == list(get_hybrid_properties(Category).keys()) ==
['lowercase_name'] ['lowercase_name']
) )
def test_mapper(self): def test_mapper(self, Category):
assert ( assert (
list(get_hybrid_properties(sa.inspect(self.Category)).keys()) == list(get_hybrid_properties(sa.inspect(Category)).keys()) ==
['lowercase_name'] ['lowercase_name']
) )
def test_aliased_class(self): def test_aliased_class(self, Category):
props = get_hybrid_properties(sa.orm.aliased(self.Category)) props = get_hybrid_properties(sa.orm.aliased(Category))
assert list(props.keys()) == ['lowercase_name'] assert list(props.keys()) == ['lowercase_name']

View File

@@ -1,104 +1,106 @@
import pytest
import sqlalchemy as sa import sqlalchemy as sa
from pytest import raises
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy_utils import get_mapper from sqlalchemy_utils import get_mapper
from tests import TestCase
class TestGetMapper(object): class TestGetMapper(object):
def setup_method(self, method):
self.Base = declarative_base()
class Building(self.Base): @pytest.fixture
def Building(self, Base):
class Building(Base):
__tablename__ = 'building' __tablename__ = 'building'
id = sa.Column(sa.Integer, primary_key=True) id = sa.Column(sa.Integer, primary_key=True)
return Building
self.Building = Building def test_table(self, Building):
assert get_mapper(Building.__table__) == sa.inspect(Building)
def test_table(self): def test_declarative_class(self, Building):
assert get_mapper(self.Building.__table__) == sa.inspect(self.Building)
def test_declarative_class(self):
assert ( assert (
get_mapper(self.Building) == get_mapper(Building) ==
sa.inspect(self.Building) sa.inspect(Building)
) )
def test_declarative_object(self): def test_declarative_object(self, Building):
assert ( assert (
get_mapper(self.Building()) == get_mapper(Building()) ==
sa.inspect(self.Building) sa.inspect(Building)
) )
def test_mapper(self): def test_mapper(self, Building):
assert ( assert (
get_mapper(self.Building.__mapper__) == get_mapper(Building.__mapper__) ==
sa.inspect(self.Building) sa.inspect(Building)
) )
def test_class_alias(self): def test_class_alias(self, Building):
assert ( assert (
get_mapper(sa.orm.aliased(self.Building)) == get_mapper(sa.orm.aliased(Building)) ==
sa.inspect(self.Building) sa.inspect(Building)
) )
def test_instrumented_attribute(self): def test_instrumented_attribute(self, Building):
assert ( assert (
get_mapper(self.Building.id) == sa.inspect(self.Building) get_mapper(Building.id) == sa.inspect(Building)
) )
def test_table_alias(self): def test_table_alias(self, Building):
alias = sa.orm.aliased(self.Building.__table__) alias = sa.orm.aliased(Building.__table__)
assert ( assert (
get_mapper(alias) == get_mapper(alias) ==
sa.inspect(self.Building) sa.inspect(Building)
) )
def test_column(self): def test_column(self, Building):
assert ( assert (
get_mapper(self.Building.__table__.c.id) == get_mapper(Building.__table__.c.id) ==
sa.inspect(self.Building) sa.inspect(Building)
) )
def test_column_of_an_alias(self): def test_column_of_an_alias(self, Building):
assert ( assert (
get_mapper(sa.orm.aliased(self.Building.__table__).c.id) == get_mapper(sa.orm.aliased(Building.__table__).c.id) ==
sa.inspect(self.Building) sa.inspect(Building)
) )
class TestGetMapperWithQueryEntities(TestCase): class TestGetMapperWithQueryEntities(object):
def create_models(self):
class Building(self.Base): @pytest.fixture
def Building(self, Base):
class Building(Base):
__tablename__ = 'building' __tablename__ = 'building'
id = sa.Column(sa.Integer, primary_key=True) id = sa.Column(sa.Integer, primary_key=True)
return Building
self.Building = Building @pytest.fixture
def init_models(self, Building):
pass
def test_mapper_entity_with_mapper(self): def test_mapper_entity_with_mapper(self, session, Building):
entity = self.session.query(self.Building.__mapper__)._entities[0] entity = session.query(Building.__mapper__)._entities[0]
assert ( assert (
get_mapper(entity) == get_mapper(entity) ==
sa.inspect(self.Building) sa.inspect(Building)
) )
def test_mapper_entity_with_class(self): def test_mapper_entity_with_class(self, session, Building):
entity = self.session.query(self.Building)._entities[0] entity = session.query(Building)._entities[0]
assert ( assert (
get_mapper(entity) == get_mapper(entity) ==
sa.inspect(self.Building) sa.inspect(Building)
) )
def test_column_entity(self): def test_column_entity(self, session, Building):
query = self.session.query(self.Building.id) query = session.query(Building.id)
assert get_mapper(query._entities[0]) == sa.inspect(self.Building) assert get_mapper(query._entities[0]) == sa.inspect(Building)
class TestGetMapperWithMultipleMappersFound(object): class TestGetMapperWithMultipleMappersFound(object):
def setup_method(self, method):
Base = declarative_base()
@pytest.fixture
def Building(self, Base):
class Building(Base): class Building(Base):
__tablename__ = 'building' __tablename__ = 'building'
id = sa.Column(sa.Integer, primary_key=True) id = sa.Column(sa.Integer, primary_key=True)
@@ -106,29 +108,30 @@ class TestGetMapperWithMultipleMappersFound(object):
class BigBuilding(Building): class BigBuilding(Building):
pass pass
self.Building = Building return Building
self.BigBuilding = BigBuilding
def test_table(self): def test_table(self, Building):
with raises(ValueError): with pytest.raises(ValueError):
get_mapper(self.Building.__table__) get_mapper(Building.__table__)
def test_table_alias(self): def test_table_alias(self, Building):
alias = sa.orm.aliased(self.Building.__table__) alias = sa.orm.aliased(Building.__table__)
with raises(ValueError): with pytest.raises(ValueError):
get_mapper(alias) get_mapper(alias)
class TestGetMapperForTableWithoutMapper(object): class TestGetMapperForTableWithoutMapper(object):
def setup_method(self, method):
@pytest.fixture
def building(self):
metadata = sa.MetaData() metadata = sa.MetaData()
self.building = sa.Table('building', metadata) return sa.Table('building', metadata)
def test_table(self): def test_table(self, building):
with raises(ValueError): with pytest.raises(ValueError):
get_mapper(self.building) get_mapper(building)
def test_table_alias(self): def test_table_alias(self, building):
alias = sa.orm.aliased(self.building) alias = sa.orm.aliased(building)
with raises(ValueError): with pytest.raises(ValueError):
get_mapper(alias) get_mapper(alias)

View File

@@ -1,5 +1,5 @@
import pytest
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy_utils import get_primary_keys from sqlalchemy_utils import get_primary_keys
@@ -9,40 +9,40 @@ except ImportError:
from ordereddict import OrderedDict from ordereddict import OrderedDict
@pytest.fixture
def Building(Base):
class Building(Base):
__tablename__ = 'building'
id = sa.Column('_id', sa.Integer, primary_key=True)
name = sa.Column('_name', sa.Unicode(255))
return Building
class TestGetPrimaryKeys(object): class TestGetPrimaryKeys(object):
def setup_method(self, method):
Base = declarative_base()
class Building(Base): def test_table(self, Building):
__tablename__ = 'building' assert get_primary_keys(Building.__table__) == OrderedDict({
id = sa.Column('_id', sa.Integer, primary_key=True) '_id': Building.__table__.c._id
name = sa.Column('_name', sa.Unicode(255))
self.Building = Building
def test_table(self):
assert get_primary_keys(self.Building.__table__) == OrderedDict({
'_id': self.Building.__table__.c._id
}) })
def test_declarative_class(self): def test_declarative_class(self, Building):
assert get_primary_keys(self.Building) == OrderedDict({ assert get_primary_keys(Building) == OrderedDict({
'id': self.Building.__table__.c._id 'id': Building.__table__.c._id
}) })
def test_declarative_object(self): def test_declarative_object(self, Building):
assert get_primary_keys(self.Building()) == OrderedDict({ assert get_primary_keys(Building()) == OrderedDict({
'id': self.Building.__table__.c._id 'id': Building.__table__.c._id
}) })
def test_class_alias(self): def test_class_alias(self, Building):
alias = sa.orm.aliased(self.Building) alias = sa.orm.aliased(Building)
assert get_primary_keys(alias) == OrderedDict({ assert get_primary_keys(alias) == OrderedDict({
'id': self.Building.__table__.c._id 'id': Building.__table__.c._id
}) })
def test_table_alias(self): def test_table_alias(self, Building):
alias = sa.orm.aliased(self.Building.__table__) alias = sa.orm.aliased(Building.__table__)
assert get_primary_keys(alias) == OrderedDict({ assert get_primary_keys(alias) == OrderedDict({
'_id': alias.c._id '_id': alias.c._id
}) })

View File

@@ -1,102 +1,115 @@
import pytest
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy_utils import get_query_entities from sqlalchemy_utils import get_query_entities
from tests import TestCase
class TestGetQueryEntities(TestCase): @pytest.fixture
def create_models(self): def TextItem(Base):
class TextItem(self.Base): class TextItem(Base):
__tablename__ = 'text_item' __tablename__ = 'text_item'
id = sa.Column(sa.Integer, primary_key=True) id = sa.Column(sa.Integer, primary_key=True)
type = sa.Column(sa.Unicode(255)) type = sa.Column(sa.Unicode(255))
__mapper_args__ = { __mapper_args__ = {
'polymorphic_on': type, 'polymorphic_on': type,
} }
return TextItem
class Article(TextItem):
__tablename__ = 'article'
id = sa.Column(
sa.Integer, sa.ForeignKey(TextItem.id), primary_key=True
)
category = sa.Column(sa.Unicode(255))
__mapper_args__ = {
'polymorphic_identity': u'article'
}
class BlogPost(TextItem): @pytest.fixture
__tablename__ = 'blog_post' def Article(TextItem):
id = sa.Column( class Article(TextItem):
sa.Integer, sa.ForeignKey(TextItem.id), primary_key=True __tablename__ = 'article'
) id = sa.Column(
__mapper_args__ = { sa.Integer, sa.ForeignKey(TextItem.id), primary_key=True
'polymorphic_identity': u'blog_post' )
} category = sa.Column(sa.Unicode(255))
__mapper_args__ = {
'polymorphic_identity': u'article'
}
return Article
self.TextItem = TextItem
self.Article = Article
self.BlogPost = BlogPost
def test_mapper(self): @pytest.fixture
query = self.session.query(sa.inspect(self.TextItem)) def BlogPost(TextItem):
assert get_query_entities(query) == [self.TextItem] class BlogPost(TextItem):
__tablename__ = 'blog_post'
id = sa.Column(
sa.Integer, sa.ForeignKey(TextItem.id), primary_key=True
)
__mapper_args__ = {
'polymorphic_identity': u'blog_post'
}
return BlogPost
def test_entity(self):
query = self.session.query(self.TextItem)
assert get_query_entities(query) == [self.TextItem]
def test_instrumented_attribute(self): @pytest.fixture
query = self.session.query(self.TextItem.id) def init_models(TextItem, Article, BlogPost):
assert get_query_entities(query) == [self.TextItem] pass
def test_column(self):
query = self.session.query(self.TextItem.__table__.c.id)
assert get_query_entities(query) == [self.TextItem.__table__]
def test_aliased_selectable(self): class TestGetQueryEntities(object):
selectable = sa.orm.with_polymorphic(self.TextItem, [self.BlogPost])
query = self.session.query(selectable) def test_mapper(self, session, TextItem):
query = session.query(sa.inspect(TextItem))
assert get_query_entities(query) == [TextItem]
def test_entity(self, session, TextItem):
query = session.query(TextItem)
assert get_query_entities(query) == [TextItem]
def test_instrumented_attribute(self, session, TextItem):
query = session.query(TextItem.id)
assert get_query_entities(query) == [TextItem]
def test_column(self, session, TextItem):
query = session.query(TextItem.__table__.c.id)
assert get_query_entities(query) == [TextItem.__table__]
def test_aliased_selectable(self, session, TextItem, BlogPost):
selectable = sa.orm.with_polymorphic(TextItem, [BlogPost])
query = session.query(selectable)
assert get_query_entities(query) == [selectable] assert get_query_entities(query) == [selectable]
def test_joined_entity(self): def test_joined_entity(self, session, TextItem, BlogPost):
query = self.session.query(self.TextItem).join( query = session.query(TextItem).join(
self.BlogPost, self.BlogPost.id == self.TextItem.id BlogPost, BlogPost.id == TextItem.id
) )
assert get_query_entities(query) == [ assert get_query_entities(query) == [
self.TextItem, sa.inspect(self.BlogPost) TextItem, sa.inspect(BlogPost)
] ]
def test_joined_aliased_entity(self): def test_joined_aliased_entity(self, session, TextItem, BlogPost):
alias = sa.orm.aliased(self.BlogPost) alias = sa.orm.aliased(BlogPost)
query = self.session.query(self.TextItem).join( query = session.query(TextItem).join(
alias, alias.id == self.TextItem.id alias, alias.id == TextItem.id
) )
assert get_query_entities(query) == [self.TextItem, alias] assert get_query_entities(query) == [TextItem, alias]
def test_column_entity_with_label(self): def test_column_entity_with_label(self, session, Article):
query = self.session.query(self.Article.id.label('id')) query = session.query(Article.id.label('id'))
assert get_query_entities(query) == [self.Article] assert get_query_entities(query) == [Article]
def test_with_subquery(self): def test_with_subquery(self, session, Article):
number_of_articles = ( number_of_articles = (
sa.select( sa.select(
[sa.func.count(self.Article.id)], [sa.func.count(Article.id)],
) )
.select_from( .select_from(
self.Article.__table__ Article.__table__
) )
).label('number_of_articles') ).label('number_of_articles')
query = self.session.query(self.Article, number_of_articles) query = session.query(Article, number_of_articles)
assert get_query_entities(query) == [ assert get_query_entities(query) == [
self.Article, Article,
number_of_articles number_of_articles
] ]
def test_aliased_entity(self): def test_aliased_entity(self, session, Article):
alias = sa.orm.aliased(self.Article) alias = sa.orm.aliased(Article)
query = self.session.query(alias) query = session.query(alias)
assert get_query_entities(query) == [alias] assert get_query_entities(query) == [alias]

View File

@@ -1,17 +1,22 @@
import pytest
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy_utils import get_referencing_foreign_keys from sqlalchemy_utils import get_referencing_foreign_keys
from tests import TestCase
class TestGetReferencingFksWithCompositeKeys(TestCase): class TestGetReferencingFksWithCompositeKeys(object):
def create_models(self):
class User(self.Base): @pytest.fixture
def User(self, Base):
class User(Base):
__tablename__ = 'user' __tablename__ = 'user'
first_name = sa.Column(sa.Unicode(255), primary_key=True) first_name = sa.Column(sa.Unicode(255), primary_key=True)
last_name = sa.Column(sa.Unicode(255), primary_key=True) last_name = sa.Column(sa.Unicode(255), primary_key=True)
return User
class Article(self.Base): @pytest.fixture
def Article(self, Base, User):
class Article(Base):
__tablename__ = 'article' __tablename__ = 'article'
id = sa.Column(sa.Integer, primary_key=True) id = sa.Column(sa.Integer, primary_key=True)
author_first_name = sa.Column(sa.Unicode(255)) author_first_name = sa.Column(sa.Unicode(255))
@@ -22,22 +27,26 @@ class TestGetReferencingFksWithCompositeKeys(TestCase):
[User.first_name, User.last_name] [User.first_name, User.last_name]
), ),
) )
return Article
self.User = User @pytest.fixture
self.Article = Article def init_models(self, User, Article):
pass
def test_with_declarative_class(self): def test_with_declarative_class(self, User, Article):
fks = get_referencing_foreign_keys(self.User) fks = get_referencing_foreign_keys(User)
assert self.Article.__table__.foreign_keys == fks assert Article.__table__.foreign_keys == fks
def test_with_table(self): def test_with_table(self, User, Article):
fks = get_referencing_foreign_keys(self.User.__table__) fks = get_referencing_foreign_keys(User.__table__)
assert self.Article.__table__.foreign_keys == fks assert Article.__table__.foreign_keys == fks
class TestGetReferencingFksWithInheritance(TestCase): class TestGetReferencingFksWithInheritance(object):
def create_models(self):
class User(self.Base): @pytest.fixture
def User(self, Base):
class User(Base):
__tablename__ = 'user' __tablename__ = 'user'
id = sa.Column(sa.Integer, primary_key=True) id = sa.Column(sa.Integer, primary_key=True)
type = sa.Column(sa.Unicode) type = sa.Column(sa.Unicode)
@@ -47,14 +56,20 @@ class TestGetReferencingFksWithInheritance(TestCase):
__mapper_args__ = { __mapper_args__ = {
'polymorphic_on': 'type' 'polymorphic_on': 'type'
} }
return User
@pytest.fixture
def Admin(self, User):
class Admin(User): class Admin(User):
__tablename__ = 'admin' __tablename__ = 'admin'
id = sa.Column( id = sa.Column(
sa.Integer, sa.ForeignKey(User.id), primary_key=True sa.Integer, sa.ForeignKey(User.id), primary_key=True
) )
return Admin
class TextItem(self.Base): @pytest.fixture
def TextItem(self, Base, User):
class TextItem(Base):
__tablename__ = 'textitem' __tablename__ = 'textitem'
id = sa.Column(sa.Integer, primary_key=True) id = sa.Column(sa.Integer, primary_key=True)
type = sa.Column(sa.Unicode) type = sa.Column(sa.Unicode)
@@ -62,7 +77,10 @@ class TestGetReferencingFksWithInheritance(TestCase):
__mapper_args__ = { __mapper_args__ = {
'polymorphic_on': 'type' 'polymorphic_on': 'type'
} }
return TextItem
@pytest.fixture
def Article(self, TextItem):
class Article(TextItem): class Article(TextItem):
__tablename__ = 'article' __tablename__ = 'article'
id = sa.Column( id = sa.Column(
@@ -71,16 +89,16 @@ class TestGetReferencingFksWithInheritance(TestCase):
__mapper_args__ = { __mapper_args__ = {
'polymorphic_identity': 'article' 'polymorphic_identity': 'article'
} }
return Article
self.Admin = Admin @pytest.fixture
self.User = User def init_models(self, User, Admin, TextItem, Article):
self.Article = Article pass
self.TextItem = TextItem
def test_with_declarative_class(self): def test_with_declarative_class(self, Admin, TextItem):
fks = get_referencing_foreign_keys(self.Admin) fks = get_referencing_foreign_keys(Admin)
assert self.TextItem.__table__.foreign_keys == fks assert TextItem.__table__.foreign_keys == fks
def test_with_table(self): def test_with_table(self, Admin):
fks = get_referencing_foreign_keys(self.Admin.__table__) fks = get_referencing_foreign_keys(Admin.__table__)
assert fks == set([]) assert fks == set([])

View File

@@ -1,76 +1,86 @@
import pytest
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy_utils import get_tables from sqlalchemy_utils import get_tables
from tests import TestCase
class TestGetTables(TestCase): @pytest.fixture
def create_models(self): def TextItem(Base):
class TextItem(self.Base): class TextItem(Base):
__tablename__ = 'text_item' __tablename__ = 'text_item'
id = sa.Column(sa.Integer, primary_key=True) id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255)) name = sa.Column(sa.Unicode(255))
type = sa.Column(sa.Unicode(255)) type = sa.Column(sa.Unicode(255))
__mapper_args__ = { __mapper_args__ = {
'polymorphic_on': type, 'polymorphic_on': type,
'with_polymorphic': '*' 'with_polymorphic': '*'
} }
return TextItem
class Article(TextItem):
__tablename__ = 'article'
id = sa.Column(
sa.Integer, sa.ForeignKey(TextItem.id), primary_key=True
)
__mapper_args__ = {
'polymorphic_identity': u'article'
}
self.TextItem = TextItem @pytest.fixture
self.Article = Article def Article(TextItem):
class Article(TextItem):
__tablename__ = 'article'
id = sa.Column(
sa.Integer, sa.ForeignKey(TextItem.id), primary_key=True
)
__mapper_args__ = {
'polymorphic_identity': u'article'
}
return Article
def test_child_class_using_join_table_inheritance(self):
assert get_tables(self.Article) == [ @pytest.fixture
self.TextItem.__table__, def init_models(TextItem, Article):
self.Article.__table__ pass
class TestGetTables(object):
def test_child_class_using_join_table_inheritance(self, TextItem, Article):
assert get_tables(Article) == [
TextItem.__table__,
Article.__table__
] ]
def test_entity_using_with_polymorphic(self): def test_entity_using_with_polymorphic(self, TextItem, Article):
assert get_tables(self.TextItem) == [ assert get_tables(TextItem) == [
self.TextItem.__table__, TextItem.__table__,
self.Article.__table__ Article.__table__
] ]
def test_instrumented_attribute(self): def test_instrumented_attribute(self, TextItem):
assert get_tables(self.TextItem.name) == [ assert get_tables(TextItem.name) == [
self.TextItem.__table__, TextItem.__table__,
] ]
def test_polymorphic_instrumented_attribute(self): def test_polymorphic_instrumented_attribute(self, TextItem, Article):
assert get_tables(self.Article.id) == [ assert get_tables(Article.id) == [
self.TextItem.__table__, TextItem.__table__,
self.Article.__table__ Article.__table__
] ]
def test_column(self): def test_column(self, Article):
assert get_tables(self.Article.__table__.c.id) == [ assert get_tables(Article.__table__.c.id) == [
self.Article.__table__ Article.__table__
] ]
def test_mapper_entity_with_class(self): def test_mapper_entity_with_class(self, session, TextItem, Article):
query = self.session.query(self.Article) query = session.query(Article)
assert get_tables(query._entities[0]) == [ assert get_tables(query._entities[0]) == [
self.TextItem.__table__, self.Article.__table__ TextItem.__table__, Article.__table__
] ]
def test_mapper_entity_with_mapper(self): def test_mapper_entity_with_mapper(self, session, TextItem, Article):
query = self.session.query(sa.inspect(self.Article)) query = session.query(sa.inspect(Article))
assert get_tables(query._entities[0]) == [ assert get_tables(query._entities[0]) == [
self.TextItem.__table__, self.Article.__table__ TextItem.__table__, Article.__table__
] ]
def test_column_entity(self): def test_column_entity(self, session, TextItem, Article):
query = self.session.query(self.Article.id) query = session.query(Article.id)
assert get_tables(query._entities[0]) == [ assert get_tables(query._entities[0]) == [
self.TextItem.__table__, self.Article.__table__ TextItem.__table__, Article.__table__
] ]

View File

@@ -1,46 +1,49 @@
import pytest
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy_utils import get_type from sqlalchemy_utils import get_type
@pytest.fixture
def User(Base):
class User(Base):
__tablename__ = 'user'
id = sa.Column(sa.Integer, primary_key=True)
return User
@pytest.fixture
def Article(Base, User):
class Article(Base):
__tablename__ = 'article'
id = sa.Column(sa.Integer, primary_key=True)
author_id = sa.Column(sa.Integer, sa.ForeignKey(User.id))
author = sa.orm.relationship(User)
some_property = sa.orm.column_property(
sa.func.coalesce(id, 1)
)
return Article
class TestGetType(object): class TestGetType(object):
def setup_method(self, method):
Base = declarative_base()
class User(Base): def test_instrumented_attribute(self, Article):
__tablename__ = 'user' assert isinstance(get_type(Article.id), sa.Integer)
id = sa.Column(sa.Integer, primary_key=True)
class Article(Base): def test_column_property(self, Article):
__tablename__ = 'article' assert isinstance(get_type(Article.id.property), sa.Integer)
id = sa.Column(sa.Integer, primary_key=True)
author_id = sa.Column(sa.Integer, sa.ForeignKey(User.id)) def test_column(self, Article):
author = sa.orm.relationship(User) assert isinstance(get_type(Article.__table__.c.id), sa.Integer)
some_property = sa.orm.column_property( def test_calculated_column_property(self, Article):
sa.func.coalesce(id, 1) assert isinstance(get_type(Article.some_property), sa.Integer)
)
self.Article = Article def test_relationship_property(self, Article, User):
self.User = User assert get_type(Article.author) == User
def test_instrumented_attribute(self): def test_scalar_select(self, Article):
assert isinstance(get_type(self.Article.id), sa.Integer) query = sa.select([Article.id]).as_scalar()
def test_column_property(self):
assert isinstance(get_type(self.Article.id.property), sa.Integer)
def test_column(self):
assert isinstance(get_type(self.Article.__table__.c.id), sa.Integer)
def test_calculated_column_property(self):
assert isinstance(get_type(self.Article.some_property), sa.Integer)
def test_relationship_property(self):
assert get_type(self.Article.author) == self.User
def test_scalar_select(self):
query = sa.select([self.Article.id]).as_scalar()
assert isinstance(get_type(query), sa.Integer) assert isinstance(get_type(query), sa.Integer)

View File

@@ -1,72 +1,94 @@
import pytest
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy_utils.functions import getdotattr from sqlalchemy_utils.functions import getdotattr
from tests import TestCase
class TestGetDotAttr(TestCase): @pytest.fixture
def create_models(self): def Document(Base):
class Document(self.Base): class Document(Base):
__tablename__ = 'document' __tablename__ = 'document'
id = sa.Column(sa.Integer, primary_key=True) id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255)) name = sa.Column(sa.Unicode(255))
return Document
class Section(self.Base):
__tablename__ = 'section'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
document_id = sa.Column( @pytest.fixture
sa.Integer, sa.ForeignKey(Document.id) def Section(Base, Document):
) class Section(Base):
__tablename__ = 'section'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
document = sa.orm.relationship(Document, backref='sections') document_id = sa.Column(
sa.Integer, sa.ForeignKey(Document.id)
)
class SubSection(self.Base): document = sa.orm.relationship(Document, backref='sections')
__tablename__ = 'subsection' return Section
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
section_id = sa.Column(
sa.Integer, sa.ForeignKey(Section.id)
)
section = sa.orm.relationship(Section, backref='subsections') @pytest.fixture
def SubSection(Base, Section):
class SubSection(Base):
__tablename__ = 'subsection'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
class SubSubSection(self.Base): section_id = sa.Column(
__tablename__ = 'subsubsection' sa.Integer, sa.ForeignKey(Section.id)
id = sa.Column(sa.Integer, primary_key=True) )
name = sa.Column(sa.Unicode(255))
locale = sa.Column(sa.String(10))
subsection_id = sa.Column( section = sa.orm.relationship(Section, backref='subsections')
sa.Integer, sa.ForeignKey(SubSection.id) return SubSection
)
subsection = sa.orm.relationship(
SubSection, backref='subsubsections'
)
self.Document = Document @pytest.fixture
self.Section = Section def SubSubSection(Base, SubSection):
self.SubSection = SubSection class SubSubSection(Base):
self.SubSubSection = SubSubSection __tablename__ = 'subsubsection'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
locale = sa.Column(sa.String(10))
def test_simple_objects(self): subsection_id = sa.Column(
document = self.Document(name=u'some document') sa.Integer, sa.ForeignKey(SubSection.id)
section = self.Section(document=document) )
subsection = self.SubSection(section=section)
subsection = sa.orm.relationship(
SubSection, backref='subsubsections'
)
return SubSubSection
@pytest.fixture
def init_models(Document, Section, SubSection, SubSubSection):
pass
class TestGetDotAttr(object):
def test_simple_objects(self, Document, Section, SubSection):
document = Document(name=u'some document')
section = Section(document=document)
subsection = SubSection(section=section)
assert getdotattr( assert getdotattr(
subsection, subsection,
'section.document.name' 'section.document.name'
) == u'some document' ) == u'some document'
def test_with_instrumented_lists(self): def test_with_instrumented_lists(
document = self.Document(name=u'some document') self,
section = self.Section(document=document) Document,
subsection = self.SubSection(section=section) Section,
subsubsection = self.SubSubSection(subsection=subsection) SubSection,
SubSubSection
):
document = Document(name=u'some document')
section = Section(document=document)
subsection = SubSection(section=section)
subsubsection = SubSubSection(subsection=subsection)
assert getdotattr(document, 'sections') == [section] assert getdotattr(document, 'sections') == [section]
assert getdotattr(document, 'sections.subsections') == [ assert getdotattr(document, 'sections.subsections') == [
@@ -76,10 +98,10 @@ class TestGetDotAttr(TestCase):
subsubsection subsubsection
] ]
def test_class_paths(self): def test_class_paths(self, Document, Section, SubSection):
assert getdotattr(self.Section, 'document') is self.Section.document assert getdotattr(Section, 'document') is Section.document
assert ( assert (
getdotattr(self.SubSection, 'section.document') is getdotattr(SubSection, 'section.document') is
self.Section.document Section.document
) )
assert getdotattr(self.Section, 'document.name') is self.Document.name assert getdotattr(Section, 'document.name') is Document.name

View File

@@ -1,47 +1,44 @@
import pytest
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy_utils import has_changes from sqlalchemy_utils import has_changes
class HasChangesTestCase(object): @pytest.fixture
def setup_method(self, method): def Article(Base):
Base = declarative_base() class Article(Base):
__tablename__ = 'article_translation'
class Article(Base): id = sa.Column(sa.Integer, primary_key=True)
__tablename__ = 'article_translation' title = sa.Column(sa.String(100))
id = sa.Column(sa.Integer, primary_key=True) return Article
title = sa.Column(sa.String(100))
self.Article = Article
class TestHasChangesWithStringAttr(HasChangesTestCase): class TestHasChangesWithStringAttr(object):
def test_without_changed_attr(self): def test_without_changed_attr(self, Article):
article = self.Article() article = Article()
assert not has_changes(article, 'title') assert not has_changes(article, 'title')
def test_with_changed_attr(self): def test_with_changed_attr(self, Article):
article = self.Article(title='Some title') article = Article(title='Some title')
assert has_changes(article, 'title') assert has_changes(article, 'title')
class TestHasChangesWithMultipleAttrs(HasChangesTestCase): class TestHasChangesWithMultipleAttrs(object):
def test_without_changed_attr(self): def test_without_changed_attr(self, Article):
article = self.Article() article = Article()
assert not has_changes(article, ['title']) assert not has_changes(article, ['title'])
def test_with_changed_attr(self): def test_with_changed_attr(self, Article):
article = self.Article(title='Some title') article = Article(title='Some title')
assert has_changes(article, ['title', 'id']) assert has_changes(article, ['title', 'id'])
class TestHasChangesWithExclude(HasChangesTestCase): class TestHasChangesWithExclude(object):
def test_without_changed_attr(self): def test_without_changed_attr(self, Article):
article = self.Article() article = Article()
assert not has_changes(article, exclude=['id']) assert not has_changes(article, exclude=['id'])
def test_with_changed_attr(self): def test_with_changed_attr(self, Article):
article = self.Article(title='Some title') article = Article(title='Some title')
assert has_changes(article, exclude=['id']) assert has_changes(article, exclude=['id'])
assert not has_changes(article, exclude=['title']) assert not has_changes(article, exclude=['title'])

View File

@@ -1,14 +1,13 @@
import pytest
import sqlalchemy as sa import sqlalchemy as sa
from pytest import raises
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy_utils import get_fk_constraint_for_columns, has_index from sqlalchemy_utils import get_fk_constraint_for_columns, has_index
class TestHasIndex(object): class TestHasIndex(object):
def setup_method(self, method):
Base = declarative_base()
@pytest.fixture
def table(self, Base):
class ArticleTranslation(Base): class ArticleTranslation(Base):
__tablename__ = 'article_translation' __tablename__ = 'article_translation'
id = sa.Column(sa.Integer, primary_key=True) id = sa.Column(sa.Integer, primary_key=True)
@@ -21,24 +20,23 @@ class TestHasIndex(object):
__table_args__ = ( __table_args__ = (
sa.Index('my_index', is_deleted, is_archived), sa.Index('my_index', is_deleted, is_archived),
) )
return ArticleTranslation.__table__
self.table = ArticleTranslation.__table__ def test_column_that_belongs_to_an_alias(self, table):
alias = sa.orm.aliased(table)
def test_column_that_belongs_to_an_alias(self): with pytest.raises(TypeError):
alias = sa.orm.aliased(self.table)
with raises(TypeError):
assert has_index(alias.c.id) assert has_index(alias.c.id)
def test_compound_primary_key(self): def test_compound_primary_key(self, table):
assert has_index(self.table.c.id) assert has_index(table.c.id)
assert not has_index(self.table.c.locale) assert not has_index(table.c.locale)
def test_single_column_index(self): def test_single_column_index(self, table):
assert has_index(self.table.c.is_published) assert has_index(table.c.is_published)
def test_compound_column_index(self): def test_compound_column_index(self, table):
assert has_index(self.table.c.is_deleted) assert has_index(table.c.is_deleted)
assert not has_index(self.table.c.is_archived) assert not has_index(table.c.is_archived)
def test_table_without_primary_key(self): def test_table_without_primary_key(self):
article = sa.Table( article = sa.Table(
@@ -50,8 +48,7 @@ class TestHasIndex(object):
class TestHasIndexWithFKConstraint(object): class TestHasIndexWithFKConstraint(object):
def test_composite_fk_without_index(self): def test_composite_fk_without_index(self, Base):
Base = declarative_base()
class User(Base): class User(Base):
__tablename__ = 'user' __tablename__ = 'user'
@@ -78,8 +75,7 @@ class TestHasIndexWithFKConstraint(object):
) )
assert not has_index(constraint) assert not has_index(constraint)
def test_composite_fk_with_index(self): def test_composite_fk_with_index(self, Base):
Base = declarative_base()
class User(Base): class User(Base):
__tablename__ = 'user' __tablename__ = 'user'
@@ -109,8 +105,7 @@ class TestHasIndexWithFKConstraint(object):
) )
assert has_index(constraint) assert has_index(constraint)
def test_composite_fk_with_partial_index_match(self): def test_composite_fk_with_partial_index_match(self, Base):
Base = declarative_base()
class User(Base): class User(Base):
__tablename__ = 'user' __tablename__ = 'user'

View File

@@ -1,18 +1,20 @@
import pytest
import sqlalchemy as sa import sqlalchemy as sa
from pytest import raises
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy_utils import get_fk_constraint_for_columns, has_unique_index from sqlalchemy_utils import get_fk_constraint_for_columns, has_unique_index
class TestHasUniqueIndex(object): class TestHasUniqueIndex(object):
def setup_method(self, method):
Base = declarative_base()
@pytest.fixture
def articles(self, Base):
class Article(Base): class Article(Base):
__tablename__ = 'article' __tablename__ = 'article'
id = sa.Column(sa.Integer, primary_key=True) id = sa.Column(sa.Integer, primary_key=True)
return Article.__table__
@pytest.fixture
def article_translations(self, Base):
class ArticleTranslation(Base): class ArticleTranslation(Base):
__tablename__ = 'article_translation' __tablename__ = 'article_translation'
id = sa.Column(sa.Integer, primary_key=True) id = sa.Column(sa.Integer, primary_key=True)
@@ -26,35 +28,33 @@ class TestHasUniqueIndex(object):
sa.Index('my_index', is_archived, is_published, unique=True), sa.Index('my_index', is_archived, is_published, unique=True),
) )
self.articles = Article.__table__ return ArticleTranslation.__table__
self.article_translations = ArticleTranslation.__table__
def test_primary_key(self): def test_primary_key(self, articles):
assert has_unique_index(self.articles.c.id) assert has_unique_index(articles.c.id)
def test_column_of_aliased_table(self): def test_column_of_aliased_table(self, articles):
alias = sa.orm.aliased(self.articles) alias = sa.orm.aliased(articles)
with raises(TypeError): with pytest.raises(TypeError):
assert has_unique_index(alias.c.id) assert has_unique_index(alias.c.id)
def test_unique_index(self): def test_unique_index(self, article_translations):
assert has_unique_index(self.article_translations.c.is_deleted) assert has_unique_index(article_translations.c.is_deleted)
def test_compound_primary_key(self): def test_compound_primary_key(self, article_translations):
assert not has_unique_index(self.article_translations.c.id) assert not has_unique_index(article_translations.c.id)
assert not has_unique_index(self.article_translations.c.locale) assert not has_unique_index(article_translations.c.locale)
def test_single_column_index(self): def test_single_column_index(self, article_translations):
assert not has_unique_index(self.article_translations.c.is_published) assert not has_unique_index(article_translations.c.is_published)
def test_compound_column_unique_index(self): def test_compound_column_unique_index(self, article_translations):
assert not has_unique_index(self.article_translations.c.is_published) assert not has_unique_index(article_translations.c.is_published)
assert not has_unique_index(self.article_translations.c.is_archived) assert not has_unique_index(article_translations.c.is_archived)
class TestHasUniqueIndexWithFKConstraint(object): class TestHasUniqueIndexWithFKConstraint(object):
def test_composite_fk_without_index(self): def test_composite_fk_without_index(self, Base):
Base = declarative_base()
class User(Base): class User(Base):
__tablename__ = 'user' __tablename__ = 'user'
@@ -81,8 +81,7 @@ class TestHasUniqueIndexWithFKConstraint(object):
) )
assert not has_unique_index(constraint) assert not has_unique_index(constraint)
def test_composite_fk_with_index(self): def test_composite_fk_with_index(self, Base):
Base = declarative_base()
class User(Base): class User(Base):
__tablename__ = 'user' __tablename__ = 'user'
@@ -115,8 +114,7 @@ class TestHasUniqueIndexWithFKConstraint(object):
) )
assert has_unique_index(constraint) assert has_unique_index(constraint)
def test_composite_fk_with_partial_index_match(self): def test_composite_fk_with_partial_index_match(self, Base):
Base = declarative_base()
class User(Base): class User(Base):
__tablename__ = 'user' __tablename__ = 'user'

View File

@@ -1,39 +1,46 @@
import pytest
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy_utils.functions import identity from sqlalchemy_utils.functions import identity
from tests import TestCase
class IdentityTestCase(TestCase): class IdentityTestCase(object):
def test_for_transient_class_without_id(self):
assert identity(self.Building()) == (None, )
def test_for_transient_class_with_id(self): @pytest.fixture
building = self.Building(name=u'Some building') def init_models(self, Building):
self.session.add(building) pass
self.session.flush()
def test_for_transient_class_without_id(self, Building):
assert identity(Building()) == (None, )
def test_for_transient_class_with_id(self, session, Building):
building = Building(name=u'Some building')
session.add(building)
session.flush()
assert identity(building) == (building.id, ) assert identity(building) == (building.id, )
def test_identity_for_class(self): def test_identity_for_class(self, Building):
assert identity(self.Building) == (self.Building.id, ) assert identity(Building) == (Building.id, )
class TestIdentity(IdentityTestCase): class TestIdentity(IdentityTestCase):
def create_models(self):
class Building(self.Base): @pytest.fixture
def Building(self, Base):
class Building(Base):
__tablename__ = 'building' __tablename__ = 'building'
id = sa.Column(sa.Integer, primary_key=True) id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255)) name = sa.Column(sa.Unicode(255))
return Building
self.Building = Building
class TestIdentityWithColumnAlias(IdentityTestCase): class TestIdentityWithColumnAlias(IdentityTestCase):
def create_models(self):
class Building(self.Base): @pytest.fixture
def Building(self, Base):
class Building(Base):
__tablename__ = 'building' __tablename__ = 'building'
id = sa.Column('_id', sa.Integer, primary_key=True) id = sa.Column('_id', sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255)) name = sa.Column(sa.Unicode(255))
return Building
self.Building = Building

View File

@@ -1,24 +1,24 @@
import pytest
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy_utils import is_loaded from sqlalchemy_utils import is_loaded
@pytest.fixture
def Article(Base):
class Article(Base):
__tablename__ = 'article_translation'
id = sa.Column(sa.Integer, primary_key=True)
title = sa.orm.deferred(sa.Column(sa.String(100)))
return Article
class TestIsLoaded(object): class TestIsLoaded(object):
def setup_method(self, method):
Base = declarative_base()
class Article(Base): def test_loaded_property(self, Article):
__tablename__ = 'article_translation' article = Article(id=1)
id = sa.Column(sa.Integer, primary_key=True)
title = sa.orm.deferred(sa.Column(sa.String(100)))
self.Article = Article
def test_loaded_property(self):
article = self.Article(id=1)
assert is_loaded(article, 'id') assert is_loaded(article, 'id')
def test_unloaded_property(self): def test_unloaded_property(self, Article):
article = self.Article(id=4) article = Article(id=4)
assert not is_loaded(article, 'title') assert not is_loaded(article, 'title')

View File

@@ -2,11 +2,10 @@ import pytest
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy_utils import json_sql from sqlalchemy_utils import json_sql
from tests import TestCase
class TestJSONSQL(TestCase): @pytest.mark.usefixtures('postgresql_dsn')
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' class TestJSONSQL(object):
@pytest.mark.parametrize( @pytest.mark.parametrize(
('value', 'result'), ('value', 'result'),
@@ -27,7 +26,7 @@ class TestJSONSQL(TestCase):
) )
) )
) )
def test_compiled_scalars(self, value, result): def test_compiled_scalars(self, connection, value, result):
assert result == ( assert result == (
self.connection.execute(sa.select([json_sql(value)])).fetchone()[0] connection.execute(sa.select([json_sql(value)])).fetchone()[0]
) )

View File

@@ -1,90 +1,102 @@
import pytest
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy_utils.functions.sort_query import make_order_by_deterministic from sqlalchemy_utils.functions.sort_query import make_order_by_deterministic
from tests import assert_contains, TestCase
from .. import assert_contains
class TestMakeOrderByDeterministic(TestCase): @pytest.fixture
def create_models(self): def Article(Base):
class User(self.Base): class Article(Base):
__tablename__ = 'user' __tablename__ = 'article'
id = sa.Column(sa.Integer, primary_key=True) id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode) author_id = sa.Column(sa.Integer, sa.ForeignKey('user.id'))
email = sa.Column(sa.Unicode, unique=True) author = sa.orm.relationship('User')
return Article
email_lower = sa.orm.column_property(
sa.func.lower(name)
)
class Article(self.Base): @pytest.fixture
__tablename__ = 'article' def User(Base, Article):
id = sa.Column(sa.Integer, primary_key=True) class User(Base):
author_id = sa.Column(sa.Integer, sa.ForeignKey('user.id')) __tablename__ = 'user'
author = sa.orm.relationship(User) id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode)
email = sa.Column(sa.Unicode, unique=True)
User.article_count = sa.orm.column_property( email_lower = sa.orm.column_property(
sa.select([sa.func.count()], from_obj=Article) sa.func.lower(name)
.where(Article.author_id == User.id)
.label('article_count')
) )
self.User = User User.article_count = sa.orm.column_property(
self.Article = Article sa.select([sa.func.count()], from_obj=Article)
.where(Article.author_id == User.id)
.label('article_count')
)
return User
def test_column_property(self):
query = self.session.query(self.User).order_by(self.User.email_lower) @pytest.fixture
def init_models(Article, User):
pass
class TestMakeOrderByDeterministic(object):
def test_column_property(self, session, User):
query = session.query(User).order_by(User.email_lower)
query = make_order_by_deterministic(query) query = make_order_by_deterministic(query)
assert_contains('lower("user".name), "user".id ASC', query) assert_contains('lower("user".name), "user".id ASC', query)
def test_unique_column(self): def test_unique_column(self, session, User):
query = self.session.query(self.User).order_by(self.User.email) query = session.query(User).order_by(User.email)
query = make_order_by_deterministic(query) query = make_order_by_deterministic(query)
assert str(query).endswith('ORDER BY "user".email') assert str(query).endswith('ORDER BY "user".email')
def test_non_unique_column(self): def test_non_unique_column(self, session, User):
query = self.session.query(self.User).order_by(self.User.name) query = session.query(User).order_by(User.name)
query = make_order_by_deterministic(query) query = make_order_by_deterministic(query)
assert_contains('ORDER BY "user".name, "user".id ASC', query) assert_contains('ORDER BY "user".name, "user".id ASC', query)
def test_descending_order_by(self): def test_descending_order_by(self, session, User):
query = self.session.query(self.User).order_by( query = session.query(User).order_by(
sa.desc(self.User.name) sa.desc(User.name)
) )
query = make_order_by_deterministic(query) query = make_order_by_deterministic(query)
assert_contains('ORDER BY "user".name DESC, "user".id DESC', query) assert_contains('ORDER BY "user".name DESC, "user".id DESC', query)
def test_ascending_order_by(self): def test_ascending_order_by(self, session, User):
query = self.session.query(self.User).order_by( query = session.query(User).order_by(
sa.asc(self.User.name) sa.asc(User.name)
) )
query = make_order_by_deterministic(query) query = make_order_by_deterministic(query)
assert_contains('ORDER BY "user".name ASC, "user".id ASC', query) assert_contains('ORDER BY "user".name ASC, "user".id ASC', query)
def test_string_order_by(self): def test_string_order_by(self, session, User):
query = self.session.query(self.User).order_by('name') query = session.query(User).order_by('name')
query = make_order_by_deterministic(query) query = make_order_by_deterministic(query)
assert_contains('ORDER BY "user".name, "user".id ASC', query) assert_contains('ORDER BY "user".name, "user".id ASC', query)
def test_annotated_label(self): def test_annotated_label(self, session, User):
query = self.session.query(self.User).order_by(self.User.article_count) query = session.query(User).order_by(User.article_count)
query = make_order_by_deterministic(query) query = make_order_by_deterministic(query)
assert_contains('article_count, "user".id ASC', query) assert_contains('article_count, "user".id ASC', query)
def test_annotated_label_with_descending_order(self): def test_annotated_label_with_descending_order(self, session, User):
query = self.session.query(self.User).order_by( query = session.query(User).order_by(
sa.desc(self.User.article_count) sa.desc(User.article_count)
) )
query = make_order_by_deterministic(query) query = make_order_by_deterministic(query)
assert_contains('ORDER BY article_count DESC, "user".id DESC', query) assert_contains('ORDER BY article_count DESC, "user".id DESC', query)
def test_query_without_order_by(self): def test_query_without_order_by(self, session, User):
query = self.session.query(self.User) query = session.query(User)
query = make_order_by_deterministic(query) query = make_order_by_deterministic(query)
assert 'ORDER BY "user".id' in str(query) assert 'ORDER BY "user".id' in str(query)
def test_alias(self): def test_alias(self, session, User):
alias = sa.orm.aliased(self.User.__table__) alias = sa.orm.aliased(User.__table__)
query = self.session.query(alias).order_by(alias.c.name) query = session.query(alias).order_by(alias.c.name)
query = make_order_by_deterministic(query) query = make_order_by_deterministic(query)
assert str(query).endswith('ORDER BY user_1.name, "user".id ASC') assert str(query).endswith('ORDER BY user_1.name, "user".id ASC')

View File

@@ -1,20 +1,25 @@
import pytest
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy_utils import merge_references from sqlalchemy_utils import merge_references
from tests import TestCase
class TestMergeReferences(TestCase): class TestMergeReferences(object):
def create_models(self):
class User(self.Base): @pytest.fixture
def User(self, Base):
class User(Base):
__tablename__ = 'user' __tablename__ = 'user'
id = sa.Column(sa.Integer, primary_key=True) id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255)) name = sa.Column(sa.Unicode(255))
def __repr__(self): def __repr__(self):
return 'User(%r)' % self.name return 'User(%r)' % self.name
return User
class BlogPost(self.Base): @pytest.fixture
def BlogPost(self, Base, User):
class BlogPost(Base):
__tablename__ = 'blog_post' __tablename__ = 'blog_post'
id = sa.Column(sa.Integer, primary_key=True) id = sa.Column(sa.Integer, primary_key=True)
title = sa.Column(sa.Unicode(255)) title = sa.Column(sa.Unicode(255))
@@ -22,35 +27,37 @@ class TestMergeReferences(TestCase):
author_id = sa.Column(sa.Integer, sa.ForeignKey('user.id')) author_id = sa.Column(sa.Integer, sa.ForeignKey('user.id'))
author = sa.orm.relationship(User) author = sa.orm.relationship(User)
return BlogPost
self.User = User @pytest.fixture
self.BlogPost = BlogPost def init_models(self, User, BlogPost):
pass
def test_updates_foreign_keys(self): def test_updates_foreign_keys(self, session, User, BlogPost):
john = self.User(name=u'John') john = User(name=u'John')
jack = self.User(name=u'Jack') jack = User(name=u'Jack')
post = self.BlogPost(title=u'Some title', author=john) post = BlogPost(title=u'Some title', author=john)
post2 = self.BlogPost(title=u'Other title', author=jack) post2 = BlogPost(title=u'Other title', author=jack)
self.session.add(john) session.add(john)
self.session.add(jack) session.add(jack)
self.session.add(post) session.add(post)
self.session.add(post2) session.add(post2)
self.session.commit() session.commit()
merge_references(john, jack) merge_references(john, jack)
self.session.commit() session.commit()
assert post.author == jack assert post.author == jack
assert post2.author == jack assert post2.author == jack
def test_object_merging_whenever_possible(self): def test_object_merging_whenever_possible(self, session, User, BlogPost):
john = self.User(name=u'John') john = User(name=u'John')
jack = self.User(name=u'Jack') jack = User(name=u'Jack')
post = self.BlogPost(title=u'Some title', author=john) post = BlogPost(title=u'Some title', author=john)
post2 = self.BlogPost(title=u'Other title', author=jack) post2 = BlogPost(title=u'Other title', author=jack)
self.session.add(john) session.add(john)
self.session.add(jack) session.add(jack)
self.session.add(post) session.add(post)
self.session.add(post2) session.add(post2)
self.session.commit() session.commit()
# Load the author for post # Load the author for post
assert post.author_id == john.id assert post.author_id == john.id
merge_references(john, jack) merge_references(john, jack)
@@ -58,18 +65,23 @@ class TestMergeReferences(TestCase):
assert post2.author_id == jack.id assert post2.author_id == jack.id
class TestMergeReferencesWithManyToManyAssociations(TestCase): class TestMergeReferencesWithManyToManyAssociations(object):
def create_models(self):
class User(self.Base): @pytest.fixture
def User(self, Base):
class User(Base):
__tablename__ = 'user' __tablename__ = 'user'
id = sa.Column(sa.Integer, primary_key=True) id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255)) name = sa.Column(sa.Unicode(255))
def __repr__(self): def __repr__(self):
return 'User(%r)' % self.name return 'User(%r)' % self.name
return User
@pytest.fixture
def Team(self, Base):
team_member = sa.Table( team_member = sa.Table(
'team_member', self.Base.metadata, 'team_member', Base.metadata,
sa.Column( sa.Column(
'user_id', sa.Integer, 'user_id', sa.Integer,
sa.ForeignKey('user.id', ondelete='CASCADE'), sa.ForeignKey('user.id', ondelete='CASCADE'),
@@ -82,46 +94,56 @@ class TestMergeReferencesWithManyToManyAssociations(TestCase):
) )
) )
class Team(self.Base): class Team(Base):
__tablename__ = 'team' __tablename__ = 'team'
id = sa.Column(sa.Integer, primary_key=True) id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255)) name = sa.Column(sa.Unicode(255))
members = sa.orm.relationship( members = sa.orm.relationship(
User, 'User',
secondary=team_member, secondary=team_member,
backref='teams' backref='teams'
) )
return Team
self.User = User @pytest.fixture
self.Team = Team def init_models(self, User, Team):
pass
def test_supports_associations(self): def test_supports_associations(self, session, User, Team):
john = self.User(name=u'John') john = User(name=u'John')
jack = self.User(name=u'Jack') jack = User(name=u'Jack')
team = self.Team(name=u'Team') team = Team(name=u'Team')
team.members.append(john) team.members.append(john)
self.session.add(john) session.add(john)
self.session.add(jack) session.add(jack)
self.session.commit() session.commit()
merge_references(john, jack) merge_references(john, jack)
assert john not in team.members assert john not in team.members
assert jack in team.members assert jack in team.members
class TestMergeReferencesWithManyToManyAssociationObjects(TestCase): class TestMergeReferencesWithManyToManyAssociationObjects(object):
def create_models(self):
class Team(self.Base): @pytest.fixture
def Team(self, Base):
class Team(Base):
__tablename__ = 'team' __tablename__ = 'team'
id = sa.Column(sa.Integer, autoincrement=True, primary_key=True) id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
name = sa.Column(sa.Unicode(255)) name = sa.Column(sa.Unicode(255))
return Team
class User(self.Base): @pytest.fixture
def User(self, Base):
class User(Base):
__tablename__ = 'user' __tablename__ = 'user'
id = sa.Column(sa.Integer, autoincrement=True, primary_key=True) id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
name = sa.Column(sa.Unicode(255)) name = sa.Column(sa.Unicode(255))
return User
class TeamMember(self.Base): @pytest.fixture
def TeamMember(self, Base, User, Team):
class TeamMember(Base):
__tablename__ = 'team_member' __tablename__ = 'team_member'
user_id = sa.Column( user_id = sa.Column(
sa.Integer, sa.Integer,
@@ -150,22 +172,23 @@ class TestMergeReferencesWithManyToManyAssociationObjects(TestCase):
), ),
primaryjoin=user_id == User.id, primaryjoin=user_id == User.id,
) )
return TeamMember
self.User = User @pytest.fixture
self.TeamMember = TeamMember def init_models(self, User, Team, TeamMember):
self.Team = Team pass
def test_supports_associations(self): def test_supports_associations(self, session, User, Team, TeamMember):
john = self.User(name=u'John') john = User(name=u'John')
jack = self.User(name=u'Jack') jack = User(name=u'Jack')
team = self.Team(name=u'Team') team = Team(name=u'Team')
team.members.append(self.TeamMember(user=john)) team.members.append(TeamMember(user=john))
self.session.add(john) session.add(john)
self.session.add(jack) session.add(jack)
self.session.add(team) session.add(team)
self.session.commit() session.commit()
merge_references(john, jack) merge_references(john, jack)
self.session.commit() session.commit()
users = [member.user for member in team.members] users = [member.user for member in team.members]
assert john not in users assert john not in users
assert jack in users assert jack in users

View File

@@ -1,14 +1,13 @@
from sqlalchemy_utils.functions import naturally_equivalent from sqlalchemy_utils.functions import naturally_equivalent
from tests import TestCase
class TestNaturallyEquivalent(TestCase): class TestNaturallyEquivalent(object):
def test_returns_true_when_properties_match(self): def test_returns_true_when_properties_match(self, User):
assert naturally_equivalent( assert naturally_equivalent(
self.User(name=u'someone'), self.User(name=u'someone') User(name=u'someone'), User(name=u'someone')
) )
def test_skips_primary_keys(self): def test_skips_primary_keys(self, User):
assert naturally_equivalent( assert naturally_equivalent(
self.User(id=1, name=u'someone'), self.User(id=2, name=u'someone') User(id=1, name=u'someone'), User(id=2, name=u'someone')
) )

View File

@@ -1,24 +1,32 @@
from itertools import chain from itertools import chain
import pytest
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy_utils.functions import non_indexed_foreign_keys from sqlalchemy_utils.functions import non_indexed_foreign_keys
from tests import TestCase
class TestFindNonIndexedForeignKeys(TestCase): class TestFindNonIndexedForeignKeys(object):
def create_models(self):
class User(self.Base): @pytest.fixture
def User(self, Base):
class User(Base):
__tablename__ = 'user' __tablename__ = 'user'
id = sa.Column(sa.Integer, autoincrement=True, primary_key=True) id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
name = sa.Column(sa.Unicode(255)) name = sa.Column(sa.Unicode(255))
return User
class Category(self.Base): @pytest.fixture
def Category(self, Base):
class Category(Base):
__tablename__ = 'category' __tablename__ = 'category'
id = sa.Column(sa.Integer, primary_key=True) id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255)) name = sa.Column(sa.Unicode(255))
return Category
class Article(self.Base): @pytest.fixture
def Article(self, Base, User, Category):
class Article(Base):
__tablename__ = 'article' __tablename__ = 'article'
id = sa.Column(sa.Integer, primary_key=True) id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255)) name = sa.Column(sa.Unicode(255))
@@ -34,13 +42,14 @@ class TestFindNonIndexedForeignKeys(TestCase):
'articles', 'articles',
) )
) )
return Article
self.User = User @pytest.fixture
self.Category = Category def init_models(self, User, Category, Article):
self.Article = Article pass
def test_finds_all_non_indexed_fks(self): def test_finds_all_non_indexed_fks(self, session, Base, engine):
fks = non_indexed_foreign_keys(self.Base.metadata, self.engine) fks = non_indexed_foreign_keys(Base.metadata, engine)
assert ( assert (
'article' in 'article' in
fks fks

View File

@@ -1,18 +1,22 @@
from sqlalchemy.dialects import postgresql from sqlalchemy.dialects import postgresql
from sqlalchemy_utils.functions import quote from sqlalchemy_utils.functions import quote
from tests import TestCase
class TestQuote(TestCase): class TestQuote(object):
def test_quote_with_preserved_keyword(self): def test_quote_with_preserved_keyword(self, engine, connection, session):
assert quote(self.connection, 'order') == '"order"' assert quote(connection, 'order') == '"order"'
assert quote(self.session, 'order') == '"order"' assert quote(session, 'order') == '"order"'
assert quote(self.engine, 'order') == '"order"' assert quote(engine, 'order') == '"order"'
assert quote(postgresql.dialect(), 'order') == '"order"' assert quote(postgresql.dialect(), 'order') == '"order"'
def test_quote_with_non_preserved_keyword(self): def test_quote_with_non_preserved_keyword(
assert quote(self.connection, 'some_order') == 'some_order' self,
assert quote(self.session, 'some_order') == 'some_order' engine,
assert quote(self.engine, 'some_order') == 'some_order' connection,
session
):
assert quote(connection, 'some_order') == 'some_order'
assert quote(session, 'some_order') == 'some_order'
assert quote(engine, 'some_order') == 'some_order'
assert quote(postgresql.dialect(), 'some_order') == 'some_order' assert quote(postgresql.dialect(), 'some_order') == 'some_order'

View File

@@ -1,3 +1,4 @@
import pytest
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy_utils.functions import ( from sqlalchemy_utils.functions import (
@@ -5,52 +6,58 @@ from sqlalchemy_utils.functions import (
render_expression, render_expression,
render_statement render_statement
) )
from tests import TestCase
class TestRender(TestCase): class TestRender(object):
def create_models(self):
class User(self.Base): @pytest.fixture
def User(self, Base):
class User(Base):
__tablename__ = 'user' __tablename__ = 'user'
id = sa.Column(sa.Integer, autoincrement=True, primary_key=True) id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
name = sa.Column(sa.Unicode(255)) name = sa.Column(sa.Unicode(255))
return User
self.User = User @pytest.fixture
def init_models(self, User):
pass
def test_render_orm_query(self): def test_render_orm_query(self, session, User):
query = self.session.query(self.User).filter_by(id=3) query = session.query(User).filter_by(id=3)
text = render_statement(query) text = render_statement(query)
assert 'SELECT user.id, user.name' in text assert 'SELECT user.id, user.name' in text
assert 'FROM user' in text assert 'FROM user' in text
assert 'WHERE user.id = 3' in text assert 'WHERE user.id = 3' in text
def test_render_statement(self): def test_render_statement(self, session, User):
statement = self.User.__table__.select().where(self.User.id == 3) statement = User.__table__.select().where(User.id == 3)
text = render_statement(statement, bind=self.session.bind) text = render_statement(statement, bind=session.bind)
assert 'SELECT user.id, user.name' in text assert 'SELECT user.id, user.name' in text
assert 'FROM user' in text assert 'FROM user' in text
assert 'WHERE user.id = 3' in text assert 'WHERE user.id = 3' in text
def test_render_statement_without_mapper(self): def test_render_statement_without_mapper(self, session):
statement = sa.select([sa.text('1')]) statement = sa.select([sa.text('1')])
text = render_statement(statement, bind=self.session.bind) text = render_statement(statement, bind=session.bind)
assert 'SELECT 1' in text assert 'SELECT 1' in text
def test_render_ddl(self): def test_render_ddl(self, engine, User):
expression = 'self.User.__table__.create(engine)' expression = 'User.__table__.create(engine)'
stream = render_expression(expression, self.engine) stream = render_expression(expression, engine)
text = stream.getvalue() text = stream.getvalue()
assert 'CREATE TABLE user' in text assert 'CREATE TABLE user' in text
assert 'PRIMARY KEY' in text assert 'PRIMARY KEY' in text
def test_render_mock_ddl(self): def test_render_mock_ddl(self, engine, User):
# TODO: mock_engine doesn't seem to work with locally scoped variables.
self.engine = engine
with mock_engine('self.engine') as stream: with mock_engine('self.engine') as stream:
self.User.__table__.create(self.engine) User.__table__.create(self.engine)
text = stream.getvalue() text = stream.getvalue()

View File

@@ -1,26 +1,33 @@
import pytest
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy_utils import table_name from sqlalchemy_utils import table_name
from tests import TestCase
class TestTableName(TestCase): @pytest.fixture
def create_models(self): def Building(Base):
class Building(self.Base): class Building(Base):
__tablename__ = 'building' __tablename__ = 'building'
id = sa.Column(sa.Integer, primary_key=True) id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255)) name = sa.Column(sa.Unicode(255))
return Building
self.Building = Building
def test_class(self): @pytest.fixture
assert table_name(self.Building) == 'building' def init_models(Base):
del self.Building.__tablename__ pass
assert table_name(self.Building) == 'building'
def test_attribute(self):
assert table_name(self.Building.id) == 'building'
assert table_name(self.Building.name) == 'building'
def test_target(self): class TestTableName(object):
assert table_name(self.Building()) == 'building'
def test_class(self, Building):
assert table_name(Building) == 'building'
del Building.__tablename__
assert table_name(Building) == 'building'
def test_attribute(self, Building):
assert table_name(Building.id) == 'building'
assert table_name(Building.name) == 'building'
def test_target(self, Building):
assert table_name(Building()) == 'building'

View File

@@ -1,109 +1,105 @@
from __future__ import unicode_literals
import six import six
from tests import TestCase
class GenericRelationshipTestCase(object):
class GenericRelationshipTestCase(TestCase): def test_set_as_none(self, Event):
def test_set_as_none(self): event = Event()
event = self.Event()
event.object = None event.object = None
assert event.object is None assert event.object is None
def test_set_manual_and_get(self): def test_set_manual_and_get(self, session, User, Event):
user = self.User() user = User()
self.session.add(user) session.add(user)
self.session.commit() session.commit()
event = self.Event() event = Event()
event.object_id = user.id event.object_id = user.id
event.object_type = six.text_type(type(user).__name__) event.object_type = six.text_type(type(user).__name__)
assert event.object is None assert event.object is None
self.session.add(event) session.add(event)
self.session.commit() session.commit()
assert event.object == user assert event.object == user
def test_set_and_get(self): def test_set_and_get(self, session, User, Event):
user = self.User() user = User()
self.session.add(user) session.add(user)
self.session.commit() session.commit()
event = self.Event(object=user) event = Event(object=user)
assert event.object_id == user.id assert event.object_id == user.id
assert event.object_type == type(user).__name__ assert event.object_type == type(user).__name__
self.session.add(event) session.add(event)
self.session.commit() session.commit()
assert event.object == user assert event.object == user
def test_compare_instance(self): def test_compare_instance(self, session, User, Event):
user1 = self.User() user1 = User()
user2 = self.User() user2 = User()
self.session.add_all([user1, user2]) session.add_all([user1, user2])
self.session.commit() session.commit()
event = self.Event(object=user1) event = Event(object=user1)
self.session.add(event) session.add(event)
self.session.commit() session.commit()
assert event.object == user1 assert event.object == user1
assert event.object != user2 assert event.object != user2
def test_compare_query(self): def test_compare_query(self, session, User, Event):
user1 = self.User() user1 = User()
user2 = self.User() user2 = User()
self.session.add_all([user1, user2]) session.add_all([user1, user2])
self.session.commit() session.commit()
event = self.Event(object=user1) event = Event(object=user1)
self.session.add(event) session.add(event)
self.session.commit() session.commit()
q = self.session.query(self.Event) q = session.query(Event)
assert q.filter_by(object=user1).first() is not None assert q.filter_by(object=user1).first() is not None
assert q.filter_by(object=user2).first() is None assert q.filter_by(object=user2).first() is None
assert q.filter(self.Event.object == user2).first() is None assert q.filter(Event.object == user2).first() is None
def test_compare_not_query(self): def test_compare_not_query(self, session, User, Event):
user1 = self.User() user1 = User()
user2 = self.User() user2 = User()
self.session.add_all([user1, user2]) session.add_all([user1, user2])
self.session.commit() session.commit()
event = self.Event(object=user1) event = Event(object=user1)
self.session.add(event) session.add(event)
self.session.commit() session.commit()
q = self.session.query(self.Event) q = session.query(Event)
assert q.filter(self.Event.object != user2).first() is not None assert q.filter(Event.object != user2).first() is not None
def test_compare_type(self): def test_compare_type(self, session, User, Event):
user1 = self.User() user1 = User()
user2 = self.User() user2 = User()
self.session.add_all([user1, user2]) session.add_all([user1, user2])
self.session.commit() session.commit()
event1 = self.Event(object=user1) event1 = Event(object=user1)
event2 = self.Event(object=user2) event2 = Event(object=user2)
self.session.add_all([event1, event2]) session.add_all([event1, event2])
self.session.commit() session.commit()
statement = self.Event.object.is_type(self.User) statement = Event.object.is_type(User)
q = self.session.query(self.Event).filter(statement) q = session.query(Event).filter(statement)
assert q.first() is not None assert q.first() is not None

View File

@@ -1,36 +1,54 @@
from __future__ import unicode_literals import pytest
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy.ext.declarative import declared_attr from sqlalchemy.ext.declarative import declared_attr
from sqlalchemy_utils import generic_relationship from sqlalchemy_utils import generic_relationship
from tests.generic_relationship import GenericRelationshipTestCase
from . import GenericRelationshipTestCase
@pytest.fixture
def Building(Base):
class Building(Base):
__tablename__ = 'building'
id = sa.Column(sa.Integer, primary_key=True)
return Building
@pytest.fixture
def User(Base):
class User(Base):
__tablename__ = 'user'
id = sa.Column(sa.Integer, primary_key=True)
return User
@pytest.fixture
def EventBase(Base):
class EventBase(Base):
__abstract__ = True
object_type = sa.Column(sa.Unicode(255))
object_id = sa.Column(sa.Integer, nullable=False)
@declared_attr
def object(cls):
return generic_relationship('object_type', 'object_id')
return EventBase
@pytest.fixture
def Event(EventBase):
class Event(EventBase):
__tablename__ = 'event'
id = sa.Column(sa.Integer, primary_key=True)
return Event
@pytest.fixture
def init_models(Building, User, Event):
pass
class TestGenericRelationshipWithAbstractBase(GenericRelationshipTestCase): class TestGenericRelationshipWithAbstractBase(GenericRelationshipTestCase):
def create_models(self): pass
class Building(self.Base):
__tablename__ = 'building'
id = sa.Column(sa.Integer, primary_key=True)
class User(self.Base):
__tablename__ = 'user'
id = sa.Column(sa.Integer, primary_key=True)
class EventBase(self.Base):
__abstract__ = True
object_type = sa.Column(sa.Unicode(255))
object_id = sa.Column(sa.Integer, nullable=False)
@declared_attr
def object(cls):
return generic_relationship('object_type', 'object_id')
class Event(EventBase):
__tablename__ = 'event'
id = sa.Column(sa.Integer, primary_key=True)
self.Building = Building
self.User = User
self.Event = Event

View File

@@ -1,30 +1,44 @@
from __future__ import unicode_literals import pytest
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy_utils import generic_relationship from sqlalchemy_utils import generic_relationship
from tests.generic_relationship import GenericRelationshipTestCase
from . import GenericRelationshipTestCase
@pytest.fixture
def Building(Base):
class Building(Base):
__tablename__ = 'building'
id = sa.Column(sa.Integer, primary_key=True)
return Building
@pytest.fixture
def User(Base):
class User(Base):
__tablename__ = 'user'
id = sa.Column(sa.Integer, primary_key=True)
return User
@pytest.fixture
def Event(Base):
class Event(Base):
__tablename__ = 'event'
id = sa.Column(sa.Integer, primary_key=True)
object_type = sa.Column(sa.Unicode(255), name="objectType")
object_id = sa.Column(sa.Integer, nullable=False)
object = generic_relationship(object_type, object_id)
return Event
@pytest.fixture
def init_models(Building, User, Event):
pass
class TestGenericRelationship(GenericRelationshipTestCase): class TestGenericRelationship(GenericRelationshipTestCase):
def create_models(self): pass
class Building(self.Base):
__tablename__ = 'building'
id = sa.Column(sa.Integer, primary_key=True)
class User(self.Base):
__tablename__ = 'user'
id = sa.Column(sa.Integer, primary_key=True)
class Event(self.Base):
__tablename__ = 'event'
id = sa.Column(sa.Integer, primary_key=True)
object_type = sa.Column(sa.Unicode(255), name="objectType")
object_id = sa.Column(sa.Integer, nullable=False)
object = generic_relationship(object_type, object_id)
self.Building = Building
self.User = User
self.Event = Event

View File

@@ -1,66 +1,84 @@
from __future__ import unicode_literals import pytest
import six import six
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy_utils import generic_relationship from sqlalchemy_utils import generic_relationship
from tests.generic_relationship import GenericRelationshipTestCase
from ..generic_relationship import GenericRelationshipTestCase
@pytest.fixture
def incrementor():
class Incrementor(object):
value = 1
return Incrementor()
@pytest.fixture
def Building(Base, incrementor):
class Building(Base):
__tablename__ = 'building'
id = sa.Column(sa.Integer, primary_key=True)
code = sa.Column(sa.Integer, primary_key=True)
def __init__(self):
incrementor.value += 1
self.id = incrementor.value
self.code = incrementor.value
return Building
@pytest.fixture
def User(Base, incrementor):
class User(Base):
__tablename__ = 'user'
id = sa.Column(sa.Integer, primary_key=True)
code = sa.Column(sa.Integer, primary_key=True)
def __init__(self):
incrementor.value += 1
self.id = incrementor.value
self.code = incrementor.value
return User
@pytest.fixture
def Event(Base):
class Event(Base):
__tablename__ = 'event'
id = sa.Column(sa.Integer, primary_key=True)
object_type = sa.Column(sa.Unicode(255))
object_id = sa.Column(sa.Integer, nullable=False)
object_code = sa.Column(sa.Integer, nullable=False)
object = generic_relationship(
object_type, (object_id, object_code)
)
return Event
@pytest.fixture
def init_models(Building, User, Event):
pass
class TestGenericRelationship(GenericRelationshipTestCase): class TestGenericRelationship(GenericRelationshipTestCase):
index = 1
def create_models(self): def test_set_manual_and_get(self, session, Event, User):
class Building(self.Base): user = User()
__tablename__ = 'building'
id = sa.Column(sa.Integer, primary_key=True)
code = sa.Column(sa.Integer, primary_key=True)
def __init__(obj_self): session.add(user)
self.index += 1 session.commit()
obj_self.id = self.index
obj_self.code = self.index
class User(self.Base): event = Event()
__tablename__ = 'user'
id = sa.Column(sa.Integer, primary_key=True)
code = sa.Column(sa.Integer, primary_key=True)
def __init__(obj_self):
self.index += 1
obj_self.id = self.index
obj_self.code = self.index
class Event(self.Base):
__tablename__ = 'event'
id = sa.Column(sa.Integer, primary_key=True)
object_type = sa.Column(sa.Unicode(255))
object_id = sa.Column(sa.Integer, nullable=False)
object_code = sa.Column(sa.Integer, nullable=False)
object = generic_relationship(
object_type, (object_id, object_code)
)
self.Building = Building
self.User = User
self.Event = Event
def test_set_manual_and_get(self):
user = self.User()
self.session.add(user)
self.session.commit()
event = self.Event()
event.object_id = user.id event.object_id = user.id
event.object_type = six.text_type(type(user).__name__) event.object_type = six.text_type(type(user).__name__)
event.object_code = user.code event.object_code = user.code
assert event.object is None assert event.object is None
self.session.add(event) session.add(event)
self.session.commit() session.commit()
assert event.object == user assert event.object == user

View File

@@ -1,68 +1,79 @@
from __future__ import unicode_literals import pytest
import six import six
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy_utils import generic_relationship from sqlalchemy_utils import generic_relationship
from tests import TestCase
class TestGenericRelationship(TestCase): @pytest.fixture
def create_models(self): def User(Base):
class User(self.Base): class User(Base):
__tablename__ = 'user' __tablename__ = 'user'
id = sa.Column(sa.Integer, primary_key=True) id = sa.Column(sa.Integer, primary_key=True)
return User
class UserHistory(self.Base):
__tablename__ = 'user_history'
id = sa.Column(sa.Integer, primary_key=True)
transaction_id = sa.Column(sa.Integer, primary_key=True) @pytest.fixture
def UserHistory(Base):
class UserHistory(Base):
__tablename__ = 'user_history'
id = sa.Column(sa.Integer, primary_key=True)
class Event(self.Base): transaction_id = sa.Column(sa.Integer, primary_key=True)
__tablename__ = 'event' return UserHistory
id = sa.Column(sa.Integer, primary_key=True)
transaction_id = sa.Column(sa.Integer)
object_type = sa.Column(sa.Unicode(255)) @pytest.fixture
object_id = sa.Column(sa.Integer, nullable=False) def Event(Base):
class Event(Base):
__tablename__ = 'event'
id = sa.Column(sa.Integer, primary_key=True)
object = generic_relationship( transaction_id = sa.Column(sa.Integer)
object_type, object_id
)
@hybrid_property object_type = sa.Column(sa.Unicode(255))
def object_version_type(self): object_id = sa.Column(sa.Integer, nullable=False)
return self.object_type + 'History'
@object_version_type.expression object = generic_relationship(
def object_version_type(cls): object_type, object_id
return sa.func.concat(cls.object_type, 'History') )
object_version = generic_relationship( @hybrid_property
object_version_type, (object_id, transaction_id) def object_version_type(self):
) return self.object_type + 'History'
self.User = User @object_version_type.expression
self.UserHistory = UserHistory def object_version_type(cls):
self.Event = Event return sa.func.concat(cls.object_type, 'History')
def test_set_manual_and_get(self): object_version = generic_relationship(
user = self.User(id=1) object_version_type, (object_id, transaction_id)
history = self.UserHistory(id=1, transaction_id=1) )
self.session.add(user) return Event
self.session.add(history)
self.session.commit()
event = self.Event(transaction_id=1)
@pytest.fixture
def init_models(User, UserHistory, Event):
pass
class TestGenericRelationship(object):
def test_set_manual_and_get(self, session, User, UserHistory, Event):
user = User(id=1)
history = UserHistory(id=1, transaction_id=1)
session.add(user)
session.add(history)
session.commit()
event = Event(transaction_id=1)
event.object_id = user.id event.object_id = user.id
event.object_type = six.text_type(type(user).__name__) event.object_type = six.text_type(type(user).__name__)
assert event.object is None assert event.object is None
self.session.add(event) session.add(event)
self.session.commit() session.commit()
assert event.object == user assert event.object == user
assert event.object_version == history assert event.object_version == history

View File

@@ -1,164 +1,178 @@
from __future__ import unicode_literals import pytest
import six import six
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy_utils import generic_relationship from sqlalchemy_utils import generic_relationship
from tests import TestCase
class TestGenericRelationship(TestCase): @pytest.fixture
def create_models(self): def Employee(Base):
class Employee(self.Base): class Employee(Base):
__tablename__ = 'employee' __tablename__ = 'employee'
id = sa.Column(sa.Integer, primary_key=True) id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.String(50)) name = sa.Column(sa.String(50))
type = sa.Column(sa.String(20)) type = sa.Column(sa.String(20))
__mapper_args__ = { __mapper_args__ = {
'polymorphic_on': type, 'polymorphic_on': type,
'polymorphic_identity': 'employee' 'polymorphic_identity': 'employee'
} }
return Employee
class Manager(Employee):
__mapper_args__ = {
'polymorphic_identity': 'manager'
}
class Engineer(Employee): @pytest.fixture
__mapper_args__ = { def Manager(Employee):
'polymorphic_identity': 'engineer' class Manager(Employee):
} __mapper_args__ = {
'polymorphic_identity': 'manager'
}
return Manager
class Event(self.Base):
__tablename__ = 'event'
id = sa.Column(sa.Integer, primary_key=True)
object_type = sa.Column(sa.Unicode(255)) @pytest.fixture
object_id = sa.Column(sa.Integer, nullable=False) def Engineer(Employee):
class Engineer(Employee):
__mapper_args__ = {
'polymorphic_identity': 'engineer'
}
return Engineer
object = generic_relationship(object_type, object_id)
self.Employee = Employee @pytest.fixture
self.Manager = Manager def Event(Base):
self.Engineer = Engineer class Event(Base):
self.Event = Event __tablename__ = 'event'
id = sa.Column(sa.Integer, primary_key=True)
def test_set_as_none(self): object_type = sa.Column(sa.Unicode(255))
event = self.Event() object_id = sa.Column(sa.Integer, nullable=False)
object = generic_relationship(object_type, object_id)
return Event
@pytest.fixture
def init_models(Employee, Manager, Engineer, Event):
pass
class TestGenericRelationship(object):
def test_set_as_none(self, Event):
event = Event()
event.object = None event.object = None
assert event.object is None assert event.object is None
def test_set_manual_and_get(self): def test_set_manual_and_get(self, session, Manager, Event):
manager = self.Manager() manager = Manager()
self.session.add(manager) session.add(manager)
self.session.commit() session.commit()
event = self.Event() event = Event()
event.object_id = manager.id event.object_id = manager.id
event.object_type = six.text_type(type(manager).__name__) event.object_type = six.text_type(type(manager).__name__)
assert event.object is None assert event.object is None
self.session.add(event) session.add(event)
self.session.commit() session.commit()
assert event.object == manager assert event.object == manager
def test_set_and_get(self): def test_set_and_get(self, session, Manager, Event):
manager = self.Manager() manager = Manager()
self.session.add(manager) session.add(manager)
self.session.commit() session.commit()
event = self.Event(object=manager) event = Event(object=manager)
assert event.object_id == manager.id assert event.object_id == manager.id
assert event.object_type == type(manager).__name__ assert event.object_type == type(manager).__name__
self.session.add(event) session.add(event)
self.session.commit() session.commit()
assert event.object == manager assert event.object == manager
def test_compare_instance(self): def test_compare_instance(self, session, Manager, Event):
manager1 = self.Manager() manager1 = Manager()
manager2 = self.Manager() manager2 = Manager()
self.session.add_all([manager1, manager2]) session.add_all([manager1, manager2])
self.session.commit() session.commit()
event = self.Event(object=manager1) event = Event(object=manager1)
self.session.add(event) session.add(event)
self.session.commit() session.commit()
assert event.object == manager1 assert event.object == manager1
assert event.object != manager2 assert event.object != manager2
def test_compare_query(self): def test_compare_query(self, session, Manager, Event):
manager1 = self.Manager() manager1 = Manager()
manager2 = self.Manager() manager2 = Manager()
self.session.add_all([manager1, manager2]) session.add_all([manager1, manager2])
self.session.commit() session.commit()
event = self.Event(object=manager1) event = Event(object=manager1)
self.session.add(event) session.add(event)
self.session.commit() session.commit()
q = self.session.query(self.Event) q = session.query(Event)
assert q.filter_by(object=manager1).first() is not None assert q.filter_by(object=manager1).first() is not None
assert q.filter_by(object=manager2).first() is None assert q.filter_by(object=manager2).first() is None
assert q.filter(self.Event.object == manager2).first() is None assert q.filter(Event.object == manager2).first() is None
def test_compare_not_query(self): def test_compare_not_query(self, session, Manager, Event):
manager1 = self.Manager() manager1 = Manager()
manager2 = self.Manager() manager2 = Manager()
self.session.add_all([manager1, manager2]) session.add_all([manager1, manager2])
self.session.commit() session.commit()
event = self.Event(object=manager1) event = Event(object=manager1)
self.session.add(event) session.add(event)
self.session.commit() session.commit()
q = self.session.query(self.Event) q = session.query(Event)
assert q.filter(self.Event.object != manager2).first() is not None assert q.filter(Event.object != manager2).first() is not None
def test_compare_type(self): def test_compare_type(self, session, Manager, Event):
manager1 = self.Manager() manager1 = Manager()
manager2 = self.Manager() manager2 = Manager()
self.session.add_all([manager1, manager2]) session.add_all([manager1, manager2])
self.session.commit() session.commit()
event1 = self.Event(object=manager1) event1 = Event(object=manager1)
event2 = self.Event(object=manager2) event2 = Event(object=manager2)
self.session.add_all([event1, event2]) session.add_all([event1, event2])
self.session.commit() session.commit()
statement = self.Event.object.is_type(self.Manager) statement = Event.object.is_type(Manager)
q = self.session.query(self.Event).filter(statement) q = session.query(Event).filter(statement)
assert q.first() is not None assert q.first() is not None
def test_compare_super_type(self): def test_compare_super_type(self, session, Manager, Event, Employee):
manager1 = self.Manager() manager1 = Manager()
manager2 = self.Manager() manager2 = Manager()
self.session.add_all([manager1, manager2]) session.add_all([manager1, manager2])
self.session.commit() session.commit()
event1 = self.Event(object=manager1) event1 = Event(object=manager1)
event2 = self.Event(object=manager2) event2 = Event(object=manager2)
self.session.add_all([event1, event2]) session.add_all([event1, event2])
self.session.commit() session.commit()
statement = self.Event.object.is_type(self.Employee) statement = Event.object.is_type(Employee)
q = self.session.query(self.Event).filter(statement) q = session.query(Event).filter(statement)
assert q.first() is not None assert q.first() is not None

View File

@@ -1,18 +1,24 @@
import pytest
import sqlalchemy as sa import sqlalchemy as sa
class ThreeLevelDeepOneToOne(object): class ThreeLevelDeepOneToOne(object):
def create_models(self):
class Catalog(self.Base): @pytest.fixture
def Catalog(self, Base, Category):
class Catalog(Base):
__tablename__ = 'catalog' __tablename__ = 'catalog'
id = sa.Column('_id', sa.Integer, primary_key=True) id = sa.Column('_id', sa.Integer, primary_key=True)
category = sa.orm.relationship( category = sa.orm.relationship(
'Category', Category,
uselist=False, uselist=False,
backref='catalog' backref='catalog'
) )
return Catalog
class Category(self.Base): @pytest.fixture
def Category(self, Base, SubCategory):
class Category(Base):
__tablename__ = 'category' __tablename__ = 'category'
id = sa.Column('_id', sa.Integer, primary_key=True) id = sa.Column('_id', sa.Integer, primary_key=True)
catalog_id = sa.Column( catalog_id = sa.Column(
@@ -22,12 +28,15 @@ class ThreeLevelDeepOneToOne(object):
) )
sub_category = sa.orm.relationship( sub_category = sa.orm.relationship(
'SubCategory', SubCategory,
uselist=False, uselist=False,
backref='category' backref='category'
) )
return Category
class SubCategory(self.Base): @pytest.fixture
def SubCategory(self, Base, Product):
class SubCategory(Base):
__tablename__ = 'sub_category' __tablename__ = 'sub_category'
id = sa.Column('_id', sa.Integer, primary_key=True) id = sa.Column('_id', sa.Integer, primary_key=True)
category_id = sa.Column( category_id = sa.Column(
@@ -36,12 +45,15 @@ class ThreeLevelDeepOneToOne(object):
sa.ForeignKey('category._id') sa.ForeignKey('category._id')
) )
product = sa.orm.relationship( product = sa.orm.relationship(
'Product', Product,
uselist=False, uselist=False,
backref='sub_category' backref='sub_category'
) )
return SubCategory
class Product(self.Base): @pytest.fixture
def Product(self, Base):
class Product(Base):
__tablename__ = 'product' __tablename__ = 'product'
id = sa.Column('_id', sa.Integer, primary_key=True) id = sa.Column('_id', sa.Integer, primary_key=True)
price = sa.Column(sa.Integer) price = sa.Column(sa.Integer)
@@ -51,22 +63,27 @@ class ThreeLevelDeepOneToOne(object):
sa.Integer, sa.Integer,
sa.ForeignKey('sub_category._id') sa.ForeignKey('sub_category._id')
) )
return Product
self.Catalog = Catalog @pytest.fixture
self.Category = Category def init_models(self, Catalog, Category, SubCategory, Product):
self.SubCategory = SubCategory pass
self.Product = Product
class ThreeLevelDeepOneToMany(object): class ThreeLevelDeepOneToMany(object):
def create_models(self):
class Catalog(self.Base): @pytest.fixture
def Catalog(self, Base, Category):
class Catalog(Base):
__tablename__ = 'catalog' __tablename__ = 'catalog'
id = sa.Column('_id', sa.Integer, primary_key=True) id = sa.Column('_id', sa.Integer, primary_key=True)
categories = sa.orm.relationship('Category', backref='catalog') categories = sa.orm.relationship(Category, backref='catalog')
return Catalog
class Category(self.Base): @pytest.fixture
def Category(self, Base, SubCategory):
class Category(Base):
__tablename__ = 'category' __tablename__ = 'category'
id = sa.Column('_id', sa.Integer, primary_key=True) id = sa.Column('_id', sa.Integer, primary_key=True)
catalog_id = sa.Column( catalog_id = sa.Column(
@@ -76,10 +93,13 @@ class ThreeLevelDeepOneToMany(object):
) )
sub_categories = sa.orm.relationship( sub_categories = sa.orm.relationship(
'SubCategory', backref='category' SubCategory, backref='category'
) )
return Category
class SubCategory(self.Base): @pytest.fixture
def SubCategory(self, Base, Product):
class SubCategory(Base):
__tablename__ = 'sub_category' __tablename__ = 'sub_category'
id = sa.Column('_id', sa.Integer, primary_key=True) id = sa.Column('_id', sa.Integer, primary_key=True)
category_id = sa.Column( category_id = sa.Column(
@@ -88,11 +108,14 @@ class ThreeLevelDeepOneToMany(object):
sa.ForeignKey('category._id') sa.ForeignKey('category._id')
) )
products = sa.orm.relationship( products = sa.orm.relationship(
'Product', Product,
backref='sub_category' backref='sub_category'
) )
return SubCategory
class Product(self.Base): @pytest.fixture
def Product(self, Base):
class Product(Base):
__tablename__ = 'product' __tablename__ = 'product'
id = sa.Column('_id', sa.Integer, primary_key=True) id = sa.Column('_id', sa.Integer, primary_key=True)
price = sa.Column(sa.Numeric) price = sa.Column(sa.Numeric)
@@ -105,25 +128,42 @@ class ThreeLevelDeepOneToMany(object):
def __repr__(self): def __repr__(self):
return '<Product id=%r>' % self.id return '<Product id=%r>' % self.id
return Product
self.Catalog = Catalog @pytest.fixture
self.Category = Category def init_models(self, Catalog, Category, SubCategory, Product):
self.SubCategory = SubCategory pass
self.Product = Product
class ThreeLevelDeepManyToMany(object): class ThreeLevelDeepManyToMany(object):
def create_models(self):
@pytest.fixture
def Catalog(self, Base, Category):
catalog_category = sa.Table( catalog_category = sa.Table(
'catalog_category', 'catalog_category',
self.Base.metadata, Base.metadata,
sa.Column('catalog_id', sa.Integer, sa.ForeignKey('catalog._id')), sa.Column('catalog_id', sa.Integer, sa.ForeignKey('catalog._id')),
sa.Column('category_id', sa.Integer, sa.ForeignKey('category._id')) sa.Column('category_id', sa.Integer, sa.ForeignKey('category._id'))
) )
class Catalog(Base):
__tablename__ = 'catalog'
id = sa.Column('_id', sa.Integer, primary_key=True)
categories = sa.orm.relationship(
Category,
backref='catalogs',
secondary=catalog_category
)
return Catalog
@pytest.fixture
def Category(self, Base, SubCategory):
category_subcategory = sa.Table( category_subcategory = sa.Table(
'category_subcategory', 'category_subcategory',
self.Base.metadata, Base.metadata,
sa.Column( sa.Column(
'category_id', 'category_id',
sa.Integer, sa.Integer,
@@ -136,9 +176,23 @@ class ThreeLevelDeepManyToMany(object):
) )
) )
class Category(Base):
__tablename__ = 'category'
id = sa.Column('_id', sa.Integer, primary_key=True)
sub_categories = sa.orm.relationship(
SubCategory,
backref='categories',
secondary=category_subcategory
)
return Category
@pytest.fixture
def SubCategory(self, Base, Product):
subcategory_product = sa.Table( subcategory_product = sa.Table(
'subcategory_product', 'subcategory_product',
self.Base.metadata, Base.metadata,
sa.Column( sa.Column(
'subcategory_id', 'subcategory_id',
sa.Integer, sa.Integer,
@@ -151,41 +205,24 @@ class ThreeLevelDeepManyToMany(object):
) )
) )
class Catalog(self.Base): class SubCategory(Base):
__tablename__ = 'catalog'
id = sa.Column('_id', sa.Integer, primary_key=True)
categories = sa.orm.relationship(
'Category',
backref='catalogs',
secondary=catalog_category
)
class Category(self.Base):
__tablename__ = 'category'
id = sa.Column('_id', sa.Integer, primary_key=True)
sub_categories = sa.orm.relationship(
'SubCategory',
backref='categories',
secondary=category_subcategory
)
class SubCategory(self.Base):
__tablename__ = 'sub_category' __tablename__ = 'sub_category'
id = sa.Column('_id', sa.Integer, primary_key=True) id = sa.Column('_id', sa.Integer, primary_key=True)
products = sa.orm.relationship( products = sa.orm.relationship(
'Product', Product,
backref='sub_categories', backref='sub_categories',
secondary=subcategory_product secondary=subcategory_product
) )
return SubCategory
class Product(self.Base): @pytest.fixture
def Product(self, Base):
class Product(Base):
__tablename__ = 'product' __tablename__ = 'product'
id = sa.Column('_id', sa.Integer, primary_key=True) id = sa.Column('_id', sa.Integer, primary_key=True)
price = sa.Column(sa.Numeric) price = sa.Column(sa.Numeric)
return Product
self.Catalog = Catalog @pytest.fixture
self.Category = Category def init_models(self, Catalog, Category, SubCategory, Product):
self.SubCategory = SubCategory pass
self.Product = Product

View File

@@ -1,15 +1,15 @@
import pytest
import sqlalchemy as sa import sqlalchemy as sa
from pytest import raises
from sqlalchemy_utils.observer import observes from sqlalchemy_utils.observer import observes
from tests import TestCase
class TestObservesForColumn(TestCase): @pytest.mark.usefixtures('postgresql_dsn')
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' class TestObservesForColumn(object):
def create_models(self): @pytest.fixture
class Product(self.Base): def Product(self, Base):
class Product(Base):
__tablename__ = 'product' __tablename__ = 'product'
id = sa.Column(sa.Integer, primary_key=True) id = sa.Column(sa.Integer, primary_key=True)
price = sa.Column(sa.Integer) price = sa.Column(sa.Integer)
@@ -17,21 +17,25 @@ class TestObservesForColumn(TestCase):
@observes('price') @observes('price')
def product_price_observer(self, price): def product_price_observer(self, price):
self.price = price * 2 self.price = price * 2
return Product
self.Product = Product @pytest.fixture
def init_models(self, Product):
pass
def test_simple_insert(self): def test_simple_insert(self, session, Product):
product = self.Product(price=100) product = Product(price=100)
self.session.add(product) session.add(product)
self.session.flush() session.flush()
assert product.price == 200 assert product.price == 200
class TestObservesForColumnWithoutActualChanges(TestCase): @pytest.mark.usefixtures('postgresql_dsn')
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' class TestObservesForColumnWithoutActualChanges(object):
def create_models(self): @pytest.fixture
class Product(self.Base): def Product(self, Base):
class Product(Base):
__tablename__ = 'product' __tablename__ = 'product'
id = sa.Column(sa.Integer, primary_key=True) id = sa.Column(sa.Integer, primary_key=True)
price = sa.Column(sa.Integer) price = sa.Column(sa.Integer)
@@ -39,15 +43,18 @@ class TestObservesForColumnWithoutActualChanges(TestCase):
@observes('price') @observes('price')
def product_price_observer(self, price): def product_price_observer(self, price):
raise Exception('Trying to change price') raise Exception('Trying to change price')
return Product
self.Product = Product @pytest.fixture
def init_models(self, Product):
pass
def test_only_notifies_observer_on_actual_changes(self): def test_only_notifies_observer_on_actual_changes(self, session, Product):
product = self.Product() product = Product()
self.session.add(product) session.add(product)
self.session.flush() session.flush()
with raises(Exception) as e: with pytest.raises(Exception) as e:
product.price = 500 product.price = 500
self.session.commit() session.commit()
assert str(e.value) == 'Trying to change price' assert str(e.value) == 'Trying to change price'

View File

@@ -1,137 +1,158 @@
import pytest
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy_utils.observer import observes from sqlalchemy_utils.observer import observes
from tests import TestCase
class TestObservesForManyToManyToManyToMany(TestCase): @pytest.fixture
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' def Catalog(Base):
catalog_category = sa.Table(
'catalog_category',
Base.metadata,
sa.Column('catalog_id', sa.Integer, sa.ForeignKey('catalog.id')),
sa.Column('category_id', sa.Integer, sa.ForeignKey('category.id'))
)
def create_models(self): class Catalog(Base):
catalog_category = sa.Table( __tablename__ = 'catalog'
'catalog_category', id = sa.Column(sa.Integer, primary_key=True)
self.Base.metadata, product_count = sa.Column(sa.Integer, default=0)
sa.Column('catalog_id', sa.Integer, sa.ForeignKey('catalog.id')),
sa.Column('category_id', sa.Integer, sa.ForeignKey('category.id')) @observes('categories.sub_categories.products')
def product_observer(self, products):
self.product_count = len(products)
categories = sa.orm.relationship(
'Category',
backref='catalogs',
secondary=catalog_category
)
return Catalog
@pytest.fixture
def Category(Base):
category_subcategory = sa.Table(
'category_subcategory',
Base.metadata,
sa.Column(
'category_id',
sa.Integer,
sa.ForeignKey('category.id')
),
sa.Column(
'subcategory_id',
sa.Integer,
sa.ForeignKey('sub_category.id')
)
)
class Category(Base):
__tablename__ = 'category'
id = sa.Column(sa.Integer, primary_key=True)
sub_categories = sa.orm.relationship(
'SubCategory',
backref='categories',
secondary=category_subcategory
)
return Category
@pytest.fixture
def SubCategory(Base):
subcategory_product = sa.Table(
'subcategory_product',
Base.metadata,
sa.Column(
'subcategory_id',
sa.Integer,
sa.ForeignKey('sub_category.id')
),
sa.Column(
'product_id',
sa.Integer,
sa.ForeignKey('product.id')
)
)
class SubCategory(Base):
__tablename__ = 'sub_category'
id = sa.Column(sa.Integer, primary_key=True)
products = sa.orm.relationship(
'Product',
backref='sub_categories',
secondary=subcategory_product
) )
category_subcategory = sa.Table( return SubCategory
'category_subcategory',
self.Base.metadata,
sa.Column(
'category_id',
sa.Integer,
sa.ForeignKey('category.id')
),
sa.Column(
'subcategory_id',
sa.Integer,
sa.ForeignKey('sub_category.id')
)
)
subcategory_product = sa.Table(
'subcategory_product',
self.Base.metadata,
sa.Column(
'subcategory_id',
sa.Integer,
sa.ForeignKey('sub_category.id')
),
sa.Column(
'product_id',
sa.Integer,
sa.ForeignKey('product.id')
)
)
class Catalog(self.Base): @pytest.fixture
__tablename__ = 'catalog' def Product(Base):
id = sa.Column(sa.Integer, primary_key=True) class Product(Base):
product_count = sa.Column(sa.Integer, default=0) __tablename__ = 'product'
id = sa.Column(sa.Integer, primary_key=True)
price = sa.Column(sa.Numeric)
return Product
@observes('categories.sub_categories.products')
def product_observer(self, products):
self.product_count = len(products)
categories = sa.orm.relationship( @pytest.fixture
'Category', def init_models(Catalog, Category, SubCategory, Product):
backref='catalogs', pass
secondary=catalog_category
)
class Category(self.Base):
__tablename__ = 'category'
id = sa.Column(sa.Integer, primary_key=True)
sub_categories = sa.orm.relationship( @pytest.fixture
'SubCategory', def catalog(session, Catalog, Category, SubCategory, Product):
backref='categories', sub_category = SubCategory(products=[Product()])
secondary=category_subcategory category = Category(sub_categories=[sub_category])
) catalog = Catalog(categories=[category])
session.add(catalog)
session.flush()
return catalog
class SubCategory(self.Base):
__tablename__ = 'sub_category'
id = sa.Column(sa.Integer, primary_key=True)
products = sa.orm.relationship(
'Product',
backref='sub_categories',
secondary=subcategory_product
)
class Product(self.Base): @pytest.mark.usefixtures('postgresql_dsn')
__tablename__ = 'product' class TestObservesForManyToManyToManyToMany(object):
id = sa.Column(sa.Integer, primary_key=True)
price = sa.Column(sa.Numeric)
self.Catalog = Catalog def test_simple_insert(self, catalog):
self.Category = Category
self.SubCategory = SubCategory
self.Product = Product
def create_catalog(self):
sub_category = self.SubCategory(products=[self.Product()])
category = self.Category(sub_categories=[sub_category])
catalog = self.Catalog(categories=[category])
self.session.add(catalog)
self.session.flush()
return catalog
def test_simple_insert(self):
catalog = self.create_catalog()
assert catalog.product_count == 1 assert catalog.product_count == 1
def test_add_leaf_object(self): def test_add_leaf_object(self, catalog, session, Product):
catalog = self.create_catalog() product = Product()
product = self.Product()
catalog.categories[0].sub_categories[0].products.append(product) catalog.categories[0].sub_categories[0].products.append(product)
self.session.flush() session.flush()
assert catalog.product_count == 2 assert catalog.product_count == 2
def test_remove_leaf_object(self): def test_remove_leaf_object(self, catalog, session, Product):
catalog = self.create_catalog() product = Product()
product = self.Product()
catalog.categories[0].sub_categories[0].products.append(product) catalog.categories[0].sub_categories[0].products.append(product)
self.session.flush() session.flush()
self.session.delete(product) session.delete(product)
self.session.flush() session.flush()
assert catalog.product_count == 1 assert catalog.product_count == 1
def test_delete_intermediate_object(self): def test_delete_intermediate_object(self, catalog, session):
catalog = self.create_catalog() session.delete(catalog.categories[0].sub_categories[0])
self.session.delete(catalog.categories[0].sub_categories[0]) session.commit()
self.session.commit()
assert catalog.product_count == 0 assert catalog.product_count == 0
def test_gathered_objects_are_distinct(self): def test_gathered_objects_are_distinct(
catalog = self.Catalog() self,
category = self.Category(catalogs=[catalog]) session,
product = self.Product() Catalog,
Category,
SubCategory,
Product
):
catalog = Catalog()
category = Category(catalogs=[catalog])
product = Product()
category.sub_categories.append( category.sub_categories.append(
self.SubCategory(products=[product]) SubCategory(products=[product])
) )
self.session.add( session.add(
self.SubCategory(categories=[category], products=[product]) SubCategory(categories=[category], products=[product])
) )
self.session.commit() session.commit()
assert catalog.product_count == 1 assert catalog.product_count == 1

View File

@@ -1,107 +1,127 @@
import pytest
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy_utils.observer import observes from sqlalchemy_utils.observer import observes
from tests import TestCase
class TestObservesFor3LevelDeepOneToMany(TestCase): @pytest.fixture
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' def Catalog(Base):
class Catalog(Base):
__tablename__ = 'catalog'
id = sa.Column(sa.Integer, primary_key=True)
product_count = sa.Column(sa.Integer, default=0)
def create_models(self): @observes('categories.sub_categories.products')
class Catalog(self.Base): def product_observer(self, products):
__tablename__ = 'catalog' self.product_count = len(products)
id = sa.Column(sa.Integer, primary_key=True)
product_count = sa.Column(sa.Integer, default=0)
@observes('categories.sub_categories.products') categories = sa.orm.relationship('Category', backref='catalog')
def product_observer(self, products): return Catalog
self.product_count = len(products)
categories = sa.orm.relationship('Category', backref='catalog')
class Category(self.Base): @pytest.fixture
__tablename__ = 'category' def Category(Base):
id = sa.Column(sa.Integer, primary_key=True) class Category(Base):
catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id')) __tablename__ = 'category'
id = sa.Column(sa.Integer, primary_key=True)
catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id'))
sub_categories = sa.orm.relationship( sub_categories = sa.orm.relationship(
'SubCategory', backref='category' 'SubCategory', backref='category'
) )
return Category
class SubCategory(self.Base):
__tablename__ = 'sub_category'
id = sa.Column(sa.Integer, primary_key=True)
category_id = sa.Column(sa.Integer, sa.ForeignKey('category.id'))
products = sa.orm.relationship(
'Product',
backref='sub_category'
)
class Product(self.Base): @pytest.fixture
__tablename__ = 'product' def SubCategory(Base):
id = sa.Column(sa.Integer, primary_key=True) class SubCategory(Base):
price = sa.Column(sa.Numeric) __tablename__ = 'sub_category'
id = sa.Column(sa.Integer, primary_key=True)
category_id = sa.Column(sa.Integer, sa.ForeignKey('category.id'))
products = sa.orm.relationship(
'Product',
backref='sub_category'
)
return SubCategory
sub_category_id = sa.Column(
sa.Integer, sa.ForeignKey('sub_category.id')
)
def __repr__(self): @pytest.fixture
return '<Product id=%r>' % self.id def Product(Base):
class Product(Base):
__tablename__ = 'product'
id = sa.Column(sa.Integer, primary_key=True)
price = sa.Column(sa.Numeric)
self.Catalog = Catalog sub_category_id = sa.Column(
self.Category = Category sa.Integer, sa.ForeignKey('sub_category.id')
self.SubCategory = SubCategory )
self.Product = Product
def create_catalog(self): def __repr__(self):
sub_category = self.SubCategory(products=[self.Product()]) return '<Product id=%r>' % self.id
category = self.Category(sub_categories=[sub_category]) return Product
catalog = self.Catalog(categories=[category])
self.session.add(catalog)
self.session.commit()
return catalog
def test_simple_insert(self):
catalog = self.create_catalog() @pytest.fixture
def init_models(Catalog, Category, SubCategory, Product):
pass
@pytest.fixture
def catalog(session, Catalog, Category, SubCategory, Product):
sub_category = SubCategory(products=[Product()])
category = Category(sub_categories=[sub_category])
catalog = Catalog(categories=[category])
session.add(catalog)
session.commit()
return catalog
@pytest.mark.usefixtures('postgresql_dsn')
class TestObservesFor3LevelDeepOneToMany(object):
def test_simple_insert(self, catalog):
assert catalog.product_count == 1 assert catalog.product_count == 1
def test_add_leaf_object(self): def test_add_leaf_object(self, catalog, session, Product):
catalog = self.create_catalog() product = Product()
product = self.Product()
catalog.categories[0].sub_categories[0].products.append(product) catalog.categories[0].sub_categories[0].products.append(product)
self.session.flush() session.flush()
assert catalog.product_count == 2 assert catalog.product_count == 2
def test_remove_leaf_object(self): def test_remove_leaf_object(self, catalog, session, Product):
catalog = self.create_catalog() product = Product()
product = self.Product()
catalog.categories[0].sub_categories[0].products.append(product) catalog.categories[0].sub_categories[0].products.append(product)
self.session.flush() session.flush()
self.session.delete(product) session.delete(product)
self.session.commit() session.commit()
assert catalog.product_count == 1 assert catalog.product_count == 1
self.session.delete( session.delete(
catalog.categories[0].sub_categories[0].products[0] catalog.categories[0].sub_categories[0].products[0]
) )
self.session.commit() session.commit()
assert catalog.product_count == 0 assert catalog.product_count == 0
def test_delete_intermediate_object(self): def test_delete_intermediate_object(self, catalog, session):
catalog = self.create_catalog() session.delete(catalog.categories[0].sub_categories[0])
self.session.delete(catalog.categories[0].sub_categories[0]) session.commit()
self.session.commit()
assert catalog.product_count == 0 assert catalog.product_count == 0
def test_gathered_objects_are_distinct(self): def test_gathered_objects_are_distinct(
catalog = self.Catalog() self,
category = self.Category(catalog=catalog) session,
product = self.Product() Catalog,
Category,
SubCategory,
Product
):
catalog = Catalog()
category = Category(catalog=catalog)
product = Product()
category.sub_categories.append( category.sub_categories.append(
self.SubCategory(products=[product]) SubCategory(products=[product])
) )
self.session.add( session.add(
self.SubCategory(category=category, products=[product]) SubCategory(category=category, products=[product])
) )
self.session.commit() session.commit()
assert catalog.product_count == 1 assert catalog.product_count == 1

View File

@@ -1,96 +1,116 @@
import pytest
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy_utils.observer import observes from sqlalchemy_utils.observer import observes
from tests import TestCase
class TestObservesForOneToManyToOneToMany(TestCase): @pytest.fixture
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' def Catalog(Base):
class Catalog(Base):
__tablename__ = 'catalog'
id = sa.Column(sa.Integer, primary_key=True)
product_count = sa.Column(sa.Integer, default=0)
def create_models(self): @observes('categories.sub_category.products')
class Catalog(self.Base): def product_observer(self, products):
__tablename__ = 'catalog' self.product_count = len(products)
id = sa.Column(sa.Integer, primary_key=True)
product_count = sa.Column(sa.Integer, default=0)
@observes('categories.sub_category.products') categories = sa.orm.relationship('Category', backref='catalog')
def product_observer(self, products): return Catalog
self.product_count = len(products)
categories = sa.orm.relationship('Category', backref='catalog')
class Category(self.Base): @pytest.fixture
__tablename__ = 'category' def Category(Base):
id = sa.Column(sa.Integer, primary_key=True) class Category(Base):
catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id')) __tablename__ = 'category'
id = sa.Column(sa.Integer, primary_key=True)
catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id'))
sub_category = sa.orm.relationship( sub_category = sa.orm.relationship(
'SubCategory', 'SubCategory',
uselist=False, uselist=False,
backref='category' backref='category'
) )
return Category
class SubCategory(self.Base):
__tablename__ = 'sub_category'
id = sa.Column(sa.Integer, primary_key=True)
category_id = sa.Column(sa.Integer, sa.ForeignKey('category.id'))
products = sa.orm.relationship('Product', backref='sub_category')
class Product(self.Base): @pytest.fixture
__tablename__ = 'product' def SubCategory(Base):
id = sa.Column(sa.Integer, primary_key=True) class SubCategory(Base):
price = sa.Column(sa.Numeric) __tablename__ = 'sub_category'
id = sa.Column(sa.Integer, primary_key=True)
category_id = sa.Column(sa.Integer, sa.ForeignKey('category.id'))
products = sa.orm.relationship('Product', backref='sub_category')
return SubCategory
sub_category_id = sa.Column(
sa.Integer, sa.ForeignKey('sub_category.id')
)
self.Catalog = Catalog @pytest.fixture
self.Category = Category def Product(Base):
self.SubCategory = SubCategory class Product(Base):
self.Product = Product __tablename__ = 'product'
id = sa.Column(sa.Integer, primary_key=True)
price = sa.Column(sa.Numeric)
def create_catalog(self): sub_category_id = sa.Column(
sub_category = self.SubCategory(products=[self.Product()]) sa.Integer, sa.ForeignKey('sub_category.id')
category = self.Category(sub_category=sub_category) )
catalog = self.Catalog(categories=[category]) return Product
self.session.add(catalog)
self.session.flush()
return catalog
def test_simple_insert(self):
catalog = self.create_catalog() @pytest.fixture
def init_models(Catalog, Category, SubCategory, Product):
pass
@pytest.fixture
def catalog(session, Catalog, Category, SubCategory, Product):
sub_category = SubCategory(products=[Product()])
category = Category(sub_category=sub_category)
catalog = Catalog(categories=[category])
session.add(catalog)
session.flush()
return catalog
@pytest.mark.usefixtures('postgresql_dsn')
class TestObservesForOneToManyToOneToMany(object):
def test_simple_insert(self, catalog):
assert catalog.product_count == 1 assert catalog.product_count == 1
def test_add_leaf_object(self): def test_add_leaf_object(self, catalog, session, Product):
catalog = self.create_catalog() product = Product()
product = self.Product()
catalog.categories[0].sub_category.products.append(product) catalog.categories[0].sub_category.products.append(product)
self.session.flush() session.flush()
assert catalog.product_count == 2 assert catalog.product_count == 2
def test_remove_leaf_object(self): def test_remove_leaf_object(self, catalog, session, Product):
catalog = self.create_catalog() product = Product()
product = self.Product()
catalog.categories[0].sub_category.products.append(product) catalog.categories[0].sub_category.products.append(product)
self.session.flush() session.flush()
self.session.delete(product) session.delete(product)
self.session.flush() session.flush()
assert catalog.product_count == 1 assert catalog.product_count == 1
def test_delete_intermediate_object(self): def test_delete_intermediate_object(self, catalog, session):
catalog = self.create_catalog() session.delete(catalog.categories[0].sub_category)
self.session.delete(catalog.categories[0].sub_category) session.commit()
self.session.commit()
assert catalog.product_count == 0 assert catalog.product_count == 0
def test_gathered_objects_are_distinct(self): def test_gathered_objects_are_distinct(
catalog = self.Catalog() self,
category = self.Category(catalog=catalog) session,
product = self.Product() Catalog,
category.sub_category = self.SubCategory(products=[product]) Category,
self.session.add( SubCategory,
self.Category(catalog=catalog, sub_category=category.sub_category) Product
):
catalog = Catalog()
category = Category(catalog=catalog)
product = Product()
category.sub_category = SubCategory(products=[product])
session.add(
Category(catalog=catalog, sub_category=category.sub_category)
) )
self.session.commit() session.commit()
assert catalog.product_count == 1 assert catalog.product_count == 1

View File

@@ -1,53 +1,66 @@
import pytest
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy_utils.observer import observes from sqlalchemy_utils.observer import observes
from tests import TestCase
class TestObservesForOneToManyToOneToMany(TestCase): @pytest.fixture
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' def Device(Base):
class Device(Base):
__tablename__ = 'device'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.String)
return Device
def create_models(self):
class Device(self.Base):
__tablename__ = 'device'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.String)
class Order(self.Base): @pytest.fixture
__tablename__ = 'order' def Order(Base):
id = sa.Column(sa.Integer, primary_key=True) class Order(Base):
__tablename__ = 'order'
id = sa.Column(sa.Integer, primary_key=True)
device_id = sa.Column( device_id = sa.Column(
'device', sa.ForeignKey('device.id'), nullable=False 'device', sa.ForeignKey('device.id'), nullable=False
)
device = sa.orm.relationship('Device', backref='orders')
return Order
@pytest.fixture
def SalesInvoice(Base):
class SalesInvoice(Base):
__tablename__ = 'sales_invoice'
id = sa.Column(sa.Integer, primary_key=True)
order_id = sa.Column(
'order',
sa.ForeignKey('order.id'),
nullable=False
)
order = sa.orm.relationship(
'Order',
backref=sa.orm.backref(
'invoice',
uselist=False
) )
device = sa.orm.relationship('Device', backref='orders') )
device_name = sa.Column(sa.String)
class SalesInvoice(self.Base): @observes('order.device')
__tablename__ = 'sales_invoice' def process_device(self, device):
id = sa.Column(sa.Integer, primary_key=True) self.device_name = device.name
order_id = sa.Column(
'order',
sa.ForeignKey('order.id'),
nullable=False
)
order = sa.orm.relationship(
'Order',
backref=sa.orm.backref(
'invoice',
uselist=False
)
)
device_name = sa.Column(sa.String)
@observes('order.device') return SalesInvoice
def process_device(self, device):
self.device_name = device.name
self.Device = Device
self.Order = Order
self.SalesInvoice = SalesInvoice
def test_observable_root_obj_is_none(self): @pytest.fixture
order = self.Order(device=self.Device(name='Something')) def init_models(Device, Order, SalesInvoice):
self.session.add(order) pass
self.session.flush()
@pytest.mark.usefixtures('postgresql_dsn')
class TestObservesForOneToManyToOneToMany(object):
def test_observable_root_obj_is_none(self, session, Device, Order):
order = Order(device=Device(name='Something'))
session.add(order)
session.flush()

View File

@@ -1,84 +1,98 @@
import pytest
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy_utils.observer import observes from sqlalchemy_utils.observer import observes
from tests import TestCase
class TestObservesForOneToOneToOneToOne(TestCase): @pytest.fixture
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' def Catalog(Base):
class Catalog(Base):
__tablename__ = 'catalog'
id = sa.Column(sa.Integer, primary_key=True)
product_price = sa.Column(sa.Integer)
def create_models(self): @observes('category.sub_category.product')
class Catalog(self.Base): def product_observer(self, product):
__tablename__ = 'catalog' self.product_price = product.price if product else None
id = sa.Column(sa.Integer, primary_key=True)
product_price = sa.Column(sa.Integer)
@observes('category.sub_category.product') category = sa.orm.relationship(
def product_observer(self, product): 'Category',
self.product_price = product.price if product else None uselist=False,
backref='catalog'
)
return Catalog
category = sa.orm.relationship(
'Category',
uselist=False,
backref='catalog'
)
class Category(self.Base): @pytest.fixture
__tablename__ = 'category' def Category(Base):
id = sa.Column(sa.Integer, primary_key=True) class Category(Base):
catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id')) __tablename__ = 'category'
id = sa.Column(sa.Integer, primary_key=True)
catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id'))
sub_category = sa.orm.relationship( sub_category = sa.orm.relationship(
'SubCategory', 'SubCategory',
uselist=False, uselist=False,
backref='category' backref='category'
) )
return Category
class SubCategory(self.Base):
__tablename__ = 'sub_category'
id = sa.Column(sa.Integer, primary_key=True)
category_id = sa.Column(sa.Integer, sa.ForeignKey('category.id'))
product = sa.orm.relationship(
'Product',
uselist=False,
backref='sub_category'
)
class Product(self.Base): @pytest.fixture
__tablename__ = 'product' def SubCategory(Base):
id = sa.Column(sa.Integer, primary_key=True) class SubCategory(Base):
price = sa.Column(sa.Integer) __tablename__ = 'sub_category'
id = sa.Column(sa.Integer, primary_key=True)
category_id = sa.Column(sa.Integer, sa.ForeignKey('category.id'))
product = sa.orm.relationship(
'Product',
uselist=False,
backref='sub_category'
)
return SubCategory
sub_category_id = sa.Column(
sa.Integer, sa.ForeignKey('sub_category.id')
)
self.Catalog = Catalog @pytest.fixture
self.Category = Category def Product(Base):
self.SubCategory = SubCategory class Product(Base):
self.Product = Product __tablename__ = 'product'
id = sa.Column(sa.Integer, primary_key=True)
price = sa.Column(sa.Integer)
def create_catalog(self): sub_category_id = sa.Column(
sub_category = self.SubCategory(product=self.Product(price=123)) sa.Integer, sa.ForeignKey('sub_category.id')
category = self.Category(sub_category=sub_category) )
catalog = self.Catalog(category=category) return Product
self.session.add(catalog)
self.session.flush()
return catalog
def test_simple_insert(self):
catalog = self.create_catalog() @pytest.fixture
def init_models(Catalog, Category, SubCategory, Product):
pass
@pytest.fixture
def catalog(session, Catalog, Category, SubCategory, Product):
sub_category = SubCategory(product=Product(price=123))
category = Category(sub_category=sub_category)
catalog = Catalog(category=category)
session.add(catalog)
session.flush()
return catalog
@pytest.mark.usefixtures('postgresql_dsn')
class TestObservesForOneToOneToOneToOne(object):
def test_simple_insert(self, catalog):
assert catalog.product_price == 123 assert catalog.product_price == 123
def test_replace_leaf_object(self): def test_replace_leaf_object(self, catalog, session, Product):
catalog = self.create_catalog() product = Product(price=44)
product = self.Product(price=44)
catalog.category.sub_category.product = product catalog.category.sub_category.product = product
self.session.flush() session.flush()
assert catalog.product_price == 44 assert catalog.product_price == 44
def test_delete_leaf_object(self): def test_delete_leaf_object(self, catalog, session):
catalog = self.create_catalog() session.delete(catalog.category.sub_category.product)
self.session.delete(catalog.category.sub_category.product) session.flush()
self.session.flush()
assert catalog.product_price is None assert catalog.product_price is None

View File

@@ -1,32 +1,36 @@
import pytest
import six import six
from pytest import mark, raises
from sqlalchemy_utils import Country, i18n from sqlalchemy_utils import Country, i18n
@mark.skipif('i18n.babel is None') @pytest.fixture
def set_get_locale():
i18n.get_locale = lambda: i18n.babel.Locale('en')
@pytest.mark.skipif('i18n.babel is None')
@pytest.mark.usefixtures('set_get_locale')
class TestCountry(object): class TestCountry(object):
def setup_method(self, method):
i18n.get_locale = lambda: i18n.babel.Locale('en')
def test_init(self): def test_init(self):
assert Country(u'FI') == Country(Country(u'FI')) assert Country(u'FI') == Country(Country(u'FI'))
def test_constructor_with_wrong_type(self): def test_constructor_with_wrong_type(self):
with raises(TypeError) as e: with pytest.raises(TypeError) as e:
Country(None) Country(None)
assert str(e.value) == ( assert str(e.value) == (
"Country() argument must be a string or a country, not 'NoneType'" "Country() argument must be a string or a country, not 'NoneType'"
) )
def test_constructor_with_invalid_code(self): def test_constructor_with_invalid_code(self):
with raises(ValueError) as e: with pytest.raises(ValueError) as e:
Country('SomeUnknownCode') Country('SomeUnknownCode')
assert str(e.value) == ( assert str(e.value) == (
'Could not convert string to country code: SomeUnknownCode' 'Could not convert string to country code: SomeUnknownCode'
) )
@mark.parametrize( @pytest.mark.parametrize(
'code', 'code',
( (
'FI', 'FI',
@@ -37,7 +41,7 @@ class TestCountry(object):
Country.validate(code) Country.validate(code)
def test_validate_with_invalid_code(self): def test_validate_with_invalid_code(self):
with raises(ValueError) as e: with pytest.raises(ValueError) as e:
Country.validate('SomeUnknownCode') Country.validate('SomeUnknownCode')
assert str(e.value) == ( assert str(e.value) == (
'Could not convert string to country code: SomeUnknownCode' 'Could not convert string to country code: SomeUnknownCode'

View File

@@ -1,14 +1,18 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import pytest
import six import six
from pytest import mark, raises
from sqlalchemy_utils import Currency, i18n from sqlalchemy_utils import Currency, i18n
@mark.skipif('i18n.babel is None') @pytest.fixture
def set_get_locale():
i18n.get_locale = lambda: i18n.babel.Locale('en')
@pytest.mark.skipif('i18n.babel is None')
@pytest.mark.usefixtures('set_get_locale')
class TestCurrency(object): class TestCurrency(object):
def setup_method(self, method):
i18n.get_locale = lambda: i18n.babel.Locale('en')
def test_init(self): def test_init(self):
assert Currency('USD') == Currency(Currency('USD')) assert Currency('USD') == Currency(Currency('USD'))
@@ -17,14 +21,14 @@ class TestCurrency(object):
assert len(set([Currency('USD'), Currency('USD')])) == 1 assert len(set([Currency('USD'), Currency('USD')])) == 1
def test_invalid_currency_code(self): def test_invalid_currency_code(self):
with raises(ValueError): with pytest.raises(ValueError):
Currency('Unknown code') Currency('Unknown code')
def test_invalid_currency_code_type(self): def test_invalid_currency_code_type(self):
with raises(TypeError): with pytest.raises(TypeError):
Currency(None) Currency(None)
@mark.parametrize( @pytest.mark.parametrize(
('code', 'name'), ('code', 'name'),
( (
('USD', 'US Dollar'), ('USD', 'US Dollar'),
@@ -34,7 +38,7 @@ class TestCurrency(object):
def test_name_property(self, code, name): def test_name_property(self, code, name):
assert Currency(code).name == name assert Currency(code).name == name
@mark.parametrize( @pytest.mark.parametrize(
('code', 'symbol'), ('code', 'symbol'),
( (
('USD', u'$'), ('USD', u'$'),

View File

@@ -6,10 +6,14 @@ from sqlalchemy_utils import i18n
from sqlalchemy_utils.primitives import WeekDay, WeekDays from sqlalchemy_utils.primitives import WeekDay, WeekDays
@pytest.fixture
def set_get_locale():
i18n.get_locale = lambda: i18n.babel.Locale('fi')
@pytest.mark.skipif('i18n.babel is None') @pytest.mark.skipif('i18n.babel is None')
@pytest.mark.usefixtures('set_get_locale')
class TestWeekDay(object): class TestWeekDay(object):
def setup_method(self, method):
i18n.get_locale = lambda: i18n.babel.Locale('fi')
def test_constructor_with_valid_index(self): def test_constructor_with_valid_index(self):
day = WeekDay(1) day = WeekDay(1)

View File

@@ -1,26 +1,27 @@
import pytest
from sqlalchemy_utils.relationships import chained_join from sqlalchemy_utils.relationships import chained_join
from tests import TestCase
from tests.mixins import ( from ..mixins import (
ThreeLevelDeepManyToMany, ThreeLevelDeepManyToMany,
ThreeLevelDeepOneToMany, ThreeLevelDeepOneToMany,
ThreeLevelDeepOneToOne ThreeLevelDeepOneToOne
) )
class TestChainedJoinFoDeepToManyToMany(ThreeLevelDeepManyToMany, TestCase): @pytest.mark.usefixtures('postgresql_dsn')
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' class TestChainedJoinFoDeepToManyToMany(ThreeLevelDeepManyToMany):
create_tables = False
def test_simple_join(self): def test_simple_join(self, Catalog):
assert str(chained_join(self.Catalog.categories)) == ( assert str(chained_join(Catalog.categories)) == (
'catalog_category JOIN category ON ' 'catalog_category JOIN category ON '
'category._id = catalog_category.category_id' 'category._id = catalog_category.category_id'
) )
def test_two_relations(self): def test_two_relations(self, Catalog, Category):
sql = chained_join( sql = chained_join(
self.Catalog.categories, Catalog.categories,
self.Category.sub_categories Category.sub_categories
) )
assert str(sql) == ( assert str(sql) == (
'catalog_category JOIN category ON category._id = ' 'catalog_category JOIN category ON category._id = '
@@ -30,11 +31,11 @@ class TestChainedJoinFoDeepToManyToMany(ThreeLevelDeepManyToMany, TestCase):
'category_subcategory.subcategory_id' 'category_subcategory.subcategory_id'
) )
def test_three_relations(self): def test_three_relations(self, Catalog, Category, SubCategory):
sql = chained_join( sql = chained_join(
self.Catalog.categories, Catalog.categories,
self.Category.sub_categories, Category.sub_categories,
self.SubCategory.products SubCategory.products
) )
assert str(sql) == ( assert str(sql) == (
'catalog_category JOIN category ON category._id = ' 'catalog_category JOIN category ON category._id = '
@@ -47,28 +48,27 @@ class TestChainedJoinFoDeepToManyToMany(ThreeLevelDeepManyToMany, TestCase):
) )
class TestChainedJoinForDeepOneToMany(ThreeLevelDeepOneToMany, TestCase): @pytest.mark.usefixtures('postgresql_dsn')
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' class TestChainedJoinForDeepOneToMany(ThreeLevelDeepOneToMany):
create_tables = False
def test_simple_join(self): def test_simple_join(self, Catalog):
assert str(chained_join(self.Catalog.categories)) == 'category' assert str(chained_join(Catalog.categories)) == 'category'
def test_two_relations(self): def test_two_relations(self, Catalog, Category):
sql = chained_join( sql = chained_join(
self.Catalog.categories, Catalog.categories,
self.Category.sub_categories Category.sub_categories
) )
assert str(sql) == ( assert str(sql) == (
'category JOIN sub_category ON category._id = ' 'category JOIN sub_category ON category._id = '
'sub_category._category_id' 'sub_category._category_id'
) )
def test_three_relations(self): def test_three_relations(self, Catalog, Category, SubCategory):
sql = chained_join( sql = chained_join(
self.Catalog.categories, Catalog.categories,
self.Category.sub_categories, Category.sub_categories,
self.SubCategory.products SubCategory.products
) )
assert str(sql) == ( assert str(sql) == (
'category JOIN sub_category ON category._id = ' 'category JOIN sub_category ON category._id = '
@@ -77,28 +77,27 @@ class TestChainedJoinForDeepOneToMany(ThreeLevelDeepOneToMany, TestCase):
) )
class TestChainedJoinForDeepOneToOne(ThreeLevelDeepOneToOne, TestCase): @pytest.mark.usefixtures('postgresql_dsn')
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' class TestChainedJoinForDeepOneToOne(ThreeLevelDeepOneToOne):
create_tables = False
def test_simple_join(self): def test_simple_join(self, Catalog):
assert str(chained_join(self.Catalog.category)) == 'category' assert str(chained_join(Catalog.category)) == 'category'
def test_two_relations(self): def test_two_relations(self, Catalog, Category):
sql = chained_join( sql = chained_join(
self.Catalog.category, Catalog.category,
self.Category.sub_category Category.sub_category
) )
assert str(sql) == ( assert str(sql) == (
'category JOIN sub_category ON category._id = ' 'category JOIN sub_category ON category._id = '
'sub_category._category_id' 'sub_category._category_id'
) )
def test_three_relations(self): def test_three_relations(self, Catalog, Category, SubCategory):
sql = chained_join( sql = chained_join(
self.Catalog.category, Catalog.category,
self.Category.sub_category, Category.sub_category,
self.SubCategory.product SubCategory.product
) )
assert str(sql) == ( assert str(sql) == (
'category JOIN sub_category ON category._id = ' 'category JOIN sub_category ON category._id = '

View File

@@ -1,31 +1,23 @@
import pytest import pytest
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import sessionmaker
from sqlalchemy_utils.relationships import select_correlated_expression from sqlalchemy_utils.relationships import select_correlated_expression
@pytest.fixture(scope='class') @pytest.fixture
def base(): def group_user_tbl(Base):
return declarative_base()
@pytest.fixture(scope='class')
def group_user_cls(base):
return sa.Table( return sa.Table(
'group_user', 'group_user',
base.metadata, Base.metadata,
sa.Column('user_id', sa.Integer, sa.ForeignKey('user.id')), sa.Column('user_id', sa.Integer, sa.ForeignKey('user.id')),
sa.Column('group_id', sa.Integer, sa.ForeignKey('group.id')) sa.Column('group_id', sa.Integer, sa.ForeignKey('group.id'))
) )
@pytest.fixture(scope='class') @pytest.fixture
def group_cls(base): def group_tbl(Base):
class Group(base): class Group(Base):
__tablename__ = 'group' __tablename__ = 'group'
id = sa.Column(sa.Integer, primary_key=True) id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.String) name = sa.Column(sa.String)
@@ -33,11 +25,11 @@ def group_cls(base):
return Group return Group
@pytest.fixture(scope='class') @pytest.fixture
def friendship_cls(base): def friendship_tbl(Base):
return sa.Table( return sa.Table(
'friendships', 'friendships',
base.metadata, Base.metadata,
sa.Column( sa.Column(
'friend_a_id', 'friend_a_id',
sa.Integer, sa.Integer,
@@ -53,35 +45,37 @@ def friendship_cls(base):
) )
@pytest.fixture(scope='class') @pytest.fixture
def user_cls(base, group_user_cls, friendship_cls): def User(Base, group_user_tbl, friendship_tbl):
class User(base): class User(Base):
__tablename__ = 'user' __tablename__ = 'user'
id = sa.Column(sa.Integer, primary_key=True) id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.String) name = sa.Column(sa.String)
groups = sa.orm.relationship( groups = sa.orm.relationship(
'Group', 'Group',
secondary=group_user_cls, secondary=group_user_tbl,
backref='users' backref='users'
) )
# this relationship is used for persistence # this relationship is used for persistence
friends = sa.orm.relationship( friends = sa.orm.relationship(
'User', 'User',
secondary=friendship_cls, secondary=friendship_tbl,
primaryjoin=id == friendship_cls.c.friend_a_id, primaryjoin=id == friendship_tbl.c.friend_a_id,
secondaryjoin=id == friendship_cls.c.friend_b_id, secondaryjoin=id == friendship_tbl.c.friend_b_id,
) )
friendship_union = sa.select([ friendship_union = (
friendship_cls.c.friend_a_id, sa.select([
friendship_cls.c.friend_b_id friendship_tbl.c.friend_a_id,
friendship_tbl.c.friend_b_id
]).union( ]).union(
sa.select([ sa.select([
friendship_cls.c.friend_b_id, friendship_tbl.c.friend_b_id,
friendship_cls.c.friend_a_id] friendship_tbl.c.friend_a_id]
) )
).alias() ).alias()
)
User.all_friends = sa.orm.relationship( User.all_friends = sa.orm.relationship(
'User', 'User',
@@ -94,9 +88,9 @@ def user_cls(base, group_user_cls, friendship_cls):
return User return User
@pytest.fixture(scope='class') @pytest.fixture
def category_cls(base, group_user_cls, friendship_cls): def Category(Base, group_user_tbl, friendship_tbl):
class Category(base): class Category(Base):
__tablename__ = 'category' __tablename__ = 'category'
id = sa.Column(sa.Integer, primary_key=True) id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.String) name = sa.Column(sa.String)
@@ -111,9 +105,9 @@ def category_cls(base, group_user_cls, friendship_cls):
return Category return Category
@pytest.fixture(scope='class') @pytest.fixture
def article_cls(base, category_cls, user_cls): def Article(Base, Category, User):
class Article(base): class Article(Base):
__tablename__ = 'article' __tablename__ = 'article'
id = sa.Column('_id', sa.Integer, primary_key=True) id = sa.Column('_id', sa.Integer, primary_key=True)
name = sa.Column(sa.String) name = sa.Column(sa.String)
@@ -129,144 +123,104 @@ def article_cls(base, category_cls, user_cls):
content = sa.Column(sa.String) content = sa.Column(sa.String)
category_id = sa.Column(sa.Integer, sa.ForeignKey(category_cls.id)) category_id = sa.Column(sa.Integer, sa.ForeignKey(Category.id))
category = sa.orm.relationship(category_cls, backref='articles') category = sa.orm.relationship(Category, backref='articles')
author_id = sa.Column(sa.Integer, sa.ForeignKey(user_cls.id)) author_id = sa.Column(sa.Integer, sa.ForeignKey(User.id))
author = sa.orm.relationship( author = sa.orm.relationship(
user_cls, User,
primaryjoin=author_id == user_cls.id, primaryjoin=author_id == User.id,
backref='authored_articles' backref='authored_articles'
) )
owner_id = sa.Column(sa.Integer, sa.ForeignKey(user_cls.id)) owner_id = sa.Column(sa.Integer, sa.ForeignKey(User.id))
owner = sa.orm.relationship( owner = sa.orm.relationship(
user_cls, User,
primaryjoin=owner_id == user_cls.id, primaryjoin=owner_id == User.id,
backref='owned_articles' backref='owned_articles'
) )
return Article return Article
@pytest.fixture(scope='class') @pytest.fixture
def comment_cls(base, article_cls, user_cls): def Comment(Base, Article, User):
class Comment(base): class Comment(Base):
__tablename__ = 'comment' __tablename__ = 'comment'
id = sa.Column(sa.Integer, primary_key=True) id = sa.Column(sa.Integer, primary_key=True)
content = sa.Column(sa.String) content = sa.Column(sa.String)
article_id = sa.Column(sa.Integer, sa.ForeignKey(article_cls.id)) article_id = sa.Column(sa.Integer, sa.ForeignKey(Article.id))
article = sa.orm.relationship(article_cls, backref='comments') article = sa.orm.relationship(Article, backref='comments')
author_id = sa.Column(sa.Integer, sa.ForeignKey(user_cls.id)) author_id = sa.Column(sa.Integer, sa.ForeignKey(User.id))
author = sa.orm.relationship(user_cls, backref='comments') author = sa.orm.relationship(User, backref='comments')
article_cls.comment_count = sa.orm.column_property( Article.comment_count = sa.orm.column_property(
sa.select([sa.func.count(Comment.id)]) sa.select([sa.func.count(Comment.id)])
.where(Comment.article_id == article_cls.id) .where(Comment.article_id == Article.id)
.correlate_except(article_cls) .correlate_except(Article)
) )
return Comment return Comment
@pytest.fixture(scope='class') @pytest.fixture
def composite_pk_cls(base): def model_mapping(Article, Category, Comment, group_tbl, User):
class CompositePKModel(base):
__tablename__ = 'composite_pk_model'
a = sa.Column(sa.Integer, primary_key=True)
b = sa.Column(sa.Integer, primary_key=True)
return CompositePKModel
@pytest.fixture(scope='class')
def dns():
return 'postgres://postgres@localhost/sqlalchemy_utils_test'
@pytest.yield_fixture(scope='class')
def engine(dns):
engine = create_engine(dns)
engine.echo = True
yield engine
engine.dispose()
@pytest.yield_fixture(scope='class')
def connection(engine):
conn = engine.connect()
yield conn
conn.close()
@pytest.fixture(scope='class')
def model_mapping(article_cls, category_cls, comment_cls, group_cls, user_cls):
return { return {
'articles': article_cls, 'articles': Article,
'categories': category_cls, 'categories': Category,
'comments': comment_cls, 'comments': Comment,
'groups': group_cls, 'groups': group_tbl,
'users': user_cls 'users': User
} }
@pytest.yield_fixture(scope='class') @pytest.fixture
def table_creator(base, connection, model_mapping): def init_models(Article, Category, Comment, group_tbl, User):
sa.orm.configure_mappers() pass
base.metadata.create_all(connection)
yield
base.metadata.drop_all(connection)
@pytest.yield_fixture(scope='class') @pytest.fixture
def session(connection):
Session = sessionmaker(bind=connection)
session = Session()
yield session
session.close_all()
@pytest.fixture(scope='class')
def dataset( def dataset(
session, session,
user_cls, User,
group_cls, group_tbl,
article_cls, Article,
category_cls, Category,
comment_cls Comment
): ):
group = group_cls(name='Group 1') group = group_tbl(name='Group 1')
group2 = group_cls(name='Group 2') group2 = group_tbl(name='Group 2')
user = user_cls(id=1, name='User 1', groups=[group, group2]) user = User(id=1, name='User 1', groups=[group, group2])
user2 = user_cls(id=2, name='User 2') user2 = User(id=2, name='User 2')
user3 = user_cls(id=3, name='User 3', groups=[group]) user3 = User(id=3, name='User 3', groups=[group])
user4 = user_cls(id=4, name='User 4', groups=[group2]) user4 = User(id=4, name='User 4', groups=[group2])
user5 = user_cls(id=5, name='User 5') user5 = User(id=5, name='User 5')
user.friends = [user2] user.friends = [user2]
user2.friends = [user3, user4] user2.friends = [user3, user4]
user3.friends = [user5] user3.friends = [user5]
article = article_cls( article = Article(
name='Some article', name='Some article',
author=user, author=user,
owner=user2, owner=user2,
category=category_cls( category=Category(
id=1, id=1,
name='Some category', name='Some category',
subcategories=[ subcategories=[
category_cls( Category(
id=2, id=2,
name='Subcategory 1', name='Subcategory 1',
subcategories=[ subcategories=[
category_cls( Category(
id=3, id=3,
name='Subsubcategory 1', name='Subsubcategory 1',
subcategories=[ subcategories=[
category_cls( Category(
id=5, id=5,
name='Subsubsubcategory 1', name='Subsubsubcategory 1',
), ),
category_cls( Category(
id=6, id=6,
name='Subsubsubcategory 2', name='Subsubsubcategory 2',
) )
@@ -274,11 +228,11 @@ def dataset(
) )
] ]
), ),
category_cls(id=4, name='Subcategory 2'), Category(id=4, name='Subcategory 2'),
] ]
), ),
comments=[ comments=[
comment_cls( Comment(
content='Some comment', content='Some comment',
author=user author=user
) )
@@ -290,7 +244,7 @@ def dataset(
session.commit() session.commit()
@pytest.mark.usefixtures('table_creator', 'dataset') @pytest.mark.usefixtures('dataset', 'postgresql_dsn')
class TestSelectCorrelatedExpression(object): class TestSelectCorrelatedExpression(object):
@pytest.mark.parametrize( @pytest.mark.parametrize(
('model_key', 'related_model_key', 'path', 'result'), ('model_key', 'related_model_key', 'path', 'result'),
@@ -428,20 +382,20 @@ class TestSelectCorrelatedExpression(object):
def test_with_non_aggregate_function( def test_with_non_aggregate_function(
self, self,
session, session,
user_cls, User,
article_cls Article
): ):
aggregate = select_correlated_expression( aggregate = select_correlated_expression(
article_cls, Article,
sa.func.json_build_object('name', user_cls.name), sa.func.json_build_object('name', User.name),
'comments.author', 'comments.author',
user_cls User
) )
query = session.query( query = session.query(
article_cls.id, Article.id,
aggregate.label('author_json') aggregate.label('author_json')
).order_by(article_cls.id) ).order_by(Article.id)
result = query.all() result = query.all()
assert result == [ assert result == [
(1, {'name': 'User 1'}) (1, {'name': 'User 1'})

View File

@@ -9,143 +9,152 @@ from sqlalchemy_utils import (
assert_non_nullable, assert_non_nullable,
assert_nullable assert_nullable
) )
from tests import TestCase
class AssertionTestCase(TestCase): @pytest.fixture()
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' def User(Base):
class User(Base):
__tablename__ = 'user'
id = sa.Column('_id', sa.Integer, primary_key=True)
name = sa.Column('_name', sa.String(20))
age = sa.Column('_age', sa.Integer, nullable=False)
email = sa.Column(
'_email', sa.String(200), nullable=False, unique=True
)
fav_numbers = sa.Column('_fav_numbers', ARRAY(sa.Integer))
def create_models(self): __table_args__ = (
class User(self.Base): sa.CheckConstraint(sa.and_(age >= 0, age <= 150)),
__tablename__ = 'user' sa.CheckConstraint(
id = sa.Column('_id', sa.Integer, primary_key=True) sa.and_(
name = sa.Column('_name', sa.String(20)) sa.func.array_length(fav_numbers, 1) <= 8
age = sa.Column('_age', sa.Integer, nullable=False)
email = sa.Column(
'_email', sa.String(200), nullable=False, unique=True
)
fav_numbers = sa.Column('_fav_numbers', ARRAY(sa.Integer))
__table_args__ = (
sa.CheckConstraint(sa.and_(age >= 0, age <= 150)),
sa.CheckConstraint(
sa.and_(
sa.func.array_length(fav_numbers, 1) <= 8
)
) )
) )
self.User = User
def setup_method(self, method):
TestCase.setup_method(self, method)
user = self.User(
name='Someone',
email='someone@example.com',
age=15,
fav_numbers=[1, 2, 3]
) )
self.session.add(user) return User
self.session.commit()
self.user = user
class TestAssertMaxLengthWithArray(AssertionTestCase): @pytest.fixture()
def test_with_max_length(self): def user(User, session):
assert_max_length(self.user, 'fav_numbers', 8) user = User(
assert_max_length(self.user, 'fav_numbers', 8) name='Someone',
email='someone@example.com',
age=15,
fav_numbers=[1, 2, 3]
)
session.add(user)
session.commit()
return user
def test_smaller_than_max_length(self):
@pytest.mark.usefixtures('postgresql_dsn')
class TestAssertMaxLengthWithArray(object):
def test_with_max_length(self, user):
assert_max_length(user, 'fav_numbers', 8)
assert_max_length(user, 'fav_numbers', 8)
def test_smaller_than_max_length(self, user):
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
assert_max_length(self.user, 'fav_numbers', 7) assert_max_length(user, 'fav_numbers', 7)
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
assert_max_length(self.user, 'fav_numbers', 7) assert_max_length(user, 'fav_numbers', 7)
def test_bigger_than_max_length(self): def test_bigger_than_max_length(self, user):
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
assert_max_length(self.user, 'fav_numbers', 9) assert_max_length(user, 'fav_numbers', 9)
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
assert_max_length(self.user, 'fav_numbers', 9) assert_max_length(user, 'fav_numbers', 9)
class TestAssertNonNullable(AssertionTestCase): @pytest.mark.usefixtures('postgresql_dsn')
def test_non_nullable_column(self): class TestAssertNonNullable(object):
def test_non_nullable_column(self, user):
# Test everything twice so that session gets rolled back properly # Test everything twice so that session gets rolled back properly
assert_non_nullable(self.user, 'age') assert_non_nullable(user, 'age')
assert_non_nullable(self.user, 'age') assert_non_nullable(user, 'age')
def test_nullable_column(self): def test_nullable_column(self, user):
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
assert_non_nullable(self.user, 'name') assert_non_nullable(user, 'name')
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
assert_non_nullable(self.user, 'name') assert_non_nullable(user, 'name')
class TestAssertNullable(AssertionTestCase): @pytest.mark.usefixtures('postgresql_dsn')
def test_nullable_column(self): class TestAssertNullable(object):
assert_nullable(self.user, 'name')
assert_nullable(self.user, 'name')
def test_non_nullable_column(self): def test_nullable_column(self, user):
assert_nullable(user, 'name')
assert_nullable(user, 'name')
def test_non_nullable_column(self, user):
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
assert_nullable(self.user, 'age') assert_nullable(user, 'age')
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
assert_nullable(self.user, 'age') assert_nullable(user, 'age')
class TestAssertMaxLength(AssertionTestCase): @pytest.mark.usefixtures('postgresql_dsn')
def test_with_max_length(self): class TestAssertMaxLength(object):
assert_max_length(self.user, 'name', 20)
assert_max_length(self.user, 'name', 20)
def test_with_non_nullable_column(self): def test_with_max_length(self, user):
assert_max_length(self.user, 'email', 200) assert_max_length(user, 'name', 20)
assert_max_length(self.user, 'email', 200) assert_max_length(user, 'name', 20)
def test_smaller_than_max_length(self): def test_with_non_nullable_column(self, user):
with pytest.raises(AssertionError): assert_max_length(user, 'email', 200)
assert_max_length(self.user, 'name', 19) assert_max_length(user, 'email', 200)
with pytest.raises(AssertionError):
assert_max_length(self.user, 'name', 19)
def test_bigger_than_max_length(self): def test_smaller_than_max_length(self, user):
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
assert_max_length(self.user, 'name', 21) assert_max_length(user, 'name', 19)
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
assert_max_length(self.user, 'name', 21) assert_max_length(user, 'name', 19)
def test_bigger_than_max_length(self, user):
with pytest.raises(AssertionError):
assert_max_length(user, 'name', 21)
with pytest.raises(AssertionError):
assert_max_length(user, 'name', 21)
class TestAssertMinValue(AssertionTestCase): @pytest.mark.usefixtures('postgresql_dsn')
def test_with_min_value(self): class TestAssertMinValue(object):
assert_min_value(self.user, 'age', 0)
assert_min_value(self.user, 'age', 0)
def test_smaller_than_min_value(self): def test_with_min_value(self, user):
with pytest.raises(AssertionError): assert_min_value(user, 'age', 0)
assert_min_value(self.user, 'age', -1) assert_min_value(user, 'age', 0)
with pytest.raises(AssertionError):
assert_min_value(self.user, 'age', -1)
def test_bigger_than_min_value(self): def test_smaller_than_min_value(self, user):
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
assert_min_value(self.user, 'age', 1) assert_min_value(user, 'age', -1)
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
assert_min_value(self.user, 'age', 1) assert_min_value(user, 'age', -1)
def test_bigger_than_min_value(self, user):
with pytest.raises(AssertionError):
assert_min_value(user, 'age', 1)
with pytest.raises(AssertionError):
assert_min_value(user, 'age', 1)
class TestAssertMaxValue(AssertionTestCase): @pytest.mark.usefixtures('postgresql_dsn')
def test_with_min_value(self): class TestAssertMaxValue(object):
assert_max_value(self.user, 'age', 150)
assert_max_value(self.user, 'age', 150)
def test_smaller_than_max_value(self): def test_with_min_value(self, user):
with pytest.raises(AssertionError): assert_max_value(user, 'age', 150)
assert_max_value(self.user, 'age', 149) assert_max_value(user, 'age', 150)
with pytest.raises(AssertionError):
assert_max_value(self.user, 'age', 149)
def test_bigger_than_max_value(self): def test_smaller_than_max_value(self, user):
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
assert_max_value(self.user, 'age', 151) assert_max_value(user, 'age', 149)
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
assert_max_value(self.user, 'age', 151) assert_max_value(user, 'age', 149)
def test_bigger_than_max_value(self, user):
with pytest.raises(AssertionError):
assert_max_value(user, 'age', 151)
with pytest.raises(AssertionError):
assert_max_value(user, 'age', 151)

View File

@@ -1,117 +1,108 @@
import pytest
import sqlalchemy as sa import sqlalchemy as sa
from pytest import raises
from sqlalchemy_utils import auto_delete_orphans, ImproperlyConfigured from sqlalchemy_utils import auto_delete_orphans, ImproperlyConfigured
from tests import TestCase
class TestAutoDeleteOrphans(TestCase): @pytest.fixture
def create_models(self): def tagging_tbl(Base):
tagging = sa.Table( return sa.Table(
'tagging', 'tagging',
self.Base.metadata, Base.metadata,
sa.Column( sa.Column(
'tag_id', 'tag_id',
sa.Integer, sa.Integer,
sa.ForeignKey('tag.id', ondelete='cascade'), sa.ForeignKey('tag.id', ondelete='cascade'),
primary_key=True primary_key=True
), ),
sa.Column( sa.Column(
'entry_id', 'entry_id',
sa.Integer, sa.Integer,
sa.ForeignKey('entry.id', ondelete='cascade'), sa.ForeignKey('entry.id', ondelete='cascade'),
primary_key=True primary_key=True
)
) )
)
class Tag(self.Base):
__tablename__ = 'tag'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.String(100), unique=True, nullable=False)
def __init__(self, name=None): @pytest.fixture
self.name = name def Tag(Base):
class Tag(Base):
__tablename__ = 'tag'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.String(100), unique=True, nullable=False)
class Entry(self.Base): def __init__(self, name=None):
__tablename__ = 'entry' self.name = name
return Tag
id = sa.Column(sa.Integer, primary_key=True)
tags = sa.orm.relationship( @pytest.fixture
'Tag', def Entry(Base, Tag, tagging_tbl):
secondary=tagging, class Entry(Base):
backref='entries' __tablename__ = 'entry'
)
auto_delete_orphans(Entry.tags) id = sa.Column(sa.Integer, primary_key=True)
self.Tag = Tag tags = sa.orm.relationship(
self.Entry = Entry 'Tag',
secondary=tagging_tbl,
backref='entries'
)
auto_delete_orphans(Entry.tags)
return Entry
def test_orphan_deletion(self):
r1 = self.Entry() @pytest.fixture
r2 = self.Entry() def EntryWithoutTagsBackref(Base, Tag, tagging_tbl):
r3 = self.Entry() class EntryWithoutTagsBackref(Base):
__tablename__ = 'entry'
id = sa.Column(sa.Integer, primary_key=True)
tags = sa.orm.relationship(
'Tag',
secondary=tagging_tbl
)
return EntryWithoutTagsBackref
class TestAutoDeleteOrphans(object):
@pytest.fixture
def init_models(self, Entry, Tag):
pass
def test_orphan_deletion(self, session, Entry, Tag):
r1 = Entry()
r2 = Entry()
r3 = Entry()
t1, t2, t3, t4 = ( t1, t2, t3, t4 = (
self.Tag('t1'), Tag('t1'),
self.Tag('t2'), Tag('t2'),
self.Tag('t3'), Tag('t3'),
self.Tag('t4') Tag('t4')
) )
r1.tags.extend([t1, t2]) r1.tags.extend([t1, t2])
r2.tags.extend([t2, t3]) r2.tags.extend([t2, t3])
r3.tags.extend([t4]) r3.tags.extend([t4])
self.session.add_all([r1, r2, r3]) session.add_all([r1, r2, r3])
assert self.session.query(self.Tag).count() == 4 assert session.query(Tag).count() == 4
r2.tags.remove(t2) r2.tags.remove(t2)
assert self.session.query(self.Tag).count() == 4 assert session.query(Tag).count() == 4
r1.tags.remove(t2) r1.tags.remove(t2)
assert self.session.query(self.Tag).count() == 3 assert session.query(Tag).count() == 3
r1.tags.remove(t1) r1.tags.remove(t1)
assert self.session.query(self.Tag).count() == 2 assert session.query(Tag).count() == 2
class TestAutoDeleteOrphansWithoutBackref(TestCase): class TestAutoDeleteOrphansWithoutBackref(object):
def create_models(self):
tagging = sa.Table(
'tagging',
self.Base.metadata,
sa.Column(
'tag_id',
sa.Integer,
sa.ForeignKey('tag.id', ondelete='cascade'),
primary_key=True
),
sa.Column(
'entry_id',
sa.Integer,
sa.ForeignKey('entry.id', ondelete='cascade'),
primary_key=True
)
)
class Tag(self.Base): @pytest.fixture
__tablename__ = 'tag' def init_models(self, EntryWithoutTagsBackref, Tag):
id = sa.Column(sa.Integer, primary_key=True) pass
name = sa.Column(sa.String(100), unique=True, nullable=False)
def __init__(self, name=None): def test_orphan_deletion(self, EntryWithoutTagsBackref):
self.name = name with pytest.raises(ImproperlyConfigured):
auto_delete_orphans(EntryWithoutTagsBackref.tags)
class Entry(self.Base):
__tablename__ = 'entry'
id = sa.Column(sa.Integer, primary_key=True)
tags = sa.orm.relationship(
'Tag',
secondary=tagging
)
self.Entry = Entry
def test_orphan_deletion(self):
with raises(ImproperlyConfigured):
auto_delete_orphans(self.Entry.tags)

View File

@@ -1,50 +1,60 @@
import pytest
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy_utils import EmailType from sqlalchemy_utils import EmailType
from tests import TestCase
class TestCaseInsensitiveComparator(TestCase): @pytest.fixture
def create_models(self): def User(Base):
class User(self.Base): class User(Base):
__tablename__ = 'user' __tablename__ = 'user'
id = sa.Column(sa.Integer, primary_key=True) id = sa.Column(sa.Integer, primary_key=True)
email = sa.Column(EmailType) email = sa.Column(EmailType)
def __repr__(self): def __repr__(self):
return 'Building(%r)' % self.id return 'Building(%r)' % self.id
return User
self.User = User
def test_supports_equals(self): @pytest.fixture
def init_models(User):
pass
class TestCaseInsensitiveComparator(object):
def test_supports_equals(self, session, User):
query = ( query = (
self.session.query(self.User) session.query(User)
.filter(self.User.email == u'email@example.com') .filter(User.email == u'email@example.com')
) )
assert '"user".email = lower(:lower_1)' in str(query) assert '"user".email = lower(:lower_1)' in str(query)
def test_supports_in_(self): def test_supports_in_(self, session, User):
query = ( query = (
self.session.query(self.User) session.query(User)
.filter(self.User.email.in_([u'email@example.com', u'a'])) .filter(User.email.in_([u'email@example.com', u'a']))
) )
assert ( assert (
'"user".email IN (lower(:lower_1), lower(:lower_2))' '"user".email IN (lower(:lower_1), lower(:lower_2))'
in str(query) in str(query)
) )
def test_supports_notin_(self): def test_supports_notin_(self, session, User):
query = ( query = (
self.session.query(self.User) session.query(User)
.filter(self.User.email.notin_([u'email@example.com', u'a'])) .filter(User.email.notin_([u'email@example.com', u'a']))
) )
assert ( assert (
'"user".email NOT IN (lower(:lower_1), lower(:lower_2))' '"user".email NOT IN (lower(:lower_1), lower(:lower_2))'
in str(query) in str(query)
) )
def test_does_not_apply_lower_to_types_that_are_already_lowercased(self): def test_does_not_apply_lower_to_types_that_are_already_lowercased(
assert str(self.User.email == self.User.email) == ( self,
User
):
assert str(User.email == User.email) == (
'"user".email = "user".email' '"user".email = "user".email'
) )

View File

@@ -1,86 +1,93 @@
import pytest
import sqlalchemy as sa import sqlalchemy as sa
from pytest import raises
from sqlalchemy.dialects import postgresql from sqlalchemy.dialects import postgresql
from sqlalchemy_utils import Asterisk, row_to_json from sqlalchemy_utils import Asterisk, row_to_json
from sqlalchemy_utils.expressions import explain, explain_analyze from sqlalchemy_utils.expressions import explain, explain_analyze
from tests import TestCase
class ExpressionTestCase(TestCase): @pytest.fixture
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' def assert_startswith(session):
def assert_startswith(query, query_part):
def create_models(self):
class Article(self.Base):
__tablename__ = 'article'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
content = sa.Column(sa.UnicodeText)
self.Article = Article
def assert_startswith(self, query, query_part):
assert str( assert str(
query.compile(dialect=postgresql.dialect()) query.compile(dialect=postgresql.dialect())
).startswith(query_part) ).startswith(query_part)
# Check that query executes properly # Check that query executes properly
self.session.execute(query) session.execute(query)
return assert_startswith
class TestExplain(ExpressionTestCase): @pytest.fixture
def test_render_explain(self): def Article(Base):
self.assert_startswith( class Article(Base):
explain(self.session.query(self.Article)), __tablename__ = 'article'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
content = sa.Column(sa.UnicodeText)
return Article
@pytest.mark.usefixtures('postgresql_dsn')
class TestExplain(object):
def test_render_explain(self, session, assert_startswith, Article):
assert_startswith(
explain(session.query(Article)),
'EXPLAIN SELECT' 'EXPLAIN SELECT'
) )
def test_render_explain_with_analyze(self): def test_render_explain_with_analyze(
self.assert_startswith( self,
explain(self.session.query(self.Article), analyze=True), session,
assert_startswith,
Article
):
assert_startswith(
explain(session.query(Article), analyze=True),
'EXPLAIN (ANALYZE true) SELECT' 'EXPLAIN (ANALYZE true) SELECT'
) )
def test_with_string_as_stmt_param(self): def test_with_string_as_stmt_param(self, assert_startswith):
self.assert_startswith( assert_startswith(
explain('SELECT 1 FROM article'), explain('SELECT 1 FROM article'),
'EXPLAIN SELECT' 'EXPLAIN SELECT'
) )
def test_format(self): def test_format(self, assert_startswith):
self.assert_startswith( assert_startswith(
explain('SELECT 1 FROM article', format='json'), explain('SELECT 1 FROM article', format='json'),
'EXPLAIN (FORMAT json) SELECT' 'EXPLAIN (FORMAT json) SELECT'
) )
def test_timing(self): def test_timing(self, assert_startswith):
self.assert_startswith( assert_startswith(
explain('SELECT 1 FROM article', analyze=True, timing=False), explain('SELECT 1 FROM article', analyze=True, timing=False),
'EXPLAIN (ANALYZE true, TIMING false) SELECT' 'EXPLAIN (ANALYZE true, TIMING false) SELECT'
) )
def test_verbose(self): def test_verbose(self, assert_startswith):
self.assert_startswith( assert_startswith(
explain('SELECT 1 FROM article', verbose=True), explain('SELECT 1 FROM article', verbose=True),
'EXPLAIN (VERBOSE true) SELECT' 'EXPLAIN (VERBOSE true) SELECT'
) )
def test_buffers(self): def test_buffers(self, assert_startswith):
self.assert_startswith( assert_startswith(
explain('SELECT 1 FROM article', analyze=True, buffers=True), explain('SELECT 1 FROM article', analyze=True, buffers=True),
'EXPLAIN (ANALYZE true, BUFFERS true) SELECT' 'EXPLAIN (ANALYZE true, BUFFERS true) SELECT'
) )
def test_costs(self): def test_costs(self, assert_startswith):
self.assert_startswith( assert_startswith(
explain('SELECT 1 FROM article', costs=False), explain('SELECT 1 FROM article', costs=False),
'EXPLAIN (COSTS false) SELECT' 'EXPLAIN (COSTS false) SELECT'
) )
class TestExplainAnalyze(ExpressionTestCase): class TestExplainAnalyze(object):
def test_render_explain_analyze(self): def test_render_explain_analyze(self, session, Article):
assert str( assert str(
explain_analyze(self.session.query(self.Article)) explain_analyze(session.query(Article))
.compile( .compile(
dialect=postgresql.dialect() dialect=postgresql.dialect()
) )
@@ -111,7 +118,7 @@ class TestAsterisk(object):
class TestRowToJson(object): class TestRowToJson(object):
def test_compiler_with_default_dialect(self): def test_compiler_with_default_dialect(self):
with raises(sa.exc.CompileError): with pytest.raises(sa.exc.CompileError):
str(row_to_json(sa.text('article.*'))) str(row_to_json(sa.text('article.*')))
def test_compiler_with_postgresql(self): def test_compiler_with_postgresql(self):
@@ -128,7 +135,7 @@ class TestRowToJson(object):
class TestArrayAgg(object): class TestArrayAgg(object):
def test_compiler_with_default_dialect(self): def test_compiler_with_default_dialect(self):
with raises(sa.exc.CompileError): with pytest.raises(sa.exc.CompileError):
str(sa.func.array_agg(sa.text('u.name'))) str(sa.func.array_agg(sa.text('u.name')))
def test_compiler_with_postgresql(self): def test_compiler_with_postgresql(self):

View File

@@ -1,27 +1,29 @@
from datetime import datetime from datetime import datetime
import pytest
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy_utils.listeners import force_instant_defaults from sqlalchemy_utils.listeners import force_instant_defaults
from tests import TestCase
force_instant_defaults() force_instant_defaults()
class TestInstantDefaultListener(TestCase): @pytest.fixture
def create_models(self): def Article(Base):
class Article(self.Base): class Article(Base):
__tablename__ = 'article' __tablename__ = 'article'
id = sa.Column(sa.Integer, primary_key=True) id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255), default=u'Some article') name = sa.Column(sa.Unicode(255), default=u'Some article')
created_at = sa.Column(sa.DateTime, default=datetime.now) created_at = sa.Column(sa.DateTime, default=datetime.now)
return Article
self.Article = Article
def test_assigns_defaults_on_object_construction(self): class TestInstantDefaultListener(object):
article = self.Article()
def test_assigns_defaults_on_object_construction(self, Article):
article = Article()
assert article.name == u'Some article' assert article.name == u'Some article'
def test_callables_as_defaults(self): def test_callables_as_defaults(self, Article):
article = self.Article() article = Article()
assert isinstance(article.created_at, datetime) assert isinstance(article.created_at, datetime)

View File

@@ -1,14 +1,19 @@
from tests import TestCase class TestInstrumentedList(object):
def test_any_returns_true_if_member_has_attr_defined(
self,
class TestInstrumentedList(TestCase): Category,
def test_any_returns_true_if_member_has_attr_defined(self): Article
category = self.Category() ):
category.articles.append(self.Article()) category = Category()
category.articles.append(self.Article(name=u'some name')) category.articles.append(Article())
category.articles.append(Article(name=u'some name'))
assert category.articles.any('name') assert category.articles.any('name')
def test_any_returns_false_if_no_member_has_attr_defined(self): def test_any_returns_false_if_no_member_has_attr_defined(
category = self.Category() self,
category.articles.append(self.Article()) Category,
Article
):
category = Category()
category.articles.append(Article())
assert not category.articles.any('name') assert not category.articles.any('name')

View File

@@ -1,39 +1,40 @@
from datetime import datetime from datetime import datetime
import pytest
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy_utils import Timestamp from sqlalchemy_utils import Timestamp
from tests import TestCase
class TestTimestamp(TestCase): @pytest.fixture
def Article(Base):
class Article(Base, Timestamp):
__tablename__ = 'article'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255), default=u'Some article')
return Article
def create_models(self):
class Article(self.Base, Timestamp):
__tablename__ = 'article'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255), default=u'Some article')
self.Article = Article class TestTimestamp(object):
def test_created(self): def test_created(self, session, Article):
then = datetime.utcnow() then = datetime.utcnow()
article = self.Article() article = Article()
self.session.add(article) session.add(article)
self.session.commit() session.commit()
assert article.created >= then and article.created <= datetime.utcnow() assert article.created >= then and article.created <= datetime.utcnow()
def test_updated(self): def test_updated(self, session, Article):
article = self.Article() article = Article()
self.session.add(article) session.add(article)
self.session.commit() session.commit()
then = datetime.utcnow() then = datetime.utcnow()
article.name = u"Something" article.name = u"Something"
self.session.commit() session.commit()
assert article.updated >= then and article.updated <= datetime.utcnow() assert article.updated >= then and article.updated <= datetime.utcnow()

View File

@@ -1,122 +1,127 @@
import pytest
import six import six
import sqlalchemy as sa import sqlalchemy as sa
from pytest import mark
from sqlalchemy.util.langhelpers import symbol from sqlalchemy.util.langhelpers import symbol
from sqlalchemy_utils.path import AttrPath, Path from sqlalchemy_utils.path import AttrPath, Path
from tests import TestCase
class TestAttrPath(TestCase): @pytest.fixture
def create_models(self): def Document(Base):
class Document(self.Base): class Document(Base):
__tablename__ = 'document' __tablename__ = 'document'
id = sa.Column(sa.Integer, primary_key=True) id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255)) name = sa.Column(sa.Unicode(255))
locale = sa.Column(sa.String(10)) locale = sa.Column(sa.String(10))
return Document
class Section(self.Base):
__tablename__ = 'section'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
locale = sa.Column(sa.String(10))
document_id = sa.Column( @pytest.fixture
sa.Integer, sa.ForeignKey(Document.id) def Section(Base, Document):
) class Section(Base):
__tablename__ = 'section'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
locale = sa.Column(sa.String(10))
document = sa.orm.relationship(Document, backref='sections') document_id = sa.Column(
sa.Integer, sa.ForeignKey(Document.id)
class SubSection(self.Base):
__tablename__ = 'subsection'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
locale = sa.Column(sa.String(10))
section_id = sa.Column(
sa.Integer, sa.ForeignKey(Section.id)
)
section = sa.orm.relationship(Section, backref='subsections')
self.Document = Document
self.Section = Section
self.SubSection = SubSection
@mark.parametrize(
('class_', 'path', 'direction'),
(
('SubSection', 'section', symbol('MANYTOONE')),
) )
)
def test_direction(self, class_, path, direction): document = sa.orm.relationship(Document, backref='sections')
return Section
@pytest.fixture
def SubSection(Base, Section):
class SubSection(Base):
__tablename__ = 'subsection'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
locale = sa.Column(sa.String(10))
section_id = sa.Column(
sa.Integer, sa.ForeignKey(Section.id)
)
section = sa.orm.relationship(Section, backref='subsections')
return SubSection
class TestAttrPath(object):
@pytest.fixture
def init_models(self, Document, Section, SubSection):
pass
def test_direction(self, SubSection):
assert ( assert (
AttrPath(getattr(self, class_), path).direction == direction AttrPath(SubSection, 'section').direction == symbol('MANYTOONE')
) )
def test_invert(self): def test_invert(self, Document, Section, SubSection):
path = ~ AttrPath(self.SubSection, 'section.document') path = ~ AttrPath(SubSection, 'section.document')
assert path.parts == [ assert path.parts == [
self.Document.sections, Document.sections,
self.Section.subsections Section.subsections
] ]
assert str(path.path) == 'sections.subsections' assert str(path.path) == 'sections.subsections'
def test_len(self): def test_len(self, SubSection):
len(AttrPath(self.SubSection, 'section.document')) == 2 len(AttrPath(SubSection, 'section.document')) == 2
def test_init(self): def test_init(self, SubSection):
path = AttrPath(self.SubSection, 'section.document') path = AttrPath(SubSection, 'section.document')
assert path.class_ == self.SubSection assert path.class_ == SubSection
assert path.path == Path('section.document') assert path.path == Path('section.document')
def test_iter(self): def test_iter(self, Section, SubSection):
path = AttrPath(self.SubSection, 'section.document') path = AttrPath(SubSection, 'section.document')
assert list(path) == [ assert list(path) == [
self.SubSection.section, SubSection.section,
self.Section.document Section.document
] ]
def test_repr(self): def test_repr(self, SubSection):
path = AttrPath(self.SubSection, 'section.document') path = AttrPath(SubSection, 'section.document')
assert repr(path) == ( assert repr(path) == (
"AttrPath(SubSection, 'section.document')" "AttrPath(SubSection, 'section.document')"
) )
def test_index(self): def test_index(self, Section, SubSection):
path = AttrPath(self.SubSection, 'section.document') path = AttrPath(SubSection, 'section.document')
assert path.index(self.Section.document) == 1 assert path.index(Section.document) == 1
assert path.index(self.SubSection.section) == 0 assert path.index(SubSection.section) == 0
def test_getitem(self): def test_getitem(self, Section, SubSection):
path = AttrPath(self.SubSection, 'section.document') path = AttrPath(SubSection, 'section.document')
assert path[0] is self.SubSection.section assert path[0] is SubSection.section
assert path[1] is self.Section.document assert path[1] is Section.document
def test_getitem_with_slice(self): def test_getitem_with_slice(self, Section, SubSection):
path = AttrPath(self.SubSection, 'section.document') path = AttrPath(SubSection, 'section.document')
assert path[:] == AttrPath(self.SubSection, 'section.document') assert path[:] == AttrPath(SubSection, 'section.document')
assert path[:-1] == AttrPath(self.SubSection, 'section') assert path[:-1] == AttrPath(SubSection, 'section')
assert path[1:] == AttrPath(self.Section, 'document') assert path[1:] == AttrPath(Section, 'document')
def test_eq(self): def test_eq(self, SubSection):
assert ( assert (
AttrPath(self.SubSection, 'section.document') == AttrPath(SubSection, 'section.document') ==
AttrPath(self.SubSection, 'section.document') AttrPath(SubSection, 'section.document')
) )
assert not ( assert not (
AttrPath(self.SubSection, 'section') == AttrPath(SubSection, 'section') ==
AttrPath(self.SubSection, 'section.document') AttrPath(SubSection, 'section.document')
) )
def test_ne(self): def test_ne(self, SubSection):
assert not ( assert not (
AttrPath(self.SubSection, 'section.document') != AttrPath(SubSection, 'section.document') !=
AttrPath(self.SubSection, 'section.document') AttrPath(SubSection, 'section.document')
) )
assert ( assert (
AttrPath(self.SubSection, 'section') != AttrPath(SubSection, 'section') !=
AttrPath(self.SubSection, 'section.document') AttrPath(SubSection, 'section.document')
) )
@@ -133,7 +138,7 @@ class TestPath(object):
path = Path('s.s2.s3') path = Path('s.s2.s3')
assert list(path) == ['s', 's2', 's3'] assert list(path) == ['s', 's2', 's3']
@mark.parametrize(('path', 'length'), ( @pytest.mark.parametrize(('path', 'length'), (
(Path('s.s2.s3'), 3), (Path('s.s2.s3'), 3),
(Path('s.s2'), 2), (Path('s.s2'), 2),
(Path(''), 0) (Path(''), 0)
@@ -167,14 +172,14 @@ class TestPath(object):
path = Path('s.s2.s3') path = Path('s.s2.s3')
assert path[1:] == Path('s2.s3') assert path[1:] == Path('s2.s3')
@mark.parametrize(('test', 'result'), ( @pytest.mark.parametrize(('test', 'result'), (
(Path('s.s2') == Path('s.s2'), True), (Path('s.s2') == Path('s.s2'), True),
(Path('s.s2') == Path('s.s3'), False) (Path('s.s2') == Path('s.s3'), False)
)) ))
def test_eq(self, test, result): def test_eq(self, test, result):
assert test is result assert test is result
@mark.parametrize(('test', 'result'), ( @pytest.mark.parametrize(('test', 'result'), (
(Path('s.s2') != Path('s.s2'), False), (Path('s.s2') != Path('s.s2'), False),
(Path('s.s2') != Path('s.s3'), True) (Path('s.s2') != Path('s.s3'), True)
)) ))

Some files were not shown because too many files have changed in this diff Show More