Files
deb-python-sqlalchemy-utils/tests/__init__.py
Konsta Vesterinen 204aba376d Fix babel primitive types
- Make Babel-dependent primitive types use Locale('en') for data
validation instead of the current locale. Using the current locale leads to
infinite recursion in cases where the loaded data depends on
the loaded object's locale.
2015-06-17 19:36:46 +03:00

125 lines
3.5 KiB
Python

import warnings
import sqlalchemy as sa
from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base, synonym_for
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import sessionmaker
from sqlalchemy_utils import (
aggregates,
coercion_listener,
i18n,
InstrumentedList
)
from sqlalchemy_utils.types.pg_composite import remove_composite_listeners
@sa.event.listens_for(sa.engine.Engine, 'before_cursor_execute')
def count_sql_calls(conn, cursor, statement, parameters, context, executemany):
    """Keep a running tally of statements executed on ``conn``.

    Registered for the ``before_cursor_execute`` event of every engine, so
    it fires once per statement sent to a cursor.  The tally is stored on
    the connection itself as ``conn.query_count``.

    Only ``conn`` is used; the remaining parameters are part of the event
    signature and are ignored.
    """
    # The attribute does not exist until the first statement runs on this
    # connection.  Start from zero in that case so the first statement is
    # counted as well (previously it initialised the counter to 0 and was
    # therefore never counted).
    conn.query_count = getattr(conn, 'query_count', 0) + 1
# Escalate SQLAlchemy warnings (SAWarning) to exceptions so that any such
# warning raised during the tests surfaces as an error.
warnings.simplefilter('error', sa.exc.SAWarning)
# Call coercion_listener on every 'mapper_configured' event.
sa.event.listen(sa.orm.mapper, 'mapper_configured', coercion_listener)
def get_locale():
    """Return a stub locale object for use in tests.

    The returned instance exposes a single ``territories`` mapping that
    contains only Finland.  A fresh class and instance are created on
    every call.
    """
    class Locale():
        # Territory code -> display name.
        territories = {'FI': 'Finland'}

    stub = Locale()
    return stub
class TestCase(object):
    """Base class for database-backed tests.

    For every test method a new engine, connection, declarative base and
    session are created, together with a small ``User`` / ``Category`` /
    ``Article`` model set.  Subclasses can override ``dns``,
    ``create_tables`` and ``create_models``.
    """

    # Database URL handed to ``create_engine``.  The attribute name is kept
    # as-is because subclasses may override it.
    dns = 'sqlite:///:memory:'
    # When true, tables are created in setup and dropped in teardown.
    create_tables = True

    def setup_method(self, method):
        """Create the engine, models, (optionally) tables and a session."""
        self.engine = create_engine(self.dns)
        # Set ``self.engine.echo = True`` here to log emitted SQL while
        # debugging.
        self.connection = self.engine.connect()
        self.Base = declarative_base()

        self.create_models()
        sa.orm.configure_mappers()
        if self.create_tables:
            self.Base.metadata.create_all(self.connection)

        Session = sessionmaker(bind=self.connection)
        self.session = Session()

        # Replace the i18n locale getter with the module-level stub.
        i18n.get_locale = get_locale

    def teardown_method(self, method):
        """Undo everything ``setup_method`` created.

        Global listeners are removed and the connection and engine are
        released even when resetting aggregates, closing sessions or
        dropping tables raises; the original exception still propagates.
        """
        try:
            aggregates.manager.reset()
            self.session.close_all()
            if self.create_tables:
                self.Base.metadata.drop_all(self.connection)
        finally:
            remove_composite_listeners()
            self.connection.close()
            self.engine.dispose()

    def create_models(self):
        """Define the default models and expose them on the test case.

        Sets ``self.User``, ``self.Category`` and ``self.Article``.
        """
        class User(self.Base):
            __tablename__ = 'user'
            id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
            name = sa.Column(sa.Unicode(255))

        class Category(self.Base):
            __tablename__ = 'category'
            id = sa.Column(sa.Integer, primary_key=True)
            name = sa.Column(sa.Unicode(255))

            @hybrid_property
            def articles_count(self):
                # ``articles`` is the backref declared on Article.category.
                return len(self.articles)

            @articles_count.expression
            def articles_count(cls):
                # NOTE: ``self`` below is the enclosing test case captured
                # from ``create_models``, not a Category instance.  The
                # models are looked up when the expression is evaluated,
                # by which time ``self.Article`` has been assigned.
                return (
                    sa.select([sa.func.count(self.Article.id)])
                    .where(self.Article.category_id == self.Category.id)
                    .correlate(self.Article.__table__)
                    .label('article_count')
                )

            @property
            def name_alias(self):
                # Plain Python property returning ``name``.
                return self.name

            @synonym_for('name')
            @property
            def name_synonym(self):
                # Declared as a synonym for the ``name`` attribute.
                return self.name

        class Article(self.Base):
            __tablename__ = 'article'
            id = sa.Column(sa.Integer, primary_key=True)
            name = sa.Column(sa.Unicode(255), index=True)
            category_id = sa.Column(sa.Integer, sa.ForeignKey(Category.id))

            category = sa.orm.relationship(
                Category,
                primaryjoin=category_id == Category.id,
                backref=sa.orm.backref(
                    'articles',
                    collection_class=InstrumentedList
                )
            )

        self.User = User
        self.Category = Category
        self.Article = Article
def assert_contains(clause, query):
    """Assert that ``query`` executes and that its string form has ``clause``.

    ``query`` must provide ``all()`` and be convertible with ``str()``.
    """
    # Run the query first so an invalid one fails before the text check.
    query.all()
    rendered = str(query)
    assert clause in rendered