From 47d636247a6af51806601a8dea191b075db4f714 Mon Sep 17 00:00:00 2001 From: Konsta Vesterinen Date: Fri, 15 Jan 2016 10:06:48 +0200 Subject: [PATCH 1/3] Add encoding issue with MySQL --- tests/types/test_weekdays.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/types/test_weekdays.py b/tests/types/test_weekdays.py index 682b616..8a4eaba 100644 --- a/tests/types/test_weekdays.py +++ b/tests/types/test_weekdays.py @@ -26,7 +26,7 @@ class WeekDaysTypeTestCase(TestCase): def test_color_parameter_processing(self): schedule = self.Schedule( - working_days='0001111' + working_days=b'0001111' ) self.session.add(schedule) self.session.commit() @@ -35,7 +35,7 @@ class WeekDaysTypeTestCase(TestCase): assert isinstance(schedule.working_days, WeekDays) def test_scalar_attributes_get_coerced_to_objects(self): - schedule = self.Schedule(working_days='1010101') + schedule = self.Schedule(working_days=b'1010101') assert isinstance(schedule.working_days, WeekDays) From 5bdd4d3efb9cde279809135a6d34accff7218d81 Mon Sep 17 00:00:00 2001 From: Konsta Vesterinen Date: Fri, 15 Jan 2016 10:06:48 +0200 Subject: [PATCH 2/3] Add encoding to fix MySQL issues --- tests/types/test_weekdays.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/types/test_weekdays.py b/tests/types/test_weekdays.py index 8a4eaba..adabc60 100644 --- a/tests/types/test_weekdays.py +++ b/tests/types/test_weekdays.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- import pytest import sqlalchemy as sa From 815f07d6c1cb3c9881209e041cd7efc4fcaebb78 Mon Sep 17 00:00:00 2001 From: Jacob Magnusson Date: Mon, 18 Jan 2016 16:32:12 +0100 Subject: [PATCH 3/3] Use pytest fixtures to reduce complexity and repetition Also: Allow override of database name and user in tests (important for me as I would have to mess with my PSQL and MySQL database users otherwise) Use dict.items instead of six.iteritems as it sporadically caused RuntimeError: dictionary changed size during iteration in Python 2.6 tests. Fix typo DNS to DSN Adds Python 3.5 to tox.ini Added an .editorconfig Import babel.dates in sqlalchemy_utils.i18n as an exception would be raised when using the latest versions of babel. --- .editorconfig | 14 + .gitignore | 5 + .isort.cfg | 2 +- .travis.yml | 40 +- MANIFEST.in | 2 +- conftest.py | 198 ++++++++++ docs/installation.rst | 2 + setup.py | 2 +- sqlalchemy_utils/aggregates.py | 5 +- sqlalchemy_utils/expressions.py | 2 +- sqlalchemy_utils/functions/database.py | 3 +- sqlalchemy_utils/functions/foreign_keys.py | 3 +- sqlalchemy_utils/functions/orm.py | 2 +- sqlalchemy_utils/generic.py | 3 +- sqlalchemy_utils/i18n.py | 1 + sqlalchemy_utils/observer.py | 6 +- sqlalchemy_utils/primitives/country.py | 4 +- sqlalchemy_utils/primitives/currency.py | 4 +- sqlalchemy_utils/primitives/weekday.py | 4 +- sqlalchemy_utils/primitives/weekdays.py | 3 +- sqlalchemy_utils/types/arrow.py | 3 +- sqlalchemy_utils/types/color.py | 3 +- sqlalchemy_utils/types/country.py | 3 +- sqlalchemy_utils/types/currency.py | 5 +- sqlalchemy_utils/types/encrypted.py | 5 +- sqlalchemy_utils/types/ip_address.py | 3 +- sqlalchemy_utils/types/password.py | 3 +- sqlalchemy_utils/types/pg_composite.py | 2 +- sqlalchemy_utils/types/phone_number.py | 5 +- sqlalchemy_utils/types/timezone.py | 3 +- sqlalchemy_utils/types/weekdays.py | 7 +- tests/__init__.py | 129 ------- tests/aggregate/test_backrefs.py | 101 ++--- .../test_custom_select_expressions.py | 83 ++-- .../aggregate/test_join_table_inheritance.py | 150 ++++---- tests/aggregate/test_m2m.py | 99 ++--- tests/aggregate/test_m2m_m2m.py | 134 ++++--- .../test_multiple_aggregates_per_class.py | 127 ++++--- tests/aggregate/test_o2m_m2m.py | 124 +++--- tests/aggregate/test_o2m_o2m.py | 90 +++-- tests/aggregate/test_o2m_o2m_o2m.py | 155 +++++--- tests/aggregate/test_search_vectors.py | 69 ++-- tests/aggregate/test_simple_paths.py | 101 ++--- tests/aggregate/test_with_column_alias.py | 103 ++--- tests/aggregate/test_with_ondelete_cascade.py | 79 ++-- tests/functions/test_analyze.py | 34 +- tests/functions/test_database.py | 98 ++--- tests/functions/test_dependent_objects.py | 253 ++++++++----- tests/functions/test_escape_like.py | 3 +- tests/functions/test_get_bind.py | 23 +- tests/functions/test_get_class_by_table.py | 75 ++-- tests/functions/test_get_column_key.py | 52 +-- tests/functions/test_get_columns.py | 60 +-- tests/functions/test_get_hybrid_properties.py | 48 +-- tests/functions/test_get_mapper.py | 131 +++---- tests/functions/test_get_primary_keys.py | 48 +-- tests/functions/test_get_query_entities.py | 143 +++---- .../test_get_referencing_foreign_keys.py | 70 ++-- tests/functions/test_get_tables.py | 108 +++--- tests/functions/test_get_type.py | 69 ++-- tests/functions/test_getdotattr.py | 126 ++++--- tests/functions/test_has_changes.py | 49 ++- tests/functions/test_has_index.py | 41 +- tests/functions/test_has_unique_index.py | 52 ++- tests/functions/test_identity.py | 43 ++- tests/functions/test_is_loaded.py | 28 +- tests/functions/test_json_sql.py | 9 +- .../test_make_order_by_deterministic.py | 104 ++--- tests/functions/test_merge_references.py | 145 ++++--- tests/functions/test_naturally_equivalent.py | 11 +- .../test_non_indexed_foreign_keys.py | 31 +- tests/functions/test_quote.py | 24 +- tests/functions/test_render.py | 41 +- tests/functions/test_table_name.py | 41 +- tests/generic_relationship/__init__.py | 116 +++--- .../test_abstract_base_class.py | 76 ++-- .../test_column_aliases.py | 62 +-- .../test_composite_keys.py | 114 +++--- .../test_hybrid_properties.py | 95 ++--- .../test_single_table_inheritance.py | 210 ++++++----- tests/mixins.py | 147 +++++--- tests/observes/test_column_property.py | 51 +-- tests/observes/test_m2m_m2m_m2m.py | 233 ++++++------ tests/observes/test_o2m_o2m_o2m.py | 166 ++++---- tests/observes/test_o2m_o2o_o2m.py | 156 ++++---- tests/observes/test_o2o_o2o.py | 93 +++-- tests/observes/test_o2o_o2o_o2o.py | 138 ++++--- tests/primitives/test_country.py | 20 +- tests/primitives/test_currency.py | 20 +- tests/primitives/test_weekdays.py | 8 +- tests/relationships/test_chained_join.py | 75 ++-- .../test_select_correlated_expression.py | 220 +++++------ tests/test_asserts.py | 201 +++++----- tests/test_auto_delete_orphans.py | 169 ++++----- tests/test_case_insensitive_comparator.py | 52 +-- tests/test_expressions.py | 87 +++-- tests/test_instant_defaults_listener.py | 28 +- tests/test_instrumented_list.py | 27 +- tests/test_models.py | 35 +- tests/test_path.py | 171 +++++---- tests/test_proxy_dict.py | 141 +++---- tests/test_query_chain.py | 161 ++++---- tests/test_sort_query.py | 255 +++++++------ tests/test_translation_hybrid.py | 92 +++-- tests/types/test_arrow.py | 56 +-- tests/types/test_choice.py | 154 ++++---- tests/types/test_color.py | 54 +-- tests/types/test_composite.py | 268 +++++++------ tests/types/test_country.py | 44 ++- tests/types/test_currency.py | 58 +-- tests/types/test_date_range.py | 87 +++-- tests/types/test_datetime_range.py | 87 +++-- tests/types/test_email.py | 32 +- tests/types/test_encrypted.py | 355 ++++++++++-------- tests/types/test_int_range.py | 296 +++++++++------ tests/types/test_ip_address.py | 40 +- tests/types/test_json.py | 62 +-- tests/types/test_locale.py | 64 ++-- tests/types/test_numeric_range.py | 125 +++--- tests/types/test_password.py | 112 +++--- tests/types/test_phonenumber.py | 165 ++++---- tests/types/test_scalar_list.py | 75 ++-- tests/types/test_timezone.py | 49 +-- tests/types/test_tsvector.py | 65 ++-- tests/types/test_url.py | 44 ++- tests/types/test_uuid.py | 39 +- tests/types/test_weekdays.py | 63 ++-- tox.ini | 34 +- 128 files changed, 5412 insertions(+), 4286 deletions(-) create mode 100644 .editorconfig create mode 100644 conftest.py diff --git a/.editorconfig b/.editorconfig new file mode 100644 index 0000000..c7e1a08 --- /dev/null +++ b/.editorconfig @@ -0,0 +1,14 @@ +# EditorConfig helps developers define and maintain consistent +# coding styles between different editors and IDEs +# editorconfig.org + +root = true + + +[*] +indent_style = space +end_of_line = lf +charset = utf-8 +trim_trailing_whitespace = true +insert_final_newline = true +indent_size = 4 diff --git a/.gitignore b/.gitignore index 95825ec..9323407 100644 --- a/.gitignore +++ b/.gitignore @@ -15,6 +15,8 @@ var sdist develop-eggs .installed.cfg +.cache +.eggs lib lib64 docs/_build @@ -42,3 +44,6 @@ nosetests.xml Session.vim .netrwhist *~ + +# Sublime Text +*.sublime-* diff --git a/.isort.cfg b/.isort.cfg index 52591e4..6c84b45 100644 --- a/.isort.cfg +++ b/.isort.cfg @@ -1,5 +1,5 @@ [settings] -known_first_party=sqlalchemy_utils,tests +known_first_party=sqlalchemy_utils line_length=79 multi_line_output=3 not_skip=__init__.py diff --git a/.travis.yml b/.travis.yml index 56387c1..96b6205 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,3 +1,6 @@ +sudo: false +language: python + addons: postgresql: "9.4" @@ -6,22 +9,29 @@ before_script: - psql -c 'create extension hstore;' -U postgres -d sqlalchemy_utils_test - mysql -e 'create database sqlalchemy_utils_test;' -language: python -python: - - 2.6 - - 2.7 - - 3.3 - - 3.4 - - 3.5 - -env: - - EXTRAS=test - - EXTRAS=test_all +matrix: + include: + - python: 2.6 + env: + - "TOXENV=py26" + - python: 2.7 + env: + - "TOXENV=py27" + - python: 3.3 + env: + - "TOXENV=py33" + - python: 3.4 + env: + - "TOXENV=py34" + - python: 3.5 + env: + - "TOXENV=py35" + - python: 3.5 + env: + - "TOXENV=lint" install: - - pip install -e .[$EXTRAS] + - pip install tox script: - - isort --recursive --diff sqlalchemy_utils tests && isort --recursive --check-only sqlalchemy_utils tests - - flake8 sqlalchemy_utils tests - - py.test + - tox diff --git a/MANIFEST.in b/MANIFEST.in index cd07949..eaf5b7c 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,4 +1,4 @@ -include CHANGES.rst LICENSE README.rst +include CHANGES.rst LICENSE README.rst conftest.py .isort.cfg recursive-include tests * recursive-exclude tests *.pyc recursive-include docs * diff --git a/conftest.py b/conftest.py new file mode 100644 index 0000000..c5abb4c --- /dev/null +++ b/conftest.py @@ -0,0 +1,198 @@ +import os +import warnings + +import pytest +import sqlalchemy as sa +from sqlalchemy import create_engine +from sqlalchemy.ext.declarative import declarative_base, synonym_for +from sqlalchemy.ext.hybrid import hybrid_property +from sqlalchemy.orm import sessionmaker +from sqlalchemy_utils import ( + aggregates, + coercion_listener, + i18n, + InstrumentedList +) + +from sqlalchemy_utils.types.pg_composite import remove_composite_listeners + + +@sa.event.listens_for(sa.engine.Engine, 'before_cursor_execute') +def count_sql_calls(conn, cursor, statement, parameters, context, executemany): + try: + conn.query_count += 1 + except AttributeError: + conn.query_count = 0 + + +warnings.simplefilter('error', sa.exc.SAWarning) + +sa.event.listen(sa.orm.mapper, 'mapper_configured', coercion_listener) + + +def get_locale(): + class Locale(): + territories = {'FI': 'Finland'} + + return Locale() + + +@pytest.fixture(scope='session') +def db_name(): + return os.environ.get('SQLALCHEMY_UTILS_TEST_DB', 'sqlalchemy_utils_test') + + +@pytest.fixture(scope='session') +def postgresql_db_user(): + return os.environ.get('SQLALCHEMY_UTILS_TEST_POSTGRESQL_USER', 'postgres') + + +@pytest.fixture(scope='session') +def mysql_db_user(): + return os.environ.get('SQLALCHEMY_UTILS_TEST_MYSQL_USER', 'root') + + +@pytest.fixture +def postgresql_dsn(postgresql_db_user, db_name): + return 'postgres://{0}@localhost/{1}'.format(postgresql_db_user, db_name) + + +@pytest.fixture +def mysql_dsn(mysql_db_user, db_name): + return 'mysql+pymysql://{0}@localhost/{1}'.format(mysql_db_user, db_name) + + +@pytest.fixture +def sqlite_memory_dsn(): + return 'sqlite:///:memory:' + + +@pytest.fixture +def sqlite_file_dsn(): + return 'sqlite:///{0}.db'.format(db_name) + + +@pytest.fixture +def dsn(request): + if 'postgresql_dsn' in request.fixturenames: + return request.getfuncargvalue('postgresql_dsn') + elif 'mysql_dsn' in request.fixturenames: + return request.getfuncargvalue('mysql_dsn') + elif 'sqlite_file_dsn' in request.fixturenames: + return request.getfuncargvalue('sqlite_file_dsn') + elif 'sqlite_memory_dsn' in request.fixturenames: + pass # Return default + return request.getfuncargvalue('sqlite_memory_dsn') + + +@pytest.fixture +def engine(dsn): + engine = create_engine(dsn) + # engine.echo = True + return engine + + +@pytest.fixture +def connection(engine): + return engine.connect() + + +@pytest.fixture +def Base(): + return declarative_base() + + +@pytest.fixture +def User(Base): + class User(Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, autoincrement=True, primary_key=True) + name = sa.Column(sa.Unicode(255)) + return User + + +@pytest.fixture +def Category(Base): + + class Category(Base): + __tablename__ = 'category' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + title = sa.Column(sa.Unicode(255)) + + @hybrid_property + def full_name(self): + return u'%s %s' % (self.title, self.name) + + @full_name.expression + def full_name(self): + return sa.func.concat(self.title, ' ', self.name) + + @hybrid_property + def articles_count(self): + return len(self.articles) + + @articles_count.expression + def articles_count(cls): + Article = Base._decl_class_registry['Article'] + return ( + sa.select([sa.func.count(Article.id)]) + .where(Article.category_id == cls.id) + .correlate(Article.__table__) + .label('article_count') + ) + + @property + def name_alias(self): + return self.name + + @synonym_for('name') + @property + def name_synonym(self): + return self.name + return Category + + +@pytest.fixture +def Article(Base, Category): + class Article(Base): + __tablename__ = 'article' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255), index=True) + category_id = sa.Column(sa.Integer, sa.ForeignKey(Category.id)) + + category = sa.orm.relationship( + Category, + primaryjoin=category_id == Category.id, + backref=sa.orm.backref( + 'articles', + collection_class=InstrumentedList + ) + ) + return Article + + +@pytest.fixture +def init_models(User, Category, Article): + pass + + +@pytest.fixture +def session(request, engine, connection, Base, init_models): + sa.orm.configure_mappers() + Base.metadata.create_all(connection) + Session = sessionmaker(bind=connection) + session = Session() + i18n.get_locale = get_locale + + def teardown(): + aggregates.manager.reset() + session.close_all() + Base.metadata.drop_all(connection) + remove_composite_listeners() + connection.close() + engine.dispose() + + request.addfinalizer(teardown) + + return session diff --git a/docs/installation.rst b/docs/installation.rst index 27fb3bf..a31b9a3 100644 --- a/docs/installation.rst +++ b/docs/installation.rst @@ -11,6 +11,8 @@ SQLAlchemy-Utils has been tested against the following Python platforms. - cPython 2.6 - cPython 2.7 - cPython 3.3 +- cPython 3.4 +- cPython 3.5 Installing an official release diff --git a/setup.py b/setup.py index 4f01d8e..18a424f 100644 --- a/setup.py +++ b/setup.py @@ -89,11 +89,11 @@ setup( 'Operating System :: OS Independent', 'Programming Language :: Python', 'Programming Language :: Python :: 2', - 'Programming Language :: Python :: 2.6', 'Programming Language :: Python :: 2.7', 'Programming Language :: Python :: 3', 'Programming Language :: Python :: 3.3', 'Programming Language :: Python :: 3.4', + 'Programming Language :: Python :: 3.5', 'Topic :: Internet :: WWW/HTTP :: Dynamic Content', 'Topic :: Software Development :: Libraries :: Python Modules' ] diff --git a/sqlalchemy_utils/aggregates.py b/sqlalchemy_utils/aggregates.py index b8b4605..da63edf 100644 --- a/sqlalchemy_utils/aggregates.py +++ b/sqlalchemy_utils/aggregates.py @@ -365,7 +365,6 @@ TODO from collections import defaultdict from weakref import WeakKeyDictionary -import six import sqlalchemy as sa from sqlalchemy.ext.declarative import declared_attr from sqlalchemy.sql.functions import _FunctionGenerator @@ -519,7 +518,7 @@ class AggregationManager(object): ) def update_generator_registry(self): - for class_, attrs in six.iteritems(aggregated_attrs): + for class_, attrs in aggregated_attrs.items(): for expr, path, column in attrs: value = AggregatedValue( class_=class_, @@ -539,7 +538,7 @@ class AggregationManager(object): if class_ in self.generator_registry: object_dict[class_].append(obj) - for class_, objects in six.iteritems(object_dict): + for class_, objects in object_dict.items(): for aggregate_value in self.generator_registry[class_]: query = aggregate_value.update_query(objects) if query is not None: diff --git a/sqlalchemy_utils/expressions.py b/sqlalchemy_utils/expressions.py index 1489a8f..c8a4b30 100644 --- a/sqlalchemy_utils/expressions.py +++ b/sqlalchemy_utils/expressions.py @@ -10,7 +10,7 @@ from sqlalchemy.sql.expression import ( ) from sqlalchemy.sql.functions import GenericFunction -from sqlalchemy_utils.functions.orm import quote +from .functions.orm import quote class explain(Executable, ClauseElement): diff --git a/sqlalchemy_utils/functions/database.py b/sqlalchemy_utils/functions/database.py index 7db8695..5b4bb30 100644 --- a/sqlalchemy_utils/functions/database.py +++ b/sqlalchemy_utils/functions/database.py @@ -7,8 +7,7 @@ import sqlalchemy as sa from sqlalchemy.engine.url import make_url from sqlalchemy.exc import OperationalError, ProgrammingError -from sqlalchemy_utils.expressions import explain_analyze - +from ..expressions import explain_analyze from ..utils import starts_with from .orm import quote diff --git a/sqlalchemy_utils/functions/foreign_keys.py b/sqlalchemy_utils/functions/foreign_keys.py index b135acc..4750b55 100644 --- a/sqlalchemy_utils/functions/foreign_keys.py +++ b/sqlalchemy_utils/functions/foreign_keys.py @@ -1,7 +1,6 @@ from collections import defaultdict from itertools import groupby -import six import sqlalchemy as sa from sqlalchemy.exc import NoInspectionAvailable from sqlalchemy.orm import object_session @@ -167,7 +166,7 @@ def merge_references(from_, to, foreign_keys=None): new_values = get_foreign_key_values(fk, to) criteria = ( getattr(fk.constraint.table.c, key) == value - for key, value in six.iteritems(old_values) + for key, value in old_values.items() ) try: mapper = get_mapper(fk.constraint.table) diff --git a/sqlalchemy_utils/functions/orm.py b/sqlalchemy_utils/functions/orm.py index 2909402..bf09a46 100644 --- a/sqlalchemy_utils/functions/orm.py +++ b/sqlalchemy_utils/functions/orm.py @@ -19,7 +19,7 @@ from sqlalchemy.orm.query import _ColumnEntity from sqlalchemy.orm.session import object_session from sqlalchemy.orm.util import AliasedInsp -from sqlalchemy_utils.utils import is_sequence +from ..utils import is_sequence def get_class_by_table(base, table, data=None): diff --git a/sqlalchemy_utils/generic.py b/sqlalchemy_utils/generic.py index e466316..c2779b8 100644 --- a/sqlalchemy_utils/generic.py +++ b/sqlalchemy_utils/generic.py @@ -8,9 +8,8 @@ from sqlalchemy.orm.interfaces import MapperProperty, PropComparator from sqlalchemy.orm.session import _state_session from sqlalchemy.util import set_creation_order -from sqlalchemy_utils.functions import identity - from .exceptions import ImproperlyConfigured +from .functions import identity class GenericAttributeImpl(attributes.ScalarAttributeImpl): diff --git a/sqlalchemy_utils/i18n.py b/sqlalchemy_utils/i18n.py index 06039e3..2626938 100644 --- a/sqlalchemy_utils/i18n.py +++ b/sqlalchemy_utils/i18n.py @@ -8,6 +8,7 @@ from .exceptions import ImproperlyConfigured try: import babel + import babel.dates except ImportError: babel = None diff --git a/sqlalchemy_utils/observer.py b/sqlalchemy_utils/observer.py index 0e0db69..4f08722 100644 --- a/sqlalchemy_utils/observer.py +++ b/sqlalchemy_utils/observer.py @@ -154,9 +154,9 @@ from collections import defaultdict, Iterable, namedtuple import sqlalchemy as sa -from sqlalchemy_utils.functions import getdotattr, has_changes -from sqlalchemy_utils.path import AttrPath -from sqlalchemy_utils.utils import is_sequence +from .functions import getdotattr, has_changes +from .path import AttrPath +from .utils import is_sequence Callback = namedtuple('Callback', ['func', 'path', 'backref', 'fullpath']) diff --git a/sqlalchemy_utils/primitives/country.py b/sqlalchemy_utils/primitives/country.py index 01e6261..18f5ede 100644 --- a/sqlalchemy_utils/primitives/country.py +++ b/sqlalchemy_utils/primitives/country.py @@ -1,7 +1,7 @@ import six -from sqlalchemy_utils import i18n -from sqlalchemy_utils.utils import str_coercible +from .. import i18n +from ..utils import str_coercible @str_coercible diff --git a/sqlalchemy_utils/primitives/currency.py b/sqlalchemy_utils/primitives/currency.py index d27f688..9eff8fd 100644 --- a/sqlalchemy_utils/primitives/currency.py +++ b/sqlalchemy_utils/primitives/currency.py @@ -1,8 +1,8 @@ # -*- coding: utf-8 -*- import six -from sqlalchemy_utils import i18n, ImproperlyConfigured -from sqlalchemy_utils.utils import str_coercible +from .. import i18n, ImproperlyConfigured +from ..utils import str_coercible @str_coercible diff --git a/sqlalchemy_utils/primitives/weekday.py b/sqlalchemy_utils/primitives/weekday.py index 29a4443..501a55f 100644 --- a/sqlalchemy_utils/primitives/weekday.py +++ b/sqlalchemy_utils/primitives/weekday.py @@ -4,8 +4,8 @@ try: except ImportError: # Python 2.6 port from total_ordering import total_ordering -from sqlalchemy_utils import i18n -from sqlalchemy_utils.utils import str_coercible +from .. import i18n +from ..utils import str_coercible @str_coercible diff --git a/sqlalchemy_utils/primitives/weekdays.py b/sqlalchemy_utils/primitives/weekdays.py index b6aedab..94c7a15 100644 --- a/sqlalchemy_utils/primitives/weekdays.py +++ b/sqlalchemy_utils/primitives/weekdays.py @@ -1,7 +1,6 @@ import six -from sqlalchemy_utils.utils import str_coercible - +from ..utils import str_coercible from .weekday import WeekDay diff --git a/sqlalchemy_utils/types/arrow.py b/sqlalchemy_utils/types/arrow.py index d5a3bd0..a800072 100644 --- a/sqlalchemy_utils/types/arrow.py +++ b/sqlalchemy_utils/types/arrow.py @@ -6,8 +6,7 @@ from datetime import datetime import six from sqlalchemy import types -from sqlalchemy_utils.exceptions import ImproperlyConfigured - +from ..exceptions import ImproperlyConfigured from .scalar_coercible import ScalarCoercible arrow = None diff --git a/sqlalchemy_utils/types/color.py b/sqlalchemy_utils/types/color.py index 7020a8f..161b285 100644 --- a/sqlalchemy_utils/types/color.py +++ b/sqlalchemy_utils/types/color.py @@ -1,8 +1,7 @@ import six from sqlalchemy import types -from sqlalchemy_utils.exceptions import ImproperlyConfigured - +from ..exceptions import ImproperlyConfigured from .scalar_coercible import ScalarCoercible colour = None diff --git a/sqlalchemy_utils/types/country.py b/sqlalchemy_utils/types/country.py index dd05fa1..159c9d0 100644 --- a/sqlalchemy_utils/types/country.py +++ b/sqlalchemy_utils/types/country.py @@ -1,8 +1,7 @@ import six from sqlalchemy import types -from sqlalchemy_utils.primitives import Country - +from ..primitives import Country from .scalar_coercible import ScalarCoercible diff --git a/sqlalchemy_utils/types/currency.py b/sqlalchemy_utils/types/currency.py index 6c8abbe..4794c09 100644 --- a/sqlalchemy_utils/types/currency.py +++ b/sqlalchemy_utils/types/currency.py @@ -1,9 +1,8 @@ import six from sqlalchemy import types -from sqlalchemy_utils import i18n, ImproperlyConfigured -from sqlalchemy_utils.primitives import Currency - +from .. import i18n, ImproperlyConfigured +from ..primitives import Currency from .scalar_coercible import ScalarCoercible diff --git a/sqlalchemy_utils/types/encrypted.py b/sqlalchemy_utils/types/encrypted.py index 4470334..a15895b 100644 --- a/sqlalchemy_utils/types/encrypted.py +++ b/sqlalchemy_utils/types/encrypted.py @@ -5,8 +5,7 @@ import datetime import six from sqlalchemy.types import Binary, String, TypeDecorator -from sqlalchemy_utils.exceptions import ImproperlyConfigured - +from ..exceptions import ImproperlyConfigured from .scalar_coercible import ScalarCoercible cryptography = None @@ -84,7 +83,7 @@ class AesEngine(EncryptionDecryptionBaseEngine): value = str(value) decryptor = self.cipher.decryptor() decrypted = base64.b64decode(value) - decrypted = decryptor.update(decrypted)+decryptor.finalize() + decrypted = decryptor.update(decrypted) + decryptor.finalize() decrypted = decrypted.rstrip(self.PADDING) if not isinstance(decrypted, six.string_types): decrypted = decrypted.decode('utf-8') diff --git a/sqlalchemy_utils/types/ip_address.py b/sqlalchemy_utils/types/ip_address.py index 7ec9741..7ff9ca5 100644 --- a/sqlalchemy_utils/types/ip_address.py +++ b/sqlalchemy_utils/types/ip_address.py @@ -1,8 +1,7 @@ import six from sqlalchemy import types -from sqlalchemy_utils.exceptions import ImproperlyConfigured - +from ..exceptions import ImproperlyConfigured from .scalar_coercible import ScalarCoercible ip_address = None diff --git a/sqlalchemy_utils/types/password.py b/sqlalchemy_utils/types/password.py index 2b45bf4..84d43f3 100644 --- a/sqlalchemy_utils/types/password.py +++ b/sqlalchemy_utils/types/password.py @@ -5,8 +5,7 @@ from sqlalchemy import types from sqlalchemy.dialects import oracle, postgresql from sqlalchemy.ext.mutable import Mutable -from sqlalchemy_utils.exceptions import ImproperlyConfigured - +from ..exceptions import ImproperlyConfigured from .scalar_coercible import ScalarCoercible passlib = None diff --git a/sqlalchemy_utils/types/pg_composite.py b/sqlalchemy_utils/types/pg_composite.py index b195f3a..8440ae5 100644 --- a/sqlalchemy_utils/types/pg_composite.py +++ b/sqlalchemy_utils/types/pg_composite.py @@ -109,7 +109,7 @@ from sqlalchemy.types import ( UserDefinedType ) -from sqlalchemy_utils import ImproperlyConfigured +from .. import ImproperlyConfigured psycopg2 = None CompositeCaster = None diff --git a/sqlalchemy_utils/types/phone_number.py b/sqlalchemy_utils/types/phone_number.py index 07490c0..3d23eb5 100644 --- a/sqlalchemy_utils/types/phone_number.py +++ b/sqlalchemy_utils/types/phone_number.py @@ -1,8 +1,7 @@ from sqlalchemy import types -from sqlalchemy_utils.exceptions import ImproperlyConfigured -from sqlalchemy_utils.utils import str_coercible - +from ..exceptions import ImproperlyConfigured +from ..utils import str_coercible from .scalar_coercible import ScalarCoercible try: diff --git a/sqlalchemy_utils/types/timezone.py b/sqlalchemy_utils/types/timezone.py index 730260f..239712c 100644 --- a/sqlalchemy_utils/types/timezone.py +++ b/sqlalchemy_utils/types/timezone.py @@ -1,8 +1,7 @@ import six from sqlalchemy import types -from sqlalchemy_utils.exceptions import ImproperlyConfigured - +from ..exceptions import ImproperlyConfigured from .scalar_coercible import ScalarCoercible diff --git a/sqlalchemy_utils/types/weekdays.py b/sqlalchemy_utils/types/weekdays.py index 5a3635c..4b6bf4a 100644 --- a/sqlalchemy_utils/types/weekdays.py +++ b/sqlalchemy_utils/types/weekdays.py @@ -1,10 +1,9 @@ import six from sqlalchemy import types -from sqlalchemy_utils import i18n -from sqlalchemy_utils.exceptions import ImproperlyConfigured -from sqlalchemy_utils.primitives import WeekDay, WeekDays - +from .. import i18n +from ..exceptions import ImproperlyConfigured +from ..primitives import WeekDay, WeekDays from .bit import BitType from .scalar_coercible import ScalarCoercible diff --git a/tests/__init__.py b/tests/__init__.py index afe0cc8..62c8dbb 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,132 +1,3 @@ -import warnings - -import sqlalchemy as sa -from sqlalchemy import create_engine -from sqlalchemy.ext.declarative import declarative_base, synonym_for -from sqlalchemy.ext.hybrid import hybrid_property -from sqlalchemy.orm import sessionmaker - -from sqlalchemy_utils import ( - aggregates, - coercion_listener, - i18n, - InstrumentedList -) -from sqlalchemy_utils.types.pg_composite import remove_composite_listeners - - -@sa.event.listens_for(sa.engine.Engine, 'before_cursor_execute') -def count_sql_calls(conn, cursor, statement, parameters, context, executemany): - try: - conn.query_count += 1 - except AttributeError: - conn.query_count = 0 - - -warnings.simplefilter('error', sa.exc.SAWarning) - - -sa.event.listen(sa.orm.mapper, 'mapper_configured', coercion_listener) - - -def get_locale(): - class Locale(): - territories = {'FI': 'Finland'} - - return Locale() - - -class TestCase(object): - dns = 'sqlite:///:memory:' - create_tables = True - - def setup_method(self, method): - self.engine = create_engine(self.dns) - # self.engine.echo = True - self.connection = self.engine.connect() - self.Base = declarative_base() - - self.create_models() - sa.orm.configure_mappers() - if self.create_tables: - self.Base.metadata.create_all(self.connection) - - Session = sessionmaker(bind=self.connection) - self.session = Session() - - i18n.get_locale = get_locale - - def teardown_method(self, method): - aggregates.manager.reset() - self.session.close_all() - if self.create_tables: - self.Base.metadata.drop_all(self.connection) - remove_composite_listeners() - self.connection.close() - self.engine.dispose() - - def create_models(self): - class User(self.Base): - __tablename__ = 'user' - id = sa.Column(sa.Integer, autoincrement=True, primary_key=True) - name = sa.Column(sa.Unicode(255)) - - class Category(self.Base): - __tablename__ = 'category' - id = sa.Column(sa.Integer, primary_key=True) - name = sa.Column(sa.Unicode(255)) - title = sa.Column(sa.Unicode(255)) - - @hybrid_property - def full_name(self): - return u'%s %s' % (self.title, self.name) - - @full_name.expression - def full_name(self): - return sa.func.concat(self.title, ' ', self.name) - - @hybrid_property - def articles_count(self): - return len(self.articles) - - @articles_count.expression - def articles_count(cls): - return ( - sa.select([sa.func.count(self.Article.id)]) - .where(self.Article.category_id == self.Category.id) - .correlate(self.Article.__table__) - .label('article_count') - ) - - @property - def name_alias(self): - return self.name - - @synonym_for('name') - @property - def name_synonym(self): - return self.name - - class Article(self.Base): - __tablename__ = 'article' - id = sa.Column(sa.Integer, primary_key=True) - name = sa.Column(sa.Unicode(255), index=True) - category_id = sa.Column(sa.Integer, sa.ForeignKey(Category.id)) - - category = sa.orm.relationship( - Category, - primaryjoin=category_id == Category.id, - backref=sa.orm.backref( - 'articles', - collection_class=InstrumentedList - ) - ) - - self.User = User - self.Category = Category - self.Article = Article - - def assert_contains(clause, query): # Test that query executes query.all() diff --git a/tests/aggregate/test_backrefs.py b/tests/aggregate/test_backrefs.py index e7d0d2d..2589a5d 100644 --- a/tests/aggregate/test_backrefs.py +++ b/tests/aggregate/test_backrefs.py @@ -1,61 +1,76 @@ +import pytest import sqlalchemy as sa from sqlalchemy_utils.aggregates import aggregated -from tests import TestCase -class TestAggregateValueGenerationWithBackrefs(TestCase): - def create_models(self): - class Thread(self.Base): - __tablename__ = 'thread' - id = sa.Column(sa.Integer, primary_key=True) - name = sa.Column(sa.Unicode(255)) +@pytest.fixture +def Thread(Base): + class Thread(Base): + __tablename__ = 'thread' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) - @aggregated('comments', sa.Column(sa.Integer, default=0)) - def comment_count(self): - return sa.func.count('1') + @aggregated('comments', sa.Column(sa.Integer, default=0)) + def comment_count(self): + return sa.func.count('1') + return Thread - class Comment(self.Base): - __tablename__ = 'comment' - id = sa.Column(sa.Integer, primary_key=True) - content = sa.Column(sa.Unicode(255)) - thread_id = sa.Column(sa.Integer, sa.ForeignKey('thread.id')) - thread = sa.orm.relationship(Thread, backref='comments') +@pytest.fixture +def Comment(Base, Thread): + class Comment(Base): + __tablename__ = 'comment' + id = sa.Column(sa.Integer, primary_key=True) + content = sa.Column(sa.Unicode(255)) + thread_id = sa.Column(sa.Integer, sa.ForeignKey('thread.id')) - self.Thread = Thread - self.Comment = Comment + thread = sa.orm.relationship(Thread, backref='comments') + return Comment - def test_assigns_aggregates_on_insert(self): - thread = self.Thread() + +@pytest.fixture +def init_models(Thread, Comment): + pass + + +class TestAggregateValueGenerationWithBackrefs(object): + + def test_assigns_aggregates_on_insert(self, session, Thread, Comment): + thread = Thread() thread.name = u'some article name' - self.session.add(thread) - comment = self.Comment(content=u'Some content', thread=thread) - self.session.add(comment) - self.session.commit() - self.session.refresh(thread) + session.add(thread) + comment = Comment(content=u'Some content', thread=thread) + session.add(comment) + session.commit() + session.refresh(thread) assert thread.comment_count == 1 - def test_assigns_aggregates_on_separate_insert(self): - thread = self.Thread() + def test_assigns_aggregates_on_separate_insert( + self, + session, + Thread, + Comment + ): + thread = Thread() thread.name = u'some article name' - self.session.add(thread) - self.session.commit() - comment = self.Comment(content=u'Some content', thread=thread) - self.session.add(comment) - self.session.commit() - self.session.refresh(thread) + session.add(thread) + session.commit() + comment = Comment(content=u'Some content', thread=thread) + session.add(comment) + session.commit() + session.refresh(thread) assert thread.comment_count == 1 - def test_assigns_aggregates_on_delete(self): - thread = self.Thread() + def test_assigns_aggregates_on_delete(self, session, Thread, Comment): + thread = Thread() thread.name = u'some article name' - self.session.add(thread) - self.session.commit() - comment = self.Comment(content=u'Some content', thread=thread) - self.session.add(comment) - self.session.commit() - self.session.delete(comment) - self.session.commit() - self.session.refresh(thread) + session.add(thread) + session.commit() + comment = Comment(content=u'Some content', thread=thread) + session.add(comment) + session.commit() + session.delete(comment) + session.commit() + session.refresh(thread) assert thread.comment_count == 0 diff --git a/tests/aggregate/test_custom_select_expressions.py b/tests/aggregate/test_custom_select_expressions.py index de03b79..90d475b 100644 --- a/tests/aggregate/test_custom_select_expressions.py +++ b/tests/aggregate/test_custom_select_expressions.py @@ -1,67 +1,76 @@ from decimal import Decimal +import pytest import sqlalchemy as sa from sqlalchemy_utils.aggregates import aggregated -from tests import TestCase -class TestLazyEvaluatedSelectExpressionsForAggregates(TestCase): - dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' +@pytest.fixture +def Product(Base): + class Product(Base): + __tablename__ = 'product' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + price = sa.Column(sa.Numeric) - def create_models(self): - class Catalog(self.Base): - __tablename__ = 'catalog' - id = sa.Column(sa.Integer, primary_key=True) - name = sa.Column(sa.Unicode(255)) + catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id')) + return Product - @aggregated('products', sa.Column(sa.Numeric, default=0)) - def net_worth(self): - return sa.func.sum(Product.price) - products = sa.orm.relationship('Product', backref='catalog') +@pytest.fixture +def Catalog(Base, Product): + class Catalog(Base): + __tablename__ = 'catalog' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) - class Product(self.Base): - __tablename__ = 'product' - id = sa.Column(sa.Integer, primary_key=True) - name = sa.Column(sa.Unicode(255)) - price = sa.Column(sa.Numeric) + @aggregated('products', sa.Column(sa.Numeric, default=0)) + def net_worth(self): + return sa.func.sum(Product.price) - catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id')) + products = sa.orm.relationship('Product', backref='catalog') + return Catalog - self.Catalog = Catalog - self.Product = Product - def test_assigns_aggregates_on_insert(self): - catalog = self.Catalog( +@pytest.fixture +def init_models(Product, Catalog): + pass + + +@pytest.mark.usefixtures('postgresql_dsn') +class TestLazyEvaluatedSelectExpressionsForAggregates(object): + + def test_assigns_aggregates_on_insert(self, session, Product, Catalog): + catalog = Catalog( name=u'Some catalog' ) - self.session.add(catalog) - self.session.commit() - product = self.Product( + session.add(catalog) + session.commit() + product = Product( name=u'Some product', price=Decimal('1000'), catalog=catalog ) - self.session.add(product) - self.session.commit() - self.session.refresh(catalog) + session.add(product) + session.commit() + session.refresh(catalog) assert catalog.net_worth == Decimal('1000') - def test_assigns_aggregates_on_update(self): - catalog = self.Catalog( + def test_assigns_aggregates_on_update(self, session, Product, Catalog): + catalog = Catalog( name=u'Some catalog' ) - self.session.add(catalog) - self.session.commit() - product = self.Product( + session.add(catalog) + session.commit() + product = Product( name=u'Some product', price=Decimal('1000'), catalog=catalog ) - self.session.add(product) - self.session.commit() + session.add(product) + session.commit() product.price = Decimal('500') - self.session.commit() - self.session.refresh(catalog) + session.commit() + session.refresh(catalog) assert catalog.net_worth == Decimal('500') diff --git a/tests/aggregate/test_join_table_inheritance.py b/tests/aggregate/test_join_table_inheritance.py index fa891ec..b0c0398 100644 --- a/tests/aggregate/test_join_table_inheritance.py +++ b/tests/aggregate/test_join_table_inheritance.py @@ -1,101 +1,121 @@ from decimal import Decimal +import pytest import sqlalchemy as sa from sqlalchemy_utils.aggregates import aggregated -from tests import TestCase -class TestLazyEvaluatedSelectExpressionsForAggregates(TestCase): - dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' +@pytest.fixture +def Product(Base): + class Product(Base): + __tablename__ = 'product' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + price = sa.Column(sa.Numeric) - def create_models(self): - class Catalog(self.Base): - __tablename__ = 'catalog' - id = sa.Column(sa.Integer, primary_key=True) - name = sa.Column(sa.Unicode(255)) - type = sa.Column(sa.Unicode(255)) + catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id')) + return Product - __mapper_args__ = { - 'polymorphic_on': type - } - @aggregated('products', sa.Column(sa.Numeric, default=0)) - def net_worth(self): - return sa.func.sum(Product.price) +@pytest.fixture +def Catalog(Base, Product): + class Catalog(Base): + __tablename__ = 'catalog' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + type = sa.Column(sa.Unicode(255)) - products = sa.orm.relationship('Product', backref='catalog') + __mapper_args__ = { + 'polymorphic_on': type + } - class CostumeCatalog(Catalog): - __tablename__ = 'costume_catalog' - id = sa.Column( - sa.Integer, sa.ForeignKey(Catalog.id), primary_key=True - ) + @aggregated('products', sa.Column(sa.Numeric, default=0)) + def net_worth(self): + return sa.func.sum(Product.price) - __mapper_args__ = { - 'polymorphic_identity': 'costumes', - } + products = sa.orm.relationship('Product', backref='catalog') + return Catalog - class CarCatalog(Catalog): - __tablename__ = 'car_catalog' - id = sa.Column( - sa.Integer, sa.ForeignKey(Catalog.id), primary_key=True - ) - __mapper_args__ = { - 'polymorphic_identity': 'cars', - } +@pytest.fixture +def CostumeCatalog(Catalog): + class CostumeCatalog(Catalog): + __tablename__ = 'costume_catalog' + id = sa.Column( + sa.Integer, sa.ForeignKey(Catalog.id), primary_key=True + ) - class Product(self.Base): - __tablename__ = 'product' - id = sa.Column(sa.Integer, primary_key=True) - name = sa.Column(sa.Unicode(255)) - price = sa.Column(sa.Numeric) + __mapper_args__ = { + 'polymorphic_identity': 'costumes', + } + return CostumeCatalog - catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id')) - self.Catalog = Catalog - self.CostumeCatalog = CostumeCatalog - self.CarCatalog = CarCatalog - self.Product = Product +@pytest.fixture +def CarCatalog(Catalog): + class CarCatalog(Catalog): + __tablename__ = 'car_catalog' + id = sa.Column( + sa.Integer, sa.ForeignKey(Catalog.id), primary_key=True + ) - def test_columns_inherited_from_parent(self): - assert self.CarCatalog.net_worth - assert self.CostumeCatalog.net_worth - assert self.Catalog.net_worth - assert not hasattr(self.CarCatalog.__table__.c, 'net_worth') - assert not hasattr(self.CostumeCatalog.__table__.c, 'net_worth') + __mapper_args__ = { + 'polymorphic_identity': 'cars', + } + return CarCatalog - def test_assigns_aggregates_on_insert(self): - catalog = self.Catalog( + +@pytest.fixture +def init_models(Product, Catalog, CostumeCatalog, CarCatalog): + pass + + +@pytest.mark.usefixtures('postgresql_dsn') +class TestLazyEvaluatedSelectExpressionsForAggregates(object): + + def test_columns_inherited_from_parent( + self, + Catalog, + CarCatalog, + CostumeCatalog + ): + assert CarCatalog.net_worth + assert CostumeCatalog.net_worth + assert Catalog.net_worth + assert not hasattr(CarCatalog.__table__.c, 'net_worth') + assert not hasattr(CostumeCatalog.__table__.c, 'net_worth') + + def test_assigns_aggregates_on_insert(self, session, Product, Catalog): + catalog = Catalog( name=u'Some catalog' ) - self.session.add(catalog) - self.session.commit() - product = self.Product( + session.add(catalog) + session.commit() + product = Product( name=u'Some product', price=Decimal('1000'), catalog=catalog ) - self.session.add(product) - self.session.commit() - self.session.refresh(catalog) + session.add(product) + session.commit() + session.refresh(catalog) assert catalog.net_worth == Decimal('1000') - def test_assigns_aggregates_on_update(self): - catalog = self.Catalog( + def test_assigns_aggregates_on_update(self, session, Catalog, Product): + catalog = Catalog( name=u'Some catalog' ) - self.session.add(catalog) - self.session.commit() - product = self.Product( + session.add(catalog) + session.commit() + product = Product( name=u'Some product', price=Decimal('1000'), catalog=catalog ) - self.session.add(product) - self.session.commit() + session.add(product) + session.commit() product.price = Decimal('500') - self.session.commit() - self.session.refresh(catalog) + session.commit() + session.refresh(catalog) assert catalog.net_worth == Decimal('500') diff --git a/tests/aggregate/test_m2m.py b/tests/aggregate/test_m2m.py index 4a9ac45..07a0c7a 100644 --- a/tests/aggregate/test_m2m.py +++ b/tests/aggregate/test_m2m.py @@ -1,72 +1,81 @@ +import pytest import sqlalchemy as sa from sqlalchemy_utils.aggregates import aggregated -from tests import TestCase -class TestAggregatesWithManyToManyRelationships(TestCase): - dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' +@pytest.fixture +def User(Base): + user_group = sa.Table( + 'user_group', + Base.metadata, + sa.Column('user_id', sa.Integer, sa.ForeignKey('user.id')), + sa.Column('group_id', sa.Integer, sa.ForeignKey('group.id')) + ) - def create_models(self): - user_group = sa.Table( - 'user_group', - self.Base.metadata, - sa.Column('user_id', sa.Integer, sa.ForeignKey('user.id')), - sa.Column('group_id', sa.Integer, sa.ForeignKey('group.id')) + class User(Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + + @aggregated('groups', sa.Column(sa.Integer, default=0)) + def group_count(self): + return sa.func.count('1') + + groups = sa.orm.relationship( + 'Group', + backref='users', + secondary=user_group ) + return User - class User(self.Base): - __tablename__ = 'user' - id = sa.Column(sa.Integer, primary_key=True) - name = sa.Column(sa.Unicode(255)) - @aggregated('groups', sa.Column(sa.Integer, default=0)) - def group_count(self): - return sa.func.count('1') +@pytest.fixture +def Group(Base): + class Group(Base): + __tablename__ = 'group' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + return Group - groups = sa.orm.relationship( - 'Group', - backref='users', - secondary=user_group - ) - class Group(self.Base): - __tablename__ = 'group' - id = sa.Column(sa.Integer, primary_key=True) - name = sa.Column(sa.Unicode(255)) +@pytest.fixture +def init_models(User, Group): + pass - self.User = User - self.Group = Group - def test_assigns_aggregates_on_insert(self): - user = self.User( +@pytest.mark.usefixtures('postgresql_dsn') +class TestAggregatesWithManyToManyRelationships(object): + + def test_assigns_aggregates_on_insert(self, session, User, Group): + user = User( name=u'John Matrix' ) - self.session.add(user) - self.session.commit() - group = self.Group( + session.add(user) + session.commit() + group = Group( name=u'Some group', users=[user] ) - self.session.add(group) - self.session.commit() - self.session.refresh(user) + session.add(group) + session.commit() + session.refresh(user) assert user.group_count == 1 - def test_updates_aggregates_on_delete(self): - user = self.User( + def test_updates_aggregates_on_delete(self, session, User, Group): + user = User( name=u'John Matrix' ) - self.session.add(user) - self.session.commit() - group = self.Group( + session.add(user) + session.commit() + group = Group( name=u'Some group', users=[user] ) - self.session.add(group) - self.session.commit() - self.session.refresh(user) + session.add(group) + session.commit() + session.refresh(user) user.groups = [] - self.session.commit() - self.session.refresh(user) + session.commit() + session.refresh(user) assert user.group_count == 0 diff --git a/tests/aggregate/test_m2m_m2m.py b/tests/aggregate/test_m2m_m2m.py index 1fee8c3..505ef5e 100644 --- a/tests/aggregate/test_m2m_m2m.py +++ b/tests/aggregate/test_m2m_m2m.py @@ -1,80 +1,92 @@ +import pytest import sqlalchemy as sa from sqlalchemy_utils import aggregated -from tests import TestCase -class TestAggregateManyToManyAndManyToMany(TestCase): - dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' +@pytest.fixture +def Category(Base): + class Category(Base): + __tablename__ = 'category' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + return Category - def create_models(self): - catalog_products = sa.Table( - 'catalog_product', - self.Base.metadata, - sa.Column('catalog_id', sa.Integer, sa.ForeignKey('catalog.id')), - sa.Column('product_id', sa.Integer, sa.ForeignKey('product.id')) + +@pytest.fixture +def Catalog(Base, Category): + class Catalog(Base): + __tablename__ = 'catalog' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + + @aggregated( + 'products.categories', + sa.Column(sa.Integer, default=0) + ) + def category_count(self): + return sa.func.count(sa.distinct(Category.id)) + return Catalog + + +@pytest.fixture +def Product(Base, Catalog, Category): + catalog_products = sa.Table( + 'catalog_product', + Base.metadata, + sa.Column('catalog_id', sa.Integer, sa.ForeignKey('catalog.id')), + sa.Column('product_id', sa.Integer, sa.ForeignKey('product.id')) + ) + + product_categories = sa.Table( + 'category_product', + Base.metadata, + sa.Column('category_id', sa.Integer, sa.ForeignKey('category.id')), + sa.Column('product_id', sa.Integer, sa.ForeignKey('product.id')) + ) + + class Product(Base): + __tablename__ = 'product' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + price = sa.Column(sa.Numeric) + + catalog_id = sa.Column( + sa.Integer, sa.ForeignKey('catalog.id') ) - product_categories = sa.Table( - 'category_product', - self.Base.metadata, - sa.Column('category_id', sa.Integer, sa.ForeignKey('category.id')), - sa.Column('product_id', sa.Integer, sa.ForeignKey('product.id')) + catalogs = sa.orm.relationship( + Catalog, + backref='products', + secondary=catalog_products ) - class Catalog(self.Base): - __tablename__ = 'catalog' - id = sa.Column(sa.Integer, primary_key=True) - name = sa.Column(sa.Unicode(255)) + categories = sa.orm.relationship( + Category, + backref='products', + secondary=product_categories + ) + return Product - @aggregated( - 'products.categories', - sa.Column(sa.Integer, default=0) - ) - def category_count(self): - return sa.func.count(sa.distinct(Category.id)) - class Category(self.Base): - __tablename__ = 'category' - id = sa.Column(sa.Integer, primary_key=True) - name = sa.Column(sa.Unicode(255)) +@pytest.fixture +def init_models(Category, Catalog, Product): + pass - class Product(self.Base): - __tablename__ = 'product' - id = sa.Column(sa.Integer, primary_key=True) - name = sa.Column(sa.Unicode(255)) - price = sa.Column(sa.Numeric) - catalog_id = sa.Column( - sa.Integer, sa.ForeignKey('catalog.id') - ) +@pytest.mark.usefixtures('postgresql_dsn') +class TestAggregateManyToManyAndManyToMany(object): - catalogs = sa.orm.relationship( - Catalog, - backref='products', - secondary=catalog_products - ) - - categories = sa.orm.relationship( - Category, - backref='products', - secondary=product_categories - ) - - self.Catalog = Catalog - self.Category = Category - self.Product = Product - - def test_insert(self): - category = self.Category() + def test_insert(self, session, Product, Category, Catalog): + category = Category() products = [ - self.Product(categories=[category]), - self.Product(categories=[category]) + Product(categories=[category]), + Product(categories=[category]) ] - catalog = self.Catalog(products=products) - self.session.add(catalog) - catalog2 = self.Catalog(products=products) - self.session.add(catalog) - self.session.commit() + catalog = Catalog(products=products) + session.add(catalog) + catalog2 = Catalog(products=products) + session.add(catalog) + session.commit() assert catalog.category_count == 1 assert catalog2.category_count == 1 diff --git a/tests/aggregate/test_multiple_aggregates_per_class.py b/tests/aggregate/test_multiple_aggregates_per_class.py index ee8b5dc..849c5c4 100644 --- a/tests/aggregate/test_multiple_aggregates_per_class.py +++ b/tests/aggregate/test_multiple_aggregates_per_class.py @@ -1,81 +1,96 @@ +import pytest import sqlalchemy as sa from sqlalchemy_utils.aggregates import aggregated -from tests import TestCase -class TestAggregateValueGenerationForSimpleModelPaths(TestCase): - def create_models(self): - class Thread(self.Base): - __tablename__ = 'thread' - id = sa.Column(sa.Integer, primary_key=True) - name = sa.Column(sa.Unicode(255)) +@pytest.fixture +def Comment(Base): + class Comment(Base): + __tablename__ = 'comment' + id = sa.Column(sa.Integer, primary_key=True) + content = sa.Column(sa.Unicode(255)) + thread_id = sa.Column(sa.Integer, sa.ForeignKey('thread.id')) + return Comment - @aggregated( - 'comments', - sa.Column(sa.Integer, default=0) - ) - def comment_count(self): - return sa.func.count('1') - @aggregated('comments', sa.Column(sa.Integer)) - def last_comment_id(self): - return sa.func.max(Comment.id) +@pytest.fixture +def Thread(Base, Comment): + class Thread(Base): + __tablename__ = 'thread' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) - comments = sa.orm.relationship( - 'Comment', - backref='thread' - ) + @aggregated( + 'comments', + sa.Column(sa.Integer, default=0) + ) + def comment_count(self): + return sa.func.count('1') - Thread.last_comment = sa.orm.relationship( + @aggregated('comments', sa.Column(sa.Integer)) + def last_comment_id(self): + return sa.func.max(Comment.id) + + comments = sa.orm.relationship( 'Comment', - primaryjoin='Thread.last_comment_id == Comment.id', - foreign_keys=[Thread.last_comment_id], - viewonly=True + backref='thread' ) - class Comment(self.Base): - __tablename__ = 'comment' - id = sa.Column(sa.Integer, primary_key=True) - content = sa.Column(sa.Unicode(255)) - thread_id = sa.Column(sa.Integer, sa.ForeignKey('thread.id')) + Thread.last_comment = sa.orm.relationship( + 'Comment', + primaryjoin='Thread.last_comment_id == Comment.id', + foreign_keys=[Thread.last_comment_id], + viewonly=True + ) + return Thread - self.Thread = Thread - self.Comment = Comment - def test_assigns_aggregates_on_insert(self): - thread = self.Thread() +@pytest.fixture +def init_models(Comment, Thread): + pass + + +class TestAggregateValueGenerationForSimpleModelPaths(object): + + def test_assigns_aggregates_on_insert(self, session, Thread, Comment): + thread = Thread() thread.name = u'some article name' - self.session.add(thread) - comment = self.Comment(content=u'Some content', thread=thread) - self.session.add(comment) - self.session.commit() - self.session.refresh(thread) + session.add(thread) + comment = Comment(content=u'Some content', thread=thread) + session.add(comment) + session.commit() + session.refresh(thread) assert thread.comment_count == 1 assert thread.last_comment_id == comment.id - def test_assigns_aggregates_on_separate_insert(self): - thread = self.Thread() + def test_assigns_aggregates_on_separate_insert( + self, + session, + Thread, + Comment + ): + thread = Thread() thread.name = u'some article name' - self.session.add(thread) - self.session.commit() - comment = self.Comment(content=u'Some content', thread=thread) - self.session.add(comment) - self.session.commit() - self.session.refresh(thread) + session.add(thread) + session.commit() + comment = Comment(content=u'Some content', thread=thread) + session.add(comment) + session.commit() + session.refresh(thread) assert thread.comment_count == 1 assert thread.last_comment_id == 1 - def test_assigns_aggregates_on_delete(self): - thread = self.Thread() + def test_assigns_aggregates_on_delete(self, session, Thread, Comment): + thread = Thread() thread.name = u'some article name' - self.session.add(thread) - self.session.commit() - comment = self.Comment(content=u'Some content', thread=thread) - self.session.add(comment) - self.session.commit() - self.session.delete(comment) - self.session.commit() - self.session.refresh(thread) + session.add(thread) + session.commit() + comment = Comment(content=u'Some content', thread=thread) + session.add(comment) + session.commit() + session.delete(comment) + session.commit() + session.refresh(thread) assert thread.comment_count == 0 assert thread.last_comment_id is None diff --git a/tests/aggregate/test_o2m_m2m.py b/tests/aggregate/test_o2m_m2m.py index c2b4ff5..c97569c 100644 --- a/tests/aggregate/test_o2m_m2m.py +++ b/tests/aggregate/test_o2m_m2m.py @@ -1,76 +1,88 @@ +import pytest import sqlalchemy as sa from sqlalchemy_utils import aggregated -from tests import TestCase -class TestAggregateOneToManyAndManyToMany(TestCase): - dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' +@pytest.fixture +def Category(Base): + class Category(Base): + __tablename__ = 'category' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + return Category - def create_models(self): - product_categories = sa.Table( - 'category_product', - self.Base.metadata, - sa.Column('category_id', sa.Integer, sa.ForeignKey('category.id')), - sa.Column('product_id', sa.Integer, sa.ForeignKey('product.id')) + +@pytest.fixture +def Catalog(Base, Category): + class Catalog(Base): + __tablename__ = 'catalog' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + + @aggregated( + 'products.categories', + sa.Column(sa.Integer, default=0) + ) + def category_count(self): + return sa.func.count(sa.distinct(Category.id)) + return Catalog + + +@pytest.fixture +def Product(Base, Catalog, Category): + product_categories = sa.Table( + 'category_product', + Base.metadata, + sa.Column('category_id', sa.Integer, sa.ForeignKey('category.id')), + sa.Column('product_id', sa.Integer, sa.ForeignKey('product.id')) + ) + + class Product(Base): + __tablename__ = 'product' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + price = sa.Column(sa.Numeric) + + catalog_id = sa.Column( + sa.Integer, sa.ForeignKey('catalog.id') ) - class Catalog(self.Base): - __tablename__ = 'catalog' - id = sa.Column(sa.Integer, primary_key=True) - name = sa.Column(sa.Unicode(255)) + catalog = sa.orm.relationship( + Catalog, + backref='products' + ) - @aggregated( - 'products.categories', - sa.Column(sa.Integer, default=0) - ) - def category_count(self): - return sa.func.count(sa.distinct(Category.id)) + categories = sa.orm.relationship( + Category, + backref='products', + secondary=product_categories + ) + return Product - class Category(self.Base): - __tablename__ = 'category' - id = sa.Column(sa.Integer, primary_key=True) - name = sa.Column(sa.Unicode(255)) - class Product(self.Base): - __tablename__ = 'product' - id = sa.Column(sa.Integer, primary_key=True) - name = sa.Column(sa.Unicode(255)) - price = sa.Column(sa.Numeric) +@pytest.fixture +def init_models(Category, Catalog, Product): + pass - catalog_id = sa.Column( - sa.Integer, sa.ForeignKey('catalog.id') - ) - catalog = sa.orm.relationship( - Catalog, - backref='products' - ) +@pytest.mark.usefixtures('postgresql_dsn') +class TestAggregateOneToManyAndManyToMany(object): - categories = sa.orm.relationship( - Category, - backref='products', - secondary=product_categories - ) - - self.Catalog = Catalog - self.Category = Category - self.Product = Product - - def test_insert(self): - category = self.Category() + def test_insert(self, session, Category, Catalog, Product): + category = Category() products = [ - self.Product(categories=[category]), - self.Product(categories=[category]) + Product(categories=[category]), + Product(categories=[category]) ] - catalog = self.Catalog(products=products) - self.session.add(catalog) + catalog = Catalog(products=products) + session.add(catalog) products2 = [ - self.Product(categories=[category]), - self.Product(categories=[category]) + Product(categories=[category]), + Product(categories=[category]) ] - catalog2 = self.Catalog(products=products2) - self.session.add(catalog) - self.session.commit() + catalog2 = Catalog(products=products2) + session.add(catalog) + session.commit() assert catalog.category_count == 1 assert catalog2.category_count == 1 diff --git a/tests/aggregate/test_o2m_o2m.py b/tests/aggregate/test_o2m_o2m.py index 9abed54..9eacdb5 100644 --- a/tests/aggregate/test_o2m_o2m.py +++ b/tests/aggregate/test_o2m_o2m.py @@ -1,64 +1,76 @@ from decimal import Decimal +import pytest import sqlalchemy as sa from sqlalchemy_utils.aggregates import aggregated -from tests import TestCase -class TestAggregateOneToManyAndOneToMany(TestCase): - dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' +@pytest.fixture +def Catalog(Base): + class Catalog(Base): + __tablename__ = 'catalog' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) - def create_models(self): - class Catalog(self.Base): - __tablename__ = 'catalog' - id = sa.Column(sa.Integer, primary_key=True) - name = sa.Column(sa.Unicode(255)) + @aggregated( + 'categories.products', + sa.Column(sa.Integer, default=0) + ) + def product_count(self): + return sa.func.count('1') - @aggregated( - 'categories.products', - sa.Column(sa.Integer, default=0) - ) - def product_count(self): - return sa.func.count('1') + categories = sa.orm.relationship('Category', backref='catalog') + return Catalog - categories = sa.orm.relationship('Category', backref='catalog') - class Category(self.Base): - __tablename__ = 'category' - id = sa.Column(sa.Integer, primary_key=True) - name = sa.Column(sa.Unicode(255)) +@pytest.fixture +def Category(Base): + class Category(Base): + __tablename__ = 'category' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) - catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id')) + catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id')) - products = sa.orm.relationship('Product', backref='category') + products = sa.orm.relationship('Product', backref='category') + return Category - class Product(self.Base): - __tablename__ = 'product' - id = sa.Column(sa.Integer, primary_key=True) - name = sa.Column(sa.Unicode(255)) - price = sa.Column(sa.Numeric) - category_id = sa.Column(sa.Integer, sa.ForeignKey('category.id')) +@pytest.fixture +def Product(Base): + class Product(Base): + __tablename__ = 'product' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + price = sa.Column(sa.Numeric) - self.Catalog = Catalog - self.Category = Category - self.Product = Product + category_id = sa.Column(sa.Integer, sa.ForeignKey('category.id')) + return Product - def test_assigns_aggregates(self): - category = self.Category(name=u'Some category') - catalog = self.Catalog( + +@pytest.fixture +def init_models(Catalog, Category, Product): + pass + + +@pytest.mark.usefixtures('postgresql_dsn') +class TestAggregateOneToManyAndOneToMany(object): + + def test_assigns_aggregates(self, session, Category, Catalog, Product): + category = Category(name=u'Some category') + catalog = Catalog( categories=[category] ) catalog.name = u'Some catalog' - self.session.add(catalog) - self.session.commit() - product = self.Product( + session.add(catalog) + session.commit() + product = Product( name=u'Some product', price=Decimal('1000'), category=category ) - self.session.add(product) - self.session.commit() - self.session.refresh(catalog) + session.add(product) + session.commit() + session.refresh(catalog) assert catalog.product_count == 1 diff --git a/tests/aggregate/test_o2m_o2m_o2m.py b/tests/aggregate/test_o2m_o2m_o2m.py index 672045b..aa1e0a3 100644 --- a/tests/aggregate/test_o2m_o2m_o2m.py +++ b/tests/aggregate/test_o2m_o2m_o2m.py @@ -1,88 +1,129 @@ +import pytest import sqlalchemy as sa from sqlalchemy_utils import aggregated -from tests import TestCase -class Test3LevelDeepOneToMany(TestCase): - dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' +@pytest.fixture +def Catalog(Base): + class Catalog(Base): + __tablename__ = 'catalog' + id = sa.Column(sa.Integer, primary_key=True) - def create_models(self): - class Catalog(self.Base): - __tablename__ = 'catalog' - id = sa.Column(sa.Integer, primary_key=True) + @aggregated( + 'categories.sub_categories.products', + sa.Column(sa.Integer, default=0) + ) + def product_count(self): + return sa.func.count('1') - @aggregated( - 'categories.sub_categories.products', - sa.Column(sa.Integer, default=0) - ) - def product_count(self): - return sa.func.count('1') + categories = sa.orm.relationship('Category', backref='catalog') + return Catalog - categories = sa.orm.relationship('Category', backref='catalog') - class Category(self.Base): - __tablename__ = 'category' - id = sa.Column(sa.Integer, primary_key=True) - catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id')) +@pytest.fixture +def Category(Base): + class Category(Base): + __tablename__ = 'category' + id = sa.Column(sa.Integer, primary_key=True) + catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id')) - sub_categories = sa.orm.relationship( - 'SubCategory', backref='category' - ) + sub_categories = sa.orm.relationship( + 'SubCategory', backref='category' + ) + return Category - class SubCategory(self.Base): - __tablename__ = 'sub_category' - id = sa.Column(sa.Integer, primary_key=True) - category_id = sa.Column(sa.Integer, sa.ForeignKey('category.id')) - products = sa.orm.relationship('Product', backref='sub_category') - class Product(self.Base): - __tablename__ = 'product' - id = sa.Column(sa.Integer, primary_key=True) - price = sa.Column(sa.Numeric) +@pytest.fixture +def SubCategory(Base): + class SubCategory(Base): + __tablename__ = 'sub_category' + id = sa.Column(sa.Integer, primary_key=True) + category_id = sa.Column(sa.Integer, sa.ForeignKey('category.id')) + products = sa.orm.relationship('Product', backref='sub_category') + return SubCategory - sub_category_id = sa.Column( - sa.Integer, sa.ForeignKey('sub_category.id') - ) - self.Catalog = Catalog - self.Category = Category - self.SubCategory = SubCategory - self.Product = Product +@pytest.fixture +def Product(Base): + class Product(Base): + __tablename__ = 'product' + id = sa.Column(sa.Integer, primary_key=True) + price = sa.Column(sa.Numeric) - def test_assigns_aggregates(self): - catalog = self.catalog_factory() - self.session.commit() - self.session.refresh(catalog) - assert catalog.product_count == 1 + sub_category_id = sa.Column( + sa.Integer, sa.ForeignKey('sub_category.id') + ) + return Product - def catalog_factory(self): - product = self.Product() - sub_category = self.SubCategory( + +@pytest.fixture +def init_models(Catalog, Category, SubCategory, Product): + pass + + +@pytest.fixture +def catalog_factory(Product, SubCategory, Category, Catalog, session): + def catalog_factory(): + product = Product() + sub_category = SubCategory( products=[product] ) - category = self.Category(sub_categories=[sub_category]) - catalog = self.Catalog(categories=[category]) - self.session.add(catalog) + category = Category(sub_categories=[sub_category]) + catalog = Catalog(categories=[category]) + session.add(catalog) + return catalog + return catalog_factory + + +@pytest.mark.usefixtures('postgresql_dsn') +class Test3LevelDeepOneToMany(object): + + def test_assigns_aggregates(self, session, catalog_factory): + catalog = catalog_factory() + session.commit() + session.refresh(catalog) + assert catalog.product_count == 1 + + def catalog_factory( + self, + session, + Product, + SubCategory, + Category, + Catalog + ): + product = Product() + sub_category = SubCategory( + products=[product] + ) + category = Category(sub_categories=[sub_category]) + catalog = Catalog(categories=[category]) + session.add(catalog) return catalog - def test_only_updates_affected_aggregates(self): - catalog = self.catalog_factory() - catalog2 = self.catalog_factory() - self.session.commit() + def test_only_updates_affected_aggregates( + self, + session, + catalog_factory, + Product + ): + catalog = catalog_factory() + catalog2 = catalog_factory() + session.commit() # force set catalog2 product_count to zero in order to check if it gets # updated when the other catalog's product count gets updated - self.session.execute( + session.execute( 'UPDATE catalog SET product_count = 0 WHERE id = %d' % catalog2.id ) catalog.categories[0].sub_categories[0].products.append( - self.Product() + Product() ) - self.session.commit() - self.session.refresh(catalog) - self.session.refresh(catalog2) + session.commit() + session.refresh(catalog) + session.refresh(catalog2) assert catalog.product_count == 2 assert catalog2.product_count == 0 diff --git a/tests/aggregate/test_search_vectors.py b/tests/aggregate/test_search_vectors.py index d2a8621..2fd0246 100644 --- a/tests/aggregate/test_search_vectors.py +++ b/tests/aggregate/test_search_vectors.py @@ -1,7 +1,7 @@ +import pytest import sqlalchemy as sa from sqlalchemy_utils import aggregated, TSVectorType -from tests import TestCase def tsvector_reduce_concat(vectors): @@ -13,45 +13,54 @@ def tsvector_reduce_concat(vectors): ) -class TestSearchVectorAggregates(TestCase): - dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' +@pytest.fixture +def Product(Base): + class Product(Base): + __tablename__ = 'product' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + price = sa.Column(sa.Numeric) - def create_models(self): - class Catalog(self.Base): - __tablename__ = 'catalog' - id = sa.Column(sa.Integer, primary_key=True) - name = sa.Column(sa.Unicode(255)) + catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id')) + return Product - @aggregated('products', sa.Column(TSVectorType)) - def product_search_vector(self): - return tsvector_reduce_concat( - sa.func.to_tsvector(Product.name) - ) - products = sa.orm.relationship('Product', backref='catalog') +@pytest.fixture +def Catalog(Base, Product): + class Catalog(Base): + __tablename__ = 'catalog' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) - class Product(self.Base): - __tablename__ = 'product' - id = sa.Column(sa.Integer, primary_key=True) - name = sa.Column(sa.Unicode(255)) - price = sa.Column(sa.Numeric) + @aggregated('products', sa.Column(TSVectorType)) + def product_search_vector(self): + return tsvector_reduce_concat( + sa.func.to_tsvector(Product.name) + ) - catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id')) + products = sa.orm.relationship('Product', backref='catalog') + return Catalog - self.Catalog = Catalog - self.Product = Product - def test_assigns_aggregates_on_insert(self): - catalog = self.Catalog( +@pytest.fixture +def init_models(Product, Catalog): + pass + + +@pytest.mark.usefixtures('postgresql_dsn') +class TestSearchVectorAggregates(object): + + def test_assigns_aggregates_on_insert(self, session, Product, Catalog): + catalog = Catalog( name=u'Some catalog' ) - self.session.add(catalog) - self.session.commit() - product = self.Product( + session.add(catalog) + session.commit() + product = Product( name=u'Product XYZ', catalog=catalog ) - self.session.add(product) - self.session.commit() - self.session.refresh(catalog) + session.add(product) + session.commit() + session.refresh(catalog) assert catalog.product_search_vector == "'product':1 'xyz':2" diff --git a/tests/aggregate/test_simple_paths.py b/tests/aggregate/test_simple_paths.py index 1f1c071..46a7213 100644 --- a/tests/aggregate/test_simple_paths.py +++ b/tests/aggregate/test_simple_paths.py @@ -1,61 +1,76 @@ +import pytest import sqlalchemy as sa from sqlalchemy_utils.aggregates import aggregated -from tests import TestCase -class TestAggregateValueGenerationForSimpleModelPaths(TestCase): - def create_models(self): - class Thread(self.Base): - __tablename__ = 'thread' - id = sa.Column(sa.Integer, primary_key=True) - name = sa.Column(sa.Unicode(255)) +@pytest.fixture +def Thread(Base): + class Thread(Base): + __tablename__ = 'thread' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) - @aggregated('comments', sa.Column(sa.Integer, default=0)) - def comment_count(self): - return sa.func.count('1') + @aggregated('comments', sa.Column(sa.Integer, default=0)) + def comment_count(self): + return sa.func.count('1') - comments = sa.orm.relationship('Comment', backref='thread') + comments = sa.orm.relationship('Comment', backref='thread') + return Thread - class Comment(self.Base): - __tablename__ = 'comment' - id = sa.Column(sa.Integer, primary_key=True) - content = sa.Column(sa.Unicode(255)) - thread_id = sa.Column(sa.Integer, sa.ForeignKey('thread.id')) - self.Thread = Thread - self.Comment = Comment +@pytest.fixture +def Comment(Base): + class Comment(Base): + __tablename__ = 'comment' + id = sa.Column(sa.Integer, primary_key=True) + content = sa.Column(sa.Unicode(255)) + thread_id = sa.Column(sa.Integer, sa.ForeignKey('thread.id')) + return Comment - def test_assigns_aggregates_on_insert(self): - thread = self.Thread() + +@pytest.fixture +def init_models(Thread, Comment): + pass + + +class TestAggregateValueGenerationForSimpleModelPaths(object): + + def test_assigns_aggregates_on_insert(self, session, Thread, Comment): + thread = Thread() thread.name = u'some article name' - self.session.add(thread) - comment = self.Comment(content=u'Some content', thread=thread) - self.session.add(comment) - self.session.commit() - self.session.refresh(thread) + session.add(thread) + comment = Comment(content=u'Some content', thread=thread) + session.add(comment) + session.commit() + session.refresh(thread) assert thread.comment_count == 1 - def test_assigns_aggregates_on_separate_insert(self): - thread = self.Thread() + def test_assigns_aggregates_on_separate_insert( + self, + session, + Thread, + Comment + ): + thread = Thread() thread.name = u'some article name' - self.session.add(thread) - self.session.commit() - comment = self.Comment(content=u'Some content', thread=thread) - self.session.add(comment) - self.session.commit() - self.session.refresh(thread) + session.add(thread) + session.commit() + comment = Comment(content=u'Some content', thread=thread) + session.add(comment) + session.commit() + session.refresh(thread) assert thread.comment_count == 1 - def test_assigns_aggregates_on_delete(self): - thread = self.Thread() + def test_assigns_aggregates_on_delete(self, session, Thread, Comment): + thread = Thread() thread.name = u'some article name' - self.session.add(thread) - self.session.commit() - comment = self.Comment(content=u'Some content', thread=thread) - self.session.add(comment) - self.session.commit() - self.session.delete(comment) - self.session.commit() - self.session.refresh(thread) + session.add(thread) + session.commit() + comment = Comment(content=u'Some content', thread=thread) + session.add(comment) + session.commit() + session.delete(comment) + session.commit() + session.refresh(thread) assert thread.comment_count == 0 diff --git a/tests/aggregate/test_with_column_alias.py b/tests/aggregate/test_with_column_alias.py index 80a87e5..316a25d 100644 --- a/tests/aggregate/test_with_column_alias.py +++ b/tests/aggregate/test_with_column_alias.py @@ -1,59 +1,74 @@ +import pytest import sqlalchemy as sa from sqlalchemy_utils.aggregates import aggregated -from tests import TestCase -class TestAggregatedWithColumnAlias(TestCase): - def create_models(self): - class Thread(self.Base): - __tablename__ = 'thread' - id = sa.Column(sa.Integer, primary_key=True) +@pytest.fixture +def Thread(Base): + class Thread(Base): + __tablename__ = 'thread' + id = sa.Column(sa.Integer, primary_key=True) - @aggregated( - 'comments', - sa.Column('_comment_count', sa.Integer, default=0) - ) - def comment_count(self): - return sa.func.count('1') + @aggregated( + 'comments', + sa.Column('_comment_count', sa.Integer, default=0) + ) + def comment_count(self): + return sa.func.count('1') - comments = sa.orm.relationship('Comment', backref='thread') + comments = sa.orm.relationship('Comment', backref='thread') + return Thread - class Comment(self.Base): - __tablename__ = 'comment' - id = sa.Column(sa.Integer, primary_key=True) - thread_id = sa.Column(sa.Integer, sa.ForeignKey('thread.id')) - self.Thread = Thread - self.Comment = Comment +@pytest.fixture +def Comment(Base): + class Comment(Base): + __tablename__ = 'comment' + id = sa.Column(sa.Integer, primary_key=True) + thread_id = sa.Column(sa.Integer, sa.ForeignKey('thread.id')) + return Comment - def test_assigns_aggregates_on_insert(self): - thread = self.Thread() - self.session.add(thread) - comment = self.Comment(thread=thread) - self.session.add(comment) - self.session.commit() - self.session.refresh(thread) + +@pytest.fixture +def init_models(Thread, Comment): + pass + + +class TestAggregatedWithColumnAlias(object): + + def test_assigns_aggregates_on_insert(self, session, Thread, Comment): + thread = Thread() + session.add(thread) + comment = Comment(thread=thread) + session.add(comment) + session.commit() + session.refresh(thread) assert thread.comment_count == 1 - def test_assigns_aggregates_on_separate_insert(self): - thread = self.Thread() - self.session.add(thread) - self.session.commit() - comment = self.Comment(thread=thread) - self.session.add(comment) - self.session.commit() - self.session.refresh(thread) + def test_assigns_aggregates_on_separate_insert( + self, + session, + Thread, + Comment + ): + thread = Thread() + session.add(thread) + session.commit() + comment = Comment(thread=thread) + session.add(comment) + session.commit() + session.refresh(thread) assert thread.comment_count == 1 - def test_assigns_aggregates_on_delete(self): - thread = self.Thread() - self.session.add(thread) - self.session.commit() - comment = self.Comment(thread=thread) - self.session.add(comment) - self.session.commit() - self.session.delete(comment) - self.session.commit() - self.session.refresh(thread) + def test_assigns_aggregates_on_delete(self, session, Thread, Comment): + thread = Thread() + session.add(thread) + session.commit() + comment = Comment(thread=thread) + session.add(comment) + session.commit() + session.delete(comment) + session.commit() + session.refresh(thread) assert thread.comment_count == 0 diff --git a/tests/aggregate/test_with_ondelete_cascade.py b/tests/aggregate/test_with_ondelete_cascade.py index 93a26d0..0a64cc0 100644 --- a/tests/aggregate/test_with_ondelete_cascade.py +++ b/tests/aggregate/test_with_ondelete_cascade.py @@ -1,47 +1,56 @@ +import pytest import sqlalchemy as sa from sqlalchemy_utils.aggregates import aggregated -from tests import TestCase -class TestAggregateValueGenerationWithCascadeDelete(TestCase): - dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' +@pytest.fixture +def Thread(Base): + class Thread(Base): + __tablename__ = 'thread' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) - def create_models(self): - class Thread(self.Base): - __tablename__ = 'thread' - id = sa.Column(sa.Integer, primary_key=True) - name = sa.Column(sa.Unicode(255)) + @aggregated('comments', sa.Column(sa.Integer, default=0)) + def comment_count(self): + return sa.func.count('1') - @aggregated('comments', sa.Column(sa.Integer, default=0)) - def comment_count(self): - return sa.func.count('1') + comments = sa.orm.relationship( + 'Comment', + passive_deletes=True, + backref='thread' + ) + return Thread - comments = sa.orm.relationship( - 'Comment', - passive_deletes=True, - backref='thread' - ) - class Comment(self.Base): - __tablename__ = 'comment' - id = sa.Column(sa.Integer, primary_key=True) - content = sa.Column(sa.Unicode(255)) - thread_id = sa.Column( - sa.Integer, - sa.ForeignKey('thread.id', ondelete='CASCADE') - ) +@pytest.fixture +def Comment(Base): + class Comment(Base): + __tablename__ = 'comment' + id = sa.Column(sa.Integer, primary_key=True) + content = sa.Column(sa.Unicode(255)) + thread_id = sa.Column( + sa.Integer, + sa.ForeignKey('thread.id', ondelete='CASCADE') + ) + return Comment - self.Thread = Thread - self.Comment = Comment - def test_something(self): - thread = self.Thread() +@pytest.fixture +def init_models(Thread, Comment): + pass + + +@pytest.mark.usefixtures('postgresql_dsn') +class TestAggregateValueGenerationWithCascadeDelete(object): + + def test_something(self, session, Thread, Comment): + thread = Thread() thread.name = u'some article name' - self.session.add(thread) - comment = self.Comment(content=u'Some content', thread=thread) - self.session.add(comment) - self.session.commit() - self.session.expire_all() - self.session.delete(thread) - self.session.commit() + session.add(thread) + comment = Comment(content=u'Some content', thread=thread) + session.add(comment) + session.commit() + session.expire_all() + session.delete(thread) + session.commit() diff --git a/tests/functions/test_analyze.py b/tests/functions/test_analyze.py index 1633efd..93af496 100644 --- a/tests/functions/test_analyze.py +++ b/tests/functions/test_analyze.py @@ -1,29 +1,35 @@ +import pytest + from sqlalchemy_utils import analyze -from tests import TestCase -class TestAnalyzeWithPostgres(TestCase): - dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' +@pytest.mark.usefixtures('postgresql_dsn') +class TestAnalyzeWithPostgres(object): - def test_runtime(self): - query = self.session.query(self.Article) - assert analyze(self.connection, query).runtime + def test_runtime(self, session, connection, Article): + query = session.query(Article) + assert analyze(connection, query).runtime - def test_node_types_with_join(self): + def test_node_types_with_join(self, session, connection, Article): query = ( - self.session.query(self.Article) - .join(self.Article.category) + session.query(Article) + .join(Article.category) ) - analysis = analyze(self.connection, query) + analysis = analyze(connection, query) assert analysis.node_types == [ u'Hash Join', u'Seq Scan', u'Hash', u'Seq Scan' ] - def test_node_types_with_index_only_scan(self): + def test_node_types_with_index_only_scan( + self, + session, + connection, + Article + ): query = ( - self.session.query(self.Article.name) - .order_by(self.Article.name) + session.query(Article.name) + .order_by(Article.name) .limit(10) ) - analysis = analyze(self.connection, query) + analysis = analyze(connection, query) assert analysis.node_types == [u'Limit', u'Index Only Scan'] diff --git a/tests/functions/test_database.py b/tests/functions/test_database.py index ab14b85..8d76942 100644 --- a/tests/functions/test_database.py +++ b/tests/functions/test_database.py @@ -1,11 +1,8 @@ -import os - +import pytest import sqlalchemy as sa from flexmock import flexmock -from pytest import mark from sqlalchemy_utils import create_database, database_exists, drop_database -from tests import TestCase pymysql = None try: @@ -14,38 +11,73 @@ except ImportError: pass -class DatabaseTest(TestCase): - def test_create_and_drop(self): - assert not database_exists(self.url) - create_database(self.url) - assert database_exists(self.url) - drop_database(self.url) - assert not database_exists(self.url) +class DatabaseTest(object): + def test_create_and_drop(self, dsn): + assert not database_exists(dsn) + create_database(dsn) + assert database_exists(dsn) + drop_database(dsn) + assert not database_exists(dsn) -class TestDatabaseSQLite(DatabaseTest): - url = 'sqlite:///sqlalchemy_utils.db' +@pytest.mark.usefixtures('sqlite_memory_dsn') +class TestDatabaseSQLiteMemory(object): - def setup(self): - if os.path.exists('sqlalchemy_utils.db'): - os.remove('sqlalchemy_utils.db') - - def test_exists_memory(self): - assert database_exists('sqlite:///:memory:') + def test_exists_memory(self, dsn): + assert database_exists(dsn) -@mark.skipif('pymysql is None') +@pytest.mark.usefixtures('sqlite_file_dsn') +class TestDatabaseSQLiteFile(DatabaseTest): + pass + + +@pytest.mark.skipif('pymysql is None') +@pytest.mark.usefixtures('mysql_dsn') class TestDatabaseMySQL(DatabaseTest): - url = 'mysql+pymysql://travis@localhost/db_test_sqlalchemy_util' + + @pytest.fixture + def db_name(self): + return 'db_test_sqlalchemy_util' -@mark.skipif('pymysql is None') +@pytest.mark.skipif('pymysql is None') +@pytest.mark.usefixtures('mysql_dsn') class TestDatabaseMySQLWithQuotedName(DatabaseTest): - url = 'mysql+pymysql://travis@localhost/db_test_sqlalchemy-util' + + @pytest.fixture + def db_name(self): + return 'db_test_sqlalchemy-util' +@pytest.mark.usefixtures('postgresql_dsn') +class TestDatabasePostgres(DatabaseTest): + + @pytest.fixture + def db_name(self): + return 'db_test_sqlalchemy_util' + + def test_template(self): + ( + flexmock(sa.engine.Engine) + .should_receive('execute') + .with_args( + "CREATE DATABASE db_test_sqlalchemy_util ENCODING 'utf8' " + "TEMPLATE my_template" + ) + ) + create_database( + 'postgres://postgres@localhost/db_test_sqlalchemy_util', + template='my_template' + ) + + +@pytest.mark.usefixtures('postgresql_dsn') class TestDatabasePostgresWithQuotedName(DatabaseTest): - url = 'postgres://postgres@localhost/db_test_sqlalchemy-util' + + @pytest.fixture + def db_name(self): + return 'db_test_sqlalchemy-util' def test_template(self): ( @@ -61,21 +93,3 @@ class TestDatabasePostgresWithQuotedName(DatabaseTest): 'postgres://postgres@localhost/db_test_sqlalchemy-util', template='my-template' ) - - -class TestDatabasePostgres(DatabaseTest): - url = 'postgres://postgres@localhost/db_test_sqlalchemy_util' - - def test_template(self): - ( - flexmock(sa.engine.Engine) - .should_receive('execute') - .with_args( - "CREATE DATABASE db_test_sqlalchemy_util ENCODING 'utf8' " - "TEMPLATE my_template" - ) - ) - create_database( - 'postgres://postgres@localhost/db_test_sqlalchemy_util', - template='my_template' - ) diff --git a/tests/functions/test_dependent_objects.py b/tests/functions/test_dependent_objects.py index 0702dd7..1f1dd66 100644 --- a/tests/functions/test_dependent_objects.py +++ b/tests/functions/test_dependent_objects.py @@ -1,18 +1,23 @@ +import pytest import sqlalchemy as sa from sqlalchemy_utils import dependent_objects, get_referencing_foreign_keys -from tests import TestCase -class TestDependentObjects(TestCase): - def create_models(self): - class User(self.Base): +class TestDependentObjects(object): + + @pytest.fixture + def User(self, Base): + class User(Base): __tablename__ = 'user' id = sa.Column(sa.Integer, primary_key=True) first_name = sa.Column(sa.Unicode(255)) last_name = sa.Column(sa.Unicode(255)) + return User - class Article(self.Base): + @pytest.fixture + def Article(self, Base, User): + class Article(Base): __tablename__ = 'article' id = sa.Column(sa.Integer, primary_key=True) author_id = sa.Column(sa.Integer, sa.ForeignKey('user.id')) @@ -22,8 +27,11 @@ class TestDependentObjects(TestCase): author = sa.orm.relationship(User, foreign_keys=[author_id]) owner = sa.orm.relationship(User, foreign_keys=[owner_id]) + return Article - class BlogPost(self.Base): + @pytest.fixture + def BlogPost(self, Base, User): + class BlogPost(Base): __tablename__ = 'blog_post' id = sa.Column(sa.Integer, primary_key=True) owner_id = sa.Column( @@ -31,21 +39,22 @@ class TestDependentObjects(TestCase): ) owner = sa.orm.relationship(User) + return BlogPost - self.User = User - self.Article = Article - self.BlogPost = BlogPost + @pytest.fixture + def init_models(self, User, Article, BlogPost): + pass - def test_returns_all_dependent_objects(self): - user = self.User(first_name=u'John') + def test_returns_all_dependent_objects(self, session, User, Article): + user = User(first_name=u'John') articles = [ - self.Article(author=user), - self.Article(), - self.Article(owner=user), - self.Article(author=user, owner=user) + Article(author=user), + Article(), + Article(owner=user), + Article(author=user, owner=user) ] - self.session.add_all(articles) - self.session.commit() + session.add_all(articles) + session.commit() deps = list(dependent_objects(user)) assert len(deps) == 3 @@ -53,23 +62,29 @@ class TestDependentObjects(TestCase): assert articles[2] in deps assert articles[3] in deps - def test_with_foreign_keys_parameter(self): - user = self.User(first_name=u'John') + def test_with_foreign_keys_parameter( + self, + session, + User, + Article, + BlogPost + ): + user = User(first_name=u'John') objects = [ - self.Article(author=user), - self.Article(), - self.Article(owner=user), - self.Article(author=user, owner=user), - self.BlogPost(owner=user) + Article(author=user), + Article(), + Article(owner=user), + Article(author=user, owner=user), + BlogPost(owner=user) ] - self.session.add_all(objects) - self.session.commit() + session.add_all(objects) + session.commit() deps = list( dependent_objects( user, ( - fk for fk in get_referencing_foreign_keys(self.User) + fk for fk in get_referencing_foreign_keys(User) if fk.ondelete == 'RESTRICT' or fk.ondelete is None ) ).limit(5) @@ -79,15 +94,20 @@ class TestDependentObjects(TestCase): assert objects[3] in deps -class TestDependentObjectsWithColumnAliases(TestCase): - def create_models(self): - class User(self.Base): +class TestDependentObjectsWithColumnAliases(object): + + @pytest.fixture + def User(self, Base): + class User(Base): __tablename__ = 'user' id = sa.Column(sa.Integer, primary_key=True) first_name = sa.Column(sa.Unicode(255)) last_name = sa.Column(sa.Unicode(255)) + return User - class Article(self.Base): + @pytest.fixture + def Article(self, Base, User): + class Article(Base): __tablename__ = 'article' id = sa.Column(sa.Integer, primary_key=True) author_id = sa.Column( @@ -100,8 +120,11 @@ class TestDependentObjectsWithColumnAliases(TestCase): author = sa.orm.relationship(User, foreign_keys=[author_id]) owner = sa.orm.relationship(User, foreign_keys=[owner_id]) + return Article - class BlogPost(self.Base): + @pytest.fixture + def BlogPost(self, Base, User): + class BlogPost(Base): __tablename__ = 'blog_post' id = sa.Column(sa.Integer, primary_key=True) owner_id = sa.Column( @@ -110,21 +133,22 @@ class TestDependentObjectsWithColumnAliases(TestCase): ) owner = sa.orm.relationship(User) + return BlogPost - self.User = User - self.Article = Article - self.BlogPost = BlogPost + @pytest.fixture + def init_models(self, User, Article, BlogPost): + pass - def test_returns_all_dependent_objects(self): - user = self.User(first_name=u'John') + def test_returns_all_dependent_objects(self, session, User, Article): + user = User(first_name=u'John') articles = [ - self.Article(author=user), - self.Article(), - self.Article(owner=user), - self.Article(author=user, owner=user) + Article(author=user), + Article(), + Article(owner=user), + Article(author=user, owner=user) ] - self.session.add_all(articles) - self.session.commit() + session.add_all(articles) + session.commit() deps = list(dependent_objects(user)) assert len(deps) == 3 @@ -132,23 +156,29 @@ class TestDependentObjectsWithColumnAliases(TestCase): assert articles[2] in deps assert articles[3] in deps - def test_with_foreign_keys_parameter(self): - user = self.User(first_name=u'John') + def test_with_foreign_keys_parameter( + self, + session, + User, + Article, + BlogPost + ): + user = User(first_name=u'John') objects = [ - self.Article(author=user), - self.Article(), - self.Article(owner=user), - self.Article(author=user, owner=user), - self.BlogPost(owner=user) + Article(author=user), + Article(), + Article(owner=user), + Article(author=user, owner=user), + BlogPost(owner=user) ] - self.session.add_all(objects) - self.session.commit() + session.add_all(objects) + session.commit() deps = list( dependent_objects( user, ( - fk for fk in get_referencing_foreign_keys(self.User) + fk for fk in get_referencing_foreign_keys(User) if fk.ondelete == 'RESTRICT' or fk.ondelete is None ) ).limit(5) @@ -158,50 +188,64 @@ class TestDependentObjectsWithColumnAliases(TestCase): assert objects[3] in deps -class TestDependentObjectsWithManyReferences(TestCase): - def create_models(self): - class User(self.Base): +class TestDependentObjectsWithManyReferences(object): + + @pytest.fixture + def User(self, Base): + class User(Base): __tablename__ = 'user' id = sa.Column(sa.Integer, primary_key=True) first_name = sa.Column(sa.Unicode(255)) last_name = sa.Column(sa.Unicode(255)) + return User - class BlogPost(self.Base): + @pytest.fixture + def BlogPost(self, Base, User): + class BlogPost(Base): __tablename__ = 'blog_post' id = sa.Column(sa.Integer, primary_key=True) author_id = sa.Column(sa.Integer, sa.ForeignKey('user.id')) author = sa.orm.relationship(User) + return BlogPost - class Article(self.Base): + @pytest.fixture + def Article(self, Base, User): + class Article(Base): __tablename__ = 'article' id = sa.Column(sa.Integer, primary_key=True) author_id = sa.Column(sa.Integer, sa.ForeignKey('user.id')) author = sa.orm.relationship(User) + return Article - self.User = User - self.Article = Article - self.BlogPost = BlogPost + @pytest.fixture + def init_models(self, User, BlogPost, Article): + pass - def test_with_many_dependencies(self): - user = self.User(first_name=u'John') + def test_with_many_dependencies(self, session, User, Article, BlogPost): + user = User(first_name=u'John') objects = [ - self.Article(author=user), - self.BlogPost(author=user) + Article(author=user), + BlogPost(author=user) ] - self.session.add_all(objects) - self.session.commit() + session.add_all(objects) + session.commit() deps = list(dependent_objects(user)) assert len(deps) == 2 -class TestDependentObjectsWithCompositeKeys(TestCase): - def create_models(self): - class User(self.Base): +class TestDependentObjectsWithCompositeKeys(object): + + @pytest.fixture + def User(self, Base): + class User(Base): __tablename__ = 'user' first_name = sa.Column(sa.Unicode(255), primary_key=True) last_name = sa.Column(sa.Unicode(255), primary_key=True) + return User - class Article(self.Base): + @pytest.fixture + def Article(self, Base, User): + class Article(Base): __tablename__ = 'article' id = sa.Column(sa.Integer, primary_key=True) author_first_name = sa.Column(sa.Unicode(255)) @@ -214,20 +258,22 @@ class TestDependentObjectsWithCompositeKeys(TestCase): ) author = sa.orm.relationship(User) + return Article - self.User = User - self.Article = Article + @pytest.fixture + def init_models(self, User, Article): + pass - def test_returns_all_dependent_objects(self): - user = self.User(first_name=u'John', last_name=u'Smith') + def test_returns_all_dependent_objects(self, session, User, Article): + user = User(first_name=u'John', last_name=u'Smith') articles = [ - self.Article(author=user), - self.Article(), - self.Article(), - self.Article(author=user) + Article(author=user), + Article(), + Article(), + Article(author=user) ] - self.session.add_all(articles) - self.session.commit() + session.add_all(articles) + session.commit() deps = list(dependent_objects(user)) assert len(deps) == 2 @@ -235,14 +281,19 @@ class TestDependentObjectsWithCompositeKeys(TestCase): assert articles[3] in deps -class TestDependentObjectsWithSingleTableInheritance(TestCase): - def create_models(self): - class Category(self.Base): +class TestDependentObjectsWithSingleTableInheritance(object): + + @pytest.fixture + def Category(self, Base): + class Category(Base): __tablename__ = 'category' id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.Unicode(255)) + return Category - class TextItem(self.Base): + @pytest.fixture + def TextItem(self, Base, Category): + class TextItem(Base): __tablename__ = 'text_item' id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.Unicode(255)) @@ -261,33 +312,39 @@ class TestDependentObjectsWithSingleTableInheritance(TestCase): __mapper_args__ = { 'polymorphic_on': type, } + return TextItem + @pytest.fixture + def Article(self, TextItem): class Article(TextItem): __mapper_args__ = { 'polymorphic_identity': u'article' } + return Article + @pytest.fixture + def BlogPost(self, TextItem): class BlogPost(TextItem): __mapper_args__ = { 'polymorphic_identity': u'blog_post' } + return BlogPost - self.Category = Category - self.TextItem = TextItem - self.Article = Article - self.BlogPost = BlogPost + @pytest.fixture + def init_models(self, Category, TextItem, Article, BlogPost): + pass - def test_returns_all_dependent_objects(self): - category1 = self.Category(name=u'Category #1') - category2 = self.Category(name=u'Category #2') + def test_returns_all_dependent_objects(self, session, Category, Article): + category1 = Category(name=u'Category #1') + category2 = Category(name=u'Category #2') articles = [ - self.Article(category=category1), - self.Article(category=category1), - self.Article(category=category2), - self.Article(category=category2), + Article(category=category1), + Article(category=category1), + Article(category=category2), + Article(category=category2), ] - self.session.add_all(articles) - self.session.commit() + session.add_all(articles) + session.commit() deps = list(dependent_objects(category1)) assert len(deps) == 2 diff --git a/tests/functions/test_escape_like.py b/tests/functions/test_escape_like.py index d1f78ca..8a3fa9f 100644 --- a/tests/functions/test_escape_like.py +++ b/tests/functions/test_escape_like.py @@ -1,7 +1,6 @@ from sqlalchemy_utils import escape_like -from tests import TestCase -class TestEscapeLike(TestCase): +class TestEscapeLike(object): def test_escapes_wildcards(self): assert escape_like('_*%') == '*_***%' diff --git a/tests/functions/test_get_bind.py b/tests/functions/test_get_bind.py index c10a5f0..090ac2b 100644 --- a/tests/functions/test_get_bind.py +++ b/tests/functions/test_get_bind.py @@ -1,21 +1,20 @@ -from pytest import raises +import pytest from sqlalchemy_utils import get_bind -from tests import TestCase -class TestGetBind(TestCase): - def test_with_session(self): - assert get_bind(self.session) == self.connection +class TestGetBind(object): + def test_with_session(self, session, connection): + assert get_bind(session) == connection - def test_with_connection(self): - assert get_bind(self.connection) == self.connection + def test_with_connection(self, session, connection): + assert get_bind(connection) == connection - def test_with_model_object(self): - article = self.Article() - self.session.add(article) - assert get_bind(article) == self.connection + def test_with_model_object(self, session, connection, Article): + article = Article() + session.add(article) + assert get_bind(article) == connection def test_with_unknown_type(self): - with raises(TypeError): + with pytest.raises(TypeError): get_bind(None) diff --git a/tests/functions/test_get_class_by_table.py b/tests/functions/test_get_class_by_table.py index 99db410..50926d6 100644 --- a/tests/functions/test_get_class_by_table.py +++ b/tests/functions/test_get_class_by_table.py @@ -1,15 +1,14 @@ +import pytest import sqlalchemy as sa -from pytest import raises -from sqlalchemy.ext.declarative import declarative_base from sqlalchemy_utils import get_class_by_table class TestGetClassByTableWithJoinedTableInheritance(object): - def setup_method(self, method): - self.Base = declarative_base() - class Entity(self.Base): + @pytest.fixture + def Entity(self, Base): + class Entity(Base): __tablename__ = 'entity' id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.String) @@ -18,7 +17,10 @@ class TestGetClassByTableWithJoinedTableInheritance(object): 'polymorphic_on': type, 'polymorphic_identity': 'entity' } + return Entity + @pytest.fixture + def User(self, Entity): class User(Entity): __tablename__ = 'user' id = sa.Column( @@ -29,31 +31,29 @@ class TestGetClassByTableWithJoinedTableInheritance(object): __mapper_args__ = { 'polymorphic_identity': 'user' } + return User - self.Entity = Entity - self.User = User - - def test_returns_class(self): - assert get_class_by_table(self.Base, self.User.__table__) == self.User + def test_returns_class(self, Base, User, Entity): + assert get_class_by_table(Base, User.__table__) == User assert get_class_by_table( - self.Base, - self.Entity.__table__ - ) == self.Entity + Base, + Entity.__table__ + ) == Entity - def test_table_with_no_associated_class(self): + def test_table_with_no_associated_class(self, Base): table = sa.Table( 'some_table', - self.Base.metadata, + Base.metadata, sa.Column('id', sa.Integer) ) - assert get_class_by_table(self.Base, table) is None + assert get_class_by_table(Base, table) is None class TestGetClassByTableWithSingleTableInheritance(object): - def setup_method(self, method): - self.Base = declarative_base() - class Entity(self.Base): + @pytest.fixture + def Entity(self, Base): + class Entity(Base): __tablename__ = 'entity' id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.String) @@ -62,38 +62,39 @@ class TestGetClassByTableWithSingleTableInheritance(object): 'polymorphic_on': type, 'polymorphic_identity': 'entity' } + return Entity + @pytest.fixture + def User(self, Entity): class User(Entity): __mapper_args__ = { 'polymorphic_identity': 'user' } + return User - self.Entity = Entity - self.User = User - - def test_multiple_classes_without_data_parameter(self): - with raises(ValueError): + def test_multiple_classes_without_data_parameter(self, Base, Entity, User): + with pytest.raises(ValueError): assert get_class_by_table( - self.Base, - self.Entity.__table__ + Base, + Entity.__table__ ) - def test_multiple_classes_with_data_parameter(self): + def test_multiple_classes_with_data_parameter(self, Base, Entity, User): assert get_class_by_table( - self.Base, - self.Entity.__table__, + Base, + Entity.__table__, {'type': 'entity'} - ) == self.Entity + ) == Entity assert get_class_by_table( - self.Base, - self.Entity.__table__, + Base, + Entity.__table__, {'type': 'user'} - ) == self.User + ) == User - def test_multiple_classes_with_bogus_data(self): - with raises(ValueError): + def test_multiple_classes_with_bogus_data(self, Base, Entity, User): + with pytest.raises(ValueError): assert get_class_by_table( - self.Base, - self.Entity.__table__, + Base, + Entity.__table__, {'type': 'unknown'} ) diff --git a/tests/functions/test_get_column_key.py b/tests/functions/test_get_column_key.py index 6aba2d2..c3959bc 100644 --- a/tests/functions/test_get_column_key.py +++ b/tests/functions/test_get_column_key.py @@ -1,42 +1,44 @@ from copy import copy +import pytest import sqlalchemy as sa -from pytest import raises -from sqlalchemy.ext.declarative import declarative_base from sqlalchemy_utils import get_column_key +@pytest.fixture +def Building(Base): + class Building(Base): + __tablename__ = 'building' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column('_name', sa.Unicode(255)) + return Building + + +@pytest.fixture +def Movie(Base): + class Movie(Base): + __tablename__ = 'movie' + id = sa.Column(sa.Integer, primary_key=True) + return Movie + + class TestGetColumnKey(object): - def setup_method(self, method): - Base = declarative_base() - class Building(Base): - __tablename__ = 'building' - id = sa.Column(sa.Integer, primary_key=True) - name = sa.Column('_name', sa.Unicode(255)) - - class Movie(Base): - __tablename__ = 'movie' - id = sa.Column(sa.Integer, primary_key=True) - - self.Building = Building - self.Movie = Movie - - def test_supports_aliases(self): + def test_supports_aliases(self, Building): assert ( - get_column_key(self.Building, self.Building.__table__.c.id) == + get_column_key(Building, Building.__table__.c.id) == 'id' ) assert ( - get_column_key(self.Building, self.Building.__table__.c._name) == + get_column_key(Building, Building.__table__.c._name) == 'name' ) - def test_supports_vague_matching_of_column_objects(self): - column = copy(self.Building.__table__.c._name) - assert get_column_key(self.Building, column) == 'name' + def test_supports_vague_matching_of_column_objects(self, Building): + column = copy(Building.__table__.c._name) + assert get_column_key(Building, column) == 'name' - def test_throws_value_error_for_unknown_column(self): - with raises(sa.orm.exc.UnmappedColumnError): - get_column_key(self.Building, self.Movie.__table__.c.id) + def test_throws_value_error_for_unknown_column(self, Building, Movie): + with pytest.raises(sa.orm.exc.UnmappedColumnError): + get_column_key(Building, Movie.__table__.c.id) diff --git a/tests/functions/test_get_columns.py b/tests/functions/test_get_columns.py index d09dd39..2254083 100644 --- a/tests/functions/test_get_columns.py +++ b/tests/functions/test_get_columns.py @@ -1,65 +1,65 @@ +import pytest import sqlalchemy as sa -from sqlalchemy.ext.declarative import declarative_base from sqlalchemy_utils import get_columns +@pytest.fixture +def Building(Base): + class Building(Base): + __tablename__ = 'building' + id = sa.Column('_id', sa.Integer, primary_key=True) + name = sa.Column('_name', sa.Unicode(255)) + return Building + + class TestGetColumns(object): - def setup_method(self, method): - Base = declarative_base() - class Building(Base): - __tablename__ = 'building' - id = sa.Column('_id', sa.Integer, primary_key=True) - name = sa.Column('_name', sa.Unicode(255)) - - self.Building = Building - - def test_table(self): + def test_table(self, Building): assert isinstance( - get_columns(self.Building.__table__), + get_columns(Building.__table__), sa.sql.base.ImmutableColumnCollection ) - def test_instrumented_attribute(self): - assert get_columns(self.Building.id) == [self.Building.__table__.c._id] + def test_instrumented_attribute(self, Building): + assert get_columns(Building.id) == [Building.__table__.c._id] - def test_column_property(self): - assert get_columns(self.Building.id.property) == [ - self.Building.__table__.c._id + def test_column_property(self, Building): + assert get_columns(Building.id.property) == [ + Building.__table__.c._id ] - def test_column(self): - assert get_columns(self.Building.__table__.c._id) == [ - self.Building.__table__.c._id + def test_column(self, Building): + assert get_columns(Building.__table__.c._id) == [ + Building.__table__.c._id ] - def test_declarative_class(self): + def test_declarative_class(self, Building): assert isinstance( - get_columns(self.Building), + get_columns(Building), sa.util._collections.OrderedProperties ) - def test_declarative_object(self): + def test_declarative_object(self, Building): assert isinstance( - get_columns(self.Building()), + get_columns(Building()), sa.util._collections.OrderedProperties ) - def test_mapper(self): + def test_mapper(self, Building): assert isinstance( - get_columns(self.Building.__mapper__), + get_columns(Building.__mapper__), sa.util._collections.OrderedProperties ) - def test_class_alias(self): + def test_class_alias(self, Building): assert isinstance( - get_columns(sa.orm.aliased(self.Building)), + get_columns(sa.orm.aliased(Building)), sa.util._collections.OrderedProperties ) - def test_table_alias(self): - alias = sa.orm.aliased(self.Building.__table__) + def test_table_alias(self, Building): + alias = sa.orm.aliased(Building.__table__) assert isinstance( get_columns(alias), sa.sql.base.ImmutableColumnCollection diff --git a/tests/functions/test_get_hybrid_properties.py b/tests/functions/test_get_hybrid_properties.py index 16c78fa..5f795b2 100644 --- a/tests/functions/test_get_hybrid_properties.py +++ b/tests/functions/test_get_hybrid_properties.py @@ -1,41 +1,41 @@ +import pytest import sqlalchemy as sa -from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy_utils import get_hybrid_properties +@pytest.fixture +def Category(Base): + class Category(Base): + __tablename__ = 'category' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + + @hybrid_property + def lowercase_name(self): + return self.name.lower() + + @lowercase_name.expression + def lowercase_name(cls): + return sa.func.lower(cls.name) + return Category + + class TestGetHybridProperties(object): - def setup_method(self, method): - Base = declarative_base() - class Category(Base): - __tablename__ = 'category' - id = sa.Column(sa.Integer, primary_key=True) - name = sa.Column(sa.Unicode(255)) - - @hybrid_property - def lowercase_name(self): - return self.name.lower() - - @lowercase_name.expression - def lowercase_name(cls): - return sa.func.lower(cls.name) - - self.Category = Category - - def test_declarative_model(self): + def test_declarative_model(self, Category): assert ( - list(get_hybrid_properties(self.Category).keys()) == + list(get_hybrid_properties(Category).keys()) == ['lowercase_name'] ) - def test_mapper(self): + def test_mapper(self, Category): assert ( - list(get_hybrid_properties(sa.inspect(self.Category)).keys()) == + list(get_hybrid_properties(sa.inspect(Category)).keys()) == ['lowercase_name'] ) - def test_aliased_class(self): - props = get_hybrid_properties(sa.orm.aliased(self.Category)) + def test_aliased_class(self, Category): + props = get_hybrid_properties(sa.orm.aliased(Category)) assert list(props.keys()) == ['lowercase_name'] diff --git a/tests/functions/test_get_mapper.py b/tests/functions/test_get_mapper.py index 766b495..1298d79 100644 --- a/tests/functions/test_get_mapper.py +++ b/tests/functions/test_get_mapper.py @@ -1,104 +1,106 @@ +import pytest import sqlalchemy as sa -from pytest import raises -from sqlalchemy.ext.declarative import declarative_base from sqlalchemy_utils import get_mapper -from tests import TestCase class TestGetMapper(object): - def setup_method(self, method): - self.Base = declarative_base() - class Building(self.Base): + @pytest.fixture + def Building(self, Base): + class Building(Base): __tablename__ = 'building' id = sa.Column(sa.Integer, primary_key=True) + return Building - self.Building = Building + def test_table(self, Building): + assert get_mapper(Building.__table__) == sa.inspect(Building) - def test_table(self): - assert get_mapper(self.Building.__table__) == sa.inspect(self.Building) - - def test_declarative_class(self): + def test_declarative_class(self, Building): assert ( - get_mapper(self.Building) == - sa.inspect(self.Building) + get_mapper(Building) == + sa.inspect(Building) ) - def test_declarative_object(self): + def test_declarative_object(self, Building): assert ( - get_mapper(self.Building()) == - sa.inspect(self.Building) + get_mapper(Building()) == + sa.inspect(Building) ) - def test_mapper(self): + def test_mapper(self, Building): assert ( - get_mapper(self.Building.__mapper__) == - sa.inspect(self.Building) + get_mapper(Building.__mapper__) == + sa.inspect(Building) ) - def test_class_alias(self): + def test_class_alias(self, Building): assert ( - get_mapper(sa.orm.aliased(self.Building)) == - sa.inspect(self.Building) + get_mapper(sa.orm.aliased(Building)) == + sa.inspect(Building) ) - def test_instrumented_attribute(self): + def test_instrumented_attribute(self, Building): assert ( - get_mapper(self.Building.id) == sa.inspect(self.Building) + get_mapper(Building.id) == sa.inspect(Building) ) - def test_table_alias(self): - alias = sa.orm.aliased(self.Building.__table__) + def test_table_alias(self, Building): + alias = sa.orm.aliased(Building.__table__) assert ( get_mapper(alias) == - sa.inspect(self.Building) + sa.inspect(Building) ) - def test_column(self): + def test_column(self, Building): assert ( - get_mapper(self.Building.__table__.c.id) == - sa.inspect(self.Building) + get_mapper(Building.__table__.c.id) == + sa.inspect(Building) ) - def test_column_of_an_alias(self): + def test_column_of_an_alias(self, Building): assert ( - get_mapper(sa.orm.aliased(self.Building.__table__).c.id) == - sa.inspect(self.Building) + get_mapper(sa.orm.aliased(Building.__table__).c.id) == + sa.inspect(Building) ) -class TestGetMapperWithQueryEntities(TestCase): - def create_models(self): - class Building(self.Base): +class TestGetMapperWithQueryEntities(object): + + @pytest.fixture + def Building(self, Base): + class Building(Base): __tablename__ = 'building' id = sa.Column(sa.Integer, primary_key=True) + return Building - self.Building = Building + @pytest.fixture + def init_models(self, Building): + pass - def test_mapper_entity_with_mapper(self): - entity = self.session.query(self.Building.__mapper__)._entities[0] + def test_mapper_entity_with_mapper(self, session, Building): + entity = session.query(Building.__mapper__)._entities[0] assert ( get_mapper(entity) == - sa.inspect(self.Building) + sa.inspect(Building) ) - def test_mapper_entity_with_class(self): - entity = self.session.query(self.Building)._entities[0] + def test_mapper_entity_with_class(self, session, Building): + entity = session.query(Building)._entities[0] assert ( get_mapper(entity) == - sa.inspect(self.Building) + sa.inspect(Building) ) - def test_column_entity(self): - query = self.session.query(self.Building.id) - assert get_mapper(query._entities[0]) == sa.inspect(self.Building) + def test_column_entity(self, session, Building): + query = session.query(Building.id) + assert get_mapper(query._entities[0]) == sa.inspect(Building) class TestGetMapperWithMultipleMappersFound(object): - def setup_method(self, method): - Base = declarative_base() + @pytest.fixture + def Building(self, Base): class Building(Base): __tablename__ = 'building' id = sa.Column(sa.Integer, primary_key=True) @@ -106,29 +108,30 @@ class TestGetMapperWithMultipleMappersFound(object): class BigBuilding(Building): pass - self.Building = Building - self.BigBuilding = BigBuilding + return Building - def test_table(self): - with raises(ValueError): - get_mapper(self.Building.__table__) + def test_table(self, Building): + with pytest.raises(ValueError): + get_mapper(Building.__table__) - def test_table_alias(self): - alias = sa.orm.aliased(self.Building.__table__) - with raises(ValueError): + def test_table_alias(self, Building): + alias = sa.orm.aliased(Building.__table__) + with pytest.raises(ValueError): get_mapper(alias) class TestGetMapperForTableWithoutMapper(object): - def setup_method(self, method): + + @pytest.fixture + def building(self): metadata = sa.MetaData() - self.building = sa.Table('building', metadata) + return sa.Table('building', metadata) - def test_table(self): - with raises(ValueError): - get_mapper(self.building) + def test_table(self, building): + with pytest.raises(ValueError): + get_mapper(building) - def test_table_alias(self): - alias = sa.orm.aliased(self.building) - with raises(ValueError): + def test_table_alias(self, building): + alias = sa.orm.aliased(building) + with pytest.raises(ValueError): get_mapper(alias) diff --git a/tests/functions/test_get_primary_keys.py b/tests/functions/test_get_primary_keys.py index 73525be..d45490e 100644 --- a/tests/functions/test_get_primary_keys.py +++ b/tests/functions/test_get_primary_keys.py @@ -1,5 +1,5 @@ +import pytest import sqlalchemy as sa -from sqlalchemy.ext.declarative import declarative_base from sqlalchemy_utils import get_primary_keys @@ -9,40 +9,40 @@ except ImportError: from ordereddict import OrderedDict +@pytest.fixture +def Building(Base): + class Building(Base): + __tablename__ = 'building' + id = sa.Column('_id', sa.Integer, primary_key=True) + name = sa.Column('_name', sa.Unicode(255)) + return Building + + class TestGetPrimaryKeys(object): - def setup_method(self, method): - Base = declarative_base() - class Building(Base): - __tablename__ = 'building' - id = sa.Column('_id', sa.Integer, primary_key=True) - name = sa.Column('_name', sa.Unicode(255)) - - self.Building = Building - - def test_table(self): - assert get_primary_keys(self.Building.__table__) == OrderedDict({ - '_id': self.Building.__table__.c._id + def test_table(self, Building): + assert get_primary_keys(Building.__table__) == OrderedDict({ + '_id': Building.__table__.c._id }) - def test_declarative_class(self): - assert get_primary_keys(self.Building) == OrderedDict({ - 'id': self.Building.__table__.c._id + def test_declarative_class(self, Building): + assert get_primary_keys(Building) == OrderedDict({ + 'id': Building.__table__.c._id }) - def test_declarative_object(self): - assert get_primary_keys(self.Building()) == OrderedDict({ - 'id': self.Building.__table__.c._id + def test_declarative_object(self, Building): + assert get_primary_keys(Building()) == OrderedDict({ + 'id': Building.__table__.c._id }) - def test_class_alias(self): - alias = sa.orm.aliased(self.Building) + def test_class_alias(self, Building): + alias = sa.orm.aliased(Building) assert get_primary_keys(alias) == OrderedDict({ - 'id': self.Building.__table__.c._id + 'id': Building.__table__.c._id }) - def test_table_alias(self): - alias = sa.orm.aliased(self.Building.__table__) + def test_table_alias(self, Building): + alias = sa.orm.aliased(Building.__table__) assert get_primary_keys(alias) == OrderedDict({ '_id': alias.c._id }) diff --git a/tests/functions/test_get_query_entities.py b/tests/functions/test_get_query_entities.py index f1ae7c9..8fd48a2 100644 --- a/tests/functions/test_get_query_entities.py +++ b/tests/functions/test_get_query_entities.py @@ -1,102 +1,115 @@ +import pytest import sqlalchemy as sa from sqlalchemy_utils import get_query_entities -from tests import TestCase -class TestGetQueryEntities(TestCase): - def create_models(self): - class TextItem(self.Base): - __tablename__ = 'text_item' - id = sa.Column(sa.Integer, primary_key=True) +@pytest.fixture +def TextItem(Base): + class TextItem(Base): + __tablename__ = 'text_item' + id = sa.Column(sa.Integer, primary_key=True) - type = sa.Column(sa.Unicode(255)) + type = sa.Column(sa.Unicode(255)) - __mapper_args__ = { - 'polymorphic_on': type, - } + __mapper_args__ = { + 'polymorphic_on': type, + } + return TextItem - class Article(TextItem): - __tablename__ = 'article' - id = sa.Column( - sa.Integer, sa.ForeignKey(TextItem.id), primary_key=True - ) - category = sa.Column(sa.Unicode(255)) - __mapper_args__ = { - 'polymorphic_identity': u'article' - } - class BlogPost(TextItem): - __tablename__ = 'blog_post' - id = sa.Column( - sa.Integer, sa.ForeignKey(TextItem.id), primary_key=True - ) - __mapper_args__ = { - 'polymorphic_identity': u'blog_post' - } +@pytest.fixture +def Article(TextItem): + class Article(TextItem): + __tablename__ = 'article' + id = sa.Column( + sa.Integer, sa.ForeignKey(TextItem.id), primary_key=True + ) + category = sa.Column(sa.Unicode(255)) + __mapper_args__ = { + 'polymorphic_identity': u'article' + } + return Article - self.TextItem = TextItem - self.Article = Article - self.BlogPost = BlogPost - def test_mapper(self): - query = self.session.query(sa.inspect(self.TextItem)) - assert get_query_entities(query) == [self.TextItem] +@pytest.fixture +def BlogPost(TextItem): + class BlogPost(TextItem): + __tablename__ = 'blog_post' + id = sa.Column( + sa.Integer, sa.ForeignKey(TextItem.id), primary_key=True + ) + __mapper_args__ = { + 'polymorphic_identity': u'blog_post' + } + return BlogPost - def test_entity(self): - query = self.session.query(self.TextItem) - assert get_query_entities(query) == [self.TextItem] - def test_instrumented_attribute(self): - query = self.session.query(self.TextItem.id) - assert get_query_entities(query) == [self.TextItem] +@pytest.fixture +def init_models(TextItem, Article, BlogPost): + pass - def test_column(self): - query = self.session.query(self.TextItem.__table__.c.id) - assert get_query_entities(query) == [self.TextItem.__table__] - def test_aliased_selectable(self): - selectable = sa.orm.with_polymorphic(self.TextItem, [self.BlogPost]) - query = self.session.query(selectable) +class TestGetQueryEntities(object): + + def test_mapper(self, session, TextItem): + query = session.query(sa.inspect(TextItem)) + assert get_query_entities(query) == [TextItem] + + def test_entity(self, session, TextItem): + query = session.query(TextItem) + assert get_query_entities(query) == [TextItem] + + def test_instrumented_attribute(self, session, TextItem): + query = session.query(TextItem.id) + assert get_query_entities(query) == [TextItem] + + def test_column(self, session, TextItem): + query = session.query(TextItem.__table__.c.id) + assert get_query_entities(query) == [TextItem.__table__] + + def test_aliased_selectable(self, session, TextItem, BlogPost): + selectable = sa.orm.with_polymorphic(TextItem, [BlogPost]) + query = session.query(selectable) assert get_query_entities(query) == [selectable] - def test_joined_entity(self): - query = self.session.query(self.TextItem).join( - self.BlogPost, self.BlogPost.id == self.TextItem.id + def test_joined_entity(self, session, TextItem, BlogPost): + query = session.query(TextItem).join( + BlogPost, BlogPost.id == TextItem.id ) assert get_query_entities(query) == [ - self.TextItem, sa.inspect(self.BlogPost) + TextItem, sa.inspect(BlogPost) ] - def test_joined_aliased_entity(self): - alias = sa.orm.aliased(self.BlogPost) + def test_joined_aliased_entity(self, session, TextItem, BlogPost): + alias = sa.orm.aliased(BlogPost) - query = self.session.query(self.TextItem).join( - alias, alias.id == self.TextItem.id + query = session.query(TextItem).join( + alias, alias.id == TextItem.id ) - assert get_query_entities(query) == [self.TextItem, alias] + assert get_query_entities(query) == [TextItem, alias] - def test_column_entity_with_label(self): - query = self.session.query(self.Article.id.label('id')) - assert get_query_entities(query) == [self.Article] + def test_column_entity_with_label(self, session, Article): + query = session.query(Article.id.label('id')) + assert get_query_entities(query) == [Article] - def test_with_subquery(self): + def test_with_subquery(self, session, Article): number_of_articles = ( sa.select( - [sa.func.count(self.Article.id)], + [sa.func.count(Article.id)], ) .select_from( - self.Article.__table__ + Article.__table__ ) ).label('number_of_articles') - query = self.session.query(self.Article, number_of_articles) + query = session.query(Article, number_of_articles) assert get_query_entities(query) == [ - self.Article, + Article, number_of_articles ] - def test_aliased_entity(self): - alias = sa.orm.aliased(self.Article) - query = self.session.query(alias) + def test_aliased_entity(self, session, Article): + alias = sa.orm.aliased(Article) + query = session.query(alias) assert get_query_entities(query) == [alias] diff --git a/tests/functions/test_get_referencing_foreign_keys.py b/tests/functions/test_get_referencing_foreign_keys.py index dc8a5de..bb4ffa0 100644 --- a/tests/functions/test_get_referencing_foreign_keys.py +++ b/tests/functions/test_get_referencing_foreign_keys.py @@ -1,17 +1,22 @@ +import pytest import sqlalchemy as sa from sqlalchemy_utils import get_referencing_foreign_keys -from tests import TestCase -class TestGetReferencingFksWithCompositeKeys(TestCase): - def create_models(self): - class User(self.Base): +class TestGetReferencingFksWithCompositeKeys(object): + + @pytest.fixture + def User(self, Base): + class User(Base): __tablename__ = 'user' first_name = sa.Column(sa.Unicode(255), primary_key=True) last_name = sa.Column(sa.Unicode(255), primary_key=True) + return User - class Article(self.Base): + @pytest.fixture + def Article(self, Base, User): + class Article(Base): __tablename__ = 'article' id = sa.Column(sa.Integer, primary_key=True) author_first_name = sa.Column(sa.Unicode(255)) @@ -22,22 +27,26 @@ class TestGetReferencingFksWithCompositeKeys(TestCase): [User.first_name, User.last_name] ), ) + return Article - self.User = User - self.Article = Article + @pytest.fixture + def init_models(self, User, Article): + pass - def test_with_declarative_class(self): - fks = get_referencing_foreign_keys(self.User) - assert self.Article.__table__.foreign_keys == fks + def test_with_declarative_class(self, User, Article): + fks = get_referencing_foreign_keys(User) + assert Article.__table__.foreign_keys == fks - def test_with_table(self): - fks = get_referencing_foreign_keys(self.User.__table__) - assert self.Article.__table__.foreign_keys == fks + def test_with_table(self, User, Article): + fks = get_referencing_foreign_keys(User.__table__) + assert Article.__table__.foreign_keys == fks -class TestGetReferencingFksWithInheritance(TestCase): - def create_models(self): - class User(self.Base): +class TestGetReferencingFksWithInheritance(object): + + @pytest.fixture + def User(self, Base): + class User(Base): __tablename__ = 'user' id = sa.Column(sa.Integer, primary_key=True) type = sa.Column(sa.Unicode) @@ -47,14 +56,20 @@ class TestGetReferencingFksWithInheritance(TestCase): __mapper_args__ = { 'polymorphic_on': 'type' } + return User + @pytest.fixture + def Admin(self, User): class Admin(User): __tablename__ = 'admin' id = sa.Column( sa.Integer, sa.ForeignKey(User.id), primary_key=True ) + return Admin - class TextItem(self.Base): + @pytest.fixture + def TextItem(self, Base, User): + class TextItem(Base): __tablename__ = 'textitem' id = sa.Column(sa.Integer, primary_key=True) type = sa.Column(sa.Unicode) @@ -62,7 +77,10 @@ class TestGetReferencingFksWithInheritance(TestCase): __mapper_args__ = { 'polymorphic_on': 'type' } + return TextItem + @pytest.fixture + def Article(self, TextItem): class Article(TextItem): __tablename__ = 'article' id = sa.Column( @@ -71,16 +89,16 @@ class TestGetReferencingFksWithInheritance(TestCase): __mapper_args__ = { 'polymorphic_identity': 'article' } + return Article - self.Admin = Admin - self.User = User - self.Article = Article - self.TextItem = TextItem + @pytest.fixture + def init_models(self, User, Admin, TextItem, Article): + pass - def test_with_declarative_class(self): - fks = get_referencing_foreign_keys(self.Admin) - assert self.TextItem.__table__.foreign_keys == fks + def test_with_declarative_class(self, Admin, TextItem): + fks = get_referencing_foreign_keys(Admin) + assert TextItem.__table__.foreign_keys == fks - def test_with_table(self): - fks = get_referencing_foreign_keys(self.Admin.__table__) + def test_with_table(self, Admin): + fks = get_referencing_foreign_keys(Admin.__table__) assert fks == set([]) diff --git a/tests/functions/test_get_tables.py b/tests/functions/test_get_tables.py index 682e782..db01fef 100644 --- a/tests/functions/test_get_tables.py +++ b/tests/functions/test_get_tables.py @@ -1,76 +1,86 @@ +import pytest import sqlalchemy as sa from sqlalchemy_utils import get_tables -from tests import TestCase -class TestGetTables(TestCase): - def create_models(self): - class TextItem(self.Base): - __tablename__ = 'text_item' - id = sa.Column(sa.Integer, primary_key=True) - name = sa.Column(sa.Unicode(255)) - type = sa.Column(sa.Unicode(255)) +@pytest.fixture +def TextItem(Base): + class TextItem(Base): + __tablename__ = 'text_item' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + type = sa.Column(sa.Unicode(255)) - __mapper_args__ = { - 'polymorphic_on': type, - 'with_polymorphic': '*' - } + __mapper_args__ = { + 'polymorphic_on': type, + 'with_polymorphic': '*' + } + return TextItem - class Article(TextItem): - __tablename__ = 'article' - id = sa.Column( - sa.Integer, sa.ForeignKey(TextItem.id), primary_key=True - ) - __mapper_args__ = { - 'polymorphic_identity': u'article' - } - self.TextItem = TextItem - self.Article = Article +@pytest.fixture +def Article(TextItem): + class Article(TextItem): + __tablename__ = 'article' + id = sa.Column( + sa.Integer, sa.ForeignKey(TextItem.id), primary_key=True + ) + __mapper_args__ = { + 'polymorphic_identity': u'article' + } + return Article - def test_child_class_using_join_table_inheritance(self): - assert get_tables(self.Article) == [ - self.TextItem.__table__, - self.Article.__table__ + +@pytest.fixture +def init_models(TextItem, Article): + pass + + +class TestGetTables(object): + + def test_child_class_using_join_table_inheritance(self, TextItem, Article): + assert get_tables(Article) == [ + TextItem.__table__, + Article.__table__ ] - def test_entity_using_with_polymorphic(self): - assert get_tables(self.TextItem) == [ - self.TextItem.__table__, - self.Article.__table__ + def test_entity_using_with_polymorphic(self, TextItem, Article): + assert get_tables(TextItem) == [ + TextItem.__table__, + Article.__table__ ] - def test_instrumented_attribute(self): - assert get_tables(self.TextItem.name) == [ - self.TextItem.__table__, + def test_instrumented_attribute(self, TextItem): + assert get_tables(TextItem.name) == [ + TextItem.__table__, ] - def test_polymorphic_instrumented_attribute(self): - assert get_tables(self.Article.id) == [ - self.TextItem.__table__, - self.Article.__table__ + def test_polymorphic_instrumented_attribute(self, TextItem, Article): + assert get_tables(Article.id) == [ + TextItem.__table__, + Article.__table__ ] - def test_column(self): - assert get_tables(self.Article.__table__.c.id) == [ - self.Article.__table__ + def test_column(self, Article): + assert get_tables(Article.__table__.c.id) == [ + Article.__table__ ] - def test_mapper_entity_with_class(self): - query = self.session.query(self.Article) + def test_mapper_entity_with_class(self, session, TextItem, Article): + query = session.query(Article) assert get_tables(query._entities[0]) == [ - self.TextItem.__table__, self.Article.__table__ + TextItem.__table__, Article.__table__ ] - def test_mapper_entity_with_mapper(self): - query = self.session.query(sa.inspect(self.Article)) + def test_mapper_entity_with_mapper(self, session, TextItem, Article): + query = session.query(sa.inspect(Article)) assert get_tables(query._entities[0]) == [ - self.TextItem.__table__, self.Article.__table__ + TextItem.__table__, Article.__table__ ] - def test_column_entity(self): - query = self.session.query(self.Article.id) + def test_column_entity(self, session, TextItem, Article): + query = session.query(Article.id) assert get_tables(query._entities[0]) == [ - self.TextItem.__table__, self.Article.__table__ + TextItem.__table__, Article.__table__ ] diff --git a/tests/functions/test_get_type.py b/tests/functions/test_get_type.py index 652e69b..edd0d32 100644 --- a/tests/functions/test_get_type.py +++ b/tests/functions/test_get_type.py @@ -1,46 +1,49 @@ +import pytest import sqlalchemy as sa -from sqlalchemy.ext.declarative import declarative_base from sqlalchemy_utils import get_type +@pytest.fixture +def User(Base): + class User(Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, primary_key=True) + return User + + +@pytest.fixture +def Article(Base, User): + class Article(Base): + __tablename__ = 'article' + id = sa.Column(sa.Integer, primary_key=True) + + author_id = sa.Column(sa.Integer, sa.ForeignKey(User.id)) + author = sa.orm.relationship(User) + + some_property = sa.orm.column_property( + sa.func.coalesce(id, 1) + ) + return Article + + class TestGetType(object): - def setup_method(self, method): - Base = declarative_base() - class User(Base): - __tablename__ = 'user' - id = sa.Column(sa.Integer, primary_key=True) + def test_instrumented_attribute(self, Article): + assert isinstance(get_type(Article.id), sa.Integer) - class Article(Base): - __tablename__ = 'article' - id = sa.Column(sa.Integer, primary_key=True) + def test_column_property(self, Article): + assert isinstance(get_type(Article.id.property), sa.Integer) - author_id = sa.Column(sa.Integer, sa.ForeignKey(User.id)) - author = sa.orm.relationship(User) + def test_column(self, Article): + assert isinstance(get_type(Article.__table__.c.id), sa.Integer) - some_property = sa.orm.column_property( - sa.func.coalesce(id, 1) - ) + def test_calculated_column_property(self, Article): + assert isinstance(get_type(Article.some_property), sa.Integer) - self.Article = Article - self.User = User + def test_relationship_property(self, Article, User): + assert get_type(Article.author) == User - def test_instrumented_attribute(self): - assert isinstance(get_type(self.Article.id), sa.Integer) - - def test_column_property(self): - assert isinstance(get_type(self.Article.id.property), sa.Integer) - - def test_column(self): - assert isinstance(get_type(self.Article.__table__.c.id), sa.Integer) - - def test_calculated_column_property(self): - assert isinstance(get_type(self.Article.some_property), sa.Integer) - - def test_relationship_property(self): - assert get_type(self.Article.author) == self.User - - def test_scalar_select(self): - query = sa.select([self.Article.id]).as_scalar() + def test_scalar_select(self, Article): + query = sa.select([Article.id]).as_scalar() assert isinstance(get_type(query), sa.Integer) diff --git a/tests/functions/test_getdotattr.py b/tests/functions/test_getdotattr.py index c768ecb..fa691c8 100644 --- a/tests/functions/test_getdotattr.py +++ b/tests/functions/test_getdotattr.py @@ -1,72 +1,94 @@ +import pytest import sqlalchemy as sa from sqlalchemy_utils.functions import getdotattr -from tests import TestCase -class TestGetDotAttr(TestCase): - def create_models(self): - class Document(self.Base): - __tablename__ = 'document' - id = sa.Column(sa.Integer, primary_key=True) - name = sa.Column(sa.Unicode(255)) +@pytest.fixture +def Document(Base): + class Document(Base): + __tablename__ = 'document' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + return Document - class Section(self.Base): - __tablename__ = 'section' - id = sa.Column(sa.Integer, primary_key=True) - name = sa.Column(sa.Unicode(255)) - document_id = sa.Column( - sa.Integer, sa.ForeignKey(Document.id) - ) +@pytest.fixture +def Section(Base, Document): + class Section(Base): + __tablename__ = 'section' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) - document = sa.orm.relationship(Document, backref='sections') + document_id = sa.Column( + sa.Integer, sa.ForeignKey(Document.id) + ) - class SubSection(self.Base): - __tablename__ = 'subsection' - id = sa.Column(sa.Integer, primary_key=True) - name = sa.Column(sa.Unicode(255)) + document = sa.orm.relationship(Document, backref='sections') + return Section - section_id = sa.Column( - sa.Integer, sa.ForeignKey(Section.id) - ) - section = sa.orm.relationship(Section, backref='subsections') +@pytest.fixture +def SubSection(Base, Section): + class SubSection(Base): + __tablename__ = 'subsection' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) - class SubSubSection(self.Base): - __tablename__ = 'subsubsection' - id = sa.Column(sa.Integer, primary_key=True) - name = sa.Column(sa.Unicode(255)) - locale = sa.Column(sa.String(10)) + section_id = sa.Column( + sa.Integer, sa.ForeignKey(Section.id) + ) - subsection_id = sa.Column( - sa.Integer, sa.ForeignKey(SubSection.id) - ) + section = sa.orm.relationship(Section, backref='subsections') + return SubSection - subsection = sa.orm.relationship( - SubSection, backref='subsubsections' - ) - self.Document = Document - self.Section = Section - self.SubSection = SubSection - self.SubSubSection = SubSubSection +@pytest.fixture +def SubSubSection(Base, SubSection): + class SubSubSection(Base): + __tablename__ = 'subsubsection' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + locale = sa.Column(sa.String(10)) - def test_simple_objects(self): - document = self.Document(name=u'some document') - section = self.Section(document=document) - subsection = self.SubSection(section=section) + subsection_id = sa.Column( + sa.Integer, sa.ForeignKey(SubSection.id) + ) + + subsection = sa.orm.relationship( + SubSection, backref='subsubsections' + ) + return SubSubSection + + +@pytest.fixture +def init_models(Document, Section, SubSection, SubSubSection): + pass + + +class TestGetDotAttr(object): + + def test_simple_objects(self, Document, Section, SubSection): + document = Document(name=u'some document') + section = Section(document=document) + subsection = SubSection(section=section) assert getdotattr( subsection, 'section.document.name' ) == u'some document' - def test_with_instrumented_lists(self): - document = self.Document(name=u'some document') - section = self.Section(document=document) - subsection = self.SubSection(section=section) - subsubsection = self.SubSubSection(subsection=subsection) + def test_with_instrumented_lists( + self, + Document, + Section, + SubSection, + SubSubSection + ): + document = Document(name=u'some document') + section = Section(document=document) + subsection = SubSection(section=section) + subsubsection = SubSubSection(subsection=subsection) assert getdotattr(document, 'sections') == [section] assert getdotattr(document, 'sections.subsections') == [ @@ -76,10 +98,10 @@ class TestGetDotAttr(TestCase): subsubsection ] - def test_class_paths(self): - assert getdotattr(self.Section, 'document') is self.Section.document + def test_class_paths(self, Document, Section, SubSection): + assert getdotattr(Section, 'document') is Section.document assert ( - getdotattr(self.SubSection, 'section.document') is - self.Section.document + getdotattr(SubSection, 'section.document') is + Section.document ) - assert getdotattr(self.Section, 'document.name') is self.Document.name + assert getdotattr(Section, 'document.name') is Document.name diff --git a/tests/functions/test_has_changes.py b/tests/functions/test_has_changes.py index 6fc2684..3a30be5 100644 --- a/tests/functions/test_has_changes.py +++ b/tests/functions/test_has_changes.py @@ -1,47 +1,44 @@ +import pytest import sqlalchemy as sa -from sqlalchemy.ext.declarative import declarative_base from sqlalchemy_utils import has_changes -class HasChangesTestCase(object): - def setup_method(self, method): - Base = declarative_base() - - class Article(Base): - __tablename__ = 'article_translation' - id = sa.Column(sa.Integer, primary_key=True) - title = sa.Column(sa.String(100)) - - self.Article = Article +@pytest.fixture +def Article(Base): + class Article(Base): + __tablename__ = 'article_translation' + id = sa.Column(sa.Integer, primary_key=True) + title = sa.Column(sa.String(100)) + return Article -class TestHasChangesWithStringAttr(HasChangesTestCase): - def test_without_changed_attr(self): - article = self.Article() +class TestHasChangesWithStringAttr(object): + def test_without_changed_attr(self, Article): + article = Article() assert not has_changes(article, 'title') - def test_with_changed_attr(self): - article = self.Article(title='Some title') + def test_with_changed_attr(self, Article): + article = Article(title='Some title') assert has_changes(article, 'title') -class TestHasChangesWithMultipleAttrs(HasChangesTestCase): - def test_without_changed_attr(self): - article = self.Article() +class TestHasChangesWithMultipleAttrs(object): + def test_without_changed_attr(self, Article): + article = Article() assert not has_changes(article, ['title']) - def test_with_changed_attr(self): - article = self.Article(title='Some title') + def test_with_changed_attr(self, Article): + article = Article(title='Some title') assert has_changes(article, ['title', 'id']) -class TestHasChangesWithExclude(HasChangesTestCase): - def test_without_changed_attr(self): - article = self.Article() +class TestHasChangesWithExclude(object): + def test_without_changed_attr(self, Article): + article = Article() assert not has_changes(article, exclude=['id']) - def test_with_changed_attr(self): - article = self.Article(title='Some title') + def test_with_changed_attr(self, Article): + article = Article(title='Some title') assert has_changes(article, exclude=['id']) assert not has_changes(article, exclude=['title']) diff --git a/tests/functions/test_has_index.py b/tests/functions/test_has_index.py index e0b8f0b..75cc8a6 100644 --- a/tests/functions/test_has_index.py +++ b/tests/functions/test_has_index.py @@ -1,14 +1,13 @@ +import pytest import sqlalchemy as sa -from pytest import raises -from sqlalchemy.ext.declarative import declarative_base from sqlalchemy_utils import get_fk_constraint_for_columns, has_index class TestHasIndex(object): - def setup_method(self, method): - Base = declarative_base() + @pytest.fixture + def table(self, Base): class ArticleTranslation(Base): __tablename__ = 'article_translation' id = sa.Column(sa.Integer, primary_key=True) @@ -21,24 +20,23 @@ class TestHasIndex(object): __table_args__ = ( sa.Index('my_index', is_deleted, is_archived), ) + return ArticleTranslation.__table__ - self.table = ArticleTranslation.__table__ - - def test_column_that_belongs_to_an_alias(self): - alias = sa.orm.aliased(self.table) - with raises(TypeError): + def test_column_that_belongs_to_an_alias(self, table): + alias = sa.orm.aliased(table) + with pytest.raises(TypeError): assert has_index(alias.c.id) - def test_compound_primary_key(self): - assert has_index(self.table.c.id) - assert not has_index(self.table.c.locale) + def test_compound_primary_key(self, table): + assert has_index(table.c.id) + assert not has_index(table.c.locale) - def test_single_column_index(self): - assert has_index(self.table.c.is_published) + def test_single_column_index(self, table): + assert has_index(table.c.is_published) - def test_compound_column_index(self): - assert has_index(self.table.c.is_deleted) - assert not has_index(self.table.c.is_archived) + def test_compound_column_index(self, table): + assert has_index(table.c.is_deleted) + assert not has_index(table.c.is_archived) def test_table_without_primary_key(self): article = sa.Table( @@ -50,8 +48,7 @@ class TestHasIndex(object): class TestHasIndexWithFKConstraint(object): - def test_composite_fk_without_index(self): - Base = declarative_base() + def test_composite_fk_without_index(self, Base): class User(Base): __tablename__ = 'user' @@ -78,8 +75,7 @@ class TestHasIndexWithFKConstraint(object): ) assert not has_index(constraint) - def test_composite_fk_with_index(self): - Base = declarative_base() + def test_composite_fk_with_index(self, Base): class User(Base): __tablename__ = 'user' @@ -109,8 +105,7 @@ class TestHasIndexWithFKConstraint(object): ) assert has_index(constraint) - def test_composite_fk_with_partial_index_match(self): - Base = declarative_base() + def test_composite_fk_with_partial_index_match(self, Base): class User(Base): __tablename__ = 'user' diff --git a/tests/functions/test_has_unique_index.py b/tests/functions/test_has_unique_index.py index 4db3f21..0794b8d 100644 --- a/tests/functions/test_has_unique_index.py +++ b/tests/functions/test_has_unique_index.py @@ -1,18 +1,20 @@ +import pytest import sqlalchemy as sa -from pytest import raises -from sqlalchemy.ext.declarative import declarative_base from sqlalchemy_utils import get_fk_constraint_for_columns, has_unique_index class TestHasUniqueIndex(object): - def setup_method(self, method): - Base = declarative_base() + @pytest.fixture + def articles(self, Base): class Article(Base): __tablename__ = 'article' id = sa.Column(sa.Integer, primary_key=True) + return Article.__table__ + @pytest.fixture + def article_translations(self, Base): class ArticleTranslation(Base): __tablename__ = 'article_translation' id = sa.Column(sa.Integer, primary_key=True) @@ -26,35 +28,33 @@ class TestHasUniqueIndex(object): sa.Index('my_index', is_archived, is_published, unique=True), ) - self.articles = Article.__table__ - self.article_translations = ArticleTranslation.__table__ + return ArticleTranslation.__table__ - def test_primary_key(self): - assert has_unique_index(self.articles.c.id) + def test_primary_key(self, articles): + assert has_unique_index(articles.c.id) - def test_column_of_aliased_table(self): - alias = sa.orm.aliased(self.articles) - with raises(TypeError): + def test_column_of_aliased_table(self, articles): + alias = sa.orm.aliased(articles) + with pytest.raises(TypeError): assert has_unique_index(alias.c.id) - def test_unique_index(self): - assert has_unique_index(self.article_translations.c.is_deleted) + def test_unique_index(self, article_translations): + assert has_unique_index(article_translations.c.is_deleted) - def test_compound_primary_key(self): - assert not has_unique_index(self.article_translations.c.id) - assert not has_unique_index(self.article_translations.c.locale) + def test_compound_primary_key(self, article_translations): + assert not has_unique_index(article_translations.c.id) + assert not has_unique_index(article_translations.c.locale) - def test_single_column_index(self): - assert not has_unique_index(self.article_translations.c.is_published) + def test_single_column_index(self, article_translations): + assert not has_unique_index(article_translations.c.is_published) - def test_compound_column_unique_index(self): - assert not has_unique_index(self.article_translations.c.is_published) - assert not has_unique_index(self.article_translations.c.is_archived) + def test_compound_column_unique_index(self, article_translations): + assert not has_unique_index(article_translations.c.is_published) + assert not has_unique_index(article_translations.c.is_archived) class TestHasUniqueIndexWithFKConstraint(object): - def test_composite_fk_without_index(self): - Base = declarative_base() + def test_composite_fk_without_index(self, Base): class User(Base): __tablename__ = 'user' @@ -81,8 +81,7 @@ class TestHasUniqueIndexWithFKConstraint(object): ) assert not has_unique_index(constraint) - def test_composite_fk_with_index(self): - Base = declarative_base() + def test_composite_fk_with_index(self, Base): class User(Base): __tablename__ = 'user' @@ -115,8 +114,7 @@ class TestHasUniqueIndexWithFKConstraint(object): ) assert has_unique_index(constraint) - def test_composite_fk_with_partial_index_match(self): - Base = declarative_base() + def test_composite_fk_with_partial_index_match(self, Base): class User(Base): __tablename__ = 'user' diff --git a/tests/functions/test_identity.py b/tests/functions/test_identity.py index d1d3d42..c3d5c3d 100644 --- a/tests/functions/test_identity.py +++ b/tests/functions/test_identity.py @@ -1,39 +1,46 @@ +import pytest import sqlalchemy as sa from sqlalchemy_utils.functions import identity -from tests import TestCase -class IdentityTestCase(TestCase): - def test_for_transient_class_without_id(self): - assert identity(self.Building()) == (None, ) +class IdentityTestCase(object): - def test_for_transient_class_with_id(self): - building = self.Building(name=u'Some building') - self.session.add(building) - self.session.flush() + @pytest.fixture + def init_models(self, Building): + pass + + def test_for_transient_class_without_id(self, Building): + assert identity(Building()) == (None, ) + + def test_for_transient_class_with_id(self, session, Building): + building = Building(name=u'Some building') + session.add(building) + session.flush() assert identity(building) == (building.id, ) - def test_identity_for_class(self): - assert identity(self.Building) == (self.Building.id, ) + def test_identity_for_class(self, Building): + assert identity(Building) == (Building.id, ) class TestIdentity(IdentityTestCase): - def create_models(self): - class Building(self.Base): + + @pytest.fixture + def Building(self, Base): + class Building(Base): __tablename__ = 'building' id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.Unicode(255)) - - self.Building = Building + return Building class TestIdentityWithColumnAlias(IdentityTestCase): - def create_models(self): - class Building(self.Base): + + @pytest.fixture + def Building(self, Base): + class Building(Base): __tablename__ = 'building' id = sa.Column('_id', sa.Integer, primary_key=True) name = sa.Column(sa.Unicode(255)) - - self.Building = Building + return Building diff --git a/tests/functions/test_is_loaded.py b/tests/functions/test_is_loaded.py index 68d7d24..8368526 100644 --- a/tests/functions/test_is_loaded.py +++ b/tests/functions/test_is_loaded.py @@ -1,24 +1,24 @@ +import pytest import sqlalchemy as sa -from sqlalchemy.ext.declarative import declarative_base from sqlalchemy_utils import is_loaded +@pytest.fixture +def Article(Base): + class Article(Base): + __tablename__ = 'article_translation' + id = sa.Column(sa.Integer, primary_key=True) + title = sa.orm.deferred(sa.Column(sa.String(100))) + return Article + + class TestIsLoaded(object): - def setup_method(self, method): - Base = declarative_base() - class Article(Base): - __tablename__ = 'article_translation' - id = sa.Column(sa.Integer, primary_key=True) - title = sa.orm.deferred(sa.Column(sa.String(100))) - - self.Article = Article - - def test_loaded_property(self): - article = self.Article(id=1) + def test_loaded_property(self, Article): + article = Article(id=1) assert is_loaded(article, 'id') - def test_unloaded_property(self): - article = self.Article(id=4) + def test_unloaded_property(self, Article): + article = Article(id=4) assert not is_loaded(article, 'title') diff --git a/tests/functions/test_json_sql.py b/tests/functions/test_json_sql.py index bb3766a..7e94746 100644 --- a/tests/functions/test_json_sql.py +++ b/tests/functions/test_json_sql.py @@ -2,11 +2,10 @@ import pytest import sqlalchemy as sa from sqlalchemy_utils import json_sql -from tests import TestCase -class TestJSONSQL(TestCase): - dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' +@pytest.mark.usefixtures('postgresql_dsn') +class TestJSONSQL(object): @pytest.mark.parametrize( ('value', 'result'), @@ -27,7 +26,7 @@ class TestJSONSQL(TestCase): ) ) ) - def test_compiled_scalars(self, value, result): + def test_compiled_scalars(self, connection, value, result): assert result == ( - self.connection.execute(sa.select([json_sql(value)])).fetchone()[0] + connection.execute(sa.select([json_sql(value)])).fetchone()[0] ) diff --git a/tests/functions/test_make_order_by_deterministic.py b/tests/functions/test_make_order_by_deterministic.py index 632b1ca..6e691f3 100644 --- a/tests/functions/test_make_order_by_deterministic.py +++ b/tests/functions/test_make_order_by_deterministic.py @@ -1,90 +1,102 @@ +import pytest import sqlalchemy as sa from sqlalchemy_utils.functions.sort_query import make_order_by_deterministic -from tests import assert_contains, TestCase + +from .. import assert_contains -class TestMakeOrderByDeterministic(TestCase): - def create_models(self): - class User(self.Base): - __tablename__ = 'user' - id = sa.Column(sa.Integer, primary_key=True) - name = sa.Column(sa.Unicode) - email = sa.Column(sa.Unicode, unique=True) +@pytest.fixture +def Article(Base): + class Article(Base): + __tablename__ = 'article' + id = sa.Column(sa.Integer, primary_key=True) + author_id = sa.Column(sa.Integer, sa.ForeignKey('user.id')) + author = sa.orm.relationship('User') + return Article - email_lower = sa.orm.column_property( - sa.func.lower(name) - ) - class Article(self.Base): - __tablename__ = 'article' - id = sa.Column(sa.Integer, primary_key=True) - author_id = sa.Column(sa.Integer, sa.ForeignKey('user.id')) - author = sa.orm.relationship(User) +@pytest.fixture +def User(Base, Article): + class User(Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode) + email = sa.Column(sa.Unicode, unique=True) - User.article_count = sa.orm.column_property( - sa.select([sa.func.count()], from_obj=Article) - .where(Article.author_id == User.id) - .label('article_count') + email_lower = sa.orm.column_property( + sa.func.lower(name) ) - self.User = User - self.Article = Article + User.article_count = sa.orm.column_property( + sa.select([sa.func.count()], from_obj=Article) + .where(Article.author_id == User.id) + .label('article_count') + ) + return User - def test_column_property(self): - query = self.session.query(self.User).order_by(self.User.email_lower) + +@pytest.fixture +def init_models(Article, User): + pass + + +class TestMakeOrderByDeterministic(object): + + def test_column_property(self, session, User): + query = session.query(User).order_by(User.email_lower) query = make_order_by_deterministic(query) assert_contains('lower("user".name), "user".id ASC', query) - def test_unique_column(self): - query = self.session.query(self.User).order_by(self.User.email) + def test_unique_column(self, session, User): + query = session.query(User).order_by(User.email) query = make_order_by_deterministic(query) assert str(query).endswith('ORDER BY "user".email') - def test_non_unique_column(self): - query = self.session.query(self.User).order_by(self.User.name) + def test_non_unique_column(self, session, User): + query = session.query(User).order_by(User.name) query = make_order_by_deterministic(query) assert_contains('ORDER BY "user".name, "user".id ASC', query) - def test_descending_order_by(self): - query = self.session.query(self.User).order_by( - sa.desc(self.User.name) + def test_descending_order_by(self, session, User): + query = session.query(User).order_by( + sa.desc(User.name) ) query = make_order_by_deterministic(query) assert_contains('ORDER BY "user".name DESC, "user".id DESC', query) - def test_ascending_order_by(self): - query = self.session.query(self.User).order_by( - sa.asc(self.User.name) + def test_ascending_order_by(self, session, User): + query = session.query(User).order_by( + sa.asc(User.name) ) query = make_order_by_deterministic(query) assert_contains('ORDER BY "user".name ASC, "user".id ASC', query) - def test_string_order_by(self): - query = self.session.query(self.User).order_by('name') + def test_string_order_by(self, session, User): + query = session.query(User).order_by('name') query = make_order_by_deterministic(query) assert_contains('ORDER BY "user".name, "user".id ASC', query) - def test_annotated_label(self): - query = self.session.query(self.User).order_by(self.User.article_count) + def test_annotated_label(self, session, User): + query = session.query(User).order_by(User.article_count) query = make_order_by_deterministic(query) assert_contains('article_count, "user".id ASC', query) - def test_annotated_label_with_descending_order(self): - query = self.session.query(self.User).order_by( - sa.desc(self.User.article_count) + def test_annotated_label_with_descending_order(self, session, User): + query = session.query(User).order_by( + sa.desc(User.article_count) ) query = make_order_by_deterministic(query) assert_contains('ORDER BY article_count DESC, "user".id DESC', query) - def test_query_without_order_by(self): - query = self.session.query(self.User) + def test_query_without_order_by(self, session, User): + query = session.query(User) query = make_order_by_deterministic(query) assert 'ORDER BY "user".id' in str(query) - def test_alias(self): - alias = sa.orm.aliased(self.User.__table__) - query = self.session.query(alias).order_by(alias.c.name) + def test_alias(self, session, User): + alias = sa.orm.aliased(User.__table__) + query = session.query(alias).order_by(alias.c.name) query = make_order_by_deterministic(query) assert str(query).endswith('ORDER BY user_1.name, "user".id ASC') diff --git a/tests/functions/test_merge_references.py b/tests/functions/test_merge_references.py index e97f7a2..216724a 100644 --- a/tests/functions/test_merge_references.py +++ b/tests/functions/test_merge_references.py @@ -1,20 +1,25 @@ +import pytest import sqlalchemy as sa from sqlalchemy_utils import merge_references -from tests import TestCase -class TestMergeReferences(TestCase): - def create_models(self): - class User(self.Base): +class TestMergeReferences(object): + + @pytest.fixture + def User(self, Base): + class User(Base): __tablename__ = 'user' id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.Unicode(255)) def __repr__(self): return 'User(%r)' % self.name + return User - class BlogPost(self.Base): + @pytest.fixture + def BlogPost(self, Base, User): + class BlogPost(Base): __tablename__ = 'blog_post' id = sa.Column(sa.Integer, primary_key=True) title = sa.Column(sa.Unicode(255)) @@ -22,35 +27,37 @@ class TestMergeReferences(TestCase): author_id = sa.Column(sa.Integer, sa.ForeignKey('user.id')) author = sa.orm.relationship(User) + return BlogPost - self.User = User - self.BlogPost = BlogPost + @pytest.fixture + def init_models(self, User, BlogPost): + pass - def test_updates_foreign_keys(self): - john = self.User(name=u'John') - jack = self.User(name=u'Jack') - post = self.BlogPost(title=u'Some title', author=john) - post2 = self.BlogPost(title=u'Other title', author=jack) - self.session.add(john) - self.session.add(jack) - self.session.add(post) - self.session.add(post2) - self.session.commit() + def test_updates_foreign_keys(self, session, User, BlogPost): + john = User(name=u'John') + jack = User(name=u'Jack') + post = BlogPost(title=u'Some title', author=john) + post2 = BlogPost(title=u'Other title', author=jack) + session.add(john) + session.add(jack) + session.add(post) + session.add(post2) + session.commit() merge_references(john, jack) - self.session.commit() + session.commit() assert post.author == jack assert post2.author == jack - def test_object_merging_whenever_possible(self): - john = self.User(name=u'John') - jack = self.User(name=u'Jack') - post = self.BlogPost(title=u'Some title', author=john) - post2 = self.BlogPost(title=u'Other title', author=jack) - self.session.add(john) - self.session.add(jack) - self.session.add(post) - self.session.add(post2) - self.session.commit() + def test_object_merging_whenever_possible(self, session, User, BlogPost): + john = User(name=u'John') + jack = User(name=u'Jack') + post = BlogPost(title=u'Some title', author=john) + post2 = BlogPost(title=u'Other title', author=jack) + session.add(john) + session.add(jack) + session.add(post) + session.add(post2) + session.commit() # Load the author for post assert post.author_id == john.id merge_references(john, jack) @@ -58,18 +65,23 @@ class TestMergeReferences(TestCase): assert post2.author_id == jack.id -class TestMergeReferencesWithManyToManyAssociations(TestCase): - def create_models(self): - class User(self.Base): +class TestMergeReferencesWithManyToManyAssociations(object): + + @pytest.fixture + def User(self, Base): + class User(Base): __tablename__ = 'user' id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.Unicode(255)) def __repr__(self): return 'User(%r)' % self.name + return User + @pytest.fixture + def Team(self, Base): team_member = sa.Table( - 'team_member', self.Base.metadata, + 'team_member', Base.metadata, sa.Column( 'user_id', sa.Integer, sa.ForeignKey('user.id', ondelete='CASCADE'), @@ -82,46 +94,56 @@ class TestMergeReferencesWithManyToManyAssociations(TestCase): ) ) - class Team(self.Base): + class Team(Base): __tablename__ = 'team' id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.Unicode(255)) members = sa.orm.relationship( - User, + 'User', secondary=team_member, backref='teams' ) + return Team - self.User = User - self.Team = Team + @pytest.fixture + def init_models(self, User, Team): + pass - def test_supports_associations(self): - john = self.User(name=u'John') - jack = self.User(name=u'Jack') - team = self.Team(name=u'Team') + def test_supports_associations(self, session, User, Team): + john = User(name=u'John') + jack = User(name=u'Jack') + team = Team(name=u'Team') team.members.append(john) - self.session.add(john) - self.session.add(jack) - self.session.commit() + session.add(john) + session.add(jack) + session.commit() merge_references(john, jack) assert john not in team.members assert jack in team.members -class TestMergeReferencesWithManyToManyAssociationObjects(TestCase): - def create_models(self): - class Team(self.Base): +class TestMergeReferencesWithManyToManyAssociationObjects(object): + + @pytest.fixture + def Team(self, Base): + class Team(Base): __tablename__ = 'team' id = sa.Column(sa.Integer, autoincrement=True, primary_key=True) name = sa.Column(sa.Unicode(255)) + return Team - class User(self.Base): + @pytest.fixture + def User(self, Base): + class User(Base): __tablename__ = 'user' id = sa.Column(sa.Integer, autoincrement=True, primary_key=True) name = sa.Column(sa.Unicode(255)) + return User - class TeamMember(self.Base): + @pytest.fixture + def TeamMember(self, Base, User, Team): + class TeamMember(Base): __tablename__ = 'team_member' user_id = sa.Column( sa.Integer, @@ -150,22 +172,23 @@ class TestMergeReferencesWithManyToManyAssociationObjects(TestCase): ), primaryjoin=user_id == User.id, ) + return TeamMember - self.User = User - self.TeamMember = TeamMember - self.Team = Team + @pytest.fixture + def init_models(self, User, Team, TeamMember): + pass - def test_supports_associations(self): - john = self.User(name=u'John') - jack = self.User(name=u'Jack') - team = self.Team(name=u'Team') - team.members.append(self.TeamMember(user=john)) - self.session.add(john) - self.session.add(jack) - self.session.add(team) - self.session.commit() + def test_supports_associations(self, session, User, Team, TeamMember): + john = User(name=u'John') + jack = User(name=u'Jack') + team = Team(name=u'Team') + team.members.append(TeamMember(user=john)) + session.add(john) + session.add(jack) + session.add(team) + session.commit() merge_references(john, jack) - self.session.commit() + session.commit() users = [member.user for member in team.members] assert john not in users assert jack in users diff --git a/tests/functions/test_naturally_equivalent.py b/tests/functions/test_naturally_equivalent.py index c443e4d..e8dd084 100644 --- a/tests/functions/test_naturally_equivalent.py +++ b/tests/functions/test_naturally_equivalent.py @@ -1,14 +1,13 @@ from sqlalchemy_utils.functions import naturally_equivalent -from tests import TestCase -class TestNaturallyEquivalent(TestCase): - def test_returns_true_when_properties_match(self): +class TestNaturallyEquivalent(object): + def test_returns_true_when_properties_match(self, User): assert naturally_equivalent( - self.User(name=u'someone'), self.User(name=u'someone') + User(name=u'someone'), User(name=u'someone') ) - def test_skips_primary_keys(self): + def test_skips_primary_keys(self, User): assert naturally_equivalent( - self.User(id=1, name=u'someone'), self.User(id=2, name=u'someone') + User(id=1, name=u'someone'), User(id=2, name=u'someone') ) diff --git a/tests/functions/test_non_indexed_foreign_keys.py b/tests/functions/test_non_indexed_foreign_keys.py index 3c1791e..80b4482 100644 --- a/tests/functions/test_non_indexed_foreign_keys.py +++ b/tests/functions/test_non_indexed_foreign_keys.py @@ -1,24 +1,32 @@ from itertools import chain +import pytest import sqlalchemy as sa from sqlalchemy_utils.functions import non_indexed_foreign_keys -from tests import TestCase -class TestFindNonIndexedForeignKeys(TestCase): - def create_models(self): - class User(self.Base): +class TestFindNonIndexedForeignKeys(object): + + @pytest.fixture + def User(self, Base): + class User(Base): __tablename__ = 'user' id = sa.Column(sa.Integer, autoincrement=True, primary_key=True) name = sa.Column(sa.Unicode(255)) + return User - class Category(self.Base): + @pytest.fixture + def Category(self, Base): + class Category(Base): __tablename__ = 'category' id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.Unicode(255)) + return Category - class Article(self.Base): + @pytest.fixture + def Article(self, Base, User, Category): + class Article(Base): __tablename__ = 'article' id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.Unicode(255)) @@ -34,13 +42,14 @@ class TestFindNonIndexedForeignKeys(TestCase): 'articles', ) ) + return Article - self.User = User - self.Category = Category - self.Article = Article + @pytest.fixture + def init_models(self, User, Category, Article): + pass - def test_finds_all_non_indexed_fks(self): - fks = non_indexed_foreign_keys(self.Base.metadata, self.engine) + def test_finds_all_non_indexed_fks(self, session, Base, engine): + fks = non_indexed_foreign_keys(Base.metadata, engine) assert ( 'article' in fks diff --git a/tests/functions/test_quote.py b/tests/functions/test_quote.py index 85b7a31..93e4d93 100644 --- a/tests/functions/test_quote.py +++ b/tests/functions/test_quote.py @@ -1,18 +1,22 @@ from sqlalchemy.dialects import postgresql from sqlalchemy_utils.functions import quote -from tests import TestCase -class TestQuote(TestCase): - def test_quote_with_preserved_keyword(self): - assert quote(self.connection, 'order') == '"order"' - assert quote(self.session, 'order') == '"order"' - assert quote(self.engine, 'order') == '"order"' +class TestQuote(object): + def test_quote_with_preserved_keyword(self, engine, connection, session): + assert quote(connection, 'order') == '"order"' + assert quote(session, 'order') == '"order"' + assert quote(engine, 'order') == '"order"' assert quote(postgresql.dialect(), 'order') == '"order"' - def test_quote_with_non_preserved_keyword(self): - assert quote(self.connection, 'some_order') == 'some_order' - assert quote(self.session, 'some_order') == 'some_order' - assert quote(self.engine, 'some_order') == 'some_order' + def test_quote_with_non_preserved_keyword( + self, + engine, + connection, + session + ): + assert quote(connection, 'some_order') == 'some_order' + assert quote(session, 'some_order') == 'some_order' + assert quote(engine, 'some_order') == 'some_order' assert quote(postgresql.dialect(), 'some_order') == 'some_order' diff --git a/tests/functions/test_render.py b/tests/functions/test_render.py index 2927617..93208ed 100644 --- a/tests/functions/test_render.py +++ b/tests/functions/test_render.py @@ -1,3 +1,4 @@ +import pytest import sqlalchemy as sa from sqlalchemy_utils.functions import ( @@ -5,52 +6,58 @@ from sqlalchemy_utils.functions import ( render_expression, render_statement ) -from tests import TestCase -class TestRender(TestCase): - def create_models(self): - class User(self.Base): +class TestRender(object): + + @pytest.fixture + def User(self, Base): + class User(Base): __tablename__ = 'user' id = sa.Column(sa.Integer, autoincrement=True, primary_key=True) name = sa.Column(sa.Unicode(255)) + return User - self.User = User + @pytest.fixture + def init_models(self, User): + pass - def test_render_orm_query(self): - query = self.session.query(self.User).filter_by(id=3) + def test_render_orm_query(self, session, User): + query = session.query(User).filter_by(id=3) text = render_statement(query) assert 'SELECT user.id, user.name' in text assert 'FROM user' in text assert 'WHERE user.id = 3' in text - def test_render_statement(self): - statement = self.User.__table__.select().where(self.User.id == 3) - text = render_statement(statement, bind=self.session.bind) + def test_render_statement(self, session, User): + statement = User.__table__.select().where(User.id == 3) + text = render_statement(statement, bind=session.bind) assert 'SELECT user.id, user.name' in text assert 'FROM user' in text assert 'WHERE user.id = 3' in text - def test_render_statement_without_mapper(self): + def test_render_statement_without_mapper(self, session): statement = sa.select([sa.text('1')]) - text = render_statement(statement, bind=self.session.bind) + text = render_statement(statement, bind=session.bind) assert 'SELECT 1' in text - def test_render_ddl(self): - expression = 'self.User.__table__.create(engine)' - stream = render_expression(expression, self.engine) + def test_render_ddl(self, engine, User): + expression = 'User.__table__.create(engine)' + stream = render_expression(expression, engine) text = stream.getvalue() assert 'CREATE TABLE user' in text assert 'PRIMARY KEY' in text - def test_render_mock_ddl(self): + def test_render_mock_ddl(self, engine, User): + # TODO: mock_engine doesn't seem to work with locally scoped variables. + self.engine = engine with mock_engine('self.engine') as stream: - self.User.__table__.create(self.engine) + User.__table__.create(self.engine) text = stream.getvalue() diff --git a/tests/functions/test_table_name.py b/tests/functions/test_table_name.py index 6018110..e56d965 100644 --- a/tests/functions/test_table_name.py +++ b/tests/functions/test_table_name.py @@ -1,26 +1,33 @@ +import pytest import sqlalchemy as sa from sqlalchemy_utils import table_name -from tests import TestCase -class TestTableName(TestCase): - def create_models(self): - class Building(self.Base): - __tablename__ = 'building' - id = sa.Column(sa.Integer, primary_key=True) - name = sa.Column(sa.Unicode(255)) +@pytest.fixture +def Building(Base): + class Building(Base): + __tablename__ = 'building' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + return Building - self.Building = Building - def test_class(self): - assert table_name(self.Building) == 'building' - del self.Building.__tablename__ - assert table_name(self.Building) == 'building' +@pytest.fixture +def init_models(Base): + pass - def test_attribute(self): - assert table_name(self.Building.id) == 'building' - assert table_name(self.Building.name) == 'building' - def test_target(self): - assert table_name(self.Building()) == 'building' +class TestTableName(object): + + def test_class(self, Building): + assert table_name(Building) == 'building' + del Building.__tablename__ + assert table_name(Building) == 'building' + + def test_attribute(self, Building): + assert table_name(Building.id) == 'building' + assert table_name(Building.name) == 'building' + + def test_target(self, Building): + assert table_name(Building()) == 'building' diff --git a/tests/generic_relationship/__init__.py b/tests/generic_relationship/__init__.py index 9b2de76..d5a69bf 100644 --- a/tests/generic_relationship/__init__.py +++ b/tests/generic_relationship/__init__.py @@ -1,109 +1,105 @@ -from __future__ import unicode_literals - import six -from tests import TestCase - -class GenericRelationshipTestCase(TestCase): - def test_set_as_none(self): - event = self.Event() +class GenericRelationshipTestCase(object): + def test_set_as_none(self, Event): + event = Event() event.object = None assert event.object is None - def test_set_manual_and_get(self): - user = self.User() + def test_set_manual_and_get(self, session, User, Event): + user = User() - self.session.add(user) - self.session.commit() + session.add(user) + session.commit() - event = self.Event() + event = Event() event.object_id = user.id event.object_type = six.text_type(type(user).__name__) assert event.object is None - self.session.add(event) - self.session.commit() + session.add(event) + session.commit() assert event.object == user - def test_set_and_get(self): - user = self.User() + def test_set_and_get(self, session, User, Event): + user = User() - self.session.add(user) - self.session.commit() + session.add(user) + session.commit() - event = self.Event(object=user) + event = Event(object=user) assert event.object_id == user.id assert event.object_type == type(user).__name__ - self.session.add(event) - self.session.commit() + session.add(event) + session.commit() assert event.object == user - def test_compare_instance(self): - user1 = self.User() - user2 = self.User() + def test_compare_instance(self, session, User, Event): + user1 = User() + user2 = User() - self.session.add_all([user1, user2]) - self.session.commit() + session.add_all([user1, user2]) + session.commit() - event = self.Event(object=user1) + event = Event(object=user1) - self.session.add(event) - self.session.commit() + session.add(event) + session.commit() assert event.object == user1 assert event.object != user2 - def test_compare_query(self): - user1 = self.User() - user2 = self.User() + def test_compare_query(self, session, User, Event): + user1 = User() + user2 = User() - self.session.add_all([user1, user2]) - self.session.commit() + session.add_all([user1, user2]) + session.commit() - event = self.Event(object=user1) + event = Event(object=user1) - self.session.add(event) - self.session.commit() + session.add(event) + session.commit() - q = self.session.query(self.Event) + q = session.query(Event) assert q.filter_by(object=user1).first() is not None assert q.filter_by(object=user2).first() is None - assert q.filter(self.Event.object == user2).first() is None + assert q.filter(Event.object == user2).first() is None - def test_compare_not_query(self): - user1 = self.User() - user2 = self.User() + def test_compare_not_query(self, session, User, Event): + user1 = User() + user2 = User() - self.session.add_all([user1, user2]) - self.session.commit() + session.add_all([user1, user2]) + session.commit() - event = self.Event(object=user1) + event = Event(object=user1) - self.session.add(event) - self.session.commit() + session.add(event) + session.commit() - q = self.session.query(self.Event) - assert q.filter(self.Event.object != user2).first() is not None + q = session.query(Event) + assert q.filter(Event.object != user2).first() is not None - def test_compare_type(self): - user1 = self.User() - user2 = self.User() + def test_compare_type(self, session, User, Event): + user1 = User() + user2 = User() - self.session.add_all([user1, user2]) - self.session.commit() + session.add_all([user1, user2]) + session.commit() - event1 = self.Event(object=user1) - event2 = self.Event(object=user2) + event1 = Event(object=user1) + event2 = Event(object=user2) - self.session.add_all([event1, event2]) - self.session.commit() + session.add_all([event1, event2]) + session.commit() - statement = self.Event.object.is_type(self.User) - q = self.session.query(self.Event).filter(statement) + statement = Event.object.is_type(User) + q = session.query(Event).filter(statement) assert q.first() is not None diff --git a/tests/generic_relationship/test_abstract_base_class.py b/tests/generic_relationship/test_abstract_base_class.py index 709f0db..d4b8f0a 100644 --- a/tests/generic_relationship/test_abstract_base_class.py +++ b/tests/generic_relationship/test_abstract_base_class.py @@ -1,36 +1,54 @@ -from __future__ import unicode_literals - +import pytest import sqlalchemy as sa from sqlalchemy.ext.declarative import declared_attr from sqlalchemy_utils import generic_relationship -from tests.generic_relationship import GenericRelationshipTestCase + +from . import GenericRelationshipTestCase + + +@pytest.fixture +def Building(Base): + class Building(Base): + __tablename__ = 'building' + id = sa.Column(sa.Integer, primary_key=True) + return Building + + +@pytest.fixture +def User(Base): + class User(Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, primary_key=True) + return User + + +@pytest.fixture +def EventBase(Base): + class EventBase(Base): + __abstract__ = True + + object_type = sa.Column(sa.Unicode(255)) + object_id = sa.Column(sa.Integer, nullable=False) + + @declared_attr + def object(cls): + return generic_relationship('object_type', 'object_id') + return EventBase + + +@pytest.fixture +def Event(EventBase): + class Event(EventBase): + __tablename__ = 'event' + id = sa.Column(sa.Integer, primary_key=True) + return Event + + +@pytest.fixture +def init_models(Building, User, Event): + pass class TestGenericRelationshipWithAbstractBase(GenericRelationshipTestCase): - def create_models(self): - class Building(self.Base): - __tablename__ = 'building' - id = sa.Column(sa.Integer, primary_key=True) - - class User(self.Base): - __tablename__ = 'user' - id = sa.Column(sa.Integer, primary_key=True) - - class EventBase(self.Base): - __abstract__ = True - - object_type = sa.Column(sa.Unicode(255)) - object_id = sa.Column(sa.Integer, nullable=False) - - @declared_attr - def object(cls): - return generic_relationship('object_type', 'object_id') - - class Event(EventBase): - __tablename__ = 'event' - id = sa.Column(sa.Integer, primary_key=True) - - self.Building = Building - self.User = User - self.Event = Event + pass diff --git a/tests/generic_relationship/test_column_aliases.py b/tests/generic_relationship/test_column_aliases.py index 1d93c79..cb9306e 100644 --- a/tests/generic_relationship/test_column_aliases.py +++ b/tests/generic_relationship/test_column_aliases.py @@ -1,30 +1,44 @@ -from __future__ import unicode_literals - +import pytest import sqlalchemy as sa from sqlalchemy_utils import generic_relationship -from tests.generic_relationship import GenericRelationshipTestCase + +from . import GenericRelationshipTestCase + + +@pytest.fixture +def Building(Base): + class Building(Base): + __tablename__ = 'building' + id = sa.Column(sa.Integer, primary_key=True) + return Building + + +@pytest.fixture +def User(Base): + class User(Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, primary_key=True) + return User + + +@pytest.fixture +def Event(Base): + class Event(Base): + __tablename__ = 'event' + id = sa.Column(sa.Integer, primary_key=True) + + object_type = sa.Column(sa.Unicode(255), name="objectType") + object_id = sa.Column(sa.Integer, nullable=False) + + object = generic_relationship(object_type, object_id) + return Event + + +@pytest.fixture +def init_models(Building, User, Event): + pass class TestGenericRelationship(GenericRelationshipTestCase): - def create_models(self): - class Building(self.Base): - __tablename__ = 'building' - id = sa.Column(sa.Integer, primary_key=True) - - class User(self.Base): - __tablename__ = 'user' - id = sa.Column(sa.Integer, primary_key=True) - - class Event(self.Base): - __tablename__ = 'event' - id = sa.Column(sa.Integer, primary_key=True) - - object_type = sa.Column(sa.Unicode(255), name="objectType") - object_id = sa.Column(sa.Integer, nullable=False) - - object = generic_relationship(object_type, object_id) - - self.Building = Building - self.User = User - self.Event = Event + pass diff --git a/tests/generic_relationship/test_composite_keys.py b/tests/generic_relationship/test_composite_keys.py index d730d09..c7a6feb 100644 --- a/tests/generic_relationship/test_composite_keys.py +++ b/tests/generic_relationship/test_composite_keys.py @@ -1,66 +1,84 @@ -from __future__ import unicode_literals - +import pytest import six import sqlalchemy as sa from sqlalchemy_utils import generic_relationship -from tests.generic_relationship import GenericRelationshipTestCase + +from ..generic_relationship import GenericRelationshipTestCase + + +@pytest.fixture +def incrementor(): + class Incrementor(object): + value = 1 + return Incrementor() + + +@pytest.fixture +def Building(Base, incrementor): + class Building(Base): + __tablename__ = 'building' + id = sa.Column(sa.Integer, primary_key=True) + code = sa.Column(sa.Integer, primary_key=True) + + def __init__(self): + incrementor.value += 1 + self.id = incrementor.value + self.code = incrementor.value + return Building + + +@pytest.fixture +def User(Base, incrementor): + class User(Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, primary_key=True) + code = sa.Column(sa.Integer, primary_key=True) + + def __init__(self): + incrementor.value += 1 + self.id = incrementor.value + self.code = incrementor.value + return User + + +@pytest.fixture +def Event(Base): + class Event(Base): + __tablename__ = 'event' + id = sa.Column(sa.Integer, primary_key=True) + + object_type = sa.Column(sa.Unicode(255)) + object_id = sa.Column(sa.Integer, nullable=False) + object_code = sa.Column(sa.Integer, nullable=False) + + object = generic_relationship( + object_type, (object_id, object_code) + ) + return Event + + +@pytest.fixture +def init_models(Building, User, Event): + pass class TestGenericRelationship(GenericRelationshipTestCase): - index = 1 - def create_models(self): - class Building(self.Base): - __tablename__ = 'building' - id = sa.Column(sa.Integer, primary_key=True) - code = sa.Column(sa.Integer, primary_key=True) + def test_set_manual_and_get(self, session, Event, User): + user = User() - def __init__(obj_self): - self.index += 1 - obj_self.id = self.index - obj_self.code = self.index + session.add(user) + session.commit() - class User(self.Base): - __tablename__ = 'user' - id = sa.Column(sa.Integer, primary_key=True) - code = sa.Column(sa.Integer, primary_key=True) - - def __init__(obj_self): - self.index += 1 - obj_self.id = self.index - obj_self.code = self.index - - class Event(self.Base): - __tablename__ = 'event' - id = sa.Column(sa.Integer, primary_key=True) - - object_type = sa.Column(sa.Unicode(255)) - object_id = sa.Column(sa.Integer, nullable=False) - object_code = sa.Column(sa.Integer, nullable=False) - - object = generic_relationship( - object_type, (object_id, object_code) - ) - - self.Building = Building - self.User = User - self.Event = Event - - def test_set_manual_and_get(self): - user = self.User() - - self.session.add(user) - self.session.commit() - - event = self.Event() + event = Event() event.object_id = user.id event.object_type = six.text_type(type(user).__name__) event.object_code = user.code assert event.object is None - self.session.add(event) - self.session.commit() + session.add(event) + session.commit() assert event.object == user diff --git a/tests/generic_relationship/test_hybrid_properties.py b/tests/generic_relationship/test_hybrid_properties.py index ce4575f..7695377 100644 --- a/tests/generic_relationship/test_hybrid_properties.py +++ b/tests/generic_relationship/test_hybrid_properties.py @@ -1,68 +1,79 @@ -from __future__ import unicode_literals - +import pytest import six import sqlalchemy as sa from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy_utils import generic_relationship -from tests import TestCase -class TestGenericRelationship(TestCase): - def create_models(self): - class User(self.Base): - __tablename__ = 'user' - id = sa.Column(sa.Integer, primary_key=True) +@pytest.fixture +def User(Base): + class User(Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, primary_key=True) + return User - class UserHistory(self.Base): - __tablename__ = 'user_history' - id = sa.Column(sa.Integer, primary_key=True) - transaction_id = sa.Column(sa.Integer, primary_key=True) +@pytest.fixture +def UserHistory(Base): + class UserHistory(Base): + __tablename__ = 'user_history' + id = sa.Column(sa.Integer, primary_key=True) - class Event(self.Base): - __tablename__ = 'event' - id = sa.Column(sa.Integer, primary_key=True) + transaction_id = sa.Column(sa.Integer, primary_key=True) + return UserHistory - transaction_id = sa.Column(sa.Integer) - object_type = sa.Column(sa.Unicode(255)) - object_id = sa.Column(sa.Integer, nullable=False) +@pytest.fixture +def Event(Base): + class Event(Base): + __tablename__ = 'event' + id = sa.Column(sa.Integer, primary_key=True) - object = generic_relationship( - object_type, object_id - ) + transaction_id = sa.Column(sa.Integer) - @hybrid_property - def object_version_type(self): - return self.object_type + 'History' + object_type = sa.Column(sa.Unicode(255)) + object_id = sa.Column(sa.Integer, nullable=False) - @object_version_type.expression - def object_version_type(cls): - return sa.func.concat(cls.object_type, 'History') + object = generic_relationship( + object_type, object_id + ) - object_version = generic_relationship( - object_version_type, (object_id, transaction_id) - ) + @hybrid_property + def object_version_type(self): + return self.object_type + 'History' - self.User = User - self.UserHistory = UserHistory - self.Event = Event + @object_version_type.expression + def object_version_type(cls): + return sa.func.concat(cls.object_type, 'History') - def test_set_manual_and_get(self): - user = self.User(id=1) - history = self.UserHistory(id=1, transaction_id=1) - self.session.add(user) - self.session.add(history) - self.session.commit() + object_version = generic_relationship( + object_version_type, (object_id, transaction_id) + ) + return Event - event = self.Event(transaction_id=1) + +@pytest.fixture +def init_models(User, UserHistory, Event): + pass + + +class TestGenericRelationship(object): + + def test_set_manual_and_get(self, session, User, UserHistory, Event): + user = User(id=1) + history = UserHistory(id=1, transaction_id=1) + session.add(user) + session.add(history) + session.commit() + + event = Event(transaction_id=1) event.object_id = user.id event.object_type = six.text_type(type(user).__name__) assert event.object is None - self.session.add(event) - self.session.commit() + session.add(event) + session.commit() assert event.object == user assert event.object_version == history diff --git a/tests/generic_relationship/test_single_table_inheritance.py b/tests/generic_relationship/test_single_table_inheritance.py index 1013717..39a9d55 100644 --- a/tests/generic_relationship/test_single_table_inheritance.py +++ b/tests/generic_relationship/test_single_table_inheritance.py @@ -1,164 +1,178 @@ -from __future__ import unicode_literals - +import pytest import six import sqlalchemy as sa from sqlalchemy_utils import generic_relationship -from tests import TestCase -class TestGenericRelationship(TestCase): - def create_models(self): - class Employee(self.Base): - __tablename__ = 'employee' - id = sa.Column(sa.Integer, primary_key=True) - name = sa.Column(sa.String(50)) - type = sa.Column(sa.String(20)) +@pytest.fixture +def Employee(Base): + class Employee(Base): + __tablename__ = 'employee' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.String(50)) + type = sa.Column(sa.String(20)) - __mapper_args__ = { - 'polymorphic_on': type, - 'polymorphic_identity': 'employee' - } + __mapper_args__ = { + 'polymorphic_on': type, + 'polymorphic_identity': 'employee' + } + return Employee - class Manager(Employee): - __mapper_args__ = { - 'polymorphic_identity': 'manager' - } - class Engineer(Employee): - __mapper_args__ = { - 'polymorphic_identity': 'engineer' - } +@pytest.fixture +def Manager(Employee): + class Manager(Employee): + __mapper_args__ = { + 'polymorphic_identity': 'manager' + } + return Manager - class Event(self.Base): - __tablename__ = 'event' - id = sa.Column(sa.Integer, primary_key=True) - object_type = sa.Column(sa.Unicode(255)) - object_id = sa.Column(sa.Integer, nullable=False) +@pytest.fixture +def Engineer(Employee): + class Engineer(Employee): + __mapper_args__ = { + 'polymorphic_identity': 'engineer' + } + return Engineer - object = generic_relationship(object_type, object_id) - self.Employee = Employee - self.Manager = Manager - self.Engineer = Engineer - self.Event = Event +@pytest.fixture +def Event(Base): + class Event(Base): + __tablename__ = 'event' + id = sa.Column(sa.Integer, primary_key=True) - def test_set_as_none(self): - event = self.Event() + object_type = sa.Column(sa.Unicode(255)) + object_id = sa.Column(sa.Integer, nullable=False) + + object = generic_relationship(object_type, object_id) + return Event + + +@pytest.fixture +def init_models(Employee, Manager, Engineer, Event): + pass + + +class TestGenericRelationship(object): + + def test_set_as_none(self, Event): + event = Event() event.object = None assert event.object is None - def test_set_manual_and_get(self): - manager = self.Manager() + def test_set_manual_and_get(self, session, Manager, Event): + manager = Manager() - self.session.add(manager) - self.session.commit() + session.add(manager) + session.commit() - event = self.Event() + event = Event() event.object_id = manager.id event.object_type = six.text_type(type(manager).__name__) assert event.object is None - self.session.add(event) - self.session.commit() + session.add(event) + session.commit() assert event.object == manager - def test_set_and_get(self): - manager = self.Manager() + def test_set_and_get(self, session, Manager, Event): + manager = Manager() - self.session.add(manager) - self.session.commit() + session.add(manager) + session.commit() - event = self.Event(object=manager) + event = Event(object=manager) assert event.object_id == manager.id assert event.object_type == type(manager).__name__ - self.session.add(event) - self.session.commit() + session.add(event) + session.commit() assert event.object == manager - def test_compare_instance(self): - manager1 = self.Manager() - manager2 = self.Manager() + def test_compare_instance(self, session, Manager, Event): + manager1 = Manager() + manager2 = Manager() - self.session.add_all([manager1, manager2]) - self.session.commit() + session.add_all([manager1, manager2]) + session.commit() - event = self.Event(object=manager1) + event = Event(object=manager1) - self.session.add(event) - self.session.commit() + session.add(event) + session.commit() assert event.object == manager1 assert event.object != manager2 - def test_compare_query(self): - manager1 = self.Manager() - manager2 = self.Manager() + def test_compare_query(self, session, Manager, Event): + manager1 = Manager() + manager2 = Manager() - self.session.add_all([manager1, manager2]) - self.session.commit() + session.add_all([manager1, manager2]) + session.commit() - event = self.Event(object=manager1) + event = Event(object=manager1) - self.session.add(event) - self.session.commit() + session.add(event) + session.commit() - q = self.session.query(self.Event) + q = session.query(Event) assert q.filter_by(object=manager1).first() is not None assert q.filter_by(object=manager2).first() is None - assert q.filter(self.Event.object == manager2).first() is None + assert q.filter(Event.object == manager2).first() is None - def test_compare_not_query(self): - manager1 = self.Manager() - manager2 = self.Manager() + def test_compare_not_query(self, session, Manager, Event): + manager1 = Manager() + manager2 = Manager() - self.session.add_all([manager1, manager2]) - self.session.commit() + session.add_all([manager1, manager2]) + session.commit() - event = self.Event(object=manager1) + event = Event(object=manager1) - self.session.add(event) - self.session.commit() + session.add(event) + session.commit() - q = self.session.query(self.Event) - assert q.filter(self.Event.object != manager2).first() is not None + q = session.query(Event) + assert q.filter(Event.object != manager2).first() is not None - def test_compare_type(self): - manager1 = self.Manager() - manager2 = self.Manager() + def test_compare_type(self, session, Manager, Event): + manager1 = Manager() + manager2 = Manager() - self.session.add_all([manager1, manager2]) - self.session.commit() + session.add_all([manager1, manager2]) + session.commit() - event1 = self.Event(object=manager1) - event2 = self.Event(object=manager2) + event1 = Event(object=manager1) + event2 = Event(object=manager2) - self.session.add_all([event1, event2]) - self.session.commit() + session.add_all([event1, event2]) + session.commit() - statement = self.Event.object.is_type(self.Manager) - q = self.session.query(self.Event).filter(statement) + statement = Event.object.is_type(Manager) + q = session.query(Event).filter(statement) assert q.first() is not None - def test_compare_super_type(self): - manager1 = self.Manager() - manager2 = self.Manager() + def test_compare_super_type(self, session, Manager, Event, Employee): + manager1 = Manager() + manager2 = Manager() - self.session.add_all([manager1, manager2]) - self.session.commit() + session.add_all([manager1, manager2]) + session.commit() - event1 = self.Event(object=manager1) - event2 = self.Event(object=manager2) + event1 = Event(object=manager1) + event2 = Event(object=manager2) - self.session.add_all([event1, event2]) - self.session.commit() + session.add_all([event1, event2]) + session.commit() - statement = self.Event.object.is_type(self.Employee) - q = self.session.query(self.Event).filter(statement) + statement = Event.object.is_type(Employee) + q = session.query(Event).filter(statement) assert q.first() is not None diff --git a/tests/mixins.py b/tests/mixins.py index 1224024..020e093 100644 --- a/tests/mixins.py +++ b/tests/mixins.py @@ -1,18 +1,24 @@ +import pytest import sqlalchemy as sa class ThreeLevelDeepOneToOne(object): - def create_models(self): - class Catalog(self.Base): + + @pytest.fixture + def Catalog(self, Base, Category): + class Catalog(Base): __tablename__ = 'catalog' id = sa.Column('_id', sa.Integer, primary_key=True) category = sa.orm.relationship( - 'Category', + Category, uselist=False, backref='catalog' ) + return Catalog - class Category(self.Base): + @pytest.fixture + def Category(self, Base, SubCategory): + class Category(Base): __tablename__ = 'category' id = sa.Column('_id', sa.Integer, primary_key=True) catalog_id = sa.Column( @@ -22,12 +28,15 @@ class ThreeLevelDeepOneToOne(object): ) sub_category = sa.orm.relationship( - 'SubCategory', + SubCategory, uselist=False, backref='category' ) + return Category - class SubCategory(self.Base): + @pytest.fixture + def SubCategory(self, Base, Product): + class SubCategory(Base): __tablename__ = 'sub_category' id = sa.Column('_id', sa.Integer, primary_key=True) category_id = sa.Column( @@ -36,12 +45,15 @@ class ThreeLevelDeepOneToOne(object): sa.ForeignKey('category._id') ) product = sa.orm.relationship( - 'Product', + Product, uselist=False, backref='sub_category' ) + return SubCategory - class Product(self.Base): + @pytest.fixture + def Product(self, Base): + class Product(Base): __tablename__ = 'product' id = sa.Column('_id', sa.Integer, primary_key=True) price = sa.Column(sa.Integer) @@ -51,22 +63,27 @@ class ThreeLevelDeepOneToOne(object): sa.Integer, sa.ForeignKey('sub_category._id') ) + return Product - self.Catalog = Catalog - self.Category = Category - self.SubCategory = SubCategory - self.Product = Product + @pytest.fixture + def init_models(self, Catalog, Category, SubCategory, Product): + pass class ThreeLevelDeepOneToMany(object): - def create_models(self): - class Catalog(self.Base): + + @pytest.fixture + def Catalog(self, Base, Category): + class Catalog(Base): __tablename__ = 'catalog' id = sa.Column('_id', sa.Integer, primary_key=True) - categories = sa.orm.relationship('Category', backref='catalog') + categories = sa.orm.relationship(Category, backref='catalog') + return Catalog - class Category(self.Base): + @pytest.fixture + def Category(self, Base, SubCategory): + class Category(Base): __tablename__ = 'category' id = sa.Column('_id', sa.Integer, primary_key=True) catalog_id = sa.Column( @@ -76,10 +93,13 @@ class ThreeLevelDeepOneToMany(object): ) sub_categories = sa.orm.relationship( - 'SubCategory', backref='category' + SubCategory, backref='category' ) + return Category - class SubCategory(self.Base): + @pytest.fixture + def SubCategory(self, Base, Product): + class SubCategory(Base): __tablename__ = 'sub_category' id = sa.Column('_id', sa.Integer, primary_key=True) category_id = sa.Column( @@ -88,11 +108,14 @@ class ThreeLevelDeepOneToMany(object): sa.ForeignKey('category._id') ) products = sa.orm.relationship( - 'Product', + Product, backref='sub_category' ) + return SubCategory - class Product(self.Base): + @pytest.fixture + def Product(self, Base): + class Product(Base): __tablename__ = 'product' id = sa.Column('_id', sa.Integer, primary_key=True) price = sa.Column(sa.Numeric) @@ -105,25 +128,42 @@ class ThreeLevelDeepOneToMany(object): def __repr__(self): return '' % self.id + return Product - self.Catalog = Catalog - self.Category = Category - self.SubCategory = SubCategory - self.Product = Product + @pytest.fixture + def init_models(self, Catalog, Category, SubCategory, Product): + pass class ThreeLevelDeepManyToMany(object): - def create_models(self): + + @pytest.fixture + def Catalog(self, Base, Category): + catalog_category = sa.Table( 'catalog_category', - self.Base.metadata, + Base.metadata, sa.Column('catalog_id', sa.Integer, sa.ForeignKey('catalog._id')), sa.Column('category_id', sa.Integer, sa.ForeignKey('category._id')) ) + class Catalog(Base): + __tablename__ = 'catalog' + id = sa.Column('_id', sa.Integer, primary_key=True) + + categories = sa.orm.relationship( + Category, + backref='catalogs', + secondary=catalog_category + ) + return Catalog + + @pytest.fixture + def Category(self, Base, SubCategory): + category_subcategory = sa.Table( 'category_subcategory', - self.Base.metadata, + Base.metadata, sa.Column( 'category_id', sa.Integer, @@ -136,9 +176,23 @@ class ThreeLevelDeepManyToMany(object): ) ) + class Category(Base): + __tablename__ = 'category' + id = sa.Column('_id', sa.Integer, primary_key=True) + + sub_categories = sa.orm.relationship( + SubCategory, + backref='categories', + secondary=category_subcategory + ) + return Category + + @pytest.fixture + def SubCategory(self, Base, Product): + subcategory_product = sa.Table( 'subcategory_product', - self.Base.metadata, + Base.metadata, sa.Column( 'subcategory_id', sa.Integer, @@ -151,41 +205,24 @@ class ThreeLevelDeepManyToMany(object): ) ) - class Catalog(self.Base): - __tablename__ = 'catalog' - id = sa.Column('_id', sa.Integer, primary_key=True) - - categories = sa.orm.relationship( - 'Category', - backref='catalogs', - secondary=catalog_category - ) - - class Category(self.Base): - __tablename__ = 'category' - id = sa.Column('_id', sa.Integer, primary_key=True) - - sub_categories = sa.orm.relationship( - 'SubCategory', - backref='categories', - secondary=category_subcategory - ) - - class SubCategory(self.Base): + class SubCategory(Base): __tablename__ = 'sub_category' id = sa.Column('_id', sa.Integer, primary_key=True) products = sa.orm.relationship( - 'Product', + Product, backref='sub_categories', secondary=subcategory_product ) + return SubCategory - class Product(self.Base): + @pytest.fixture + def Product(self, Base): + class Product(Base): __tablename__ = 'product' id = sa.Column('_id', sa.Integer, primary_key=True) price = sa.Column(sa.Numeric) + return Product - self.Catalog = Catalog - self.Category = Category - self.SubCategory = SubCategory - self.Product = Product + @pytest.fixture + def init_models(self, Catalog, Category, SubCategory, Product): + pass diff --git a/tests/observes/test_column_property.py b/tests/observes/test_column_property.py index 058383a..f388f26 100644 --- a/tests/observes/test_column_property.py +++ b/tests/observes/test_column_property.py @@ -1,15 +1,15 @@ +import pytest import sqlalchemy as sa -from pytest import raises from sqlalchemy_utils.observer import observes -from tests import TestCase -class TestObservesForColumn(TestCase): - dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' +@pytest.mark.usefixtures('postgresql_dsn') +class TestObservesForColumn(object): - def create_models(self): - class Product(self.Base): + @pytest.fixture + def Product(self, Base): + class Product(Base): __tablename__ = 'product' id = sa.Column(sa.Integer, primary_key=True) price = sa.Column(sa.Integer) @@ -17,21 +17,25 @@ class TestObservesForColumn(TestCase): @observes('price') def product_price_observer(self, price): self.price = price * 2 + return Product - self.Product = Product + @pytest.fixture + def init_models(self, Product): + pass - def test_simple_insert(self): - product = self.Product(price=100) - self.session.add(product) - self.session.flush() + def test_simple_insert(self, session, Product): + product = Product(price=100) + session.add(product) + session.flush() assert product.price == 200 -class TestObservesForColumnWithoutActualChanges(TestCase): - dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' +@pytest.mark.usefixtures('postgresql_dsn') +class TestObservesForColumnWithoutActualChanges(object): - def create_models(self): - class Product(self.Base): + @pytest.fixture + def Product(self, Base): + class Product(Base): __tablename__ = 'product' id = sa.Column(sa.Integer, primary_key=True) price = sa.Column(sa.Integer) @@ -39,15 +43,18 @@ class TestObservesForColumnWithoutActualChanges(TestCase): @observes('price') def product_price_observer(self, price): raise Exception('Trying to change price') + return Product - self.Product = Product + @pytest.fixture + def init_models(self, Product): + pass - def test_only_notifies_observer_on_actual_changes(self): - product = self.Product() - self.session.add(product) - self.session.flush() + def test_only_notifies_observer_on_actual_changes(self, session, Product): + product = Product() + session.add(product) + session.flush() - with raises(Exception) as e: + with pytest.raises(Exception) as e: product.price = 500 - self.session.commit() + session.commit() assert str(e.value) == 'Trying to change price' diff --git a/tests/observes/test_m2m_m2m_m2m.py b/tests/observes/test_m2m_m2m_m2m.py index 3b416f2..a9c529a 100644 --- a/tests/observes/test_m2m_m2m_m2m.py +++ b/tests/observes/test_m2m_m2m_m2m.py @@ -1,137 +1,158 @@ +import pytest import sqlalchemy as sa from sqlalchemy_utils.observer import observes -from tests import TestCase -class TestObservesForManyToManyToManyToMany(TestCase): - dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' +@pytest.fixture +def Catalog(Base): + catalog_category = sa.Table( + 'catalog_category', + Base.metadata, + sa.Column('catalog_id', sa.Integer, sa.ForeignKey('catalog.id')), + sa.Column('category_id', sa.Integer, sa.ForeignKey('category.id')) + ) - def create_models(self): - catalog_category = sa.Table( - 'catalog_category', - self.Base.metadata, - sa.Column('catalog_id', sa.Integer, sa.ForeignKey('catalog.id')), - sa.Column('category_id', sa.Integer, sa.ForeignKey('category.id')) + class Catalog(Base): + __tablename__ = 'catalog' + id = sa.Column(sa.Integer, primary_key=True) + product_count = sa.Column(sa.Integer, default=0) + + @observes('categories.sub_categories.products') + def product_observer(self, products): + self.product_count = len(products) + + categories = sa.orm.relationship( + 'Category', + backref='catalogs', + secondary=catalog_category + ) + return Catalog + + +@pytest.fixture +def Category(Base): + category_subcategory = sa.Table( + 'category_subcategory', + Base.metadata, + sa.Column( + 'category_id', + sa.Integer, + sa.ForeignKey('category.id') + ), + sa.Column( + 'subcategory_id', + sa.Integer, + sa.ForeignKey('sub_category.id') + ) + ) + + class Category(Base): + __tablename__ = 'category' + id = sa.Column(sa.Integer, primary_key=True) + + sub_categories = sa.orm.relationship( + 'SubCategory', + backref='categories', + secondary=category_subcategory + ) + return Category + + +@pytest.fixture +def SubCategory(Base): + subcategory_product = sa.Table( + 'subcategory_product', + Base.metadata, + sa.Column( + 'subcategory_id', + sa.Integer, + sa.ForeignKey('sub_category.id') + ), + sa.Column( + 'product_id', + sa.Integer, + sa.ForeignKey('product.id') + ) + ) + + class SubCategory(Base): + __tablename__ = 'sub_category' + id = sa.Column(sa.Integer, primary_key=True) + products = sa.orm.relationship( + 'Product', + backref='sub_categories', + secondary=subcategory_product ) - category_subcategory = sa.Table( - 'category_subcategory', - self.Base.metadata, - sa.Column( - 'category_id', - sa.Integer, - sa.ForeignKey('category.id') - ), - sa.Column( - 'subcategory_id', - sa.Integer, - sa.ForeignKey('sub_category.id') - ) - ) + return SubCategory - subcategory_product = sa.Table( - 'subcategory_product', - self.Base.metadata, - sa.Column( - 'subcategory_id', - sa.Integer, - sa.ForeignKey('sub_category.id') - ), - sa.Column( - 'product_id', - sa.Integer, - sa.ForeignKey('product.id') - ) - ) - class Catalog(self.Base): - __tablename__ = 'catalog' - id = sa.Column(sa.Integer, primary_key=True) - product_count = sa.Column(sa.Integer, default=0) +@pytest.fixture +def Product(Base): + class Product(Base): + __tablename__ = 'product' + id = sa.Column(sa.Integer, primary_key=True) + price = sa.Column(sa.Numeric) + return Product - @observes('categories.sub_categories.products') - def product_observer(self, products): - self.product_count = len(products) - categories = sa.orm.relationship( - 'Category', - backref='catalogs', - secondary=catalog_category - ) +@pytest.fixture +def init_models(Catalog, Category, SubCategory, Product): + pass - class Category(self.Base): - __tablename__ = 'category' - id = sa.Column(sa.Integer, primary_key=True) - sub_categories = sa.orm.relationship( - 'SubCategory', - backref='categories', - secondary=category_subcategory - ) +@pytest.fixture +def catalog(session, Catalog, Category, SubCategory, Product): + sub_category = SubCategory(products=[Product()]) + category = Category(sub_categories=[sub_category]) + catalog = Catalog(categories=[category]) + session.add(catalog) + session.flush() + return catalog - class SubCategory(self.Base): - __tablename__ = 'sub_category' - id = sa.Column(sa.Integer, primary_key=True) - products = sa.orm.relationship( - 'Product', - backref='sub_categories', - secondary=subcategory_product - ) - class Product(self.Base): - __tablename__ = 'product' - id = sa.Column(sa.Integer, primary_key=True) - price = sa.Column(sa.Numeric) +@pytest.mark.usefixtures('postgresql_dsn') +class TestObservesForManyToManyToManyToMany(object): - self.Catalog = Catalog - self.Category = Category - self.SubCategory = SubCategory - self.Product = Product - - def create_catalog(self): - sub_category = self.SubCategory(products=[self.Product()]) - category = self.Category(sub_categories=[sub_category]) - catalog = self.Catalog(categories=[category]) - self.session.add(catalog) - self.session.flush() - return catalog - - def test_simple_insert(self): - catalog = self.create_catalog() + def test_simple_insert(self, catalog): assert catalog.product_count == 1 - def test_add_leaf_object(self): - catalog = self.create_catalog() - product = self.Product() + def test_add_leaf_object(self, catalog, session, Product): + product = Product() catalog.categories[0].sub_categories[0].products.append(product) - self.session.flush() + session.flush() assert catalog.product_count == 2 - def test_remove_leaf_object(self): - catalog = self.create_catalog() - product = self.Product() + def test_remove_leaf_object(self, catalog, session, Product): + product = Product() catalog.categories[0].sub_categories[0].products.append(product) - self.session.flush() - self.session.delete(product) - self.session.flush() + session.flush() + session.delete(product) + session.flush() assert catalog.product_count == 1 - def test_delete_intermediate_object(self): - catalog = self.create_catalog() - self.session.delete(catalog.categories[0].sub_categories[0]) - self.session.commit() + def test_delete_intermediate_object(self, catalog, session): + session.delete(catalog.categories[0].sub_categories[0]) + session.commit() assert catalog.product_count == 0 - def test_gathered_objects_are_distinct(self): - catalog = self.Catalog() - category = self.Category(catalogs=[catalog]) - product = self.Product() + def test_gathered_objects_are_distinct( + self, + session, + Catalog, + Category, + SubCategory, + Product + ): + catalog = Catalog() + category = Category(catalogs=[catalog]) + product = Product() category.sub_categories.append( - self.SubCategory(products=[product]) + SubCategory(products=[product]) ) - self.session.add( - self.SubCategory(categories=[category], products=[product]) + session.add( + SubCategory(categories=[category], products=[product]) ) - self.session.commit() + session.commit() assert catalog.product_count == 1 diff --git a/tests/observes/test_o2m_o2m_o2m.py b/tests/observes/test_o2m_o2m_o2m.py index a11b378..74fda0b 100644 --- a/tests/observes/test_o2m_o2m_o2m.py +++ b/tests/observes/test_o2m_o2m_o2m.py @@ -1,107 +1,127 @@ +import pytest import sqlalchemy as sa from sqlalchemy_utils.observer import observes -from tests import TestCase -class TestObservesFor3LevelDeepOneToMany(TestCase): - dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' +@pytest.fixture +def Catalog(Base): + class Catalog(Base): + __tablename__ = 'catalog' + id = sa.Column(sa.Integer, primary_key=True) + product_count = sa.Column(sa.Integer, default=0) - def create_models(self): - class Catalog(self.Base): - __tablename__ = 'catalog' - id = sa.Column(sa.Integer, primary_key=True) - product_count = sa.Column(sa.Integer, default=0) + @observes('categories.sub_categories.products') + def product_observer(self, products): + self.product_count = len(products) - @observes('categories.sub_categories.products') - def product_observer(self, products): - self.product_count = len(products) + categories = sa.orm.relationship('Category', backref='catalog') + return Catalog - categories = sa.orm.relationship('Category', backref='catalog') - class Category(self.Base): - __tablename__ = 'category' - id = sa.Column(sa.Integer, primary_key=True) - catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id')) +@pytest.fixture +def Category(Base): + class Category(Base): + __tablename__ = 'category' + id = sa.Column(sa.Integer, primary_key=True) + catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id')) - sub_categories = sa.orm.relationship( - 'SubCategory', backref='category' - ) + sub_categories = sa.orm.relationship( + 'SubCategory', backref='category' + ) + return Category - class SubCategory(self.Base): - __tablename__ = 'sub_category' - id = sa.Column(sa.Integer, primary_key=True) - category_id = sa.Column(sa.Integer, sa.ForeignKey('category.id')) - products = sa.orm.relationship( - 'Product', - backref='sub_category' - ) - class Product(self.Base): - __tablename__ = 'product' - id = sa.Column(sa.Integer, primary_key=True) - price = sa.Column(sa.Numeric) +@pytest.fixture +def SubCategory(Base): + class SubCategory(Base): + __tablename__ = 'sub_category' + id = sa.Column(sa.Integer, primary_key=True) + category_id = sa.Column(sa.Integer, sa.ForeignKey('category.id')) + products = sa.orm.relationship( + 'Product', + backref='sub_category' + ) + return SubCategory - sub_category_id = sa.Column( - sa.Integer, sa.ForeignKey('sub_category.id') - ) - def __repr__(self): - return '' % self.id +@pytest.fixture +def Product(Base): + class Product(Base): + __tablename__ = 'product' + id = sa.Column(sa.Integer, primary_key=True) + price = sa.Column(sa.Numeric) - self.Catalog = Catalog - self.Category = Category - self.SubCategory = SubCategory - self.Product = Product + sub_category_id = sa.Column( + sa.Integer, sa.ForeignKey('sub_category.id') + ) - def create_catalog(self): - sub_category = self.SubCategory(products=[self.Product()]) - category = self.Category(sub_categories=[sub_category]) - catalog = self.Catalog(categories=[category]) - self.session.add(catalog) - self.session.commit() - return catalog + def __repr__(self): + return '' % self.id + return Product - def test_simple_insert(self): - catalog = self.create_catalog() + +@pytest.fixture +def init_models(Catalog, Category, SubCategory, Product): + pass + + +@pytest.fixture +def catalog(session, Catalog, Category, SubCategory, Product): + sub_category = SubCategory(products=[Product()]) + category = Category(sub_categories=[sub_category]) + catalog = Catalog(categories=[category]) + session.add(catalog) + session.commit() + return catalog + + +@pytest.mark.usefixtures('postgresql_dsn') +class TestObservesFor3LevelDeepOneToMany(object): + + def test_simple_insert(self, catalog): assert catalog.product_count == 1 - def test_add_leaf_object(self): - catalog = self.create_catalog() - product = self.Product() + def test_add_leaf_object(self, catalog, session, Product): + product = Product() catalog.categories[0].sub_categories[0].products.append(product) - self.session.flush() + session.flush() assert catalog.product_count == 2 - def test_remove_leaf_object(self): - catalog = self.create_catalog() - product = self.Product() + def test_remove_leaf_object(self, catalog, session, Product): + product = Product() catalog.categories[0].sub_categories[0].products.append(product) - self.session.flush() - self.session.delete(product) - self.session.commit() + session.flush() + session.delete(product) + session.commit() assert catalog.product_count == 1 - self.session.delete( + session.delete( catalog.categories[0].sub_categories[0].products[0] ) - self.session.commit() + session.commit() assert catalog.product_count == 0 - def test_delete_intermediate_object(self): - catalog = self.create_catalog() - self.session.delete(catalog.categories[0].sub_categories[0]) - self.session.commit() + def test_delete_intermediate_object(self, catalog, session): + session.delete(catalog.categories[0].sub_categories[0]) + session.commit() assert catalog.product_count == 0 - def test_gathered_objects_are_distinct(self): - catalog = self.Catalog() - category = self.Category(catalog=catalog) - product = self.Product() + def test_gathered_objects_are_distinct( + self, + session, + Catalog, + Category, + SubCategory, + Product + ): + catalog = Catalog() + category = Category(catalog=catalog) + product = Product() category.sub_categories.append( - self.SubCategory(products=[product]) + SubCategory(products=[product]) ) - self.session.add( - self.SubCategory(category=category, products=[product]) + session.add( + SubCategory(category=category, products=[product]) ) - self.session.commit() + session.commit() assert catalog.product_count == 1 diff --git a/tests/observes/test_o2m_o2o_o2m.py b/tests/observes/test_o2m_o2o_o2m.py index 2299280..fd53cad 100644 --- a/tests/observes/test_o2m_o2o_o2m.py +++ b/tests/observes/test_o2m_o2o_o2m.py @@ -1,96 +1,116 @@ +import pytest import sqlalchemy as sa from sqlalchemy_utils.observer import observes -from tests import TestCase -class TestObservesForOneToManyToOneToMany(TestCase): - dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' +@pytest.fixture +def Catalog(Base): + class Catalog(Base): + __tablename__ = 'catalog' + id = sa.Column(sa.Integer, primary_key=True) + product_count = sa.Column(sa.Integer, default=0) - def create_models(self): - class Catalog(self.Base): - __tablename__ = 'catalog' - id = sa.Column(sa.Integer, primary_key=True) - product_count = sa.Column(sa.Integer, default=0) + @observes('categories.sub_category.products') + def product_observer(self, products): + self.product_count = len(products) - @observes('categories.sub_category.products') - def product_observer(self, products): - self.product_count = len(products) + categories = sa.orm.relationship('Category', backref='catalog') + return Catalog - categories = sa.orm.relationship('Category', backref='catalog') - class Category(self.Base): - __tablename__ = 'category' - id = sa.Column(sa.Integer, primary_key=True) - catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id')) +@pytest.fixture +def Category(Base): + class Category(Base): + __tablename__ = 'category' + id = sa.Column(sa.Integer, primary_key=True) + catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id')) - sub_category = sa.orm.relationship( - 'SubCategory', - uselist=False, - backref='category' - ) + sub_category = sa.orm.relationship( + 'SubCategory', + uselist=False, + backref='category' + ) + return Category - class SubCategory(self.Base): - __tablename__ = 'sub_category' - id = sa.Column(sa.Integer, primary_key=True) - category_id = sa.Column(sa.Integer, sa.ForeignKey('category.id')) - products = sa.orm.relationship('Product', backref='sub_category') - class Product(self.Base): - __tablename__ = 'product' - id = sa.Column(sa.Integer, primary_key=True) - price = sa.Column(sa.Numeric) +@pytest.fixture +def SubCategory(Base): + class SubCategory(Base): + __tablename__ = 'sub_category' + id = sa.Column(sa.Integer, primary_key=True) + category_id = sa.Column(sa.Integer, sa.ForeignKey('category.id')) + products = sa.orm.relationship('Product', backref='sub_category') + return SubCategory - sub_category_id = sa.Column( - sa.Integer, sa.ForeignKey('sub_category.id') - ) - self.Catalog = Catalog - self.Category = Category - self.SubCategory = SubCategory - self.Product = Product +@pytest.fixture +def Product(Base): + class Product(Base): + __tablename__ = 'product' + id = sa.Column(sa.Integer, primary_key=True) + price = sa.Column(sa.Numeric) - def create_catalog(self): - sub_category = self.SubCategory(products=[self.Product()]) - category = self.Category(sub_category=sub_category) - catalog = self.Catalog(categories=[category]) - self.session.add(catalog) - self.session.flush() - return catalog + sub_category_id = sa.Column( + sa.Integer, sa.ForeignKey('sub_category.id') + ) + return Product - def test_simple_insert(self): - catalog = self.create_catalog() + +@pytest.fixture +def init_models(Catalog, Category, SubCategory, Product): + pass + + +@pytest.fixture +def catalog(session, Catalog, Category, SubCategory, Product): + sub_category = SubCategory(products=[Product()]) + category = Category(sub_category=sub_category) + catalog = Catalog(categories=[category]) + session.add(catalog) + session.flush() + return catalog + + +@pytest.mark.usefixtures('postgresql_dsn') +class TestObservesForOneToManyToOneToMany(object): + + def test_simple_insert(self, catalog): assert catalog.product_count == 1 - def test_add_leaf_object(self): - catalog = self.create_catalog() - product = self.Product() + def test_add_leaf_object(self, catalog, session, Product): + product = Product() catalog.categories[0].sub_category.products.append(product) - self.session.flush() + session.flush() assert catalog.product_count == 2 - def test_remove_leaf_object(self): - catalog = self.create_catalog() - product = self.Product() + def test_remove_leaf_object(self, catalog, session, Product): + product = Product() catalog.categories[0].sub_category.products.append(product) - self.session.flush() - self.session.delete(product) - self.session.flush() + session.flush() + session.delete(product) + session.flush() assert catalog.product_count == 1 - def test_delete_intermediate_object(self): - catalog = self.create_catalog() - self.session.delete(catalog.categories[0].sub_category) - self.session.commit() + def test_delete_intermediate_object(self, catalog, session): + session.delete(catalog.categories[0].sub_category) + session.commit() assert catalog.product_count == 0 - def test_gathered_objects_are_distinct(self): - catalog = self.Catalog() - category = self.Category(catalog=catalog) - product = self.Product() - category.sub_category = self.SubCategory(products=[product]) - self.session.add( - self.Category(catalog=catalog, sub_category=category.sub_category) + def test_gathered_objects_are_distinct( + self, + session, + Catalog, + Category, + SubCategory, + Product + ): + catalog = Catalog() + category = Category(catalog=catalog) + product = Product() + category.sub_category = SubCategory(products=[product]) + session.add( + Category(catalog=catalog, sub_category=category.sub_category) ) - self.session.commit() + session.commit() assert catalog.product_count == 1 diff --git a/tests/observes/test_o2o_o2o.py b/tests/observes/test_o2o_o2o.py index a923295..ffbde72 100644 --- a/tests/observes/test_o2o_o2o.py +++ b/tests/observes/test_o2o_o2o.py @@ -1,53 +1,66 @@ +import pytest import sqlalchemy as sa from sqlalchemy_utils.observer import observes -from tests import TestCase -class TestObservesForOneToManyToOneToMany(TestCase): - dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' +@pytest.fixture +def Device(Base): + class Device(Base): + __tablename__ = 'device' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.String) + return Device - def create_models(self): - class Device(self.Base): - __tablename__ = 'device' - id = sa.Column(sa.Integer, primary_key=True) - name = sa.Column(sa.String) - class Order(self.Base): - __tablename__ = 'order' - id = sa.Column(sa.Integer, primary_key=True) +@pytest.fixture +def Order(Base): + class Order(Base): + __tablename__ = 'order' + id = sa.Column(sa.Integer, primary_key=True) - device_id = sa.Column( - 'device', sa.ForeignKey('device.id'), nullable=False + device_id = sa.Column( + 'device', sa.ForeignKey('device.id'), nullable=False + ) + device = sa.orm.relationship('Device', backref='orders') + return Order + + +@pytest.fixture +def SalesInvoice(Base): + class SalesInvoice(Base): + __tablename__ = 'sales_invoice' + id = sa.Column(sa.Integer, primary_key=True) + order_id = sa.Column( + 'order', + sa.ForeignKey('order.id'), + nullable=False + ) + order = sa.orm.relationship( + 'Order', + backref=sa.orm.backref( + 'invoice', + uselist=False ) - device = sa.orm.relationship('Device', backref='orders') + ) + device_name = sa.Column(sa.String) - class SalesInvoice(self.Base): - __tablename__ = 'sales_invoice' - id = sa.Column(sa.Integer, primary_key=True) - order_id = sa.Column( - 'order', - sa.ForeignKey('order.id'), - nullable=False - ) - order = sa.orm.relationship( - 'Order', - backref=sa.orm.backref( - 'invoice', - uselist=False - ) - ) - device_name = sa.Column(sa.String) + @observes('order.device') + def process_device(self, device): + self.device_name = device.name - @observes('order.device') - def process_device(self, device): - self.device_name = device.name + return SalesInvoice - self.Device = Device - self.Order = Order - self.SalesInvoice = SalesInvoice - def test_observable_root_obj_is_none(self): - order = self.Order(device=self.Device(name='Something')) - self.session.add(order) - self.session.flush() +@pytest.fixture +def init_models(Device, Order, SalesInvoice): + pass + + +@pytest.mark.usefixtures('postgresql_dsn') +class TestObservesForOneToManyToOneToMany(object): + + def test_observable_root_obj_is_none(self, session, Device, Order): + order = Order(device=Device(name='Something')) + session.add(order) + session.flush() diff --git a/tests/observes/test_o2o_o2o_o2o.py b/tests/observes/test_o2o_o2o_o2o.py index 00cfca8..c70302b 100644 --- a/tests/observes/test_o2o_o2o_o2o.py +++ b/tests/observes/test_o2o_o2o_o2o.py @@ -1,84 +1,98 @@ +import pytest import sqlalchemy as sa from sqlalchemy_utils.observer import observes -from tests import TestCase -class TestObservesForOneToOneToOneToOne(TestCase): - dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' +@pytest.fixture +def Catalog(Base): + class Catalog(Base): + __tablename__ = 'catalog' + id = sa.Column(sa.Integer, primary_key=True) + product_price = sa.Column(sa.Integer) - def create_models(self): - class Catalog(self.Base): - __tablename__ = 'catalog' - id = sa.Column(sa.Integer, primary_key=True) - product_price = sa.Column(sa.Integer) + @observes('category.sub_category.product') + def product_observer(self, product): + self.product_price = product.price if product else None - @observes('category.sub_category.product') - def product_observer(self, product): - self.product_price = product.price if product else None + category = sa.orm.relationship( + 'Category', + uselist=False, + backref='catalog' + ) + return Catalog - category = sa.orm.relationship( - 'Category', - uselist=False, - backref='catalog' - ) - class Category(self.Base): - __tablename__ = 'category' - id = sa.Column(sa.Integer, primary_key=True) - catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id')) +@pytest.fixture +def Category(Base): + class Category(Base): + __tablename__ = 'category' + id = sa.Column(sa.Integer, primary_key=True) + catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id')) - sub_category = sa.orm.relationship( - 'SubCategory', - uselist=False, - backref='category' - ) + sub_category = sa.orm.relationship( + 'SubCategory', + uselist=False, + backref='category' + ) + return Category - class SubCategory(self.Base): - __tablename__ = 'sub_category' - id = sa.Column(sa.Integer, primary_key=True) - category_id = sa.Column(sa.Integer, sa.ForeignKey('category.id')) - product = sa.orm.relationship( - 'Product', - uselist=False, - backref='sub_category' - ) - class Product(self.Base): - __tablename__ = 'product' - id = sa.Column(sa.Integer, primary_key=True) - price = sa.Column(sa.Integer) +@pytest.fixture +def SubCategory(Base): + class SubCategory(Base): + __tablename__ = 'sub_category' + id = sa.Column(sa.Integer, primary_key=True) + category_id = sa.Column(sa.Integer, sa.ForeignKey('category.id')) + product = sa.orm.relationship( + 'Product', + uselist=False, + backref='sub_category' + ) + return SubCategory - sub_category_id = sa.Column( - sa.Integer, sa.ForeignKey('sub_category.id') - ) - self.Catalog = Catalog - self.Category = Category - self.SubCategory = SubCategory - self.Product = Product +@pytest.fixture +def Product(Base): + class Product(Base): + __tablename__ = 'product' + id = sa.Column(sa.Integer, primary_key=True) + price = sa.Column(sa.Integer) - def create_catalog(self): - sub_category = self.SubCategory(product=self.Product(price=123)) - category = self.Category(sub_category=sub_category) - catalog = self.Catalog(category=category) - self.session.add(catalog) - self.session.flush() - return catalog + sub_category_id = sa.Column( + sa.Integer, sa.ForeignKey('sub_category.id') + ) + return Product - def test_simple_insert(self): - catalog = self.create_catalog() + +@pytest.fixture +def init_models(Catalog, Category, SubCategory, Product): + pass + + +@pytest.fixture +def catalog(session, Catalog, Category, SubCategory, Product): + sub_category = SubCategory(product=Product(price=123)) + category = Category(sub_category=sub_category) + catalog = Catalog(category=category) + session.add(catalog) + session.flush() + return catalog + + +@pytest.mark.usefixtures('postgresql_dsn') +class TestObservesForOneToOneToOneToOne(object): + + def test_simple_insert(self, catalog): assert catalog.product_price == 123 - def test_replace_leaf_object(self): - catalog = self.create_catalog() - product = self.Product(price=44) + def test_replace_leaf_object(self, catalog, session, Product): + product = Product(price=44) catalog.category.sub_category.product = product - self.session.flush() + session.flush() assert catalog.product_price == 44 - def test_delete_leaf_object(self): - catalog = self.create_catalog() - self.session.delete(catalog.category.sub_category.product) - self.session.flush() + def test_delete_leaf_object(self, catalog, session): + session.delete(catalog.category.sub_category.product) + session.flush() assert catalog.product_price is None diff --git a/tests/primitives/test_country.py b/tests/primitives/test_country.py index 751e876..24b9df3 100644 --- a/tests/primitives/test_country.py +++ b/tests/primitives/test_country.py @@ -1,32 +1,36 @@ +import pytest import six -from pytest import mark, raises from sqlalchemy_utils import Country, i18n -@mark.skipif('i18n.babel is None') +@pytest.fixture +def set_get_locale(): + i18n.get_locale = lambda: i18n.babel.Locale('en') + + +@pytest.mark.skipif('i18n.babel is None') +@pytest.mark.usefixtures('set_get_locale') class TestCountry(object): - def setup_method(self, method): - i18n.get_locale = lambda: i18n.babel.Locale('en') def test_init(self): assert Country(u'FI') == Country(Country(u'FI')) def test_constructor_with_wrong_type(self): - with raises(TypeError) as e: + with pytest.raises(TypeError) as e: Country(None) assert str(e.value) == ( "Country() argument must be a string or a country, not 'NoneType'" ) def test_constructor_with_invalid_code(self): - with raises(ValueError) as e: + with pytest.raises(ValueError) as e: Country('SomeUnknownCode') assert str(e.value) == ( 'Could not convert string to country code: SomeUnknownCode' ) - @mark.parametrize( + @pytest.mark.parametrize( 'code', ( 'FI', @@ -37,7 +41,7 @@ class TestCountry(object): Country.validate(code) def test_validate_with_invalid_code(self): - with raises(ValueError) as e: + with pytest.raises(ValueError) as e: Country.validate('SomeUnknownCode') assert str(e.value) == ( 'Could not convert string to country code: SomeUnknownCode' diff --git a/tests/primitives/test_currency.py b/tests/primitives/test_currency.py index a6c4876..5dbe466 100644 --- a/tests/primitives/test_currency.py +++ b/tests/primitives/test_currency.py @@ -1,14 +1,18 @@ # -*- coding: utf-8 -*- +import pytest import six -from pytest import mark, raises from sqlalchemy_utils import Currency, i18n -@mark.skipif('i18n.babel is None') +@pytest.fixture +def set_get_locale(): + i18n.get_locale = lambda: i18n.babel.Locale('en') + + +@pytest.mark.skipif('i18n.babel is None') +@pytest.mark.usefixtures('set_get_locale') class TestCurrency(object): - def setup_method(self, method): - i18n.get_locale = lambda: i18n.babel.Locale('en') def test_init(self): assert Currency('USD') == Currency(Currency('USD')) @@ -17,14 +21,14 @@ class TestCurrency(object): assert len(set([Currency('USD'), Currency('USD')])) == 1 def test_invalid_currency_code(self): - with raises(ValueError): + with pytest.raises(ValueError): Currency('Unknown code') def test_invalid_currency_code_type(self): - with raises(TypeError): + with pytest.raises(TypeError): Currency(None) - @mark.parametrize( + @pytest.mark.parametrize( ('code', 'name'), ( ('USD', 'US Dollar'), @@ -34,7 +38,7 @@ class TestCurrency(object): def test_name_property(self, code, name): assert Currency(code).name == name - @mark.parametrize( + @pytest.mark.parametrize( ('code', 'symbol'), ( ('USD', u'$'), diff --git a/tests/primitives/test_weekdays.py b/tests/primitives/test_weekdays.py index 3e1b246..5dc4ce4 100644 --- a/tests/primitives/test_weekdays.py +++ b/tests/primitives/test_weekdays.py @@ -6,10 +6,14 @@ from sqlalchemy_utils import i18n from sqlalchemy_utils.primitives import WeekDay, WeekDays +@pytest.fixture +def set_get_locale(): + i18n.get_locale = lambda: i18n.babel.Locale('fi') + + @pytest.mark.skipif('i18n.babel is None') +@pytest.mark.usefixtures('set_get_locale') class TestWeekDay(object): - def setup_method(self, method): - i18n.get_locale = lambda: i18n.babel.Locale('fi') def test_constructor_with_valid_index(self): day = WeekDay(1) diff --git a/tests/relationships/test_chained_join.py b/tests/relationships/test_chained_join.py index aff42d4..81e6193 100644 --- a/tests/relationships/test_chained_join.py +++ b/tests/relationships/test_chained_join.py @@ -1,26 +1,27 @@ +import pytest + from sqlalchemy_utils.relationships import chained_join -from tests import TestCase -from tests.mixins import ( + +from ..mixins import ( ThreeLevelDeepManyToMany, ThreeLevelDeepOneToMany, ThreeLevelDeepOneToOne ) -class TestChainedJoinFoDeepToManyToMany(ThreeLevelDeepManyToMany, TestCase): - dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' - create_tables = False +@pytest.mark.usefixtures('postgresql_dsn') +class TestChainedJoinFoDeepToManyToMany(ThreeLevelDeepManyToMany): - def test_simple_join(self): - assert str(chained_join(self.Catalog.categories)) == ( + def test_simple_join(self, Catalog): + assert str(chained_join(Catalog.categories)) == ( 'catalog_category JOIN category ON ' 'category._id = catalog_category.category_id' ) - def test_two_relations(self): + def test_two_relations(self, Catalog, Category): sql = chained_join( - self.Catalog.categories, - self.Category.sub_categories + Catalog.categories, + Category.sub_categories ) assert str(sql) == ( 'catalog_category JOIN category ON category._id = ' @@ -30,11 +31,11 @@ class TestChainedJoinFoDeepToManyToMany(ThreeLevelDeepManyToMany, TestCase): 'category_subcategory.subcategory_id' ) - def test_three_relations(self): + def test_three_relations(self, Catalog, Category, SubCategory): sql = chained_join( - self.Catalog.categories, - self.Category.sub_categories, - self.SubCategory.products + Catalog.categories, + Category.sub_categories, + SubCategory.products ) assert str(sql) == ( 'catalog_category JOIN category ON category._id = ' @@ -47,28 +48,27 @@ class TestChainedJoinFoDeepToManyToMany(ThreeLevelDeepManyToMany, TestCase): ) -class TestChainedJoinForDeepOneToMany(ThreeLevelDeepOneToMany, TestCase): - dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' - create_tables = False +@pytest.mark.usefixtures('postgresql_dsn') +class TestChainedJoinForDeepOneToMany(ThreeLevelDeepOneToMany): - def test_simple_join(self): - assert str(chained_join(self.Catalog.categories)) == 'category' + def test_simple_join(self, Catalog): + assert str(chained_join(Catalog.categories)) == 'category' - def test_two_relations(self): + def test_two_relations(self, Catalog, Category): sql = chained_join( - self.Catalog.categories, - self.Category.sub_categories + Catalog.categories, + Category.sub_categories ) assert str(sql) == ( 'category JOIN sub_category ON category._id = ' 'sub_category._category_id' ) - def test_three_relations(self): + def test_three_relations(self, Catalog, Category, SubCategory): sql = chained_join( - self.Catalog.categories, - self.Category.sub_categories, - self.SubCategory.products + Catalog.categories, + Category.sub_categories, + SubCategory.products ) assert str(sql) == ( 'category JOIN sub_category ON category._id = ' @@ -77,28 +77,27 @@ class TestChainedJoinForDeepOneToMany(ThreeLevelDeepOneToMany, TestCase): ) -class TestChainedJoinForDeepOneToOne(ThreeLevelDeepOneToOne, TestCase): - dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' - create_tables = False +@pytest.mark.usefixtures('postgresql_dsn') +class TestChainedJoinForDeepOneToOne(ThreeLevelDeepOneToOne): - def test_simple_join(self): - assert str(chained_join(self.Catalog.category)) == 'category' + def test_simple_join(self, Catalog): + assert str(chained_join(Catalog.category)) == 'category' - def test_two_relations(self): + def test_two_relations(self, Catalog, Category): sql = chained_join( - self.Catalog.category, - self.Category.sub_category + Catalog.category, + Category.sub_category ) assert str(sql) == ( 'category JOIN sub_category ON category._id = ' 'sub_category._category_id' ) - def test_three_relations(self): + def test_three_relations(self, Catalog, Category, SubCategory): sql = chained_join( - self.Catalog.category, - self.Category.sub_category, - self.SubCategory.product + Catalog.category, + Category.sub_category, + SubCategory.product ) assert str(sql) == ( 'category JOIN sub_category ON category._id = ' diff --git a/tests/relationships/test_select_correlated_expression.py b/tests/relationships/test_select_correlated_expression.py index b098281..996e908 100644 --- a/tests/relationships/test_select_correlated_expression.py +++ b/tests/relationships/test_select_correlated_expression.py @@ -1,31 +1,23 @@ import pytest import sqlalchemy as sa -from sqlalchemy import create_engine -from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.hybrid import hybrid_property -from sqlalchemy.orm import sessionmaker from sqlalchemy_utils.relationships import select_correlated_expression -@pytest.fixture(scope='class') -def base(): - return declarative_base() - - -@pytest.fixture(scope='class') -def group_user_cls(base): +@pytest.fixture +def group_user_tbl(Base): return sa.Table( 'group_user', - base.metadata, + Base.metadata, sa.Column('user_id', sa.Integer, sa.ForeignKey('user.id')), sa.Column('group_id', sa.Integer, sa.ForeignKey('group.id')) ) -@pytest.fixture(scope='class') -def group_cls(base): - class Group(base): +@pytest.fixture +def group_tbl(Base): + class Group(Base): __tablename__ = 'group' id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.String) @@ -33,11 +25,11 @@ def group_cls(base): return Group -@pytest.fixture(scope='class') -def friendship_cls(base): +@pytest.fixture +def friendship_tbl(Base): return sa.Table( 'friendships', - base.metadata, + Base.metadata, sa.Column( 'friend_a_id', sa.Integer, @@ -53,35 +45,37 @@ def friendship_cls(base): ) -@pytest.fixture(scope='class') -def user_cls(base, group_user_cls, friendship_cls): - class User(base): +@pytest.fixture +def User(Base, group_user_tbl, friendship_tbl): + class User(Base): __tablename__ = 'user' id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.String) groups = sa.orm.relationship( 'Group', - secondary=group_user_cls, + secondary=group_user_tbl, backref='users' ) # this relationship is used for persistence friends = sa.orm.relationship( 'User', - secondary=friendship_cls, - primaryjoin=id == friendship_cls.c.friend_a_id, - secondaryjoin=id == friendship_cls.c.friend_b_id, + secondary=friendship_tbl, + primaryjoin=id == friendship_tbl.c.friend_a_id, + secondaryjoin=id == friendship_tbl.c.friend_b_id, ) - friendship_union = sa.select([ - friendship_cls.c.friend_a_id, - friendship_cls.c.friend_b_id + friendship_union = ( + sa.select([ + friendship_tbl.c.friend_a_id, + friendship_tbl.c.friend_b_id ]).union( sa.select([ - friendship_cls.c.friend_b_id, - friendship_cls.c.friend_a_id] + friendship_tbl.c.friend_b_id, + friendship_tbl.c.friend_a_id] ) - ).alias() + ).alias() + ) User.all_friends = sa.orm.relationship( 'User', @@ -94,9 +88,9 @@ def user_cls(base, group_user_cls, friendship_cls): return User -@pytest.fixture(scope='class') -def category_cls(base, group_user_cls, friendship_cls): - class Category(base): +@pytest.fixture +def Category(Base, group_user_tbl, friendship_tbl): + class Category(Base): __tablename__ = 'category' id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.String) @@ -111,9 +105,9 @@ def category_cls(base, group_user_cls, friendship_cls): return Category -@pytest.fixture(scope='class') -def article_cls(base, category_cls, user_cls): - class Article(base): +@pytest.fixture +def Article(Base, Category, User): + class Article(Base): __tablename__ = 'article' id = sa.Column('_id', sa.Integer, primary_key=True) name = sa.Column(sa.String) @@ -129,144 +123,104 @@ def article_cls(base, category_cls, user_cls): content = sa.Column(sa.String) - category_id = sa.Column(sa.Integer, sa.ForeignKey(category_cls.id)) - category = sa.orm.relationship(category_cls, backref='articles') + category_id = sa.Column(sa.Integer, sa.ForeignKey(Category.id)) + category = sa.orm.relationship(Category, backref='articles') - author_id = sa.Column(sa.Integer, sa.ForeignKey(user_cls.id)) + author_id = sa.Column(sa.Integer, sa.ForeignKey(User.id)) author = sa.orm.relationship( - user_cls, - primaryjoin=author_id == user_cls.id, + User, + primaryjoin=author_id == User.id, backref='authored_articles' ) - owner_id = sa.Column(sa.Integer, sa.ForeignKey(user_cls.id)) + owner_id = sa.Column(sa.Integer, sa.ForeignKey(User.id)) owner = sa.orm.relationship( - user_cls, - primaryjoin=owner_id == user_cls.id, + User, + primaryjoin=owner_id == User.id, backref='owned_articles' ) return Article -@pytest.fixture(scope='class') -def comment_cls(base, article_cls, user_cls): - class Comment(base): +@pytest.fixture +def Comment(Base, Article, User): + class Comment(Base): __tablename__ = 'comment' id = sa.Column(sa.Integer, primary_key=True) content = sa.Column(sa.String) - article_id = sa.Column(sa.Integer, sa.ForeignKey(article_cls.id)) - article = sa.orm.relationship(article_cls, backref='comments') + article_id = sa.Column(sa.Integer, sa.ForeignKey(Article.id)) + article = sa.orm.relationship(Article, backref='comments') - author_id = sa.Column(sa.Integer, sa.ForeignKey(user_cls.id)) - author = sa.orm.relationship(user_cls, backref='comments') + author_id = sa.Column(sa.Integer, sa.ForeignKey(User.id)) + author = sa.orm.relationship(User, backref='comments') - article_cls.comment_count = sa.orm.column_property( + Article.comment_count = sa.orm.column_property( sa.select([sa.func.count(Comment.id)]) - .where(Comment.article_id == article_cls.id) - .correlate_except(article_cls) + .where(Comment.article_id == Article.id) + .correlate_except(Article) ) return Comment -@pytest.fixture(scope='class') -def composite_pk_cls(base): - class CompositePKModel(base): - __tablename__ = 'composite_pk_model' - a = sa.Column(sa.Integer, primary_key=True) - b = sa.Column(sa.Integer, primary_key=True) - return CompositePKModel - - -@pytest.fixture(scope='class') -def dns(): - return 'postgres://postgres@localhost/sqlalchemy_utils_test' - - -@pytest.yield_fixture(scope='class') -def engine(dns): - engine = create_engine(dns) - engine.echo = True - yield engine - engine.dispose() - - -@pytest.yield_fixture(scope='class') -def connection(engine): - conn = engine.connect() - yield conn - conn.close() - - -@pytest.fixture(scope='class') -def model_mapping(article_cls, category_cls, comment_cls, group_cls, user_cls): +@pytest.fixture +def model_mapping(Article, Category, Comment, group_tbl, User): return { - 'articles': article_cls, - 'categories': category_cls, - 'comments': comment_cls, - 'groups': group_cls, - 'users': user_cls + 'articles': Article, + 'categories': Category, + 'comments': Comment, + 'groups': group_tbl, + 'users': User } -@pytest.yield_fixture(scope='class') -def table_creator(base, connection, model_mapping): - sa.orm.configure_mappers() - base.metadata.create_all(connection) - yield - base.metadata.drop_all(connection) +@pytest.fixture +def init_models(Article, Category, Comment, group_tbl, User): + pass -@pytest.yield_fixture(scope='class') -def session(connection): - Session = sessionmaker(bind=connection) - session = Session() - yield session - session.close_all() - - -@pytest.fixture(scope='class') +@pytest.fixture def dataset( session, - user_cls, - group_cls, - article_cls, - category_cls, - comment_cls + User, + group_tbl, + Article, + Category, + Comment ): - group = group_cls(name='Group 1') - group2 = group_cls(name='Group 2') - user = user_cls(id=1, name='User 1', groups=[group, group2]) - user2 = user_cls(id=2, name='User 2') - user3 = user_cls(id=3, name='User 3', groups=[group]) - user4 = user_cls(id=4, name='User 4', groups=[group2]) - user5 = user_cls(id=5, name='User 5') + group = group_tbl(name='Group 1') + group2 = group_tbl(name='Group 2') + user = User(id=1, name='User 1', groups=[group, group2]) + user2 = User(id=2, name='User 2') + user3 = User(id=3, name='User 3', groups=[group]) + user4 = User(id=4, name='User 4', groups=[group2]) + user5 = User(id=5, name='User 5') user.friends = [user2] user2.friends = [user3, user4] user3.friends = [user5] - article = article_cls( + article = Article( name='Some article', author=user, owner=user2, - category=category_cls( + category=Category( id=1, name='Some category', subcategories=[ - category_cls( + Category( id=2, name='Subcategory 1', subcategories=[ - category_cls( + Category( id=3, name='Subsubcategory 1', subcategories=[ - category_cls( + Category( id=5, name='Subsubsubcategory 1', ), - category_cls( + Category( id=6, name='Subsubsubcategory 2', ) @@ -274,11 +228,11 @@ def dataset( ) ] ), - category_cls(id=4, name='Subcategory 2'), + Category(id=4, name='Subcategory 2'), ] ), comments=[ - comment_cls( + Comment( content='Some comment', author=user ) @@ -290,7 +244,7 @@ def dataset( session.commit() -@pytest.mark.usefixtures('table_creator', 'dataset') +@pytest.mark.usefixtures('dataset', 'postgresql_dsn') class TestSelectCorrelatedExpression(object): @pytest.mark.parametrize( ('model_key', 'related_model_key', 'path', 'result'), @@ -428,20 +382,20 @@ class TestSelectCorrelatedExpression(object): def test_with_non_aggregate_function( self, session, - user_cls, - article_cls + User, + Article ): aggregate = select_correlated_expression( - article_cls, - sa.func.json_build_object('name', user_cls.name), + Article, + sa.func.json_build_object('name', User.name), 'comments.author', - user_cls + User ) query = session.query( - article_cls.id, + Article.id, aggregate.label('author_json') - ).order_by(article_cls.id) + ).order_by(Article.id) result = query.all() assert result == [ (1, {'name': 'User 1'}) diff --git a/tests/test_asserts.py b/tests/test_asserts.py index 2583217..3379574 100644 --- a/tests/test_asserts.py +++ b/tests/test_asserts.py @@ -9,143 +9,152 @@ from sqlalchemy_utils import ( assert_non_nullable, assert_nullable ) -from tests import TestCase -class AssertionTestCase(TestCase): - dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' +@pytest.fixture() +def User(Base): + class User(Base): + __tablename__ = 'user' + id = sa.Column('_id', sa.Integer, primary_key=True) + name = sa.Column('_name', sa.String(20)) + age = sa.Column('_age', sa.Integer, nullable=False) + email = sa.Column( + '_email', sa.String(200), nullable=False, unique=True + ) + fav_numbers = sa.Column('_fav_numbers', ARRAY(sa.Integer)) - def create_models(self): - class User(self.Base): - __tablename__ = 'user' - id = sa.Column('_id', sa.Integer, primary_key=True) - name = sa.Column('_name', sa.String(20)) - age = sa.Column('_age', sa.Integer, nullable=False) - email = sa.Column( - '_email', sa.String(200), nullable=False, unique=True - ) - fav_numbers = sa.Column('_fav_numbers', ARRAY(sa.Integer)) - - __table_args__ = ( - sa.CheckConstraint(sa.and_(age >= 0, age <= 150)), - sa.CheckConstraint( - sa.and_( - sa.func.array_length(fav_numbers, 1) <= 8 - ) + __table_args__ = ( + sa.CheckConstraint(sa.and_(age >= 0, age <= 150)), + sa.CheckConstraint( + sa.and_( + sa.func.array_length(fav_numbers, 1) <= 8 ) ) - - self.User = User - - def setup_method(self, method): - TestCase.setup_method(self, method) - user = self.User( - name='Someone', - email='someone@example.com', - age=15, - fav_numbers=[1, 2, 3] ) - self.session.add(user) - self.session.commit() - self.user = user + return User -class TestAssertMaxLengthWithArray(AssertionTestCase): - def test_with_max_length(self): - assert_max_length(self.user, 'fav_numbers', 8) - assert_max_length(self.user, 'fav_numbers', 8) +@pytest.fixture() +def user(User, session): + user = User( + name='Someone', + email='someone@example.com', + age=15, + fav_numbers=[1, 2, 3] + ) + session.add(user) + session.commit() + return user - def test_smaller_than_max_length(self): + +@pytest.mark.usefixtures('postgresql_dsn') +class TestAssertMaxLengthWithArray(object): + + def test_with_max_length(self, user): + assert_max_length(user, 'fav_numbers', 8) + assert_max_length(user, 'fav_numbers', 8) + + def test_smaller_than_max_length(self, user): with pytest.raises(AssertionError): - assert_max_length(self.user, 'fav_numbers', 7) + assert_max_length(user, 'fav_numbers', 7) with pytest.raises(AssertionError): - assert_max_length(self.user, 'fav_numbers', 7) + assert_max_length(user, 'fav_numbers', 7) - def test_bigger_than_max_length(self): + def test_bigger_than_max_length(self, user): with pytest.raises(AssertionError): - assert_max_length(self.user, 'fav_numbers', 9) + assert_max_length(user, 'fav_numbers', 9) with pytest.raises(AssertionError): - assert_max_length(self.user, 'fav_numbers', 9) + assert_max_length(user, 'fav_numbers', 9) -class TestAssertNonNullable(AssertionTestCase): - def test_non_nullable_column(self): +@pytest.mark.usefixtures('postgresql_dsn') +class TestAssertNonNullable(object): + + def test_non_nullable_column(self, user): # Test everything twice so that session gets rolled back properly - assert_non_nullable(self.user, 'age') - assert_non_nullable(self.user, 'age') + assert_non_nullable(user, 'age') + assert_non_nullable(user, 'age') - def test_nullable_column(self): + def test_nullable_column(self, user): with pytest.raises(AssertionError): - assert_non_nullable(self.user, 'name') + assert_non_nullable(user, 'name') with pytest.raises(AssertionError): - assert_non_nullable(self.user, 'name') + assert_non_nullable(user, 'name') -class TestAssertNullable(AssertionTestCase): - def test_nullable_column(self): - assert_nullable(self.user, 'name') - assert_nullable(self.user, 'name') +@pytest.mark.usefixtures('postgresql_dsn') +class TestAssertNullable(object): - def test_non_nullable_column(self): + def test_nullable_column(self, user): + assert_nullable(user, 'name') + assert_nullable(user, 'name') + + def test_non_nullable_column(self, user): with pytest.raises(AssertionError): - assert_nullable(self.user, 'age') + assert_nullable(user, 'age') with pytest.raises(AssertionError): - assert_nullable(self.user, 'age') + assert_nullable(user, 'age') -class TestAssertMaxLength(AssertionTestCase): - def test_with_max_length(self): - assert_max_length(self.user, 'name', 20) - assert_max_length(self.user, 'name', 20) +@pytest.mark.usefixtures('postgresql_dsn') +class TestAssertMaxLength(object): - def test_with_non_nullable_column(self): - assert_max_length(self.user, 'email', 200) - assert_max_length(self.user, 'email', 200) + def test_with_max_length(self, user): + assert_max_length(user, 'name', 20) + assert_max_length(user, 'name', 20) - def test_smaller_than_max_length(self): - with pytest.raises(AssertionError): - assert_max_length(self.user, 'name', 19) - with pytest.raises(AssertionError): - assert_max_length(self.user, 'name', 19) + def test_with_non_nullable_column(self, user): + assert_max_length(user, 'email', 200) + assert_max_length(user, 'email', 200) - def test_bigger_than_max_length(self): + def test_smaller_than_max_length(self, user): with pytest.raises(AssertionError): - assert_max_length(self.user, 'name', 21) + assert_max_length(user, 'name', 19) with pytest.raises(AssertionError): - assert_max_length(self.user, 'name', 21) + assert_max_length(user, 'name', 19) + + def test_bigger_than_max_length(self, user): + with pytest.raises(AssertionError): + assert_max_length(user, 'name', 21) + with pytest.raises(AssertionError): + assert_max_length(user, 'name', 21) -class TestAssertMinValue(AssertionTestCase): - def test_with_min_value(self): - assert_min_value(self.user, 'age', 0) - assert_min_value(self.user, 'age', 0) +@pytest.mark.usefixtures('postgresql_dsn') +class TestAssertMinValue(object): - def test_smaller_than_min_value(self): - with pytest.raises(AssertionError): - assert_min_value(self.user, 'age', -1) - with pytest.raises(AssertionError): - assert_min_value(self.user, 'age', -1) + def test_with_min_value(self, user): + assert_min_value(user, 'age', 0) + assert_min_value(user, 'age', 0) - def test_bigger_than_min_value(self): + def test_smaller_than_min_value(self, user): with pytest.raises(AssertionError): - assert_min_value(self.user, 'age', 1) + assert_min_value(user, 'age', -1) with pytest.raises(AssertionError): - assert_min_value(self.user, 'age', 1) + assert_min_value(user, 'age', -1) + + def test_bigger_than_min_value(self, user): + with pytest.raises(AssertionError): + assert_min_value(user, 'age', 1) + with pytest.raises(AssertionError): + assert_min_value(user, 'age', 1) -class TestAssertMaxValue(AssertionTestCase): - def test_with_min_value(self): - assert_max_value(self.user, 'age', 150) - assert_max_value(self.user, 'age', 150) +@pytest.mark.usefixtures('postgresql_dsn') +class TestAssertMaxValue(object): - def test_smaller_than_max_value(self): - with pytest.raises(AssertionError): - assert_max_value(self.user, 'age', 149) - with pytest.raises(AssertionError): - assert_max_value(self.user, 'age', 149) + def test_with_min_value(self, user): + assert_max_value(user, 'age', 150) + assert_max_value(user, 'age', 150) - def test_bigger_than_max_value(self): + def test_smaller_than_max_value(self, user): with pytest.raises(AssertionError): - assert_max_value(self.user, 'age', 151) + assert_max_value(user, 'age', 149) with pytest.raises(AssertionError): - assert_max_value(self.user, 'age', 151) + assert_max_value(user, 'age', 149) + + def test_bigger_than_max_value(self, user): + with pytest.raises(AssertionError): + assert_max_value(user, 'age', 151) + with pytest.raises(AssertionError): + assert_max_value(user, 'age', 151) diff --git a/tests/test_auto_delete_orphans.py b/tests/test_auto_delete_orphans.py index 6efb5b9..06daaa1 100644 --- a/tests/test_auto_delete_orphans.py +++ b/tests/test_auto_delete_orphans.py @@ -1,117 +1,108 @@ +import pytest import sqlalchemy as sa -from pytest import raises from sqlalchemy_utils import auto_delete_orphans, ImproperlyConfigured -from tests import TestCase -class TestAutoDeleteOrphans(TestCase): - def create_models(self): - tagging = sa.Table( - 'tagging', - self.Base.metadata, - sa.Column( - 'tag_id', - sa.Integer, - sa.ForeignKey('tag.id', ondelete='cascade'), - primary_key=True - ), - sa.Column( - 'entry_id', - sa.Integer, - sa.ForeignKey('entry.id', ondelete='cascade'), - primary_key=True - ) +@pytest.fixture +def tagging_tbl(Base): + return sa.Table( + 'tagging', + Base.metadata, + sa.Column( + 'tag_id', + sa.Integer, + sa.ForeignKey('tag.id', ondelete='cascade'), + primary_key=True + ), + sa.Column( + 'entry_id', + sa.Integer, + sa.ForeignKey('entry.id', ondelete='cascade'), + primary_key=True ) + ) - class Tag(self.Base): - __tablename__ = 'tag' - id = sa.Column(sa.Integer, primary_key=True) - name = sa.Column(sa.String(100), unique=True, nullable=False) - def __init__(self, name=None): - self.name = name +@pytest.fixture +def Tag(Base): + class Tag(Base): + __tablename__ = 'tag' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.String(100), unique=True, nullable=False) - class Entry(self.Base): - __tablename__ = 'entry' + def __init__(self, name=None): + self.name = name + return Tag - id = sa.Column(sa.Integer, primary_key=True) - tags = sa.orm.relationship( - 'Tag', - secondary=tagging, - backref='entries' - ) +@pytest.fixture +def Entry(Base, Tag, tagging_tbl): + class Entry(Base): + __tablename__ = 'entry' - auto_delete_orphans(Entry.tags) + id = sa.Column(sa.Integer, primary_key=True) - self.Tag = Tag - self.Entry = Entry + tags = sa.orm.relationship( + 'Tag', + secondary=tagging_tbl, + backref='entries' + ) + auto_delete_orphans(Entry.tags) + return Entry - def test_orphan_deletion(self): - r1 = self.Entry() - r2 = self.Entry() - r3 = self.Entry() + +@pytest.fixture +def EntryWithoutTagsBackref(Base, Tag, tagging_tbl): + class EntryWithoutTagsBackref(Base): + __tablename__ = 'entry' + + id = sa.Column(sa.Integer, primary_key=True) + + tags = sa.orm.relationship( + 'Tag', + secondary=tagging_tbl + ) + return EntryWithoutTagsBackref + + +class TestAutoDeleteOrphans(object): + + @pytest.fixture + def init_models(self, Entry, Tag): + pass + + def test_orphan_deletion(self, session, Entry, Tag): + r1 = Entry() + r2 = Entry() + r3 = Entry() t1, t2, t3, t4 = ( - self.Tag('t1'), - self.Tag('t2'), - self.Tag('t3'), - self.Tag('t4') + Tag('t1'), + Tag('t2'), + Tag('t3'), + Tag('t4') ) r1.tags.extend([t1, t2]) r2.tags.extend([t2, t3]) r3.tags.extend([t4]) - self.session.add_all([r1, r2, r3]) + session.add_all([r1, r2, r3]) - assert self.session.query(self.Tag).count() == 4 + assert session.query(Tag).count() == 4 r2.tags.remove(t2) - assert self.session.query(self.Tag).count() == 4 + assert session.query(Tag).count() == 4 r1.tags.remove(t2) - assert self.session.query(self.Tag).count() == 3 + assert session.query(Tag).count() == 3 r1.tags.remove(t1) - assert self.session.query(self.Tag).count() == 2 + assert session.query(Tag).count() == 2 -class TestAutoDeleteOrphansWithoutBackref(TestCase): - def create_models(self): - tagging = sa.Table( - 'tagging', - self.Base.metadata, - sa.Column( - 'tag_id', - sa.Integer, - sa.ForeignKey('tag.id', ondelete='cascade'), - primary_key=True - ), - sa.Column( - 'entry_id', - sa.Integer, - sa.ForeignKey('entry.id', ondelete='cascade'), - primary_key=True - ) - ) +class TestAutoDeleteOrphansWithoutBackref(object): - class Tag(self.Base): - __tablename__ = 'tag' - id = sa.Column(sa.Integer, primary_key=True) - name = sa.Column(sa.String(100), unique=True, nullable=False) + @pytest.fixture + def init_models(self, EntryWithoutTagsBackref, Tag): + pass - def __init__(self, name=None): - self.name = name - - class Entry(self.Base): - __tablename__ = 'entry' - - id = sa.Column(sa.Integer, primary_key=True) - - tags = sa.orm.relationship( - 'Tag', - secondary=tagging - ) - - self.Entry = Entry - - def test_orphan_deletion(self): - with raises(ImproperlyConfigured): - auto_delete_orphans(self.Entry.tags) + def test_orphan_deletion(self, EntryWithoutTagsBackref): + with pytest.raises(ImproperlyConfigured): + auto_delete_orphans(EntryWithoutTagsBackref.tags) diff --git a/tests/test_case_insensitive_comparator.py b/tests/test_case_insensitive_comparator.py index 19812e1..51740d5 100644 --- a/tests/test_case_insensitive_comparator.py +++ b/tests/test_case_insensitive_comparator.py @@ -1,50 +1,60 @@ +import pytest import sqlalchemy as sa from sqlalchemy_utils import EmailType -from tests import TestCase -class TestCaseInsensitiveComparator(TestCase): - def create_models(self): - class User(self.Base): - __tablename__ = 'user' - id = sa.Column(sa.Integer, primary_key=True) - email = sa.Column(EmailType) +@pytest.fixture +def User(Base): + class User(Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, primary_key=True) + email = sa.Column(EmailType) - def __repr__(self): - return 'Building(%r)' % self.id + def __repr__(self): + return 'Building(%r)' % self.id + return User - self.User = User - def test_supports_equals(self): +@pytest.fixture +def init_models(User): + pass + + +class TestCaseInsensitiveComparator(object): + + def test_supports_equals(self, session, User): query = ( - self.session.query(self.User) - .filter(self.User.email == u'email@example.com') + session.query(User) + .filter(User.email == u'email@example.com') ) assert '"user".email = lower(:lower_1)' in str(query) - def test_supports_in_(self): + def test_supports_in_(self, session, User): query = ( - self.session.query(self.User) - .filter(self.User.email.in_([u'email@example.com', u'a'])) + session.query(User) + .filter(User.email.in_([u'email@example.com', u'a'])) ) assert ( '"user".email IN (lower(:lower_1), lower(:lower_2))' in str(query) ) - def test_supports_notin_(self): + def test_supports_notin_(self, session, User): query = ( - self.session.query(self.User) - .filter(self.User.email.notin_([u'email@example.com', u'a'])) + session.query(User) + .filter(User.email.notin_([u'email@example.com', u'a'])) ) assert ( '"user".email NOT IN (lower(:lower_1), lower(:lower_2))' in str(query) ) - def test_does_not_apply_lower_to_types_that_are_already_lowercased(self): - assert str(self.User.email == self.User.email) == ( + def test_does_not_apply_lower_to_types_that_are_already_lowercased( + self, + User + ): + assert str(User.email == User.email) == ( '"user".email = "user".email' ) diff --git a/tests/test_expressions.py b/tests/test_expressions.py index 1bcd223..00032c4 100644 --- a/tests/test_expressions.py +++ b/tests/test_expressions.py @@ -1,86 +1,93 @@ +import pytest import sqlalchemy as sa -from pytest import raises from sqlalchemy.dialects import postgresql from sqlalchemy_utils import Asterisk, row_to_json from sqlalchemy_utils.expressions import explain, explain_analyze -from tests import TestCase -class ExpressionTestCase(TestCase): - dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' - - def create_models(self): - class Article(self.Base): - __tablename__ = 'article' - id = sa.Column(sa.Integer, primary_key=True) - name = sa.Column(sa.Unicode(255)) - content = sa.Column(sa.UnicodeText) - - self.Article = Article - - def assert_startswith(self, query, query_part): +@pytest.fixture +def assert_startswith(session): + def assert_startswith(query, query_part): assert str( query.compile(dialect=postgresql.dialect()) ).startswith(query_part) # Check that query executes properly - self.session.execute(query) + session.execute(query) + return assert_startswith -class TestExplain(ExpressionTestCase): - def test_render_explain(self): - self.assert_startswith( - explain(self.session.query(self.Article)), +@pytest.fixture +def Article(Base): + class Article(Base): + __tablename__ = 'article' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + content = sa.Column(sa.UnicodeText) + return Article + + +@pytest.mark.usefixtures('postgresql_dsn') +class TestExplain(object): + + def test_render_explain(self, session, assert_startswith, Article): + assert_startswith( + explain(session.query(Article)), 'EXPLAIN SELECT' ) - def test_render_explain_with_analyze(self): - self.assert_startswith( - explain(self.session.query(self.Article), analyze=True), + def test_render_explain_with_analyze( + self, + session, + assert_startswith, + Article + ): + assert_startswith( + explain(session.query(Article), analyze=True), 'EXPLAIN (ANALYZE true) SELECT' ) - def test_with_string_as_stmt_param(self): - self.assert_startswith( + def test_with_string_as_stmt_param(self, assert_startswith): + assert_startswith( explain('SELECT 1 FROM article'), 'EXPLAIN SELECT' ) - def test_format(self): - self.assert_startswith( + def test_format(self, assert_startswith): + assert_startswith( explain('SELECT 1 FROM article', format='json'), 'EXPLAIN (FORMAT json) SELECT' ) - def test_timing(self): - self.assert_startswith( + def test_timing(self, assert_startswith): + assert_startswith( explain('SELECT 1 FROM article', analyze=True, timing=False), 'EXPLAIN (ANALYZE true, TIMING false) SELECT' ) - def test_verbose(self): - self.assert_startswith( + def test_verbose(self, assert_startswith): + assert_startswith( explain('SELECT 1 FROM article', verbose=True), 'EXPLAIN (VERBOSE true) SELECT' ) - def test_buffers(self): - self.assert_startswith( + def test_buffers(self, assert_startswith): + assert_startswith( explain('SELECT 1 FROM article', analyze=True, buffers=True), 'EXPLAIN (ANALYZE true, BUFFERS true) SELECT' ) - def test_costs(self): - self.assert_startswith( + def test_costs(self, assert_startswith): + assert_startswith( explain('SELECT 1 FROM article', costs=False), 'EXPLAIN (COSTS false) SELECT' ) -class TestExplainAnalyze(ExpressionTestCase): - def test_render_explain_analyze(self): +class TestExplainAnalyze(object): + def test_render_explain_analyze(self, session, Article): assert str( - explain_analyze(self.session.query(self.Article)) + explain_analyze(session.query(Article)) .compile( dialect=postgresql.dialect() ) @@ -111,7 +118,7 @@ class TestAsterisk(object): class TestRowToJson(object): def test_compiler_with_default_dialect(self): - with raises(sa.exc.CompileError): + with pytest.raises(sa.exc.CompileError): str(row_to_json(sa.text('article.*'))) def test_compiler_with_postgresql(self): @@ -128,7 +135,7 @@ class TestRowToJson(object): class TestArrayAgg(object): def test_compiler_with_default_dialect(self): - with raises(sa.exc.CompileError): + with pytest.raises(sa.exc.CompileError): str(sa.func.array_agg(sa.text('u.name'))) def test_compiler_with_postgresql(self): diff --git a/tests/test_instant_defaults_listener.py b/tests/test_instant_defaults_listener.py index 0ec444a..702e723 100644 --- a/tests/test_instant_defaults_listener.py +++ b/tests/test_instant_defaults_listener.py @@ -1,27 +1,29 @@ from datetime import datetime +import pytest import sqlalchemy as sa from sqlalchemy_utils.listeners import force_instant_defaults -from tests import TestCase force_instant_defaults() -class TestInstantDefaultListener(TestCase): - def create_models(self): - class Article(self.Base): - __tablename__ = 'article' - id = sa.Column(sa.Integer, primary_key=True) - name = sa.Column(sa.Unicode(255), default=u'Some article') - created_at = sa.Column(sa.DateTime, default=datetime.now) +@pytest.fixture +def Article(Base): + class Article(Base): + __tablename__ = 'article' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255), default=u'Some article') + created_at = sa.Column(sa.DateTime, default=datetime.now) + return Article - self.Article = Article - def test_assigns_defaults_on_object_construction(self): - article = self.Article() +class TestInstantDefaultListener(object): + + def test_assigns_defaults_on_object_construction(self, Article): + article = Article() assert article.name == u'Some article' - def test_callables_as_defaults(self): - article = self.Article() + def test_callables_as_defaults(self, Article): + article = Article() assert isinstance(article.created_at, datetime) diff --git a/tests/test_instrumented_list.py b/tests/test_instrumented_list.py index 525d46b..d02c23e 100644 --- a/tests/test_instrumented_list.py +++ b/tests/test_instrumented_list.py @@ -1,14 +1,19 @@ -from tests import TestCase - - -class TestInstrumentedList(TestCase): - def test_any_returns_true_if_member_has_attr_defined(self): - category = self.Category() - category.articles.append(self.Article()) - category.articles.append(self.Article(name=u'some name')) +class TestInstrumentedList(object): + def test_any_returns_true_if_member_has_attr_defined( + self, + Category, + Article + ): + category = Category() + category.articles.append(Article()) + category.articles.append(Article(name=u'some name')) assert category.articles.any('name') - def test_any_returns_false_if_no_member_has_attr_defined(self): - category = self.Category() - category.articles.append(self.Article()) + def test_any_returns_false_if_no_member_has_attr_defined( + self, + Category, + Article + ): + category = Category() + category.articles.append(Article()) assert not category.articles.any('name') diff --git a/tests/test_models.py b/tests/test_models.py index 83aa912..d0790b5 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -1,39 +1,40 @@ from datetime import datetime +import pytest import sqlalchemy as sa from sqlalchemy_utils import Timestamp -from tests import TestCase -class TestTimestamp(TestCase): +@pytest.fixture +def Article(Base): + class Article(Base, Timestamp): + __tablename__ = 'article' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255), default=u'Some article') + return Article - def create_models(self): - class Article(self.Base, Timestamp): - __tablename__ = 'article' - id = sa.Column(sa.Integer, primary_key=True) - name = sa.Column(sa.Unicode(255), default=u'Some article') - self.Article = Article +class TestTimestamp(object): - def test_created(self): + def test_created(self, session, Article): then = datetime.utcnow() - article = self.Article() + article = Article() - self.session.add(article) - self.session.commit() + session.add(article) + session.commit() assert article.created >= then and article.created <= datetime.utcnow() - def test_updated(self): - article = self.Article() + def test_updated(self, session, Article): + article = Article() - self.session.add(article) - self.session.commit() + session.add(article) + session.commit() then = datetime.utcnow() article.name = u"Something" - self.session.commit() + session.commit() assert article.updated >= then and article.updated <= datetime.utcnow() diff --git a/tests/test_path.py b/tests/test_path.py index 06af498..7e08f4f 100644 --- a/tests/test_path.py +++ b/tests/test_path.py @@ -1,122 +1,127 @@ +import pytest import six import sqlalchemy as sa -from pytest import mark from sqlalchemy.util.langhelpers import symbol from sqlalchemy_utils.path import AttrPath, Path -from tests import TestCase -class TestAttrPath(TestCase): - def create_models(self): - class Document(self.Base): - __tablename__ = 'document' - id = sa.Column(sa.Integer, primary_key=True) - name = sa.Column(sa.Unicode(255)) - locale = sa.Column(sa.String(10)) +@pytest.fixture +def Document(Base): + class Document(Base): + __tablename__ = 'document' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + locale = sa.Column(sa.String(10)) + return Document - class Section(self.Base): - __tablename__ = 'section' - id = sa.Column(sa.Integer, primary_key=True) - name = sa.Column(sa.Unicode(255)) - locale = sa.Column(sa.String(10)) - document_id = sa.Column( - sa.Integer, sa.ForeignKey(Document.id) - ) +@pytest.fixture +def Section(Base, Document): + class Section(Base): + __tablename__ = 'section' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + locale = sa.Column(sa.String(10)) - document = sa.orm.relationship(Document, backref='sections') - - class SubSection(self.Base): - __tablename__ = 'subsection' - id = sa.Column(sa.Integer, primary_key=True) - name = sa.Column(sa.Unicode(255)) - locale = sa.Column(sa.String(10)) - - section_id = sa.Column( - sa.Integer, sa.ForeignKey(Section.id) - ) - - section = sa.orm.relationship(Section, backref='subsections') - - self.Document = Document - self.Section = Section - self.SubSection = SubSection - - @mark.parametrize( - ('class_', 'path', 'direction'), - ( - ('SubSection', 'section', symbol('MANYTOONE')), + document_id = sa.Column( + sa.Integer, sa.ForeignKey(Document.id) ) - ) - def test_direction(self, class_, path, direction): + + document = sa.orm.relationship(Document, backref='sections') + return Section + + +@pytest.fixture +def SubSection(Base, Section): + class SubSection(Base): + __tablename__ = 'subsection' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + locale = sa.Column(sa.String(10)) + + section_id = sa.Column( + sa.Integer, sa.ForeignKey(Section.id) + ) + + section = sa.orm.relationship(Section, backref='subsections') + return SubSection + + +class TestAttrPath(object): + + @pytest.fixture + def init_models(self, Document, Section, SubSection): + pass + + def test_direction(self, SubSection): assert ( - AttrPath(getattr(self, class_), path).direction == direction + AttrPath(SubSection, 'section').direction == symbol('MANYTOONE') ) - def test_invert(self): - path = ~ AttrPath(self.SubSection, 'section.document') + def test_invert(self, Document, Section, SubSection): + path = ~ AttrPath(SubSection, 'section.document') assert path.parts == [ - self.Document.sections, - self.Section.subsections + Document.sections, + Section.subsections ] assert str(path.path) == 'sections.subsections' - def test_len(self): - len(AttrPath(self.SubSection, 'section.document')) == 2 + def test_len(self, SubSection): + len(AttrPath(SubSection, 'section.document')) == 2 - def test_init(self): - path = AttrPath(self.SubSection, 'section.document') - assert path.class_ == self.SubSection + def test_init(self, SubSection): + path = AttrPath(SubSection, 'section.document') + assert path.class_ == SubSection assert path.path == Path('section.document') - def test_iter(self): - path = AttrPath(self.SubSection, 'section.document') + def test_iter(self, Section, SubSection): + path = AttrPath(SubSection, 'section.document') assert list(path) == [ - self.SubSection.section, - self.Section.document + SubSection.section, + Section.document ] - def test_repr(self): - path = AttrPath(self.SubSection, 'section.document') + def test_repr(self, SubSection): + path = AttrPath(SubSection, 'section.document') assert repr(path) == ( "AttrPath(SubSection, 'section.document')" ) - def test_index(self): - path = AttrPath(self.SubSection, 'section.document') - assert path.index(self.Section.document) == 1 - assert path.index(self.SubSection.section) == 0 + def test_index(self, Section, SubSection): + path = AttrPath(SubSection, 'section.document') + assert path.index(Section.document) == 1 + assert path.index(SubSection.section) == 0 - def test_getitem(self): - path = AttrPath(self.SubSection, 'section.document') - assert path[0] is self.SubSection.section - assert path[1] is self.Section.document + def test_getitem(self, Section, SubSection): + path = AttrPath(SubSection, 'section.document') + assert path[0] is SubSection.section + assert path[1] is Section.document - def test_getitem_with_slice(self): - path = AttrPath(self.SubSection, 'section.document') - assert path[:] == AttrPath(self.SubSection, 'section.document') - assert path[:-1] == AttrPath(self.SubSection, 'section') - assert path[1:] == AttrPath(self.Section, 'document') + def test_getitem_with_slice(self, Section, SubSection): + path = AttrPath(SubSection, 'section.document') + assert path[:] == AttrPath(SubSection, 'section.document') + assert path[:-1] == AttrPath(SubSection, 'section') + assert path[1:] == AttrPath(Section, 'document') - def test_eq(self): + def test_eq(self, SubSection): assert ( - AttrPath(self.SubSection, 'section.document') == - AttrPath(self.SubSection, 'section.document') + AttrPath(SubSection, 'section.document') == + AttrPath(SubSection, 'section.document') ) assert not ( - AttrPath(self.SubSection, 'section') == - AttrPath(self.SubSection, 'section.document') + AttrPath(SubSection, 'section') == + AttrPath(SubSection, 'section.document') ) - def test_ne(self): + def test_ne(self, SubSection): assert not ( - AttrPath(self.SubSection, 'section.document') != - AttrPath(self.SubSection, 'section.document') + AttrPath(SubSection, 'section.document') != + AttrPath(SubSection, 'section.document') ) assert ( - AttrPath(self.SubSection, 'section') != - AttrPath(self.SubSection, 'section.document') + AttrPath(SubSection, 'section') != + AttrPath(SubSection, 'section.document') ) @@ -133,7 +138,7 @@ class TestPath(object): path = Path('s.s2.s3') assert list(path) == ['s', 's2', 's3'] - @mark.parametrize(('path', 'length'), ( + @pytest.mark.parametrize(('path', 'length'), ( (Path('s.s2.s3'), 3), (Path('s.s2'), 2), (Path(''), 0) @@ -167,14 +172,14 @@ class TestPath(object): path = Path('s.s2.s3') assert path[1:] == Path('s2.s3') - @mark.parametrize(('test', 'result'), ( + @pytest.mark.parametrize(('test', 'result'), ( (Path('s.s2') == Path('s.s2'), True), (Path('s.s2') == Path('s.s3'), False) )) def test_eq(self, test, result): assert test is result - @mark.parametrize(('test', 'result'), ( + @pytest.mark.parametrize(('test', 'result'), ( (Path('s.s2') != Path('s.s2'), False), (Path('s.s2') != Path('s.s3'), True) )) diff --git a/tests/test_proxy_dict.py b/tests/test_proxy_dict.py index a87f39c..a9776b0 100644 --- a/tests/test_proxy_dict.py +++ b/tests/test_proxy_dict.py @@ -1,119 +1,134 @@ +import pytest import sqlalchemy as sa from flexmock import flexmock from sqlalchemy_utils import proxy_dict, ProxyDict -from tests import TestCase -class TestProxyDict(TestCase): - def create_models(self): - class Article(self.Base): - __tablename__ = 'article' +@pytest.fixture +def ArticleTranslation(Base): + class ArticleTranslation(Base): + __tablename__ = 'article_translation' - id = sa.Column(sa.Integer, autoincrement=True, primary_key=True) - description = sa.Column(sa.UnicodeText) - _translations = sa.orm.relationship( - 'ArticleTranslation', - lazy='dynamic', - cascade='all, delete-orphan', - passive_deletes=True, - backref=sa.orm.backref('parent'), + id = sa.Column( + sa.Integer, + sa.ForeignKey('article.id'), + autoincrement=True, + primary_key=True + ) + locale = sa.Column(sa.String(10), primary_key=True) + name = sa.Column(sa.UnicodeText) + return ArticleTranslation + + +@pytest.fixture +def Article(Base, ArticleTranslation): + + class Article(Base): + __tablename__ = 'article' + + id = sa.Column(sa.Integer, autoincrement=True, primary_key=True) + description = sa.Column(sa.UnicodeText) + _translations = sa.orm.relationship( + ArticleTranslation, + lazy='dynamic', + cascade='all, delete-orphan', + passive_deletes=True, + backref=sa.orm.backref('parent'), + ) + + @property + def translations(self): + return proxy_dict( + self, + '_translations', + ArticleTranslation.locale ) + return Article - @property - def translations(self): - return proxy_dict( - self, - '_translations', - ArticleTranslation.locale - ) - class ArticleTranslation(self.Base): - __tablename__ = 'article_translation' +@pytest.fixture +def init_models(ArticleTranslation, Article): + pass - id = sa.Column( - sa.Integer, - sa.ForeignKey(Article.id), - autoincrement=True, - primary_key=True - ) - locale = sa.Column(sa.String(10), primary_key=True) - name = sa.Column(sa.UnicodeText) - self.Article = Article - self.ArticleTranslation = ArticleTranslation +class TestProxyDict(object): - def test_access_key_for_pending_parent(self): - article = self.Article() - self.session.add(article) + def test_access_key_for_pending_parent(self, session, Article): + article = Article() + session.add(article) assert article.translations['en'] - def test_access_key_for_transient_parent(self): - article = self.Article() + def test_access_key_for_transient_parent(self, Article): + article = Article() assert article.translations['en'] - def test_cache(self): - article = self.Article() + def test_cache(self, session, Article): + article = Article() ( flexmock(ProxyDict) .should_receive('fetch') .once() ) - self.session.add(article) - self.session.commit() + session.add(article) + session.commit() article.translations['en'] article.translations['en'] - def test_set_updates_cache(self): - article = self.Article() + def test_set_updates_cache(self, session, Article, ArticleTranslation): + article = Article() ( flexmock(ProxyDict) .should_receive('fetch') .once() ) - self.session.add(article) - self.session.commit() + session.add(article) + session.commit() article.translations['en'] - article.translations['en'] = self.ArticleTranslation( + article.translations['en'] = ArticleTranslation( locale='en', name=u'something' ) article.translations['en'] - def test_contains_efficiency(self): - article = self.Article() - self.session.add(article) - self.session.commit() + def test_contains_efficiency(self, connection, session, Article): + article = Article() + session.add(article) + session.commit() article.id - query_count = self.connection.query_count + query_count = connection.query_count 'en' in article.translations 'en' in article.translations 'en' in article.translations - assert self.connection.query_count == query_count + 1 + assert connection.query_count == query_count + 1 - def test_getitem_with_none_value_in_cache(self): - article = self.Article() - self.session.add(article) - self.session.commit() + def test_getitem_with_none_value_in_cache(self, session, Article): + article = Article() + session.add(article) + session.commit() article.id 'en' in article.translations assert article.translations['en'] - def test_contains(self): - article = self.Article() + def test_contains(self, Article): + article = Article() assert 'en' not in article.translations # does not auto-append new translation assert 'en' not in article.translations - def test_committing_session_empties_proxy_dict_cache(self): - article = self.Article() + def test_committing_session_empties_proxy_dict_cache( + self, + session, + Article + ): + article = Article() ( flexmock(ProxyDict) .should_receive('fetch') .twice() ) - self.session.add(article) - self.session.commit() + session.add(article) + session.commit() article.translations['en'] - self.session.commit() + session.commit() article.translations['en'] diff --git a/tests/test_query_chain.py b/tests/test_query_chain.py index 89b4ea4..fc0a61a 100644 --- a/tests/test_query_chain.py +++ b/tests/test_query_chain.py @@ -1,93 +1,110 @@ +import pytest import sqlalchemy as sa from sqlalchemy_utils import QueryChain -from tests import TestCase -class TestQueryChain(TestCase): - def create_models(self): - class User(self.Base): - __tablename__ = 'user' - id = sa.Column(sa.Integer, primary_key=True) +@pytest.fixture +def User(Base): + class User(Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, primary_key=True) + return User - class Article(self.Base): - __tablename__ = 'article' - id = sa.Column(sa.Integer, primary_key=True) - class BlogPost(self.Base): - __tablename__ = 'blog_post' - id = sa.Column(sa.Integer, primary_key=True) +@pytest.fixture +def Article(Base): + class Article(Base): + __tablename__ = 'article' + id = sa.Column(sa.Integer, primary_key=True) + return Article - self.User = User - self.Article = Article - self.BlogPost = BlogPost - def setup_method(self, method): - TestCase.setup_method(self, method) - self.users = [ - self.User(), - self.User() - ] - self.articles = [ - self.Article(), - self.Article(), - self.Article(), - self.Article() - ] - self.posts = [ - self.BlogPost(), - self.BlogPost(), - self.BlogPost(), +@pytest.fixture +def BlogPost(Base): + class BlogPost(Base): + __tablename__ = 'blog_post' + id = sa.Column(sa.Integer, primary_key=True) + return BlogPost + + +@pytest.fixture +def init_models(User, Article, BlogPost): + pass + + +@pytest.fixture +def users(session, User): + users = [User(), User()] + session.add_all(users) + session.commit() + return users + + +@pytest.fixture +def articles(session, Article): + articles = [Article(), Article(), Article(), Article()] + session.add_all(articles) + session.commit() + return articles + + +@pytest.fixture +def posts(session, BlogPost): + posts = [BlogPost(), BlogPost(), BlogPost()] + session.add_all(posts) + session.commit() + return posts + + +@pytest.fixture +def chain(session, users, articles, posts, User, Article, BlogPost): + return QueryChain( + [ + session.query(User).order_by('id'), + session.query(Article).order_by('id'), + session.query(BlogPost).order_by('id') ] + ) - self.session.add_all(self.users) - self.session.add_all(self.articles) - self.session.add_all(self.posts) - self.session.commit() - self.chain = QueryChain( - [ - self.session.query(self.User).order_by('id'), - self.session.query(self.Article).order_by('id'), - self.session.query(self.BlogPost).order_by('id') - ] - ) +class TestQueryChain(object): - def test_iter(self): - assert len(list(self.chain)) == 9 + def test_iter(self, chain): + assert len(list(chain)) == 9 - def test_iter_with_limit(self): - chain = self.chain.limit(4) - objects = list(chain) - assert self.users == objects[0:2] - assert self.articles[0:2] == objects[2:] + def test_iter_with_limit(self, chain, users, articles): + c = chain.limit(4) + objects = list(c) + assert users == objects[0:2] + assert articles[0:2] == objects[2:] - def test_iter_with_offset(self): - chain = self.chain.offset(3) - objects = list(chain) - assert self.articles[1:] + self.posts == objects + def test_iter_with_offset(self, chain, articles, posts): + c = chain.offset(3) + objects = list(c) + assert articles[1:] + posts == objects - def test_iter_with_limit_and_offset(self): - chain = self.chain.offset(3).limit(4) - objects = list(chain) - assert self.articles[1:] + self.posts[0:1] == objects + def test_iter_with_limit_and_offset(self, chain, articles, posts): + c = chain.offset(3).limit(4) + objects = list(c) + assert articles[1:] + posts[0:1] == objects - def test_iter_with_offset_spanning_multiple_queries(self): - chain = self.chain.offset(7) - objects = list(chain) - assert self.posts[1:] == objects + def test_iter_with_offset_spanning_multiple_queries(self, chain, posts): + c = chain.offset(7) + objects = list(c) + assert posts[1:] == objects - def test_repr(self): - assert repr(self.chain) == '' % id(self.chain) + def test_repr(self, chain): + assert repr(chain) == '' % id(chain) - def test_getitem_with_slice(self): - chain = self.chain[1:] - assert chain._offset == 1 - assert chain._limit is None + def test_getitem_with_slice(self, chain): + c = chain[1:] + assert c._offset == 1 + assert c._limit is None - def test_getitem_with_single_key(self): - article = self.chain[2] - assert article == self.articles[0] + def test_getitem_with_single_key(self, chain, articles): + article = chain[2] + assert article == articles[0] - def test_count(self): - assert self.chain.count() == 9 + def test_count(self, chain): + assert chain.count() == 9 diff --git a/tests/test_sort_query.py b/tests/test_sort_query.py index 221a47d..d44b650 100644 --- a/tests/test_sort_query.py +++ b/tests/test_sort_query.py @@ -1,137 +1,142 @@ +import pytest import sqlalchemy as sa -from pytest import raises from sqlalchemy_utils import sort_query from sqlalchemy_utils.functions import QuerySorterException -from tests import assert_contains, TestCase + +from . import assert_contains -class TestSortQuery(TestCase): - def test_without_sort_param_returns_the_query_object_untouched(self): - query = self.session.query(self.Article) +class TestSortQuery(object): + def test_without_sort_param_returns_the_query_object_untouched( + self, + session, + Article + ): + query = session.query(Article) query = sort_query(query, '') assert query == query - def test_column_ascending(self): - query = sort_query(self.session.query(self.Article), 'name') + def test_column_ascending(self, session, Article): + query = sort_query(session.query(Article), 'name') assert_contains('ORDER BY article.name ASC', query) - def test_column_descending(self): - query = sort_query(self.session.query(self.Article), '-name') + def test_column_descending(self, session, Article): + query = sort_query(session.query(Article), '-name') assert_contains('ORDER BY article.name DESC', query) - def test_skips_unknown_columns(self): - query = self.session.query(self.Article) + def test_skips_unknown_columns(self, session, Article): + query = session.query(Article) query = sort_query(query, '-unknown') assert query == query - def test_non_silent_mode(self): - query = self.session.query(self.Article) - with raises(QuerySorterException): + def test_non_silent_mode(self, session, Article): + query = session.query(Article) + with pytest.raises(QuerySorterException): sort_query(query, '-unknown', silent=False) - def test_join(self): + def test_join(self, session, Article): query = ( - self.session.query(self.Article) - .join(self.Article.category) + session.query(Article) + .join(Article.category) ) query = sort_query(query, 'name', silent=False) assert_contains('ORDER BY article.name ASC', query) - def test_calculated_value_ascending(self): - query = self.session.query( - self.Category, sa.func.count(self.Article.id).label('articles') + def test_calculated_value_ascending(self, session, Article, Category): + query = session.query( + Category, sa.func.count(Article.id).label('articles') ) query = sort_query(query, 'articles') assert_contains('ORDER BY articles ASC', query) - def test_calculated_value_descending(self): - query = self.session.query( - self.Category, sa.func.count(self.Article.id).label('articles') + def test_calculated_value_descending(self, session, Article, Category): + query = session.query( + Category, sa.func.count(Article.id).label('articles') ) query = sort_query(query, '-articles') assert_contains('ORDER BY articles DESC', query) - def test_subqueried_scalar(self): + def test_subqueried_scalar(self, session, Article, Category): article_count = ( sa.sql.select( - [sa.func.count(self.Article.id)], - from_obj=[self.Article.__table__] + [sa.func.count(Article.id)], + from_obj=[Article.__table__] ) - .where(self.Article.category_id == self.Category.id) - .correlate(self.Category.__table__) + .where(Article.category_id == Category.id) + .correlate(Category.__table__) ) - query = self.session.query( - self.Category, article_count.label('articles') + query = session.query( + Category, article_count.label('articles') ) query = sort_query(query, '-articles') assert_contains('ORDER BY articles DESC', query) - def test_aliased_joined_entity(self): - alias = sa.orm.aliased(self.Category, name='categories') - query = self.session.query( - self.Article + def test_aliased_joined_entity(self, session, Article, Category): + alias = sa.orm.aliased(Category, name='categories') + query = session.query( + Article ).join( - alias, self.Article.category + alias, Article.category ) query = sort_query(query, '-categories-name') assert_contains('ORDER BY categories.name DESC', query) - def test_joined_table_column(self): - query = self.session.query(self.Article).join(self.Article.category) + def test_joined_table_column(self, session, Article): + query = session.query(Article).join(Article.category) query = sort_query(query, 'category-name') assert_contains('category.name ASC', query) - def test_multiple_columns(self): - query = self.session.query(self.Article) + def test_multiple_columns(self, session, Article): + query = session.query(Article) query = sort_query(query, 'name', 'id') assert_contains('article.name ASC, article.id ASC', query) - def test_column_property(self): - self.Category.article_count = sa.orm.column_property( - sa.select([sa.func.count(self.Article.id)]) - .where(self.Article.category_id == self.Category.id) + def test_column_property(self, session, Article, Category): + Category.article_count = sa.orm.column_property( + sa.select([sa.func.count(Article.id)]) + .where(Article.category_id == Category.id) .label('article_count') ) - query = self.session.query(self.Category) + query = session.query(Category) query = sort_query(query, 'article_count') assert_contains('article_count ASC', query) - def test_column_property_descending(self): - self.Category.article_count = sa.orm.column_property( - sa.select([sa.func.count(self.Article.id)]) - .where(self.Article.category_id == self.Category.id) + def test_column_property_descending(self, session, Article, Category): + Category.article_count = sa.orm.column_property( + sa.select([sa.func.count(Article.id)]) + .where(Article.category_id == Category.id) .label('article_count') ) - query = self.session.query(self.Category) + query = session.query(Category) query = sort_query(query, '-article_count') assert_contains('article_count DESC', query) - def test_relationship_property(self): - query = self.session.query(self.Category) + def test_relationship_property(self, session, Category): + query = session.query(Category) query = sort_query(query, 'articles') assert 'ORDER BY' not in str(query) - def test_regular_property(self): - query = self.session.query(self.Category) + def test_regular_property(self, session, Category): + query = session.query(Category) query = sort_query(query, 'name_alias') assert 'ORDER BY' not in str(query) - def test_synonym_property(self): - query = self.session.query(self.Category) + def test_synonym_property(self, session, Category): + query = session.query(Category) query = sort_query(query, 'name_synonym') assert_contains('ORDER BY category.name ASC', query) - def test_hybrid_property(self): - query = self.session.query(self.Category) + def test_hybrid_property(self, session, Category): + query = session.query(Category) query = sort_query(query, 'articles_count') assert_contains('ORDER BY (SELECT count(article.id) AS count_1', query) - def test_hybrid_property_descending(self): - query = self.session.query(self.Category) + def test_hybrid_property_descending(self, session, Category): + query = session.query(Category) query = sort_query(query, '-articles_count') assert_contains( 'ORDER BY (SELECT count(article.id) AS count_1', @@ -139,88 +144,88 @@ class TestSortQuery(TestCase): ) assert ' DESC' in str(query) - def test_assigned_hybrid_property(self): + def test_assigned_hybrid_property(self, session, Article): def getter(self): return self.name - self.Article.some_hybrid = sa.ext.hybrid.hybrid_property( + Article.some_hybrid = sa.ext.hybrid.hybrid_property( fget=getter ) - query = self.session.query(self.Article) + query = session.query(Article) query = sort_query(query, 'some_hybrid') assert_contains('ORDER BY article.name ASC', query) - def test_with_mapper_and_column_property(self): - class Apple(self.Base): + def test_with_mapper_and_column_property(self, session, Base, Article): + class Apple(Base): __tablename__ = 'apple' id = sa.Column(sa.Integer, primary_key=True) - article_id = sa.Column(sa.Integer, sa.ForeignKey(self.Article.id)) + article_id = sa.Column(sa.Integer, sa.ForeignKey(Article.id)) - self.Article.apples = sa.orm.relationship(Apple) + Article.apples = sa.orm.relationship(Apple) - self.Article.apple_count = sa.orm.column_property( + Article.apple_count = sa.orm.column_property( sa.select([sa.func.count(Apple.id)]) - .where(Apple.article_id == self.Article.id) - .correlate(self.Article.__table__) + .where(Apple.article_id == Article.id) + .correlate(Article.__table__) .label('apple_count'), deferred=True ) query = ( - self.session.query(sa.inspect(self.Article)) - .outerjoin(self.Article.apples) + session.query(sa.inspect(Article)) + .outerjoin(Article.apples) .options( - sa.orm.undefer(self.Article.apple_count) + sa.orm.undefer(Article.apple_count) ) - .options(sa.orm.contains_eager(self.Article.apples)) + .options(sa.orm.contains_eager(Article.apples)) ) query = sort_query(query, 'apple_count') assert 'ORDER BY apple_count' in str(query) - def test_table(self): - query = self.session.query(self.Article.__table__) + def test_table(self, session, Article): + query = session.query(Article.__table__) query = sort_query(query, 'name') assert_contains('ORDER BY article.name', query) -class TestSortQueryRelationshipCounts(TestCase): +@pytest.mark.usefixtures('postgresql_dsn') +class TestSortQueryRelationshipCounts(object): """ Currently this doesn't work with SQLite """ - dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' - def test_relation_hybrid_property(self): + def test_relation_hybrid_property(self, session, Article): query = ( - self.session.query(self.Article) - .join(self.Article.category) - ).group_by(self.Article.id) + session.query(Article) + .join(Article.category) + ).group_by(Article.id) query = sort_query(query, '-category-articles_count') assert_contains('ORDER BY (SELECT count(article.id) AS count_1', query) - def test_aliased_hybrid_property(self): + def test_aliased_hybrid_property(self, session, Article, Category): alias = sa.orm.aliased( - self.Category, + Category, name='categories' ) query = ( - self.session.query(self.Article) - .outerjoin(alias, self.Article.category) + session.query(Article) + .outerjoin(alias, Article.category) .options( - sa.orm.contains_eager(self.Article.category, alias=alias) + sa.orm.contains_eager(Article.category, alias=alias) ) - ).group_by(alias.id, self.Article.id) + ).group_by(alias.id, Article.id) query = sort_query(query, '-categories-articles_count') assert_contains('ORDER BY (SELECT count(article.id) AS count_1', query) - def test_aliased_concat_hybrid_property(self): + def test_aliased_concat_hybrid_property(self, session, Article, Category): alias = sa.orm.aliased( - self.Category, + Category, name='aliased' ) query = ( - self.session.query(self.Article) - .outerjoin(alias, self.Article.category) + session.query(Article) + .outerjoin(alias, Article.category) .options( - sa.orm.contains_eager(self.Article.category, alias=alias) + sa.orm.contains_eager(Article.category, alias=alias) ) ) query = sort_query(query, 'aliased-full_name') @@ -229,14 +234,15 @@ class TestSortQueryRelationshipCounts(TestCase): ) -class TestSortQueryWithPolymorphicInheritance(TestCase): +@pytest.mark.usefixtures('postgresql_dsn') +class TestSortQueryWithPolymorphicInheritance(object): """ Currently this doesn't work with SQLite """ - dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' - def create_models(self): - class TextItem(self.Base): + @pytest.fixture + def TextItem(self, Base): + class TextItem(Base): __tablename__ = 'text_item' id = sa.Column(sa.Integer, primary_key=True) @@ -246,7 +252,10 @@ class TestSortQueryWithPolymorphicInheritance(TestCase): 'polymorphic_on': type, 'with_polymorphic': '*' } + return TextItem + @pytest.fixture + def Article(self, TextItem): class Article(TextItem): __tablename__ = 'article' id = sa.Column( @@ -256,50 +265,53 @@ class TestSortQueryWithPolymorphicInheritance(TestCase): __mapper_args__ = { 'polymorphic_identity': u'article' } + return Article - self.TextItem = TextItem - self.Article = Article + @pytest.fixture + def init_models(self, TextItem, Article): + pass - def test_column_property(self): - self.TextItem.item_count = sa.orm.column_property( + def test_column_property(self, session, TextItem): + TextItem.item_count = sa.orm.column_property( sa.select( [ sa.func.count('1') ], ) - .select_from(self.TextItem.__table__) + .select_from(TextItem.__table__) .label('item_count') ) query = sort_query( - self.session.query(self.TextItem), + session.query(TextItem), 'item_count' ) assert_contains('ORDER BY item_count', query) - def test_child_class_attribute(self): + def test_child_class_attribute(self, session, TextItem): query = sort_query( - self.session.query(self.TextItem), + session.query(TextItem), 'category' ) assert_contains('ORDER BY article.category ASC', query) - def test_with_ambiguous_column(self): + def test_with_ambiguous_column(self, session, TextItem): query = sort_query( - self.session.query(self.TextItem), + session.query(TextItem), 'id' ) assert_contains('ORDER BY text_item.id ASC', query) -class TestSortQueryWithCustomPolymorphic(TestCase): +@pytest.mark.usefixtures('postgresql_dsn') +class TestSortQueryWithCustomPolymorphic(object): """ Currently this doesn't work with SQLite """ - dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' - def create_models(self): - class TextItem(self.Base): + @pytest.fixture + def TextItem(self, Base): + class TextItem(Base): __tablename__ = 'text_item' id = sa.Column(sa.Integer, primary_key=True) @@ -308,7 +320,10 @@ class TestSortQueryWithCustomPolymorphic(TestCase): __mapper_args__ = { 'polymorphic_on': type, } + return TextItem + @pytest.fixture + def Article(self, TextItem): class Article(TextItem): __tablename__ = 'article' id = sa.Column( @@ -318,7 +333,10 @@ class TestSortQueryWithCustomPolymorphic(TestCase): __mapper_args__ = { 'polymorphic_identity': u'article' } + return Article + @pytest.fixture + def BlogPost(self, TextItem): class BlogPost(TextItem): __tablename__ = 'blog_post' id = sa.Column( @@ -327,24 +345,21 @@ class TestSortQueryWithCustomPolymorphic(TestCase): __mapper_args__ = { 'polymorphic_identity': u'blog_post' } + return BlogPost - self.TextItem = TextItem - self.Article = Article - self.BlogPost = BlogPost - - def test_with_unknown_column(self): + def test_with_unknown_column(self, session, TextItem, BlogPost): query = sort_query( - self.session.query( - sa.orm.with_polymorphic(self.TextItem, [self.BlogPost]) + session.query( + sa.orm.with_polymorphic(TextItem, [BlogPost]) ), 'category' ) assert 'ORDER BY' not in str(query) - def test_with_existing_column(self): + def test_with_existing_column(self, session, TextItem, Article): query = sort_query( - self.session.query( - sa.orm.with_polymorphic(self.TextItem, [self.Article]) + session.query( + sa.orm.with_polymorphic(TextItem, [Article]) ), 'category' ) diff --git a/tests/test_translation_hybrid.py b/tests/test_translation_hybrid.py index dc81f65..465ba40 100644 --- a/tests/test_translation_hybrid.py +++ b/tests/test_translation_hybrid.py @@ -1,58 +1,64 @@ +import pytest import sqlalchemy as sa from flexmock import flexmock -from pytest import mark from sqlalchemy.dialects.postgresql import HSTORE from sqlalchemy_utils import i18n, TranslationHybrid # noqa -from tests import TestCase -@mark.skipif('i18n.babel is None') -class TestTranslationHybrid(TestCase): - dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' +@pytest.fixture +def translation_hybrid(): + return TranslationHybrid('fi', 'en') - def create_models(self): - class City(self.Base): - __tablename__ = 'city' - id = sa.Column(sa.Integer, primary_key=True) - name_translations = sa.Column(HSTORE) - name = self.translation_hybrid(name_translations) - locale = 'en' - self.City = City +@pytest.fixture +def City(Base, translation_hybrid): + class City(Base): + __tablename__ = 'city' + id = sa.Column(sa.Integer, primary_key=True) + name_translations = sa.Column(HSTORE) + name = translation_hybrid(name_translations) + locale = 'en' + return City - def setup_method(self, method): - self.translation_hybrid = TranslationHybrid('fi', 'en') - TestCase.setup_method(self, method) - def test_using_hybrid_as_constructor(self): - city = self.City(name='Helsinki') +@pytest.fixture +def init_models(City): + pass + + +@pytest.mark.usefixtures('postgresql_dsn') +@pytest.mark.skipif('i18n.babel is None') +class TestTranslationHybrid(object): + + def test_using_hybrid_as_constructor(self, City): + city = City(name='Helsinki') assert city.name_translations['fi'] == 'Helsinki' - def test_if_no_translation_exists_returns_none(self): - city = self.City() + def test_if_no_translation_exists_returns_none(self, City): + city = City() assert city.name is None - def test_custom_default_value(self): - self.translation_hybrid.default_value = 'Some value' - city = self.City() + def test_custom_default_value(self, City, translation_hybrid): + translation_hybrid.default_value = 'Some value' + city = City() assert city.name is 'Some value' - def test_fall_back_to_default_translation(self): - city = self.City(name_translations={'en': 'Helsinki'}) - self.translation_hybrid.current_locale = 'sv' + def test_fall_back_to_default_translation(self, City, translation_hybrid): + city = City(name_translations={'en': 'Helsinki'}) + translation_hybrid.current_locale = 'sv' assert city.name == 'Helsinki' - def test_fallback_to_dynamic_locale(self): - self.translation_hybrid.current_locale = 'en' - self.translation_hybrid.default_locale = lambda self: self.locale - city = self.City(name_translations={}) + def test_fallback_to_dynamic_locale(self, City, translation_hybrid): + translation_hybrid.current_locale = 'en' + translation_hybrid.default_locale = lambda self: self.locale + city = City(name_translations={}) city.locale = 'fi' city.name_translations['fi'] = 'Helsinki' assert city.name == 'Helsinki' - @mark.parametrize( + @pytest.mark.parametrize( ('name_translations', 'name'), ( ({'fi': 'Helsinki', 'en': 'Helsing'}, 'Helsinki'), @@ -61,20 +67,26 @@ class TestTranslationHybrid(TestCase): ({}, None), ) ) - def test_hybrid_as_an_expression(self, name_translations, name): - city = self.City(name_translations=name_translations) - self.session.add(city) - self.session.commit() + def test_hybrid_as_an_expression( + self, + session, + City, + name_translations, + name + ): + city = City(name_translations=name_translations) + session.add(city) + session.commit() - assert self.session.query(self.City.name).scalar() == name + assert session.query(City.name).scalar() == name - def test_dynamic_locale(self): + def test_dynamic_locale(self, Base): translation_hybrid = TranslationHybrid( lambda obj: obj.locale, 'fi' ) - class Article(self.Base): + class Article(Base): __tablename__ = 'article' id = sa.Column(sa.Integer, primary_key=True) name_translations = sa.Column(HSTORE) @@ -86,7 +98,7 @@ class TestTranslationHybrid(TestCase): in str(Article.name) ) - def test_locales_casted_only_in_compilation_phase(self): + def test_locales_casted_only_in_compilation_phase(self, Base): class LocaleGetter(object): def current_locale(self): return lambda obj: obj.locale @@ -97,7 +109,7 @@ class TestTranslationHybrid(TestCase): 'fi' ) - class Article(self.Base): + class Article(Base): __tablename__ = 'article' id = sa.Column(sa.Integer, primary_key=True) name_translations = sa.Column(HSTORE) diff --git a/tests/types/test_arrow.py b/tests/types/test_arrow.py index f3ca975..38e829d 100644 --- a/tests/types/test_arrow.py +++ b/tests/types/test_arrow.py @@ -1,52 +1,58 @@ from datetime import datetime +import pytest import sqlalchemy as sa -from pytest import mark from sqlalchemy_utils.types import arrow -from tests import TestCase -@mark.skipif('arrow.arrow is None') -class TestArrowDateTimeType(TestCase): - def create_models(self): - class Article(self.Base): - __tablename__ = 'article' - id = sa.Column(sa.Integer, primary_key=True) - created_at = sa.Column(arrow.ArrowType) +@pytest.fixture +def Article(Base): + class Article(Base): + __tablename__ = 'article' + id = sa.Column(sa.Integer, primary_key=True) + created_at = sa.Column(arrow.ArrowType) + return Article - self.Article = Article - def test_parameter_processing(self): - article = self.Article( +@pytest.fixture +def init_models(Article): + pass + + +@pytest.mark.skipif('arrow.arrow is None') +class TestArrowDateTimeType(object): + + def test_parameter_processing(self, session, Article): + article = Article( created_at=arrow.arrow.get(datetime(2000, 11, 1)) ) - self.session.add(article) - self.session.commit() + session.add(article) + session.commit() - article = self.session.query(self.Article).first() + article = session.query(Article).first() assert article.created_at.datetime - def test_string_coercion(self): - article = self.Article( + def test_string_coercion(self, Article): + article = Article( created_at='1367900664' ) assert article.created_at.year == 2013 - def test_utc(self): + def test_utc(self, session, Article): time = arrow.arrow.utcnow() - article = self.Article(created_at=time) - self.session.add(article) + article = Article(created_at=time) + session.add(article) assert article.created_at == time - self.session.commit() + session.commit() assert article.created_at == time - def test_other_tz(self): + def test_other_tz(self, session, Article): time = arrow.arrow.utcnow() local = time.to('US/Pacific') - article = self.Article(created_at=local) - self.session.add(article) + article = Article(created_at=local) + session.add(article) assert article.created_at == time == local - self.session.commit() + session.commit() assert article.created_at == time diff --git a/tests/types/test_choice.py b/tests/types/test_choice.py index 831f0d6..ff9057c 100644 --- a/tests/types/test_choice.py +++ b/tests/types/test_choice.py @@ -1,10 +1,9 @@ +import pytest import sqlalchemy as sa from flexmock import flexmock -from pytest import mark, raises from sqlalchemy_utils import Choice, ChoiceType, ImproperlyConfigured from sqlalchemy_utils.types.choice import Enum -from tests import TestCase class TestChoice(object): @@ -18,9 +17,12 @@ class TestChoice(object): assert not (Choice(1, 1) != 1) -class TestChoiceType(TestCase): - def create_models(self): - class User(self.Base): +class TestChoiceType(object): + + @pytest.fixture + def User(self, Base): + + class User(Base): TYPES = [ ('admin', 'Admin'), ('regular-user', 'Regular user') @@ -33,61 +35,71 @@ class TestChoiceType(TestCase): def __repr__(self): return 'User(%r)' % self.id - self.User = User + return User - def test_python_type(self): - type_ = self.User.__table__.c.type.type + @pytest.fixture + def init_models(self, User): + pass + + def test_python_type(self, User): + type_ = User.__table__.c.type.type assert type_.python_type - def test_string_processing(self): + def test_string_processing(self, session, User): flexmock(ChoiceType).should_receive('_coerce').and_return( u'admin' ) - user = self.User( + user = User( type=u'admin' ) - self.session.add(user) - self.session.commit() + session.add(user) + session.commit() - user = self.session.query(self.User).first() + user = session.query(User).first() assert user.type.value == u'Admin' - def test_parameter_processing(self): - user = self.User( + def test_parameter_processing(self, session, User): + user = User( type=u'admin' ) - self.session.add(user) - self.session.commit() + session.add(user) + session.commit() - user = self.session.query(self.User).first() + user = session.query(User).first() assert user.type.value == u'Admin' - def test_scalar_attributes_get_coerced_to_objects(self): - user = self.User(type=u'admin') + def test_scalar_attributes_get_coerced_to_objects(self, User): + user = User(type=u'admin') assert isinstance(user.type, Choice) def test_throws_exception_if_no_choices_given(self): - with raises(ImproperlyConfigured): + with pytest.raises(ImproperlyConfigured): ChoiceType([]) -class TestChoiceTypeWithCustomUnderlyingType(TestCase): +class TestChoiceTypeWithCustomUnderlyingType(object): def test_init_type(self): type_ = ChoiceType([(1, u'something')], impl=sa.Integer) assert type_.impl == sa.Integer -@mark.skipif('Enum is None') -class TestEnumType(TestCase): - def create_models(self): +@pytest.mark.skipif('Enum is None') +class TestEnumType(object): + + @pytest.fixture + def OrderStatus(self): class OrderStatus(Enum): unpaid = 0 paid = 1 + return OrderStatus - class Order(self.Base): + @pytest.fixture + def Order(self, Base, OrderStatus): + + class Order(Base): __tablename__ = 'order' id_ = sa.Column(sa.Integer, primary_key=True) status = sa.Column( @@ -98,7 +110,12 @@ class TestEnumType(TestCase): def __repr__(self): return 'Order(%r, %r)' % (self.id_, self.status) - class OrderNullable(self.Base): + return Order + + @pytest.fixture + def OrderNullable(self, Base, OrderStatus): + + class OrderNullable(Base): __tablename__ = 'order_nullable' id_ = sa.Column(sa.Integer, primary_key=True) status = sa.Column( @@ -106,76 +123,83 @@ class TestEnumType(TestCase): nullable=True, ) - self.OrderStatus = OrderStatus - self.Order = Order - self.OrderNullable = OrderNullable + return OrderNullable - def test_parameter_initialization(self): - order = self.Order() + @pytest.fixture + def init_models(self, Order, OrderNullable): + pass - self.session.add(order) - self.session.commit() + def test_parameter_initialization(self, session, Order, OrderStatus): + order = Order() - order = self.session.query(self.Order).first() - assert order.status is self.OrderStatus.unpaid + session.add(order) + session.commit() + + order = session.query(Order).first() + assert order.status is OrderStatus.unpaid assert order.status.value == 0 - def test_setting_by_value(self): - order = self.Order() + def test_setting_by_value(self, session, Order, OrderStatus): + order = Order() order.status = 1 - self.session.add(order) - self.session.commit() + session.add(order) + session.commit() - order = self.session.query(self.Order).first() - assert order.status is self.OrderStatus.paid + order = session.query(Order).first() + assert order.status is OrderStatus.paid - def test_setting_by_enum(self): - order = self.Order() - order.status = self.OrderStatus.paid + def test_setting_by_enum(self, session, Order, OrderStatus): + order = Order() + order.status = OrderStatus.paid - self.session.add(order) - self.session.commit() + session.add(order) + session.commit() - order = self.session.query(self.Order).first() - assert order.status is self.OrderStatus.paid + order = session.query(Order).first() + assert order.status is OrderStatus.paid - def test_setting_value_that_resolves_to_none(self): - order = self.Order() + def test_setting_value_that_resolves_to_none( + self, + session, + Order, + OrderStatus + ): + order = Order() order.status = 0 - self.session.add(order) - self.session.commit() + session.add(order) + session.commit() - order = self.session.query(self.Order).first() - assert order.status is self.OrderStatus.unpaid + order = session.query(Order).first() + assert order.status is OrderStatus.unpaid - def test_setting_to_wrong_enum_raises_valueerror(self): + def test_setting_to_wrong_enum_raises_valueerror(self, Order): class WrongEnum(Enum): foo = 0 bar = 1 - order = self.Order() + order = Order() - with raises(ValueError): + with pytest.raises(ValueError): order.status = WrongEnum.foo - def test_setting_to_uncoerceable_type_raises_valueerror(self): - order = self.Order() - with raises(ValueError): + def test_setting_to_uncoerceable_type_raises_valueerror(self, Order): + order = Order() + with pytest.raises(ValueError): order.status = 'Bad value' - def test_order_nullable_stores_none(self): + def test_order_nullable_stores_none(self, session, OrderNullable): # With nullable=False as in `Order`, a `None` value is always # converted to the default value, unless we explicitly set it to # sqlalchemy.sql.null(), so we use this class to test our ability # to set and retrive `None`. - order_nullable = self.OrderNullable() + order_nullable = OrderNullable() assert order_nullable.status is None order_nullable.status = None - self.session.add(order_nullable) - self.session.commit() + session.add(order_nullable) + session.commit() assert order_nullable.status is None diff --git a/tests/types/test_color.py b/tests/types/test_color.py index aa1d60c..3842ec5 100644 --- a/tests/types/test_color.py +++ b/tests/types/test_color.py @@ -1,56 +1,62 @@ +import pytest import sqlalchemy as sa from flexmock import flexmock -from pytest import mark from sqlalchemy_utils import ColorType, types # noqa -from tests import TestCase -@mark.skipif('types.color.python_colour_type is None') -class TestColorType(TestCase): - def create_models(self): - class Document(self.Base): - __tablename__ = 'document' - id = sa.Column(sa.Integer, primary_key=True) - bg_color = sa.Column(ColorType) +@pytest.fixture +def Document(Base): + class Document(Base): + __tablename__ = 'document' + id = sa.Column(sa.Integer, primary_key=True) + bg_color = sa.Column(ColorType) - def __repr__(self): - return 'Document(%r)' % self.id + def __repr__(self): + return 'Document(%r)' % self.id + return Document - self.Document = Document - def test_string_parameter_processing(self): +@pytest.fixture +def init_models(Document): + pass + + +@pytest.mark.skipif('types.color.python_colour_type is None') +class TestColorType(object): + + def test_string_parameter_processing(self, session, Document): from colour import Color flexmock(ColorType).should_receive('_coerce').and_return( u'white' ) - document = self.Document( + document = Document( bg_color='white' ) - self.session.add(document) - self.session.commit() + session.add(document) + session.commit() - document = self.session.query(self.Document).first() + document = session.query(Document).first() assert document.bg_color.hex == Color(u'white').hex - def test_color_parameter_processing(self): + def test_color_parameter_processing(self, session, Document): from colour import Color - document = self.Document( + document = Document( bg_color=Color(u'white') ) - self.session.add(document) - self.session.commit() + session.add(document) + session.commit() - document = self.session.query(self.Document).first() + document = session.query(Document).first() assert document.bg_color.hex == Color(u'white').hex - def test_scalar_attributes_get_coerced_to_objects(self): + def test_scalar_attributes_get_coerced_to_objects(self, Document): from colour import Color - document = self.Document(bg_color='white') + document = Document(bg_color='white') assert isinstance(document.bg_color, Color) diff --git a/tests/types/test_composite.py b/tests/types/test_composite.py index f5f7fba..f125522 100644 --- a/tests/types/test_composite.py +++ b/tests/types/test_composite.py @@ -1,8 +1,6 @@ # -*- coding: utf-8 -*- +import pytest import sqlalchemy as sa -from pytest import mark -from sqlalchemy import create_engine -from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import sessionmaker from sqlalchemy_utils import ( @@ -17,14 +15,14 @@ from sqlalchemy_utils import ( ) from sqlalchemy_utils.types import pg_composite from sqlalchemy_utils.types.range import intervals -from tests import TestCase -class TestCompositeTypeWithRegularTypes(TestCase): - dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' +@pytest.mark.usefixtures('postgresql_dsn') +class TestCompositeTypeWithRegularTypes(object): - def create_models(self): - class Account(self.Base): + @pytest.fixture + def Account(self, Base): + class Account(Base): __tablename__ = 'account' id = sa.Column(sa.Integer, primary_key=True) balance = sa.Column( @@ -37,43 +35,44 @@ class TestCompositeTypeWithRegularTypes(TestCase): ) ) - self.Account = Account + return Account - def test_parameter_processing(self): - account = self.Account( + @pytest.fixture + def init_models(self, Account): + pass + + def test_parameter_processing(self, session, Account): + account = Account( balance=('USD', 15) ) - self.session.add(account) - self.session.commit() + session.add(account) + session.commit() - account = self.session.query(self.Account).first() + account = session.query(Account).first() assert account.balance.currency == 'USD' assert account.balance.amount == 15 - def test_non_ascii_chars(self): - account = self.Account( + def test_non_ascii_chars(self, session, Account): + account = Account( balance=(u'ääöö', 15) ) - self.session.add(account) - self.session.commit() + session.add(account) + session.commit() - account = self.session.query(self.Account).first() + account = session.query(Account).first() assert account.balance.currency == u'ääöö' assert account.balance.amount == 15 -@mark.skipif('i18n.babel is None') -class TestCompositeTypeWithTypeDecorators(TestCase): - dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' +@pytest.mark.skipif('i18n.babel is None') +@pytest.mark.usefixtures('postgresql_dsn') +class TestCompositeTypeWithTypeDecorators(object): - def setup_method(self, method): - TestCase.setup_method(self, method) - i18n.get_locale = lambda: i18n.babel.Locale('en') - - def create_models(self): - class Account(self.Base): + @pytest.fixture + def Account(self, Base): + class Account(Base): __tablename__ = 'account' id = sa.Column(sa.Integer, primary_key=True) balance = sa.Column( @@ -86,39 +85,44 @@ class TestCompositeTypeWithTypeDecorators(TestCase): ) ) - self.Account = Account + return Account - def test_result_set_processing(self): - account = self.Account( + @pytest.fixture + def init_models(self, Account): + i18n.get_locale = lambda: i18n.babel.Locale('en') + + def test_result_set_processing(self, session, Account): + account = Account( balance=('USD', 15) ) - self.session.add(account) - self.session.commit() + session.add(account) + session.commit() - account = self.session.query(self.Account).first() + account = session.query(Account).first() assert account.balance.currency == Currency('USD') assert account.balance.amount == 15 - def test_parameter_processing(self): - account = self.Account( + def test_parameter_processing(self, session, Account): + account = Account( balance=(Currency('USD'), 15) ) - self.session.add(account) - self.session.commit() + session.add(account) + session.commit() - account = self.session.query(self.Account).first() + account = session.query(Account).first() assert account.balance.currency == Currency('USD') assert account.balance.amount == 15 -@mark.skipif('i18n.babel is None') -class TestCompositeTypeInsideArray(TestCase): - dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' +@pytest.mark.skipif('i18n.babel is None') +@pytest.mark.usefixtures('postgresql_dsn') +class TestCompositeTypeInsideArray(object): - def setup_method(self, method): - self.type = CompositeType( + @pytest.fixture + def type_(self): + return CompositeType( 'money_type', [ sa.Column('currency', CurrencyType), @@ -126,43 +130,46 @@ class TestCompositeTypeInsideArray(TestCase): ] ) - TestCase.setup_method(self, method) - i18n.get_locale = lambda: i18n.babel.Locale('en') - - def create_models(self): - class Account(self.Base): + @pytest.fixture + def Account(self, Base, type_): + class Account(Base): __tablename__ = 'account' id = sa.Column(sa.Integer, primary_key=True) balances = sa.Column( - CompositeArray(self.type) + CompositeArray(type_) ) - self.Account = Account + return Account - def test_parameter_processing(self): - account = self.Account( + @pytest.fixture + def init_models(self, Account): + i18n.get_locale = lambda: i18n.babel.Locale('en') + + def test_parameter_processing(self, session, type_, Account): + account = Account( balances=[ - self.type.type_cls(Currency('USD'), 15), - self.type.type_cls(Currency('AUD'), 20) + type_.type_cls(Currency('USD'), 15), + type_.type_cls(Currency('AUD'), 20) ] ) - self.session.add(account) - self.session.commit() + session.add(account) + session.commit() - account = self.session.query(self.Account).first() + account = session.query(Account).first() assert account.balances[0].currency == Currency('USD') assert account.balances[0].amount == 15 assert account.balances[1].currency == Currency('AUD') assert account.balances[1].amount == 20 -@mark.skipif('intervals is None') -class TestCompositeTypeWithRangeTypeInsideArray(TestCase): - dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' +@pytest.mark.skipif('intervals is None') +@pytest.mark.usefixtures('postgresql_dsn') +class TestCompositeTypeWithRangeTypeInsideArray(object): - def setup_method(self, method): - self.type = CompositeType( + @pytest.fixture + def type_(self): + return CompositeType( 'category', [ sa.Column('scale', NumericRangeType), @@ -170,36 +177,44 @@ class TestCompositeTypeWithRangeTypeInsideArray(TestCase): ] ) - TestCase.setup_method(self, method) - - def create_models(self): - class Account(self.Base): + @pytest.fixture + def Account(self, Base, type_): + class Account(Base): __tablename__ = 'account' id = sa.Column(sa.Integer, primary_key=True) categories = sa.Column( - CompositeArray(self.type) + CompositeArray(type_) ) - self.Account = Account + return Account - def test_parameter_processing_with_named_tuple(self): - account = self.Account( + @pytest.fixture + def init_models(self, Account): + pass + + def test_parameter_processing_with_named_tuple( + self, + session, + type_, + Account + ): + account = Account( categories=[ - self.type.type_cls( + type_.type_cls( intervals.DecimalInterval([15, 18]), 'bad' ), - self.type.type_cls( + type_.type_cls( intervals.DecimalInterval([18, 20]), 'good' ) ] ) - self.session.add(account) - self.session.commit() + session.add(account) + session.commit() - account = self.session.query(self.Account).first() + account = session.query(Account).first() assert ( account.categories[0].scale == intervals.DecimalInterval([15, 18]) ) @@ -209,18 +224,18 @@ class TestCompositeTypeWithRangeTypeInsideArray(TestCase): ) assert account.categories[1].name == 'good' - def test_parameter_processing_with_tuple(self): - account = self.Account( + def test_parameter_processing_with_tuple(self, session, Account): + account = Account( categories=[ (intervals.DecimalInterval([15, 18]), 'bad'), (intervals.DecimalInterval([18, 20]), 'good') ] ) - self.session.add(account) - self.session.commit() + session.add(account) + session.commit() - account = self.session.query(self.Account).first() + account = session.query(Account).first() assert ( account.categories[0].scale == intervals.DecimalInterval([15, 18]) ) @@ -230,15 +245,19 @@ class TestCompositeTypeWithRangeTypeInsideArray(TestCase): ) assert account.categories[1].name == 'good' - def test_parameter_processing_with_nulls_as_composite_fields(self): - account = self.Account( + def test_parameter_processing_with_nulls_as_composite_fields( + self, + session, + Account + ): + account = Account( categories=[ (None, 'bad'), (intervals.DecimalInterval([18, 20]), None) ] ) - self.session.add(account) - self.session.commit() + session.add(account) + session.commit() assert account.categories[0].scale is None assert account.categories[0].name == 'bad' assert ( @@ -246,31 +265,32 @@ class TestCompositeTypeWithRangeTypeInsideArray(TestCase): ) assert account.categories[1].name is None - def test_parameter_processing_with_nulls_as_composites(self): - account = self.Account( + def test_parameter_processing_with_nulls_as_composites( + self, + session, + Account + ): + account = Account( categories=[ (None, None), None ] ) - self.session.add(account) - self.session.commit() + session.add(account) + session.commit() assert account.categories[0].scale is None assert account.categories[0].name is None assert account.categories[1] is None -class TestCompositeTypeWhenTypeAlreadyExistsInDatabase(TestCase): - dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' +@pytest.mark.usefixtures('postgresql_dsn') +class TestCompositeTypeWhenTypeAlreadyExistsInDatabase(object): - def setup_method(self, method): - self.engine = create_engine(self.dns) - self.engine.echo = True - self.connection = self.engine.connect() - self.Base = declarative_base() + @pytest.fixture + def Account(self, Base): pg_composite.registered_composites = {} - self.type = CompositeType( + type_ = CompositeType( 'money_type', [ sa.Column('currency', sa.String), @@ -278,46 +298,50 @@ class TestCompositeTypeWhenTypeAlreadyExistsInDatabase(TestCase): ] ) - self.create_models() + class Account(Base): + __tablename__ = 'account' + id = sa.Column(sa.Integer, primary_key=True) + balance = sa.Column(type_) + + return Account + + @pytest.fixture + def session(self, request, engine, connection, Base, Account): sa.orm.configure_mappers() - Session = sessionmaker(bind=self.connection) - self.session = Session() - self.session.execute( + Session = sessionmaker(bind=connection) + session = Session() + session.execute( "CREATE TYPE money_type AS (currency VARCHAR, amount INTEGER)" ) - self.session.execute( + session.execute( """CREATE TABLE account ( id SERIAL, balance MONEY_TYPE, PRIMARY KEY(id) )""" ) - register_composites(self.connection) - def teardown_method(self, method): - self.session.execute('DROP TABLE account') - self.session.execute('DROP TYPE money_type') - self.session.commit() - self.session.close_all() - self.connection.close() - remove_composite_listeners() - self.engine.dispose() + def teardown(): + session.execute('DROP TABLE account') + session.execute('DROP TYPE money_type') + session.commit() + session.close_all() + connection.close() + remove_composite_listeners() + engine.dispose() - def create_models(self): - class Account(self.Base): - __tablename__ = 'account' - id = sa.Column(sa.Integer, primary_key=True) - balance = sa.Column(self.type) + register_composites(connection) + request.addfinalizer(teardown) - self.Account = Account + return session - def test_parameter_processing(self): - account = self.Account( + def test_parameter_processing(self, session, Account): + account = Account( balance=('USD', 15), ) - self.session.add(account) - self.session.commit() + session.add(account) + session.commit() - account = self.session.query(self.Account).first() + account = session.query(Account).first() assert account.balance.currency == 'USD' assert account.balance.amount == 15 diff --git a/tests/types/test_country.py b/tests/types/test_country.py index 59bae9d..7ef53b4 100644 --- a/tests/types/test_country.py +++ b/tests/types/test_country.py @@ -1,34 +1,40 @@ +import pytest import sqlalchemy as sa -from pytest import mark from sqlalchemy_utils import Country, CountryType, i18n # noqa -from tests import TestCase -@mark.skipif('i18n.babel is None') -class TestCountryType(TestCase): - def create_models(self): - class User(self.Base): - __tablename__ = 'user' - id = sa.Column(sa.Integer, primary_key=True) - country = sa.Column(CountryType) +@pytest.fixture +def User(Base): + class User(Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, primary_key=True) + country = sa.Column(CountryType) - def __repr__(self): - return 'User(%r)' % self.id + def __repr__(self): + return 'User(%r)' % self.id + return User - self.User = User - def test_parameter_processing(self): - user = self.User( +@pytest.fixture +def init_models(User): + pass + + +@pytest.mark.skipif('i18n.babel is None') +class TestCountryType(object): + + def test_parameter_processing(self, session, User): + user = User( country=Country(u'FI') ) - self.session.add(user) - self.session.commit() + session.add(user) + session.commit() - user = self.session.query(self.User).first() + user = session.query(User).first() assert user.country.name == u'Finland' - def test_scalar_attributes_get_coerced_to_objects(self): - user = self.User(country='FI') + def test_scalar_attributes_get_coerced_to_objects(self, User): + user = User(country='FI') assert isinstance(user.country, Country) diff --git a/tests/types/test_currency.py b/tests/types/test_currency.py index caa7a1d..8266170 100644 --- a/tests/types/test_currency.py +++ b/tests/types/test_currency.py @@ -1,39 +1,51 @@ -# -*- coding: utf-8 -*- +import pytest import sqlalchemy as sa -from pytest import mark from sqlalchemy_utils import Currency, CurrencyType, i18n -from tests import TestCase -@mark.skipif('i18n.babel is None') -class TestCurrencyType(TestCase): - def setup_method(self, method): - TestCase.setup_method(self, method) - i18n.get_locale = lambda: i18n.babel.Locale('en') +@pytest.fixture +def set_get_locale(): + i18n.get_locale = lambda: i18n.babel.Locale('en') - def create_models(self): - class User(self.Base): - __tablename__ = 'user' - id = sa.Column(sa.Integer, primary_key=True) - currency = sa.Column(CurrencyType) - def __repr__(self): - return 'User(%r)' % self.id +@pytest.fixture +def User(Base): - self.User = User + class User(Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, primary_key=True) + currency = sa.Column(CurrencyType) - def test_parameter_processing(self): - user = self.User( + def __repr__(self): + return 'User(%r)' % self.id + + return User + + +@pytest.fixture +def init_models(User): + pass + + +@pytest.mark.skipif('i18n.babel is None') +class TestCurrencyType(object): + + def test_parameter_processing(self, session, User, set_get_locale): + user = User( currency=Currency('USD') ) - self.session.add(user) - self.session.commit() + session.add(user) + session.commit() - user = self.session.query(self.User).first() + user = session.query(User).first() assert user.currency.name == u'US Dollar' - def test_scalar_attributes_get_coerced_to_objects(self): - user = self.User(currency='USD') + def test_scalar_attributes_get_coerced_to_objects( + self, + User, + set_get_locale + ): + user = User(currency='USD') assert isinstance(user.currency, Currency) diff --git a/tests/types/test_date_range.py b/tests/types/test_date_range.py index 238e30f..779bd62 100644 --- a/tests/types/test_date_range.py +++ b/tests/types/test_date_range.py @@ -1,10 +1,9 @@ from datetime import datetime, timedelta +import pytest import sqlalchemy as sa -from pytest import mark from sqlalchemy_utils import DateRangeType -from tests import TestCase intervals = None inf = 0 @@ -15,30 +14,41 @@ except ImportError: pass -@mark.skipif('intervals is None') -class DateRangeTestCase(TestCase): - def create_models(self): - class Booking(self.Base): - __tablename__ = 'booking' - id = sa.Column(sa.Integer, primary_key=True) - during = sa.Column(DateRangeType) +@pytest.fixture +def Booking(Base): + class Booking(Base): + __tablename__ = 'booking' + id = sa.Column(sa.Integer, primary_key=True) + during = sa.Column(DateRangeType) - self.Booking = Booking + return Booking - def create_booking(self, date_range): - booking = self.Booking( + +@pytest.fixture +def create_booking(session, Booking): + def create_booking(date_range): + booking = Booking( during=date_range ) + session.add(booking) + session.commit() + return session.query(Booking).first() + return create_booking - self.session.add(booking) - self.session.commit() - return self.session.query(self.Booking).first() - def test_nullify_range(self): - booking = self.create_booking(None) +@pytest.fixture +def init_models(Booking): + pass + + +@pytest.mark.skipif('intervals is None') +class DateRangeTestCase(object): + + def test_nullify_range(self, create_booking): + booking = create_booking(None) assert booking.during is None - @mark.parametrize( + @pytest.mark.parametrize( ('date_range'), ( [datetime(2015, 1, 1).date(), datetime(2015, 1, 3).date()], @@ -46,38 +56,38 @@ class DateRangeTestCase(TestCase): [-inf, datetime(2015, 1, 1).date()] ) ) - def test_save_date_range(self, date_range): - booking = self.create_booking(date_range) + def test_save_date_range(self, create_booking, date_range): + booking = create_booking(date_range) assert booking.during.lower == date_range[0] assert booking.during.upper == date_range[1] - def test_nullify_date_range(self): - booking = self.Booking( + def test_nullify_date_range(self, session, Booking): + booking = Booking( during=intervals.DateInterval( [datetime(2015, 1, 1).date(), datetime(2015, 1, 3).date()] ) ) - self.session.add(booking) - self.session.commit() + session.add(booking) + session.commit() - booking = self.session.query(self.Booking).first() + booking = session.query(Booking).first() booking.during = None - self.session.commit() + session.commit() - booking = self.session.query(self.Booking).first() + booking = session.query(Booking).first() assert booking.during is None - def test_integer_coercion(self): - booking = self.Booking(during=datetime(2015, 1, 1).date()) + def test_integer_coercion(self, Booking): + booking = Booking(during=datetime(2015, 1, 1).date()) assert booking.during.lower == datetime(2015, 1, 1).date() assert booking.during.upper == datetime(2015, 1, 1).date() -class TestDateRangeOnPostgres(DateRangeTestCase): - dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' +@pytest.mark.usefixtures('postgresql_dsn') +class TestDateRangeOnPostgres(object): - @mark.parametrize( + @pytest.mark.parametrize( ('date_range', 'length'), ( ( @@ -92,9 +102,16 @@ class TestDateRangeOnPostgres(DateRangeTestCase): ([datetime(2015, 1, 1).date(), inf], None), ) ) - def test_length(self, date_range, length): - self.create_booking(date_range) + def test_length( + self, + session, + Booking, + create_booking, + date_range, + length + ): + create_booking(date_range) query = ( - self.session.query(self.Booking.during.length) + session.query(Booking.during.length) ) assert query.scalar() == length diff --git a/tests/types/test_datetime_range.py b/tests/types/test_datetime_range.py index b6bf731..f25c477 100644 --- a/tests/types/test_datetime_range.py +++ b/tests/types/test_datetime_range.py @@ -1,10 +1,9 @@ from datetime import datetime, timedelta +import pytest import sqlalchemy as sa -from pytest import mark from sqlalchemy_utils import DateTimeRangeType -from tests import TestCase intervals = None inf = 0 @@ -15,30 +14,41 @@ except ImportError: pass -@mark.skipif('intervals is None') -class DateRangeTestCase(TestCase): - def create_models(self): - class Booking(self.Base): - __tablename__ = 'booking' - id = sa.Column(sa.Integer, primary_key=True) - during = sa.Column(DateTimeRangeType) +@pytest.fixture +def Booking(Base): + class Booking(Base): + __tablename__ = 'booking' + id = sa.Column(sa.Integer, primary_key=True) + during = sa.Column(DateTimeRangeType) - self.Booking = Booking + return Booking - def create_booking(self, date_range): - booking = self.Booking( + +@pytest.fixture +def create_booking(session, Booking): + def create_booking(date_range): + booking = Booking( during=date_range ) + session.add(booking) + session.commit() + return session.query(Booking).first() + return create_booking - self.session.add(booking) - self.session.commit() - return self.session.query(self.Booking).first() - def test_nullify_range(self): - booking = self.create_booking(None) +@pytest.fixture +def init_models(Booking): + pass + + +@pytest.mark.skipif('intervals is None') +class DateRangeTestCase(object): + + def test_nullify_range(self, create_booking): + booking = create_booking(None) assert booking.during is None - @mark.parametrize( + @pytest.mark.parametrize( ('date_range'), ( [datetime(2015, 1, 1), datetime(2015, 1, 3)], @@ -46,38 +56,38 @@ class DateRangeTestCase(TestCase): [-inf, datetime(2015, 1, 1)] ) ) - def test_save_date_range(self, date_range): - booking = self.create_booking(date_range) + def test_save_date_range(self, create_booking, date_range): + booking = create_booking(date_range) assert booking.during.lower == date_range[0] assert booking.during.upper == date_range[1] - def test_nullify_date_range(self): - booking = self.Booking( + def test_nullify_date_range(self, session, Booking): + booking = Booking( during=intervals.DateInterval( [datetime(2015, 1, 1), datetime(2015, 1, 3)] ) ) - self.session.add(booking) - self.session.commit() + session.add(booking) + session.commit() - booking = self.session.query(self.Booking).first() + booking = session.query(Booking).first() booking.during = None - self.session.commit() + session.commit() - booking = self.session.query(self.Booking).first() + booking = session.query(Booking).first() assert booking.during is None - def test_integer_coercion(self): - booking = self.Booking(during=datetime(2015, 1, 1)) + def test_integer_coercion(self, Booking): + booking = Booking(during=datetime(2015, 1, 1)) assert booking.during.lower == datetime(2015, 1, 1) assert booking.during.upper == datetime(2015, 1, 1) -class TestDateRangeOnPostgres(DateRangeTestCase): - dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' +@pytest.mark.usefixtures('postgresql_dsn') +class TestDateRangeOnPostgres(object): - @mark.parametrize( + @pytest.mark.parametrize( ('date_range', 'length'), ( ( @@ -92,9 +102,16 @@ class TestDateRangeOnPostgres(DateRangeTestCase): ([datetime(2015, 1, 1), inf], None), ) ) - def test_length(self, date_range, length): - self.create_booking(date_range) + def test_length( + self, + session, + Booking, + create_booking, + date_range, + length + ): + create_booking(date_range) query = ( - self.session.query(self.Booking.during.length) + session.query(Booking.during.length) ) assert query.scalar() == length diff --git a/tests/types/test_email.py b/tests/types/test_email.py index 5b65f71..9bc5c78 100644 --- a/tests/types/test_email.py +++ b/tests/types/test_email.py @@ -1,28 +1,30 @@ +import pytest import sqlalchemy as sa from sqlalchemy_utils import EmailType -from tests import TestCase -class TestEmailType(TestCase): - def create_models(self): - class User(self.Base): - __tablename__ = 'user' - id = sa.Column(sa.Integer, primary_key=True) - email = sa.Column(EmailType) +@pytest.fixture +def User(Base): + class User(Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, primary_key=True) + email = sa.Column(EmailType) - def __repr__(self): - return 'User(%r)' % self.id + def __repr__(self): + return 'User(%r)' % self.id + return User - self.User = User - def test_saves_email_as_lowercased(self): - user = self.User( +class TestEmailType(object): + + def test_saves_email_as_lowercased(self, session, User): + user = User( email=u'Someone@example.com' ) - self.session.add(user) - self.session.commit() + session.add(user) + session.commit() - user = self.session.query(self.User).first() + user = session.query(User).first() assert user.email == u'someone@example.com' diff --git a/tests/types/test_encrypted.py b/tests/types/test_encrypted.py index a91e642..f59bcd1 100644 --- a/tests/types/test_encrypted.py +++ b/tests/types/test_encrypted.py @@ -2,11 +2,9 @@ from datetime import date, datetime, time import pytest import sqlalchemy as sa -from pytest import mark from sqlalchemy_utils import ColorType, EncryptedType, PhoneNumberType from sqlalchemy_utils.types.encrypted import AesEngine, FernetEngine -from tests import TestCase cryptography = None try: @@ -15,225 +13,278 @@ except ImportError: pass -@mark.skipif('cryptography is None') -class EncryptedTypeTestCase(TestCase): +@pytest.fixture +def User(Base, encryption_engine, test_key): + class User(Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, primary_key=True) - @pytest.fixture(scope='function') - def user(self, request): - # set the values to the user object - self.user = self.User() - self.user.username = self.user_name - self.user.phone = self.user_phone - self.user.color = self.user_color - self.user.date = self.user_date - self.user.time = self.user_time - self.user.enum = self.user_enum - self.user.datetime = self.user_datetime - self.user.access_token = self.test_token - self.user.is_active = self.active - self.user.accounts_num = self.accounts_num - self.session.add(self.user) - self.session.commit() + username = sa.Column(EncryptedType( + sa.Unicode, + test_key, + encryption_engine) + ) - # register a finalizer to cleanup - def finalize(): - del self.user_name - del self.test_token - del self.active - del self.accounts_num - del self.test_key - del self.searched_user + access_token = sa.Column(EncryptedType( + sa.String, + test_key, + encryption_engine) + ) - request.addfinalizer(finalize) + is_active = sa.Column(EncryptedType( + sa.Boolean, + test_key, + encryption_engine) + ) - return self.session.query(self.User).get(self.user.id) + accounts_num = sa.Column(EncryptedType( + sa.Integer, + test_key, + encryption_engine) + ) - def generate_test_token(self): - import string - import random - token = '' - characters = string.ascii_letters + string.digits - for i in range(60): - token += ''.join(random.choice(characters)) - return token + phone = sa.Column(EncryptedType( + PhoneNumberType, + test_key, + encryption_engine) + ) - def create_models(self): - # set some test values - self.test_key = 'secretkey1234' - self.user_name = u'someone' - self.user_phone = u'(555) 555-5555' - self.user_color = u'#fff' - self.user_enum = 'One' - self.user_date = date(2010, 10, 2) - self.user_time = time(10, 12) - self.user_datetime = datetime(2010, 10, 2, 10, 12) - self.test_token = self.generate_test_token() - self.active = True - self.accounts_num = 2 - self.searched_user = None + color = sa.Column(EncryptedType( + ColorType, + test_key, + encryption_engine) + ) - class User(self.Base): - __tablename__ = 'user' - id = sa.Column(sa.Integer, primary_key=True) + date = sa.Column(EncryptedType( + sa.Date, + test_key, + encryption_engine) + ) - username = sa.Column(EncryptedType( - sa.Unicode, - self.test_key, - self.__class__.encryption_engine) - ) + time = sa.Column(EncryptedType( + sa.Time, + test_key, + encryption_engine) + ) - access_token = sa.Column(EncryptedType( - sa.String, - self.test_key, - self.__class__.encryption_engine) - ) + datetime = sa.Column(EncryptedType( + sa.DateTime, + test_key, + encryption_engine) + ) - is_active = sa.Column(EncryptedType( - sa.Boolean, - self.test_key, - self.__class__.encryption_engine) - ) + enum = sa.Column(EncryptedType( + sa.Enum('One', name='user_enum_t'), + test_key, + encryption_engine) + ) - accounts_num = sa.Column(EncryptedType( - sa.Integer, - self.test_key, - self.__class__.encryption_engine) - ) + return User - phone = sa.Column(EncryptedType( - PhoneNumberType, - self.test_key, - self.__class__.encryption_engine) - ) - color = sa.Column(EncryptedType( - ColorType, - self.test_key, - self.__class__.encryption_engine) - ) +@pytest.fixture +def test_key(): + return 'secretkey1234' - date = sa.Column(EncryptedType( - sa.Date, - self.test_key, - self.__class__.encryption_engine) - ) - time = sa.Column(EncryptedType( - sa.Time, - self.test_key, - self.__class__.encryption_engine) - ) +@pytest.fixture +def user_name(): + return u'someone' - datetime = sa.Column(EncryptedType( - sa.DateTime, - self.test_key, - self.__class__.encryption_engine) - ) - enum = sa.Column(EncryptedType( - sa.Enum('One', name='user_enum_t'), - self.test_key, - self.__class__.encryption_engine) - ) +@pytest.fixture +def user_phone(): + return u'(555) 555-5555' - self.User = User - class Team(self.Base): +@pytest.fixture +def user_color(): + return u'#fff' + + +@pytest.fixture +def user_enum(): + return 'One' + + +@pytest.fixture +def user_date(): + return date(2010, 10, 2) + + +@pytest.fixture +def user_time(): + return time(10, 12) + + +@pytest.fixture +def user_datetime(): + return datetime(2010, 10, 2, 10, 12) + + +@pytest.fixture +def test_token(): + import string + import random + token = '' + characters = string.ascii_letters + string.digits + for i in range(60): + token += ''.join(random.choice(characters)) + return token + + +@pytest.fixture +def active(): + return True + + +@pytest.fixture +def accounts_num(): + return 2 + + +@pytest.fixture +def user( + request, + session, + User, + user_name, + user_phone, + user_color, + user_date, + user_time, + user_enum, + user_datetime, + test_token, + active, + accounts_num +): + # set the values to the user object + user = User() + user.username = user_name + user.phone = user_phone + user.color = user_color + user.date = user_date + user.time = user_time + user.enum = user_enum + user.datetime = user_datetime + user.access_token = test_token + user.is_active = active + user.accounts_num = accounts_num + session.add(user) + session.commit() + + return session.query(User).get(user.id) + + +@pytest.mark.skipif('cryptography is None') +class EncryptedTypeTestCase(object): + + @pytest.fixture + def Team(self, Base, encryption_engine): + self._team_key = None + + class Team(Base): __tablename__ = 'team' id = sa.Column(sa.Integer, primary_key=True) key = sa.Column(sa.String(50)) name = sa.Column(EncryptedType( sa.Unicode, lambda: self._team_key, - self.__class__.encryption_engine) + encryption_engine) ) + return Team - self.Team = Team + @pytest.fixture + def init_models(self, User, Team): + pass - def test_unicode(self, user): - assert user.username == self.user_name + def test_unicode(self, user, user_name): + assert user.username == user_name - def test_string(self, user): - assert user.access_token == self.test_token + def test_string(self, user, test_token): + assert user.access_token == test_token - def test_boolean(self, user): - assert user.is_active == self.active + def test_boolean(self, user, active): + assert user.is_active == active - def test_integer(self, user): - assert user.accounts_num == self.accounts_num + def test_integer(self, user, accounts_num): + assert user.accounts_num == accounts_num - def test_phone_number(self, user): - assert str(user.phone) == self.user_phone + def test_phone_number(self, user, user_phone): + assert str(user.phone) == user_phone - def test_color(self, user): - assert user.color.hex == self.user_color + def test_color(self, user, user_color): + assert user.color.hex == user_color - def test_date(self, user): - assert user.date == self.user_date + def test_date(self, user, user_date): + assert user.date == user_date - def test_datetime(self, user): - assert user.datetime == self.user_datetime + def test_datetime(self, user, user_datetime): + assert user.datetime == user_datetime - def test_time(self, user): - assert user.time == self.user_time + def test_time(self, user, user_time): + assert user.time == user_time - def test_enum(self, user): - assert user.enum == self.user_enum + def test_enum(self, user, user_enum): + assert user.enum == user_enum - def test_lookup_key(self): + def test_lookup_key(self, session, Team): # Add teams self._team_key = 'one' - team = self.Team(key=self._team_key, name=u'One') - self.session.add(team) - self.session.commit() + team = Team(key=self._team_key, name=u'One') + session.add(team) + session.commit() team_1_id = team.id self._team_key = 'two' - team = self.Team(key=self._team_key) + team = Team(key=self._team_key) team.name = u'Two' - self.session.add(team) - self.session.commit() + session.add(team) + session.commit() team_2_id = team.id # Lookup teams - self._team_key = self.session.query(self.Team.key).filter_by( + self._team_key = session.query(Team.key).filter_by( id=team_1_id ).one()[0] - team = self.session.query(self.Team).get(team_1_id) + team = session.query(Team).get(team_1_id) assert team.name == u'One' with pytest.raises(Exception): - self.session.query(self.Team).get(team_2_id) + session.query(Team).get(team_2_id) - self.session.expunge_all() + session.expunge_all() - self._team_key = self.session.query(self.Team.key).filter_by( + self._team_key = session.query(Team.key).filter_by( id=team_2_id ).one()[0] - team = self.session.query(self.Team).get(team_2_id) + team = session.query(Team).get(team_2_id) assert team.name == u'Two' with pytest.raises(Exception): - self.session.query(self.Team).get(team_1_id) + session.query(Team).get(team_1_id) - self.session.expunge_all() + session.expunge_all() # Remove teams - self.session.query(self.Team).delete() - self.session.commit() + session.query(Team).delete() + session.commit() class TestAesEncryptedTypeTestcase(EncryptedTypeTestCase): - encryption_engine = AesEngine + @pytest.fixture + def encryption_engine(self): + return AesEngine - def test_lookup_by_encrypted_string(self, user): - test = self.session.query(self.User).filter( - self.User.username == self.user_name + def test_lookup_by_encrypted_string(self, session, User, user, user_name): + test = session.query(User).filter( + User.username == user_name ).first() assert test.username == user.username @@ -241,4 +292,6 @@ class TestAesEncryptedTypeTestcase(EncryptedTypeTestCase): class TestFernetEncryptedTypeTestCase(EncryptedTypeTestCase): - encryption_engine = FernetEngine + @pytest.fixture + def encryption_engine(self): + return FernetEngine diff --git a/tests/types/test_int_range.py b/tests/types/test_int_range.py index c349496..41c6607 100644 --- a/tests/types/test_int_range.py +++ b/tests/types/test_int_range.py @@ -1,8 +1,7 @@ +import pytest import sqlalchemy as sa -from pytest import mark from sqlalchemy_utils import IntRangeType -from tests import TestCase intervals = None inf = -1 @@ -13,91 +12,102 @@ except ImportError: pass -@mark.skipif('intervals is None') -class NumberRangeTestCase(TestCase): - def create_models(self): - class Building(self.Base): - __tablename__ = 'building' - id = sa.Column(sa.Integer, primary_key=True) - persons_at_night = sa.Column(IntRangeType) +@pytest.fixture +def Building(Base): + class Building(Base): + __tablename__ = 'building' + id = sa.Column(sa.Integer, primary_key=True) + persons_at_night = sa.Column(IntRangeType) - def __repr__(self): - return 'Building(%r)' % self.id + def __repr__(self): + return 'Building(%r)' % self.id + return Building - self.Building = Building - def create_building(self, number_range): - building = self.Building( +@pytest.fixture +def init_models(Building): + pass + + +@pytest.fixture +def create_building(session, Building): + def create_building(number_range): + building = Building( persons_at_night=number_range ) - self.session.add(building) - self.session.commit() - return self.session.query(self.Building).first() + session.add(building) + session.commit() + return session.query(Building).first() + return create_building - def test_nullify_range(self): - building = self.create_building(None) + +@pytest.mark.skipif('intervals is None') +class NumberRangeTestCase(object): + + def test_nullify_range(self, create_building): + building = create_building(None) assert building.persons_at_night is None - def test_update_with_none(self): + def test_update_with_none(self, session, create_building): interval = intervals.IntInterval('(,)') - building = self.create_building(interval) + building = create_building(interval) building.persons_at_night = None assert building.persons_at_night is None - self.session.commit() + session.commit() assert building.persons_at_night is None - @mark.parametrize( + @pytest.mark.parametrize( 'number_range', ( [1, 3], '1 - 3', ) ) - def test_save_number_range(self, number_range): - building = self.create_building(number_range) + def test_save_number_range(self, create_building, number_range): + building = create_building(number_range) assert building.persons_at_night.lower == 1 assert building.persons_at_night.upper == 3 - def test_infinite_upper_bound(self): - building = self.create_building([1, inf]) + def test_infinite_upper_bound(self, create_building): + building = create_building([1, inf]) assert building.persons_at_night.lower == 1 assert building.persons_at_night.upper == inf - def test_infinite_lower_bound(self): - building = self.create_building([-inf, 1]) + def test_infinite_lower_bound(self, create_building): + building = create_building([-inf, 1]) assert building.persons_at_night.lower == -inf assert building.persons_at_night.upper == 1 - def test_nullify_number_range(self): - building = self.Building( + def test_nullify_number_range(self, session, Building): + building = Building( persons_at_night=intervals.IntInterval([1, 3]) ) - self.session.add(building) - self.session.commit() + session.add(building) + session.commit() - building = self.session.query(self.Building).first() + building = session.query(Building).first() building.persons_at_night = None - self.session.commit() + session.commit() - building = self.session.query(self.Building).first() + building = session.query(Building).first() assert building.persons_at_night is None - def test_string_coercion(self): - building = self.Building(persons_at_night='[12, 18]') + def test_string_coercion(self, Building): + building = Building(persons_at_night='[12, 18]') assert isinstance(building.persons_at_night, intervals.IntInterval) - def test_integer_coercion(self): - building = self.Building(persons_at_night=15) + def test_integer_coercion(self, Building): + building = Building(persons_at_night=15) assert building.persons_at_night.lower == 15 assert building.persons_at_night.upper == 15 +@pytest.mark.usefixtures('postgresql_dsn') class TestIntRangeTypeOnPostgres(NumberRangeTestCase): - dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' - @mark.parametrize( + @pytest.mark.parametrize( 'number_range', ( [1, 3], @@ -105,15 +115,21 @@ class TestIntRangeTypeOnPostgres(NumberRangeTestCase): (0, 4) ) ) - def test_eq_operator(self, number_range): - self.create_building([1, 3]) + def test_eq_operator( + self, + session, + Building, + create_building, + number_range + ): + create_building([1, 3]) query = ( - self.session.query(self.Building) - .filter(self.Building.persons_at_night == number_range) + session.query(Building) + .filter(Building.persons_at_night == number_range) ) assert query.count() - @mark.parametrize( + @pytest.mark.parametrize( ('number_range', 'length'), ( ([1, 3], 2), @@ -125,14 +141,21 @@ class TestIntRangeTypeOnPostgres(NumberRangeTestCase): ([-3, -1], 2) ) ) - def test_length(self, number_range, length): - self.create_building(number_range) + def test_length( + self, + session, + Building, + create_building, + number_range, + length + ): + create_building(number_range) query = ( - self.session.query(self.Building.persons_at_night.length) + session.query(Building.persons_at_night.length) ) assert query.scalar() == length - @mark.parametrize( + @pytest.mark.parametrize( 'number_range', ( [[1, 3]], @@ -140,15 +163,21 @@ class TestIntRangeTypeOnPostgres(NumberRangeTestCase): [(0, 4)], ) ) - def test_in_operator(self, number_range): - self.create_building([1, 3]) + def test_in_operator( + self, + session, + Building, + create_building, + number_range + ): + create_building([1, 3]) query = ( - self.session.query(self.Building) - .filter(self.Building.persons_at_night.in_(number_range)) + session.query(Building) + .filter(Building.persons_at_night.in_(number_range)) ) assert query.count() - @mark.parametrize( + @pytest.mark.parametrize( 'number_range', ( [1, 3], @@ -156,15 +185,21 @@ class TestIntRangeTypeOnPostgres(NumberRangeTestCase): (0, 4), ) ) - def test_rshift_operator(self, number_range): - self.create_building([5, 6]) + def test_rshift_operator( + self, + session, + Building, + create_building, + number_range + ): + create_building([5, 6]) query = ( - self.session.query(self.Building) - .filter(self.Building.persons_at_night >> number_range) + session.query(Building) + .filter(Building.persons_at_night >> number_range) ) assert query.count() - @mark.parametrize( + @pytest.mark.parametrize( 'number_range', ( [1, 3], @@ -172,15 +207,21 @@ class TestIntRangeTypeOnPostgres(NumberRangeTestCase): (0, 4), ) ) - def test_lshift_operator(self, number_range): - self.create_building([-1, 0]) + def test_lshift_operator( + self, + session, + Building, + create_building, + number_range + ): + create_building([-1, 0]) query = ( - self.session.query(self.Building) - .filter(self.Building.persons_at_night << number_range) + session.query(Building) + .filter(Building.persons_at_night << number_range) ) assert query.count() - @mark.parametrize( + @pytest.mark.parametrize( 'number_range', ( [1, 3], @@ -189,15 +230,21 @@ class TestIntRangeTypeOnPostgres(NumberRangeTestCase): 2 ) ) - def test_contains_operator(self, number_range): - self.create_building([1, 3]) + def test_contains_operator( + self, + session, + Building, + create_building, + number_range + ): + create_building([1, 3]) query = ( - self.session.query(self.Building) - .filter(self.Building.persons_at_night.contains(number_range)) + session.query(Building) + .filter(Building.persons_at_night.contains(number_range)) ) assert query.count() - @mark.parametrize( + @pytest.mark.parametrize( 'number_range', ( [1, 3], @@ -206,15 +253,21 @@ class TestIntRangeTypeOnPostgres(NumberRangeTestCase): (-inf, inf) ) ) - def test_contained_by_operator(self, number_range): - self.create_building([1, 3]) + def test_contained_by_operator( + self, + session, + Building, + create_building, + number_range + ): + create_building([1, 3]) query = ( - self.session.query(self.Building) - .filter(self.Building.persons_at_night.contained_by(number_range)) + session.query(Building) + .filter(Building.persons_at_night.contained_by(number_range)) ) assert query.count() - @mark.parametrize( + @pytest.mark.parametrize( 'number_range', ( [2, 5], @@ -222,27 +275,32 @@ class TestIntRangeTypeOnPostgres(NumberRangeTestCase): 0 ) ) - def test_not_in_operator(self, number_range): - self.create_building([1, 3]) + def test_not_in_operator( + self, + session, + Building, + create_building, + number_range + ): + create_building([1, 3]) query = ( - self.session.query(self.Building) - .filter(~ self.Building.persons_at_night.in_([number_range])) + session.query(Building) + .filter(~ Building.persons_at_night.in_([number_range])) ) assert query.count() - def test_eq_with_query_arg(self): - self.create_building([1, 3]) + def test_eq_with_query_arg(self, session, Building, create_building): + create_building([1, 3]) query = ( - self.session.query(self.Building) + session.query(Building) .filter( - self.Building.persons_at_night == - self.session.query( - self.Building.persons_at_night) - ).order_by(self.Building.persons_at_night).limit(1) + Building.persons_at_night == + session.query(Building.persons_at_night) + ).order_by(Building.persons_at_night).limit(1) ) assert query.count() - @mark.parametrize( + @pytest.mark.parametrize( 'number_range', ( [1, 2], @@ -253,15 +311,21 @@ class TestIntRangeTypeOnPostgres(NumberRangeTestCase): 1, ) ) - def test_ge_operator(self, number_range): - self.create_building([1, 3]) + def test_ge_operator( + self, + session, + Building, + create_building, + number_range + ): + create_building([1, 3]) query = ( - self.session.query(self.Building) - .filter(self.Building.persons_at_night >= number_range) + session.query(Building) + .filter(Building.persons_at_night >= number_range) ) assert query.count() - @mark.parametrize( + @pytest.mark.parametrize( 'number_range', ( [0, 2], @@ -269,15 +333,21 @@ class TestIntRangeTypeOnPostgres(NumberRangeTestCase): [-inf, 2] ) ) - def test_gt_operator(self, number_range): - self.create_building([1, 3]) + def test_gt_operator( + self, + session, + Building, + create_building, + number_range + ): + create_building([1, 3]) query = ( - self.session.query(self.Building) - .filter(self.Building.persons_at_night > number_range) + session.query(Building) + .filter(Building.persons_at_night > number_range) ) assert query.count() - @mark.parametrize( + @pytest.mark.parametrize( 'number_range', ( [1, 4], @@ -285,15 +355,21 @@ class TestIntRangeTypeOnPostgres(NumberRangeTestCase): [2, inf] ) ) - def test_le_operator(self, number_range): - self.create_building([1, 3]) + def test_le_operator( + self, + session, + Building, + create_building, + number_range + ): + create_building([1, 3]) query = ( - self.session.query(self.Building) - .filter(self.Building.persons_at_night <= number_range) + session.query(Building) + .filter(Building.persons_at_night <= number_range) ) assert query.count() - @mark.parametrize( + @pytest.mark.parametrize( 'number_range', ( [2, 4], @@ -301,11 +377,17 @@ class TestIntRangeTypeOnPostgres(NumberRangeTestCase): [1, inf] ) ) - def test_lt_operator(self, number_range): - self.create_building([1, 3]) + def test_lt_operator( + self, + session, + Building, + create_building, + number_range + ): + create_building([1, 3]) query = ( - self.session.query(self.Building) - .filter(self.Building.persons_at_night < number_range) + session.query(Building) + .filter(Building.persons_at_night < number_range) ) assert query.count() diff --git a/tests/types/test_ip_address.py b/tests/types/test_ip_address.py index 3b5a4e5..f1df074 100644 --- a/tests/types/test_ip_address.py +++ b/tests/types/test_ip_address.py @@ -1,31 +1,37 @@ +import pytest import six import sqlalchemy as sa -from pytest import mark from sqlalchemy_utils.types import ip_address -from tests import TestCase -@mark.skipif('ip_address.ip_address is None') -class TestIPAddressType(TestCase): - def create_models(self): - class Visitor(self.Base): - __tablename__ = 'document' - id = sa.Column(sa.Integer, primary_key=True) - ip_address = sa.Column(ip_address.IPAddressType) +@pytest.fixture +def Visitor(Base): + class Visitor(Base): + __tablename__ = 'document' + id = sa.Column(sa.Integer, primary_key=True) + ip_address = sa.Column(ip_address.IPAddressType) - def __repr__(self): - return 'Visitor(%r)' % self.id + def __repr__(self): + return 'Visitor(%r)' % self.id + return Visitor - self.Visitor = Visitor - def test_parameter_processing(self): - visitor = self.Visitor( +@pytest.fixture +def init_models(Visitor): + pass + + +@pytest.mark.skipif('ip_address.ip_address is None') +class TestIPAddressType(object): + + def test_parameter_processing(self, session, Visitor): + visitor = Visitor( ip_address=u'111.111.111.111' ) - self.session.add(visitor) - self.session.commit() + session.add(visitor) + session.commit() - visitor = self.session.query(self.Visitor).first() + visitor = session.query(Visitor).first() assert six.text_type(visitor.ip_address) == u'111.111.111.111' diff --git a/tests/types/test_json.py b/tests/types/test_json.py index 8305a36..caed82f 100644 --- a/tests/types/test_json.py +++ b/tests/types/test_json.py @@ -1,59 +1,67 @@ # -*- coding: utf-8 -*- +import pytest import sqlalchemy as sa -from pytest import mark from sqlalchemy_utils.types import json -from tests import TestCase -class JSONTestCase(TestCase): - def create_models(self): - class Document(self.Base): - __tablename__ = 'document' - id = sa.Column(sa.Integer, primary_key=True) - json = sa.Column(json.JSONType) +@pytest.fixture +def Document(Base): + class Document(Base): + __tablename__ = 'document' + id = sa.Column(sa.Integer, primary_key=True) + json = sa.Column(json.JSONType) + return Document - self.Document = Document - def test_list(self): - document = self.Document( +@pytest.fixture +def init_models(Document): + pass + + +class JSONTestCase(object): + + def test_list(self, session, Document): + document = Document( json=[1, 2, 3] ) - self.session.add(document) - self.session.commit() + session.add(document) + session.commit() - document = self.session.query(self.Document).first() + document = session.query(Document).first() assert document.json == [1, 2, 3] - def test_parameter_processing(self): - document = self.Document( + def test_parameter_processing(self, session, Document): + document = Document( json={'something': 12} ) - self.session.add(document) - self.session.commit() + session.add(document) + session.commit() - document = self.session.query(self.Document).first() + document = session.query(Document).first() assert document.json == {'something': 12} - def test_non_ascii_chars(self): - document = self.Document( + def test_non_ascii_chars(self, session, Document): + document = Document( json={'something': u'äääööö'} ) - self.session.add(document) - self.session.commit() + session.add(document) + session.commit() - document = self.session.query(self.Document).first() + document = session.query(Document).first() assert document.json == {'something': u'äääööö'} -@mark.skipif('json.json is None') +@pytest.mark.skipif('json.json is None') +@pytest.mark.usefixtures('sqlite_memory_dsn') class TestSqliteJSONType(JSONTestCase): pass -@mark.skipif('json.json is None') +@pytest.mark.skipif('json.json is None') +@pytest.mark.usefixtures('postgresql_dsn') class TestPostgresJSONType(JSONTestCase): - dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' + pass diff --git a/tests/types/test_locale.py b/tests/types/test_locale.py index 7360a22..ee6225b 100644 --- a/tests/types/test_locale.py +++ b/tests/types/test_locale.py @@ -1,51 +1,57 @@ +import pytest import sqlalchemy as sa -from pytest import mark, raises from sqlalchemy_utils.types import locale -from tests import TestCase -@mark.skipif('locale.babel is None') -class TestLocaleType(TestCase): - def create_models(self): - class User(self.Base): - __tablename__ = 'user' - id = sa.Column(sa.Integer, primary_key=True) - locale = sa.Column(locale.LocaleType) +@pytest.fixture +def User(Base): + class User(Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, primary_key=True) + locale = sa.Column(locale.LocaleType) - def __repr__(self): - return 'User(%r)' % self.id + def __repr__(self): + return 'User(%r)' % self.id + return User - self.User = User - def test_parameter_processing(self): - user = self.User( +@pytest.fixture +def init_models(User): + pass + + +@pytest.mark.skipif('locale.babel is None') +class TestLocaleType(object): + + def test_parameter_processing(self, session, User): + user = User( locale=locale.babel.Locale(u'fi') ) - self.session.add(user) - self.session.commit() + session.add(user) + session.commit() - user = self.session.query(self.User).first() + user = session.query(User).first() - def test_territory_parsing(self): + def test_territory_parsing(self, session, User): ko_kr = locale.babel.Locale(u'ko', territory=u'KR') - user = self.User(locale=ko_kr) - self.session.add(user) - self.session.commit() + user = User(locale=ko_kr) + session.add(user) + session.commit() - assert self.session.query(self.User.locale).first()[0] == ko_kr + assert session.query(User.locale).first()[0] == ko_kr - def test_coerce_territory_parsing(self): - user = self.User() + def test_coerce_territory_parsing(self, User): + user = User() user.locale = 'ko_KR' assert user.locale == locale.babel.Locale(u'ko', territory=u'KR') - def test_scalar_attributes_get_coerced_to_objects(self): - user = self.User(locale='en_US') + def test_scalar_attributes_get_coerced_to_objects(self, User): + user = User(locale='en_US') assert isinstance(user.locale, locale.babel.Locale) - def test_unknown_locale_throws_exception(self): - with raises(locale.babel.UnknownLocaleError): - self.User(locale=u'unknown') + def test_unknown_locale_throws_exception(self, User): + with pytest.raises(locale.babel.UnknownLocaleError): + User(locale=u'unknown') diff --git a/tests/types/test_numeric_range.py b/tests/types/test_numeric_range.py index ad40acb..0ffd9bd 100644 --- a/tests/types/test_numeric_range.py +++ b/tests/types/test_numeric_range.py @@ -1,10 +1,9 @@ from decimal import Decimal +import pytest import sqlalchemy as sa -from pytest import mark from sqlalchemy_utils import NumericRangeType -from tests import TestCase intervals = None inf = 0 @@ -15,80 +14,90 @@ except ImportError: pass -@mark.skipif('intervals is None') -class NumericRangeTestCase(TestCase): - def create_models(self): - class Car(self.Base): +@pytest.fixture +def create_car(session, Car): + def create_car(number_range): + car = Car( + price_range=number_range + ) + + session.add(car) + session.commit() + return session.query(Car).first() + return create_car + + +@pytest.mark.skipif('intervals is None') +class NumericRangeTestCase(object): + + @pytest.fixture + def Car(self, Base): + class Car(Base): __tablename__ = 'car' id = sa.Column(sa.Integer, primary_key=True) price_range = sa.Column(NumericRangeType) - self.Car = Car + return Car - def create_car(self, number_range): - car = self.Car( - price_range=number_range - ) + @pytest.fixture + def init_models(self, Car): + pass - self.session.add(car) - self.session.commit() - return self.session.query(self.Car).first() - - def test_nullify_range(self): - car = self.create_car(None) + def test_nullify_range(self, create_car): + car = create_car(None) assert car.price_range is None - @mark.parametrize( + @pytest.mark.parametrize( 'number_range', ( [1, 3], '1 - 3', ) ) - def test_save_number_range(self, number_range): - car = self.create_car(number_range) + def test_save_number_range(self, create_car, number_range): + car = create_car(number_range) assert car.price_range.lower == 1 assert car.price_range.upper == 3 - def test_infinite_upper_bound(self): - car = self.create_car([1, inf]) + def test_infinite_upper_bound(self, create_car): + car = create_car([1, inf]) assert car.price_range.lower == 1 assert car.price_range.upper == inf - def test_infinite_lower_bound(self): - car = self.create_car([-inf, 1]) + def test_infinite_lower_bound(self, create_car): + car = create_car([-inf, 1]) assert car.price_range.lower == -inf assert car.price_range.upper == 1 - def test_nullify_number_range(self): - car = self.Car( + def test_nullify_number_range(self, session, Car): + car = Car( price_range=intervals.DecimalInterval([1, 3]) ) - self.session.add(car) - self.session.commit() + session.add(car) + session.commit() - car = self.session.query(self.Car).first() + car = session.query(Car).first() car.price_range = None - self.session.commit() + session.commit() - car = self.session.query(self.Car).first() + car = session.query(Car).first() assert car.price_range is None - def test_string_coercion(self): - car = self.Car(price_range='[12, 18]') + def test_string_coercion(self, Car): + car = Car(price_range='[12, 18]') assert isinstance(car.price_range, intervals.DecimalInterval) - def test_integer_coercion(self): - car = self.Car(price_range=15) + def test_integer_coercion(self, Car): + car = Car(price_range=15) assert car.price_range.lower == 15 assert car.price_range.upper == 15 +@pytest.mark.usefixtures('postgresql_dsn') class TestNumericRangeOnPostgres(NumericRangeTestCase): - dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' - @mark.parametrize( + @pytest.mark.parametrize( ('number_range', 'length'), ( ([1, 3], 2), @@ -100,43 +109,39 @@ class TestNumericRangeOnPostgres(NumericRangeTestCase): ([-3, -1], 2) ) ) - def test_length(self, number_range, length): - self.create_car(number_range) + def test_length(self, session, Car, create_car, number_range, length): + create_car(number_range) query = ( - self.session.query(self.Car.price_range.length) + session.query(Car.price_range.length) ) assert query.scalar() == length -@mark.skipif('intervals is None') -class TestNumericRangeWithStep(TestCase): - def create_models(self): - class Car(self.Base): +@pytest.mark.skipif('intervals is None') +class TestNumericRangeWithStep(object): + + @pytest.fixture + def Car(self, Base): + class Car(Base): __tablename__ = 'car' id = sa.Column(sa.Integer, primary_key=True) price_range = sa.Column(NumericRangeType(step=Decimal('0.5'))) + return Car - self.Car = Car + @pytest.fixture + def init_models(self, Car): + pass - def create_car(self, number_range): - car = self.Car( - price_range=number_range - ) - - self.session.add(car) - self.session.commit() - return self.session.query(self.Car).first() - - def test_passes_step_argument_to_interval_object(self): - car = self.create_car([Decimal('0.2'), Decimal('0.8')]) + def test_passes_step_argument_to_interval_object(self, create_car): + car = create_car([Decimal('0.2'), Decimal('0.8')]) assert car.price_range.lower == Decimal('0') assert car.price_range.upper == Decimal('1') assert car.price_range.step == Decimal('0.5') - def test_passes_step_fetched_objects(self): - self.create_car([Decimal('0.2'), Decimal('0.8')]) - self.session.expunge_all() - car = self.session.query(self.Car).first() + def test_passes_step_fetched_objects(self, session, Car, create_car): + create_car([Decimal('0.2'), Decimal('0.8')]) + session.expunge_all() + car = session.query(Car).first() assert car.price_range.lower == Decimal('0') assert car.price_range.upper == Decimal('1') assert car.price_range.step == Decimal('0.5') diff --git a/tests/types/test_password.py b/tests/types/test_password.py index 474ad94..e3c2d8a 100644 --- a/tests/types/test_password.py +++ b/tests/types/test_password.py @@ -1,61 +1,67 @@ +import pytest import sqlalchemy as sa -from pytest import mark from sqlalchemy import inspect from sqlalchemy_utils import Password, PasswordType, types # noqa -from tests import TestCase -@mark.skipif('types.password.passlib is None') -class TestPasswordType(TestCase): - def create_models(self): - class User(self.Base): - __tablename__ = 'user' - id = sa.Column(sa.Integer, primary_key=True) - password = sa.Column(PasswordType( - schemes=[ - 'pbkdf2_sha512', - 'pbkdf2_sha256', - 'md5_crypt', - 'hex_md5' - ], +@pytest.fixture +def User(Base): + class User(Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, primary_key=True) + password = sa.Column(PasswordType( + schemes=[ + 'pbkdf2_sha512', + 'pbkdf2_sha256', + 'md5_crypt', + 'hex_md5' + ], - deprecated=['md5_crypt', 'hex_md5'] - )) + deprecated=['md5_crypt', 'hex_md5'] + )) - def __repr__(self): - return 'User(%r)' % self.id + def __repr__(self): + return 'User(%r)' % self.id + return User - self.User = User - def test_encrypt(self): +@pytest.fixture +def init_models(User): + pass + + +@pytest.mark.skipif('types.password.passlib is None') +class TestPasswordType(object): + + def test_encrypt(self, User): """Should encrypt the password on setting the attribute.""" - obj = self.User() + obj = User() obj.password = b'b' assert obj.password.hash != 'b' assert obj.password.hash.startswith(b'$pbkdf2-sha512$') - def test_check(self): + def test_check(self, session, User): """ Should be able to compare the plaintext against the encrypted form. """ - obj = self.User() + obj = User() obj.password = 'b' assert obj.password == 'b' assert obj.password != 'a' - self.session.add(obj) - self.session.commit() + session.add(obj) + session.commit() - obj = self.session.query(self.User).get(obj.id) + obj = session.query(User).get(obj.id) assert obj.password == b'b' assert obj.password != 'a' - def test_check_and_update(self): + def test_check_and_update(self, User): """ Should be able to compare the plaintext against a deprecated encrypted form and have it auto-update to the preferred version. @@ -63,20 +69,20 @@ class TestPasswordType(TestCase): from passlib.hash import md5_crypt - obj = self.User() + obj = User() obj.password = Password(md5_crypt.encrypt('b')) assert obj.password.hash.decode('utf8').startswith('$1$') assert obj.password == 'b' assert obj.password.hash.decode('utf8').startswith('$pbkdf2-sha512$') - def test_auto_column_length(self): + def test_auto_column_length(self, User): """Should derive the correct column length from the specified schemes. """ from passlib.hash import pbkdf2_sha512 - kind = inspect(self.User).c.password.type + kind = inspect(User).c.password.type # name + rounds + salt + hash + ($ * 4) of largest hash expected_length = len(pbkdf2_sha512.name) @@ -90,55 +96,55 @@ class TestPasswordType(TestCase): def test_without_schemes(self): assert PasswordType(schemes=[]).length == 1024 - def test_compare(self): + def test_compare(self, User): from passlib.hash import md5_crypt - obj = self.User() + obj = User() obj.password = Password(md5_crypt.encrypt('b')) - other = self.User() + other = User() other.password = Password(md5_crypt.encrypt('b')) # Not sure what to assert here; the test raised an error before. assert obj.password != other.password - def test_set_none(self): + def test_set_none(self, session, User): - obj = self.User() + obj = User() obj.password = None assert obj.password is None - self.session.add(obj) - self.session.commit() + session.add(obj) + session.commit() - obj = self.session.query(self.User).get(obj.id) + obj = session.query(User).get(obj.id) assert obj.password is None - def test_update_none(self): + def test_update_none(self, session, User): """ Should be able to change a password from ``None`` to a valid password. """ - obj = self.User() + obj = User() obj.password = None - self.session.add(obj) - self.session.commit() + session.add(obj) + session.commit() - obj = self.session.query(self.User).get(obj.id) + obj = session.query(User).get(obj.id) obj.password = 'b' - self.session.commit() + session.commit() - def test_compare_none(self): + def test_compare_none(self, User): """ Should be able to compare a password of ``None``. """ - obj = self.User() + obj = User() obj.password = None assert obj.password is None @@ -149,7 +155,7 @@ class TestPasswordType(TestCase): assert obj.password is not None assert obj.password != None # noqa - def test_check_and_update_persist(self): + def test_check_and_update_persist(self, session, User): """ When a password is compared, the hash should update if needed to change the algorithm; and, commit to the database. @@ -157,18 +163,18 @@ class TestPasswordType(TestCase): from passlib.hash import md5_crypt - obj = self.User() + obj = User() obj.password = Password(md5_crypt.encrypt('b')) - self.session.add(obj) - self.session.commit() + session.add(obj) + session.commit() assert obj.password.hash.decode('utf8').startswith('$1$') assert obj.password == 'b' - self.session.commit() + session.commit() - obj = self.session.query(self.User).get(obj.id) + obj = session.query(User).get(obj.id) assert obj.password.hash.decode('utf8').startswith('$pbkdf2-sha512$') assert obj.password == 'b' diff --git a/tests/types/test_phonenumber.py b/tests/types/test_phonenumber.py index dd8e3dc..6e3e719 100644 --- a/tests/types/test_phonenumber.py +++ b/tests/types/test_phonenumber.py @@ -1,37 +1,76 @@ +import pytest import six import sqlalchemy as sa -from pytest import mark from sqlalchemy_utils import PhoneNumber, PhoneNumberType, types # noqa -from tests import TestCase -@mark.skipif('types.phone_number.phonenumbers is None') +@pytest.fixture +def valid_phone_numbers(): + return [ + '040 1234567', + '+358 401234567', + '09 2501234', + '+358 92501234', + '0800 939393', + '09 4243 0456', + '0600 900 500' + ] + + +@pytest.fixture +def invalid_phone_numbers(): + return [ + 'abc', + '+040 1234567', + '0111234567', + '358' + ] + + +@pytest.fixture +def User(Base): + class User(Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, autoincrement=True, primary_key=True) + name = sa.Column(sa.Unicode(255)) + phone_number = sa.Column(PhoneNumberType()) + return User + + +@pytest.fixture +def init_models(User): + pass + + +@pytest.fixture +def phone_number(): + return PhoneNumber( + '040 1234567', + 'FI' + ) + + +@pytest.fixture +def user(session, User, phone_number): + user = User() + user.name = u'Someone' + user.phone_number = phone_number + session.add(user) + session.commit() + return user + + +@pytest.mark.skipif('types.phone_number.phonenumbers is None') class TestPhoneNumber(object): - def setup_method(self, method): - self.valid_phone_numbers = [ - '040 1234567', - '+358 401234567', - '09 2501234', - '+358 92501234', - '0800 939393', - '09 4243 0456', - '0600 900 500' - ] - self.invalid_phone_numbers = [ - 'abc', - '+040 1234567', - '0111234567', - '358' - ] - def test_valid_phone_numbers(self): - for raw_number in self.valid_phone_numbers: + def test_valid_phone_numbers(self, valid_phone_numbers): + for raw_number in valid_phone_numbers: number = PhoneNumber(raw_number, 'FI') assert number.is_valid_number() - def test_invalid_phone_numbers(self): - for raw_number in self.invalid_phone_numbers: + def test_invalid_phone_numbers(self, invalid_phone_numbers): + for raw_number in invalid_phone_numbers: try: number = PhoneNumber(raw_number, 'FI') assert not number.is_valid_number() @@ -53,73 +92,59 @@ class TestPhoneNumber(object): assert str(number) == number.national -@mark.skipif('types.phone_number.phonenumbers is None') -class TestPhoneNumberType(TestCase): +@pytest.mark.skipif('types.phone_number.phonenumbers is None') +class TestPhoneNumberType(object): - def create_models(self): - class User(self.Base): - __tablename__ = 'user' - id = sa.Column(sa.Integer, autoincrement=True, primary_key=True) - name = sa.Column(sa.Unicode(255)) - phone_number = sa.Column(PhoneNumberType()) + def test_query_returns_phone_number_object( + self, + session, + User, + user, + phone_number + ): + queried_user = session.query(User).first() + assert queried_user.phone_number == phone_number - self.User = User - - def setup_method(self, method): - super(TestPhoneNumberType, self).setup_method(method) - self.phone_number = PhoneNumber( - '040 1234567', - 'FI' - ) - self.user = self.User() - self.user.name = u'Someone' - self.user.phone_number = self.phone_number - self.session.add(self.user) - self.session.commit() - - def test_query_returns_phone_number_object(self): - queried_user = self.session.query(self.User).first() - assert queried_user.phone_number == self.phone_number - - def test_phone_number_is_stored_as_string(self): - result = self.session.execute( + def test_phone_number_is_stored_as_string(self, session, user): + result = session.execute( 'SELECT phone_number FROM "user" WHERE id=:param', - {'param': self.user.id} + {'param': user.id} ) assert result.first()[0] == u'+358401234567' - def test_phone_number_with_extension(self): - user = self.User(phone_number='555-555-5555 Ext. 555') + def test_phone_number_with_extension(self, session, User): + user = User(phone_number='555-555-5555 Ext. 555') - self.session.add(user) - self.session.commit() - self.session.refresh(user) + session.add(user) + session.commit() + session.refresh(user) assert user.phone_number.extension == '555' - def test_empty_phone_number_is_equiv_to_none(self): - user = self.User(phone_number='') + def test_empty_phone_number_is_equiv_to_none(self, session, User): + user = User(phone_number='') - self.session.add(user) - self.session.commit() - self.session.refresh(user) + session.add(user) + session.commit() + session.refresh(user) assert user.phone_number is None - def test_phone_number_is_none(self): + @pytest.mark.usefixtures('user') + def test_phone_number_is_none(self, session, User): phone_number = None - user = self.User() + user = User() user.name = u'Someone' user.phone_number = phone_number - self.session.add(user) - self.session.commit() - queried_user = self.session.query(self.User)[1] + session.add(user) + session.commit() + queried_user = session.query(User)[1] assert queried_user.phone_number is None - result = self.session.execute( + result = session.execute( 'SELECT phone_number FROM "user" WHERE id=:param', {'param': user.id} ) assert result.first()[0] is None - def test_scalar_attributes_get_coerced_to_objects(self): - user = self.User(phone_number='050111222') + def test_scalar_attributes_get_coerced_to_objects(self, User): + user = User(phone_number='050111222') assert isinstance(user.phone_number, PhoneNumber) diff --git a/tests/types/test_scalar_list.py b/tests/types/test_scalar_list.py index 3f556fc..2a7c852 100644 --- a/tests/types/test_scalar_list.py +++ b/tests/types/test_scalar_list.py @@ -1,14 +1,15 @@ +import pytest import six import sqlalchemy as sa -from pytest import raises from sqlalchemy_utils import ScalarListType -from tests import TestCase -class TestScalarIntegerList(TestCase): - def create_models(self): - class User(self.Base): +class TestScalarIntegerList(object): + + @pytest.fixture + def User(self, Base): + class User(Base): __tablename__ = 'user' id = sa.Column(sa.Integer, primary_key=True) some_list = sa.Column(ScalarListType(int)) @@ -16,23 +17,29 @@ class TestScalarIntegerList(TestCase): def __repr__(self): return 'User(%r)' % self.id - self.User = User + return User - def test_save_integer_list(self): - user = self.User( + @pytest.fixture + def init_models(self, User): + pass + + def test_save_integer_list(self, session, User): + user = User( some_list=[1, 2, 3, 4] ) - self.session.add(user) - self.session.commit() + session.add(user) + session.commit() - user = self.session.query(self.User).first() + user = session.query(User).first() assert user.some_list == [1, 2, 3, 4] -class TestScalarUnicodeList(TestCase): - def create_models(self): - class User(self.Base): +class TestScalarUnicodeList(object): + + @pytest.fixture + def User(self, Base): + class User(Base): __tablename__ = 'user' id = sa.Column(sa.Integer, primary_key=True) some_list = sa.Column(ScalarListType(six.text_type)) @@ -40,40 +47,48 @@ class TestScalarUnicodeList(TestCase): def __repr__(self): return 'User(%r)' % self.id - self.User = User + return User - def test_throws_exception_if_using_separator_in_list_values(self): - user = self.User( + @pytest.fixture + def init_models(self, User): + pass + + def test_throws_exception_if_using_separator_in_list_values( + self, + session, + User + ): + user = User( some_list=[u','] ) - self.session.add(user) - with raises(sa.exc.StatementError) as db_err: - self.session.commit() + session.add(user) + with pytest.raises(sa.exc.StatementError) as db_err: + session.commit() assert ( "List values can't contain string ',' (its being used as " "separator. If you wish for scalar list values to contain " "these strings, use a different separator string.)" ) in str(db_err.value) - def test_save_unicode_list(self): - user = self.User( + def test_save_unicode_list(self, session, User): + user = User( some_list=[u'1', u'2', u'3', u'4'] ) - self.session.add(user) - self.session.commit() + session.add(user) + session.commit() - user = self.session.query(self.User).first() + user = session.query(User).first() assert user.some_list == [u'1', u'2', u'3', u'4'] - def test_save_and_retrieve_empty_list(self): - user = self.User( + def test_save_and_retrieve_empty_list(self, session, User): + user = User( some_list=[] ) - self.session.add(user) - self.session.commit() + session.add(user) + session.commit() - user = self.session.query(self.User).first() + user = session.query(User).first() assert user.some_list == [] diff --git a/tests/types/test_timezone.py b/tests/types/test_timezone.py index 3d8a11c..502288f 100644 --- a/tests/types/test_timezone.py +++ b/tests/types/test_timezone.py @@ -1,39 +1,46 @@ +import pytest import sqlalchemy as sa from sqlalchemy_utils.types import timezone -from tests import TestCase -class TestTimezoneType(TestCase): - def create_models(self): - class Visitor(self.Base): - __tablename__ = 'visitor' - id = sa.Column(sa.Integer, primary_key=True) - timezone_dateutil = sa.Column( - timezone.TimezoneType(backend='dateutil') - ) - timezone_pytz = sa.Column( - timezone.TimezoneType(backend='pytz') - ) +@pytest.fixture +def Visitor(Base): + class Visitor(Base): + __tablename__ = 'visitor' + id = sa.Column(sa.Integer, primary_key=True) + timezone_dateutil = sa.Column( + timezone.TimezoneType(backend='dateutil') + ) + timezone_pytz = sa.Column( + timezone.TimezoneType(backend='pytz') + ) - def __repr__(self): - return 'Visitor(%r)' % self.id + def __repr__(self): + return 'Visitor(%r)' % self.id + return Visitor - self.Visitor = Visitor - def test_parameter_processing(self): - visitor = self.Visitor( +@pytest.fixture +def init_models(Visitor): + pass + + +class TestTimezoneType(object): + + def test_parameter_processing(self, session, Visitor): + visitor = Visitor( timezone_dateutil=u'America/Los_Angeles', timezone_pytz=u'America/Los_Angeles' ) - self.session.add(visitor) - self.session.commit() + session.add(visitor) + session.commit() - visitor_dateutil = self.session.query(self.Visitor).filter_by( + visitor_dateutil = session.query(Visitor).filter_by( timezone_dateutil=u'America/Los_Angeles' ).first() - visitor_pytz = self.session.query(self.Visitor).filter_by( + visitor_pytz = session.query(Visitor).filter_by( timezone_pytz=u'America/Los_Angeles' ).first() diff --git a/tests/types/test_tsvector.py b/tests/types/test_tsvector.py index 5a75a91..959a17f 100644 --- a/tests/types/test_tsvector.py +++ b/tests/types/test_tsvector.py @@ -1,37 +1,44 @@ +import pytest import sqlalchemy as sa from sqlalchemy.dialects.postgresql import TSVECTOR from sqlalchemy_utils import TSVectorType -from tests import TestCase -class TestTSVector(TestCase): - dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' +@pytest.fixture +def User(Base): + class User(Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + search_index = sa.Column( + TSVectorType(name, regconfig='pg_catalog.finnish') + ) - def create_models(self): - class User(self.Base): - __tablename__ = 'user' - id = sa.Column(sa.Integer, primary_key=True) - name = sa.Column(sa.Unicode(255)) - search_index = sa.Column( - TSVectorType(name, regconfig='pg_catalog.finnish') - ) + def __repr__(self): + return 'User(%r)' % self.id + return User - def __repr__(self): - return 'User(%r)' % self.id - self.User = User +@pytest.fixture +def init_models(User): + pass - def test_generates_table(self): - assert 'search_index' in self.User.__table__.c - def test_type_reflection(self): +@pytest.mark.usefixtures('postgresql_dsn') +class TestTSVector(object): + + def test_generates_table(self, User): + assert 'search_index' in User.__table__.c + + @pytest.mark.usefixtures('session') + def test_type_reflection(self, engine): reflected_metadata = sa.schema.MetaData() table = sa.schema.Table( 'user', reflected_metadata, autoload=True, - autoload_with=self.engine + autoload_with=engine ) assert isinstance(table.c['search_index'].type, TSVECTOR) @@ -40,32 +47,32 @@ class TestTSVector(TestCase): assert type_.columns == ('name', 'age') assert type_.options['regconfig'] == 'pg_catalog.simple' - def test_match(self): - expr = self.User.search_index.match(u'something') - assert str(expr.compile(self.connection)) == ( + def test_match(self, connection, User): + expr = User.search_index.match(u'something') + assert str(expr.compile(connection)) == ( '''"user".search_index @@ to_tsquery('pg_catalog.finnish', ''' '''%(search_index_1)s)''' ) - def test_concat(self): - assert str(self.User.search_index | self.User.search_index) == ( + def test_concat(self, User): + assert str(User.search_index | User.search_index) == ( '"user".search_index || "user".search_index' ) - def test_match_concatenation(self): - concat = self.User.search_index | self.User.search_index - bind = self.session.bind + def test_match_concatenation(self, session, User): + concat = User.search_index | User.search_index + bind = session.bind assert str(concat.match('something').compile(bind)) == ( '("user".search_index || "user".search_index) @@ ' "to_tsquery('pg_catalog.finnish', %(param_1)s)" ) - def test_match_with_catalog(self): - expr = self.User.search_index.match( + def test_match_with_catalog(self, connection, User): + expr = User.search_index.match( u'something', postgresql_regconfig='pg_catalog.simple' ) - assert str(expr.compile(self.connection)) == ( + assert str(expr.compile(connection)) == ( '''"user".search_index @@ to_tsquery('pg_catalog.simple', ''' '''%(search_index_1)s)''' ) diff --git a/tests/types/test_url.py b/tests/types/test_url.py index 63970d4..c1efb6f 100644 --- a/tests/types/test_url.py +++ b/tests/types/test_url.py @@ -1,35 +1,41 @@ +import pytest import sqlalchemy as sa -from pytest import mark from sqlalchemy_utils.types import url -from tests import TestCase -@mark.skipif('url.furl is None') -class TestURLType(TestCase): - def create_models(self): - class User(self.Base): - __tablename__ = 'user' - id = sa.Column(sa.Integer, primary_key=True) - website = sa.Column(url.URLType) +@pytest.fixture +def User(Base): + class User(Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, primary_key=True) + website = sa.Column(url.URLType) - def __repr__(self): - return 'User(%r)' % self.id + def __repr__(self): + return 'User(%r)' % self.id + return User - self.User = User - def test_color_parameter_processing(self): - user = self.User( +@pytest.fixture +def init_models(User): + pass + + +@pytest.mark.skipif('url.furl is None') +class TestURLType(object): + + def test_color_parameter_processing(self, session, User): + user = User( website=url.furl(u'www.example.com') ) - self.session.add(user) - self.session.commit() + session.add(user) + session.commit() - user = self.session.query(self.User).first() + user = session.query(User).first() assert isinstance(user.website, url.furl) - def test_scalar_attributes_get_coerced_to_objects(self): - user = self.User(website=u'www.example.com') + def test_scalar_attributes_get_coerced_to_objects(self, User): + user = User(website=u'www.example.com') assert isinstance(user.website, url.furl) diff --git a/tests/types/test_uuid.py b/tests/types/test_uuid.py index 5bc4437..4ef7901 100644 --- a/tests/types/test_uuid.py +++ b/tests/types/test_uuid.py @@ -1,35 +1,42 @@ import uuid +import pytest import sqlalchemy as sa from sqlalchemy_utils import UUIDType -from tests import TestCase -class TestUUIDType(TestCase): - def create_models(self): - class User(self.Base): - __tablename__ = 'user' - id = sa.Column(UUIDType, default=uuid.uuid4, primary_key=True) +@pytest.fixture +def User(Base): + class User(Base): + __tablename__ = 'user' + id = sa.Column(UUIDType, default=uuid.uuid4, primary_key=True) - def __repr__(self): - return 'User(%r)' % self.id + def __repr__(self): + return 'User(%r)' % self.id + return User - self.User = User - def test_commit(self): - obj = self.User() +@pytest.fixture +def init_models(User): + pass + + +class TestUUIDType(object): + + def test_commit(self, session, User): + obj = User() obj.id = uuid.uuid4().hex - self.session.add(obj) - self.session.commit() + session.add(obj) + session.commit() - u = self.session.query(self.User).one() + u = session.query(User).one() assert u.id == obj.id - def test_coerce(self): - obj = self.User() + def test_coerce(self, User): + obj = User() obj.id = identifier = uuid.uuid4().hex assert isinstance(obj.id, uuid.UUID) diff --git a/tests/types/test_weekdays.py b/tests/types/test_weekdays.py index adabc60..48bd2ef 100644 --- a/tests/types/test_weekdays.py +++ b/tests/types/test_weekdays.py @@ -5,49 +5,60 @@ import sqlalchemy as sa from sqlalchemy_utils import i18n from sqlalchemy_utils.primitives import WeekDays from sqlalchemy_utils.types import WeekDaysType -from tests import TestCase +@pytest.fixture +def Schedule(Base): + class Schedule(Base): + __tablename__ = 'schedule' + id = sa.Column(sa.Integer, primary_key=True) + working_days = sa.Column(WeekDaysType) + + def __repr__(self): + return 'Schedule(%r)' % self.id + return Schedule + + +@pytest.fixture +def init_models(Schedule): + pass + + +@pytest.fixture +def set_get_locale(): + i18n.get_locale = lambda: i18n.babel.Locale('en') + + +@pytest.mark.usefixtures('set_get_locale') @pytest.mark.skipif('i18n.babel is None') -class WeekDaysTypeTestCase(TestCase): - def setup_method(self, method): - TestCase.setup_method(self, method) - i18n.get_locale = lambda: i18n.babel.Locale('en') +class WeekDaysTypeTestCase(object): - def create_models(self): - class Schedule(self.Base): - __tablename__ = 'schedule' - id = sa.Column(sa.Integer, primary_key=True) - working_days = sa.Column(WeekDaysType) - - def __repr__(self): - return 'Schedule(%r)' % self.id - - self.Schedule = Schedule - - def test_color_parameter_processing(self): - schedule = self.Schedule( + def test_color_parameter_processing(self, session, Schedule): + schedule = Schedule( working_days=b'0001111' ) - self.session.add(schedule) - self.session.commit() + session.add(schedule) + session.commit() - schedule = self.session.query(self.Schedule).first() + schedule = session.query(Schedule).first() assert isinstance(schedule.working_days, WeekDays) - def test_scalar_attributes_get_coerced_to_objects(self): - schedule = self.Schedule(working_days=b'1010101') + def test_scalar_attributes_get_coerced_to_objects(self, Schedule): + schedule = Schedule(working_days=b'1010101') assert isinstance(schedule.working_days, WeekDays) +@pytest.mark.usefixtures('sqlite_memory_dsn') class TestWeekDaysTypeOnSQLite(WeekDaysTypeTestCase): - dns = 'sqlite:///:memory:' + pass +@pytest.mark.usefixtures('postgresql_dsn') class TestWeekDaysTypeOnPostgres(WeekDaysTypeTestCase): - dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' + pass +@pytest.mark.usefixtures('mysql_dsn') class TestWeekDaysTypeOnMySQL(WeekDaysTypeTestCase): - dns = 'mysql+pymysql://travis@localhost/sqlalchemy_utils_test' + pass diff --git a/tox.ini b/tox.ini index de28324..ef84916 100644 --- a/tox.ini +++ b/tox.ini @@ -1,8 +1,36 @@ [tox] -envlist = py26, py27, py33, py34 +envlist = py26, py27, py33, py34, py35, lint [testenv] -commands = py.test {posargs} +commands = + py.test sqlalchemy_utils tests deps = - SQLAlchemy==1.0.4 .[test_all] +passenv = SQLALCHEMY_UTILS_TEST_DB SQLALCHEMY_UTILS_TEST_POSTGRESQL_USER SQLALCHEMY_UTILS_TEST_MYSQL_USER + +[testenv:py26] +recreate = True + +[testenv:py27] +recreate = True + +[testenv:py33] +recreate = True + +[testenv:py34] +recreate = True + +[testenv:py35] +recreate = True + +[testenv:lint] +recreate = True +commands = + flake8 sqlalchemy_utils tests + isort --verbose --recursive --diff sqlalchemy_utils tests + isort --verbose --recursive --check-only sqlalchemy_utils tests +skip_install = True +deps = + .[test_all] + flake8>=2.5.0 + isort==4.2.2