From 26db1397d55952dd243217f395103e5d0ffe0ba1 Mon Sep 17 00:00:00 2001 From: Konsta Vesterinen Date: Tue, 26 Mar 2013 13:03:43 +0200 Subject: [PATCH] Added NumberRange types, refactored file structure --- sqlalchemy_utils/__init__.py | 399 +++----------------------------- sqlalchemy_utils/functions.py | 159 +++++++++++++ sqlalchemy_utils/merge.py | 123 ++++++++++ sqlalchemy_utils/types.py | 179 ++++++++++++++ tests.py | 397 ------------------------------- tests/__init__.py | 86 +++++++ tests/test_instrumented_list.py | 14 ++ tests/test_merge.py | 196 ++++++++++++++++ tests/test_phonenumber_type.py | 65 ++++++ tests/test_utility_functions.py | 47 ++++ 10 files changed, 896 insertions(+), 769 deletions(-) create mode 100644 sqlalchemy_utils/functions.py create mode 100644 sqlalchemy_utils/merge.py create mode 100644 sqlalchemy_utils/types.py delete mode 100644 tests.py create mode 100644 tests/__init__.py create mode 100644 tests/test_instrumented_list.py create mode 100644 tests/test_merge.py create mode 100644 tests/test_phonenumber_type.py create mode 100644 tests/test_utility_functions.py diff --git a/sqlalchemy_utils/__init__.py b/sqlalchemy_utils/__init__.py index 8ee83a6..d57bb63 100644 --- a/sqlalchemy_utils/__init__.py +++ b/sqlalchemy_utils/__init__.py @@ -1,372 +1,27 @@ -import phonenumbers -from functools import wraps -import sqlalchemy as sa -from sqlalchemy.engine import reflection -from sqlalchemy.orm import defer, object_session, mapperlib -from sqlalchemy.orm.collections import InstrumentedList as _InstrumentedList -from sqlalchemy.orm.mapper import Mapper -from sqlalchemy.orm.query import _ColumnEntity -from sqlalchemy.orm.properties import ColumnProperty -from sqlalchemy.sql.expression import desc, asc -from sqlalchemy import types - - -class PhoneNumber(phonenumbers.phonenumber.PhoneNumber): - ''' - Extends a PhoneNumber class from `Python phonenumbers library`_. Adds - different phone number formats to attributes, so they can be easily used - in templates. Phone number validation method is also implemented. - - Takes the raw phone number and country code as params and parses them - into a PhoneNumber object. - - .. _Python phonenumbers library: - https://github.com/daviddrysdale/python-phonenumbers - - :param raw_number: - String representation of the phone number. - :param country_code: - Country code of the phone number. - ''' - def __init__(self, raw_number, country_code=None): - self._phone_number = phonenumbers.parse(raw_number, country_code) - super(PhoneNumber, self).__init__( - country_code=self._phone_number.country_code, - national_number=self._phone_number.national_number, - extension=self._phone_number.extension, - italian_leading_zero=self._phone_number.italian_leading_zero, - raw_input=self._phone_number.raw_input, - country_code_source=self._phone_number.country_code_source, - preferred_domestic_carrier_code= - self._phone_number.preferred_domestic_carrier_code - ) - self.national = phonenumbers.format_number( - self._phone_number, - phonenumbers.PhoneNumberFormat.NATIONAL - ) - self.international = phonenumbers.format_number( - self._phone_number, - phonenumbers.PhoneNumberFormat.INTERNATIONAL - ) - self.e164 = phonenumbers.format_number( - self._phone_number, - phonenumbers.PhoneNumberFormat.E164 - ) - - def is_valid_number(self): - return phonenumbers.is_valid_number(self._phone_number) - - -class PhoneNumberType(types.TypeDecorator): - """ - Changes PhoneNumber objects to a string representation on the way in and - changes them back to PhoneNumber objects on the way out. If E164 is used - as storing format, no country code is needed for parsing the database - value to PhoneNumber object. - """ - STORE_FORMAT = 'e164' - impl = types.Unicode(20) - - def __init__(self, country_code='US', max_length=20, *args, **kwargs): - super(PhoneNumberType, self).__init__(*args, **kwargs) - self.country_code = country_code - self.impl = types.Unicode(max_length) - - def process_bind_param(self, value, dialect): - return getattr(value, self.STORE_FORMAT) - - def process_result_value(self, value, dialect): - return PhoneNumber(value, self.country_code) - - -class InstrumentedList(_InstrumentedList): - """Enhanced version of SQLAlchemy InstrumentedList. Provides some - additional functionality.""" - - def any(self, attr): - return any(getattr(item, attr) for item in self) - - def all(self, attr): - return all(getattr(item, attr) for item in self) - - -def instrumented_list(f): - @wraps(f) - def wrapper(*args, **kwargs): - return InstrumentedList([item for item in f(*args, **kwargs)]) - return wrapper - - -def sort_query(query, sort): - """ - Applies an sql ORDER BY for given query. This function can be easily used - with user-defined sorting. - - The examples use the following model definition: - - >>> import sqlalchemy as sa - >>> from sqlalchemy import create_engine - >>> from sqlalchemy.orm import sessionmaker - >>> from sqlalchemy.ext.declarative import declarative_base - >>> from sqlalchemy_utils import sort_query - >>> - >>> - >>> engine = create_engine( - ... 'sqlite:///' - ... ) - >>> Base = declarative_base() - >>> Session = sessionmaker(bind=engine) - >>> session = Session() - >>> - >>> class Category(Base): - ... __tablename__ = 'category' - ... id = sa.Column(sa.Integer, primary_key=True) - ... name = sa.Column(sa.Unicode(255)) - >>> - >>> class Article(Base): - ... __tablename__ = 'article' - ... id = sa.Column(sa.Integer, primary_key=True) - ... name = sa.Column(sa.Unicode(255)) - ... category_id = sa.Column(sa.Integer, sa.ForeignKey(Category.id)) - ... - ... category = sa.orm.relationship( - ... Category, primaryjoin=category_id == Category.id - ... ) - - - - 1. Applying simple ascending sort - - >>> query = session.query(Article) - >>> query = sort_query(query, 'name') - - 2. Appying descending sort - - >>> query = sort_query(query, '-name') - - 3. Applying sort to custom calculated label - - >>> query = session.query( - ... Category, db.func.count(Article.id).label('articles') - ... ) - >>> query = sort_query(query, 'articles') - - 4. Applying sort to joined table column - - >>> query = session.query(Article).join(Article.category) - >>> query = sort_query(query, 'category-name') - - - :param query: query to be modified - :param sort: string that defines the label or column to sort the query by - :param errors: whether or not to raise exceptions if unknown sort column - is passed - """ - entities = [entity.entity_zero.class_ for entity in query._entities] - for mapper in query._join_entities: - if isinstance(mapper, Mapper): - entities.append(mapper.class_) - else: - entities.append(mapper) - - # get all label names for queries such as: - # db.session.query(Category, db.func.count(Article.id).label('articles')) - labels = [] - for entity in query._entities: - if isinstance(entity, _ColumnEntity) and entity._label_name: - labels.append(entity._label_name) - - if not sort: - return query - - if sort[0] == '-': - func = desc - sort = sort[1:] - else: - func = asc - - component = None - parts = sort.split('-') - if len(parts) > 1: - component = parts[0] - sort = parts[1] - if sort in labels: - return query.order_by(func(sort)) - - for entity in entities: - table = entity.__table__ - if component and table.name != component: - continue - if sort in table.columns: - try: - attr = getattr(entity, sort) - query = query.order_by(func(attr)) - except AttributeError: - pass - break - return query - - -def defer_except(query, columns): - """ - Deferred loads all columns in given query, except the ones given. - - This function is very useful when working with models with myriad of - columns and you want to deferred load many columns. - - >>> from sqlalchemy_utils import defer_except - >>> query = session.query(Article) - >>> query = defer_except(Article, [Article.id, Article.name]) - - :param columns: columns not to deferred load - """ - model = query._entities[0].entity_zero.class_ - fields = set(model._sa_class_manager.values()) - for field in fields: - property_ = field.property - if isinstance(property_, ColumnProperty): - column = property_.columns[0] - if column.name not in columns: - query = query.options(defer(property_.key)) - return query - - -def escape_like(string, escape_char='*'): - """ - Escapes the string paremeter used in SQL LIKE expressions - - >>> from sqlalchemy_utils import escape_like - >>> query = session.query(User).filter( - ... User.name.ilike(escape_like('John')) - ... ) - - - :param string: a string to escape - :param escape_char: escape character - """ - return ( - string - .replace(escape_char, escape_char * 2) - .replace('%', escape_char + '%') - .replace('_', escape_char + '_') - ) - - -def dependent_foreign_keys(model_class): - """ - Returns dependent foreign keys as dicts for given model class. - - ** Experimental function ** - """ - session = object_session(model_class) - - engine = session.bind - inspector = reflection.Inspector.from_engine(engine) - table_names = inspector.get_table_names() - - dependent_foreign_keys = {} - - for table_name in table_names: - fks = inspector.get_foreign_keys(table_name) - if fks: - dependent_foreign_keys[table_name] = [] - for fk in fks: - if fk['referred_table'] == model_class.__tablename__: - dependent_foreign_keys[table_name].append(fk) - return dependent_foreign_keys - - -class Merger(object): - def memory_merge(self, session, table_name, old_values, new_values): - # try to fetch mappers for given table and update in memory objects as - # well as database table - found = False - for mapper in mapperlib._mapper_registry: - class_ = mapper.class_ - if table_name == class_.__table__.name: - try: - ( - session.query(mapper.class_) - .filter_by(**old_values) - .update( - new_values, - 'fetch' - ) - ) - except sa.exc.IntegrityError: - pass - found = True - return found - - def raw_merge(self, session, table, old_values, new_values): - conditions = [] - for key, value in old_values.items(): - conditions.append(getattr(table.c, key) == value) - sql = ( - table - .update() - .where(sa.and_( - *conditions - )) - .values( - new_values - ) - ) - try: - session.execute(sql) - except sa.exc.IntegrityError: - pass - - def merge_update(self, table_name, from_, to, foreign_key): - session = object_session(from_) - constrained_columns = foreign_key['constrained_columns'] - referred_columns = foreign_key['referred_columns'] - metadata = from_.metadata - table = metadata.tables[table_name] - - new_values = {} - for index, column in enumerate(constrained_columns): - new_values[column] = getattr( - to, referred_columns[index] - ) - - old_values = {} - for index, column in enumerate(constrained_columns): - old_values[column] = getattr( - from_, referred_columns[index] - ) - - if not self.memory_merge(session, table_name, old_values, new_values): - self.raw_merge(session, table, old_values, new_values) - - def __call__(self, from_, to): - """ - Merges entity into another entity. After merging deletes the from_ - argument entity. - """ - if from_.__tablename__ != to.__tablename__: - raise Exception() - - session = object_session(from_) - foreign_keys = dependent_foreign_keys(from_) - - for table_name in foreign_keys: - for foreign_key in foreign_keys[table_name]: - self.merge_update(table_name, from_, to, foreign_key) - - session.delete(from_) - - -def merge(from_, to, merger=Merger): - """ - Merges entity into another entity. After merging deletes the from_ argument - entity. - - After merging the from_ entity is deleted from database. - - :param from_: an entity to merge into another entity - :param to: an entity to merge another entity into - :param merger: Merger class, by default this is sqlalchemy_utils.Merger - class - """ - return Merger()(from_, to) +from .functions import sort_query, defer_except, escape_like +from .merge import merge, Merger +from .types import ( + instrumented_list, + InstrumentedList, + PhoneNumber, + PhoneNumberType, + NumberRange, + NumberRangeRawType, + NumberRangeType +) + + +__all__ = ( + sort_query, + defer_except, + escape_like, + instrumented_list, + merge, + InstrumentedList, + Merger, + NumberRange, + NumberRangeRawType, + NumberRangeType, + PhoneNumber, + PhoneNumberType, +) diff --git a/sqlalchemy_utils/functions.py b/sqlalchemy_utils/functions.py new file mode 100644 index 0000000..9495130 --- /dev/null +++ b/sqlalchemy_utils/functions.py @@ -0,0 +1,159 @@ +from sqlalchemy.orm.mapper import Mapper +from sqlalchemy.orm.query import _ColumnEntity +from sqlalchemy.orm.properties import ColumnProperty +from sqlalchemy.sql.expression import desc, asc + + +def sort_query(query, sort): + """ + Applies an sql ORDER BY for given query. This function can be easily used + with user-defined sorting. + + The examples use the following model definition: + + >>> import sqlalchemy as sa + >>> from sqlalchemy import create_engine + >>> from sqlalchemy.orm import sessionmaker + >>> from sqlalchemy.ext.declarative import declarative_base + >>> from sqlalchemy_utils import sort_query + >>> + >>> + >>> engine = create_engine( + ... 'sqlite:///' + ... ) + >>> Base = declarative_base() + >>> Session = sessionmaker(bind=engine) + >>> session = Session() + >>> + >>> class Category(Base): + ... __tablename__ = 'category' + ... id = sa.Column(sa.Integer, primary_key=True) + ... name = sa.Column(sa.Unicode(255)) + >>> + >>> class Article(Base): + ... __tablename__ = 'article' + ... id = sa.Column(sa.Integer, primary_key=True) + ... name = sa.Column(sa.Unicode(255)) + ... category_id = sa.Column(sa.Integer, sa.ForeignKey(Category.id)) + ... + ... category = sa.orm.relationship( + ... Category, primaryjoin=category_id == Category.id + ... ) + + + + 1. Applying simple ascending sort + + >>> query = session.query(Article) + >>> query = sort_query(query, 'name') + + 2. Appying descending sort + + >>> query = sort_query(query, '-name') + + 3. Applying sort to custom calculated label + + >>> query = session.query( + ... Category, db.func.count(Article.id).label('articles') + ... ) + >>> query = sort_query(query, 'articles') + + 4. Applying sort to joined table column + + >>> query = session.query(Article).join(Article.category) + >>> query = sort_query(query, 'category-name') + + + :param query: query to be modified + :param sort: string that defines the label or column to sort the query by + :param errors: whether or not to raise exceptions if unknown sort column + is passed + """ + entities = [entity.entity_zero.class_ for entity in query._entities] + for mapper in query._join_entities: + if isinstance(mapper, Mapper): + entities.append(mapper.class_) + else: + entities.append(mapper) + + # get all label names for queries such as: + # db.session.query(Category, db.func.count(Article.id).label('articles')) + labels = [] + for entity in query._entities: + if isinstance(entity, _ColumnEntity) and entity._label_name: + labels.append(entity._label_name) + + if not sort: + return query + + if sort[0] == '-': + func = desc + sort = sort[1:] + else: + func = asc + + component = None + parts = sort.split('-') + if len(parts) > 1: + component = parts[0] + sort = parts[1] + if sort in labels: + return query.order_by(func(sort)) + + for entity in entities: + table = entity.__table__ + if component and table.name != component: + continue + if sort in table.columns: + try: + attr = getattr(entity, sort) + query = query.order_by(func(attr)) + except AttributeError: + pass + break + return query + + +def defer_except(query, columns): + """ + Deferred loads all columns in given query, except the ones given. + + This function is very useful when working with models with myriad of + columns and you want to deferred load many columns. + + >>> from sqlalchemy_utils import defer_except + >>> query = session.query(Article) + >>> query = defer_except(Article, [Article.id, Article.name]) + + :param columns: columns not to deferred load + """ + model = query._entities[0].entity_zero.class_ + fields = set(model._sa_class_manager.values()) + for field in fields: + property_ = field.property + if isinstance(property_, ColumnProperty): + column = property_.columns[0] + if column.name not in columns: + query = query.options(defer(property_.key)) + return query + + +def escape_like(string, escape_char='*'): + """ + Escapes the string paremeter used in SQL LIKE expressions + + >>> from sqlalchemy_utils import escape_like + >>> query = session.query(User).filter( + ... User.name.ilike(escape_like('John')) + ... ) + + + :param string: a string to escape + :param escape_char: escape character + """ + return ( + string + .replace(escape_char, escape_char * 2) + .replace('%', escape_char + '%') + .replace('_', escape_char + '_') + ) diff --git a/sqlalchemy_utils/merge.py b/sqlalchemy_utils/merge.py new file mode 100644 index 0000000..eaf2854 --- /dev/null +++ b/sqlalchemy_utils/merge.py @@ -0,0 +1,123 @@ +import sqlalchemy as sa +from sqlalchemy.engine import reflection +from sqlalchemy.orm import object_session, mapperlib + + +def dependent_foreign_keys(model_class): + """ + Returns dependent foreign keys as dicts for given model class. + + ** Experimental function ** + """ + session = object_session(model_class) + + engine = session.bind + inspector = reflection.Inspector.from_engine(engine) + table_names = inspector.get_table_names() + + dependent_foreign_keys = {} + + for table_name in table_names: + fks = inspector.get_foreign_keys(table_name) + if fks: + dependent_foreign_keys[table_name] = [] + for fk in fks: + if fk['referred_table'] == model_class.__tablename__: + dependent_foreign_keys[table_name].append(fk) + return dependent_foreign_keys + + +class Merger(object): + def memory_merge(self, session, table_name, old_values, new_values): + # try to fetch mappers for given table and update in memory objects as + # well as database table + found = False + for mapper in mapperlib._mapper_registry: + class_ = mapper.class_ + if table_name == class_.__table__.name: + try: + ( + session.query(mapper.class_) + .filter_by(**old_values) + .update( + new_values, + 'fetch' + ) + ) + except sa.exc.IntegrityError: + pass + found = True + return found + + def raw_merge(self, session, table, old_values, new_values): + conditions = [] + for key, value in old_values.items(): + conditions.append(getattr(table.c, key) == value) + sql = ( + table + .update() + .where(sa.and_( + *conditions + )) + .values( + new_values + ) + ) + try: + session.execute(sql) + except sa.exc.IntegrityError: + pass + + def merge_update(self, table_name, from_, to, foreign_key): + session = object_session(from_) + constrained_columns = foreign_key['constrained_columns'] + referred_columns = foreign_key['referred_columns'] + metadata = from_.metadata + table = metadata.tables[table_name] + + new_values = {} + for index, column in enumerate(constrained_columns): + new_values[column] = getattr( + to, referred_columns[index] + ) + + old_values = {} + for index, column in enumerate(constrained_columns): + old_values[column] = getattr( + from_, referred_columns[index] + ) + + if not self.memory_merge(session, table_name, old_values, new_values): + self.raw_merge(session, table, old_values, new_values) + + def __call__(self, from_, to): + """ + Merges entity into another entity. After merging deletes the from_ + argument entity. + """ + if from_.__tablename__ != to.__tablename__: + raise Exception() + + session = object_session(from_) + foreign_keys = dependent_foreign_keys(from_) + + for table_name in foreign_keys: + for foreign_key in foreign_keys[table_name]: + self.merge_update(table_name, from_, to, foreign_key) + + session.delete(from_) + + +def merge(from_, to, merger=Merger): + """ + Merges entity into another entity. After merging deletes the from_ argument + entity. + + After merging the from_ entity is deleted from database. + + :param from_: an entity to merge into another entity + :param to: an entity to merge another entity into + :param merger: Merger class, by default this is sqlalchemy_utils.Merger + class + """ + return Merger()(from_, to) diff --git a/sqlalchemy_utils/types.py b/sqlalchemy_utils/types.py new file mode 100644 index 0000000..0261df2 --- /dev/null +++ b/sqlalchemy_utils/types.py @@ -0,0 +1,179 @@ +import phonenumbers +from functools import wraps +from sqlalchemy.orm.collections import InstrumentedList as _InstrumentedList +from sqlalchemy import types + + +class PhoneNumber(phonenumbers.phonenumber.PhoneNumber): + ''' + Extends a PhoneNumber class from `Python phonenumbers library`_. Adds + different phone number formats to attributes, so they can be easily used + in templates. Phone number validation method is also implemented. + + Takes the raw phone number and country code as params and parses them + into a PhoneNumber object. + + .. _Python phonenumbers library: + https://github.com/daviddrysdale/python-phonenumbers + + :param raw_number: + String representation of the phone number. + :param country_code: + Country code of the phone number. + ''' + def __init__(self, raw_number, country_code=None): + self._phone_number = phonenumbers.parse(raw_number, country_code) + super(PhoneNumber, self).__init__( + country_code=self._phone_number.country_code, + national_number=self._phone_number.national_number, + extension=self._phone_number.extension, + italian_leading_zero=self._phone_number.italian_leading_zero, + raw_input=self._phone_number.raw_input, + country_code_source=self._phone_number.country_code_source, + preferred_domestic_carrier_code= + self._phone_number.preferred_domestic_carrier_code + ) + self.national = phonenumbers.format_number( + self._phone_number, + phonenumbers.PhoneNumberFormat.NATIONAL + ) + self.international = phonenumbers.format_number( + self._phone_number, + phonenumbers.PhoneNumberFormat.INTERNATIONAL + ) + self.e164 = phonenumbers.format_number( + self._phone_number, + phonenumbers.PhoneNumberFormat.E164 + ) + + def is_valid_number(self): + return phonenumbers.is_valid_number(self._phone_number) + + +class PhoneNumberType(types.TypeDecorator): + """ + Changes PhoneNumber objects to a string representation on the way in and + changes them back to PhoneNumber objects on the way out. If E164 is used + as storing format, no country code is needed for parsing the database + value to PhoneNumber object. + """ + STORE_FORMAT = 'e164' + impl = types.Unicode(20) + + def __init__(self, country_code='US', max_length=20, *args, **kwargs): + super(PhoneNumberType, self).__init__(*args, **kwargs) + self.country_code = country_code + self.impl = types.Unicode(max_length) + + def process_bind_param(self, value, dialect): + return getattr(value, self.STORE_FORMAT) + + def process_result_value(self, value, dialect): + return PhoneNumber(value, self.country_code) + + +class NumberRangeRawType(types.UserDefinedType): + """ + Raw number range type, only supports PostgreSQL for now. + """ + def get_col_spec(self): + return 'int4range' + + +class NumberRangeType(types.TypeDecorator): + impl = NumberRangeRawType + + def process_bind_param(self, value, dialect): + return value + + def process_result_value(self, value, dialect): + return NumberRange.from_normalized_str(value) + + +class NumberRange(object): + def __init__(self, min_value, max_value): + self.min_value = min_value + self.max_value = max_value + + @classmethod + def from_normalized_str(cls, value): + if value is not None: + values = value[1:-1].split(',') + min_value, max_value = map( + lambda a: int(a.strip()), values + ) + + if value[0] == '(': + min_value += 1 + + if value[1] == ')': + max_value -= 1 + + return cls(min_value, max_value) + + @classmethod + def from_str(cls, value): + if value is not None: + values = value.split('-') + min_value, max_value = map( + lambda a: int(a.strip()), values + ) + return cls(min_value, max_value) + + def __repr__(self): + return 'NumberRange(%r, %r)' % (self.min_value, self.max_value) + + def __str__(self): + return '[%s, %s]' % (self.min_value, self.max_value) + + def __add__(self, other): + try: + return NumberRange( + self.min_value + other.min_value, + self.max_value + other.max_value + ) + except AttributeError: + return NotImplemented + + def __iadd__(self, other): + try: + self.min_value += other.min_value + self.max_value += other.max_value + return self + except AttributeError: + return NotImplemented + + def __sub__(self, other): + try: + return NumberRange( + self.min_value - other.min_value, + self.max_value - other.max_value + ) + except AttributeError: + return NotImplemented + + def __isub__(self, other): + try: + self.min_value -= other.min_value + self.max_value -= other.max_value + return self + except AttributeError: + return NotImplemented + + +class InstrumentedList(_InstrumentedList): + """Enhanced version of SQLAlchemy InstrumentedList. Provides some + additional functionality.""" + + def any(self, attr): + return any(getattr(item, attr) for item in self) + + def all(self, attr): + return all(getattr(item, attr) for item in self) + + +def instrumented_list(f): + @wraps(f) + def wrapper(*args, **kwargs): + return InstrumentedList([item for item in f(*args, **kwargs)]) + return wrapper diff --git a/tests.py b/tests.py deleted file mode 100644 index dce5c7a..0000000 --- a/tests.py +++ /dev/null @@ -1,397 +0,0 @@ -import sqlalchemy as sa - -from sqlalchemy import create_engine -from sqlalchemy.orm import sessionmaker -from sqlalchemy.ext.declarative import declarative_base - -from sqlalchemy_utils import ( - escape_like, - sort_query, - InstrumentedList, - PhoneNumber, - PhoneNumberType, - merge -) - - -class TestCase(object): - - def setup_method(self, method): - self.engine = create_engine('sqlite:///:memory:') - self.Base = declarative_base() - - self.create_models() - self.Base.metadata.create_all(self.engine) - - Session = sessionmaker(bind=self.engine) - self.session = Session() - - def teardown_method(self, method): - self.session.close_all() - self.Base.metadata.drop_all(self.engine) - self.engine.dispose() - - def create_models(self): - class User(self.Base): - __tablename__ = 'user' - id = sa.Column(sa.Integer, autoincrement=True, primary_key=True) - name = sa.Column(sa.Unicode(255)) - phone_number = sa.Column(PhoneNumberType()) - - class Category(self.Base): - __tablename__ = 'category' - id = sa.Column(sa.Integer, primary_key=True) - name = sa.Column(sa.Unicode(255)) - - class Article(self.Base): - __tablename__ = 'article' - id = sa.Column(sa.Integer, primary_key=True) - name = sa.Column(sa.Unicode(255)) - category_id = sa.Column(sa.Integer, sa.ForeignKey(Category.id)) - - category = sa.orm.relationship( - Category, - primaryjoin=category_id == Category.id, - backref=sa.orm.backref( - 'articles', - collection_class=InstrumentedList - ) - ) - - self.User = User - self.Category = Category - self.Article = Article - - -class TestInstrumentedList(TestCase): - def test_any_returns_true_if_member_has_attr_defined(self): - category = self.Category() - category.articles.append(self.Article()) - category.articles.append(self.Article(name=u'some name')) - assert category.articles.any('name') - - def test_any_returns_false_if_no_member_has_attr_defined(self): - category = self.Category() - category.articles.append(self.Article()) - assert not category.articles.any('name') - - -class TestEscapeLike(TestCase): - def test_escapes_wildcards(self): - assert escape_like('_*%') == '*_***%' - - -class TestSortQuery(TestCase): - def test_without_sort_param_returns_the_query_object_untouched(self): - query = self.session.query(self.Article) - sorted_query = sort_query(query, '') - assert query == sorted_query - - def test_sort_by_column_ascending(self): - query = sort_query(self.session.query(self.Article), 'name') - assert 'ORDER BY article.name ASC' in str(query) - - def test_sort_by_column_descending(self): - query = sort_query(self.session.query(self.Article), '-name') - assert 'ORDER BY article.name DESC' in str(query) - - def test_skips_unknown_columns(self): - query = self.session.query(self.Article) - sorted_query = sort_query(query, '-unknown') - assert query == sorted_query - - def test_sort_by_calculated_value_ascending(self): - query = self.session.query( - self.Category, sa.func.count(self.Article.id).label('articles') - ) - query = sort_query(query, 'articles') - assert 'ORDER BY articles ASC' in str(query) - - def test_sort_by_calculated_value_descending(self): - query = self.session.query( - self.Category, sa.func.count(self.Article.id).label('articles') - ) - query = sort_query(query, '-articles') - assert 'ORDER BY articles DESC' in str(query) - - def test_sort_by_joined_table_column(self): - query = self.session.query(self.Article).join(self.Article.category) - sorted_query = sort_query(query, 'category-name') - assert 'category.name ASC' in str(sorted_query) - - -class TestPhoneNumber(object): - def setup_method(self, method): - self.valid_phone_numbers = [ - '040 1234567', - '+358 401234567', - '09 2501234', - '+358 92501234', - '0800 939393', - '09 4243 0456', - '0600 900 500' - ] - self.invalid_phone_numbers = [ - 'abc', - '+040 1234567', - '0111234567', - '358' - ] - - def test_valid_phone_numbers(self): - for raw_number in self.valid_phone_numbers: - phone_number = PhoneNumber(raw_number, 'FI') - assert phone_number.is_valid_number() - - def test_invalid_phone_numbers(self): - for raw_number in self.invalid_phone_numbers: - try: - phone_number = PhoneNumber(raw_number, 'FI') - assert not phone_number.is_valid_number() - except: - pass - - def test_phone_number_attributes(self): - phone_number = PhoneNumber('+358401234567') - assert phone_number.e164 == u'+358401234567' - assert phone_number.international == u'+358 40 1234567' - assert phone_number.national == u'040 1234567' - - -class TestPhoneNumberType(TestCase): - def setup_method(self, method): - super(TestPhoneNumberType, self).setup_method(method) - self.phone_number = PhoneNumber( - '040 1234567', - 'FI' - ) - self.user = self.User() - self.user.name = u'Someone' - self.user.phone_number = self.phone_number - self.session.add(self.user) - self.session.commit() - - def test_query_returns_phone_number_object(self): - queried_user = self.session.query(self.User).first() - assert queried_user.phone_number == self.phone_number - - def test_phone_number_is_stored_as_string(self): - result = self.session.execute( - 'SELECT phone_number FROM user WHERE id=:param', - {'param': self.user.id} - ) - assert result.first()[0] == u'+358401234567' - - -class DatabaseTestCase(object): - def create_models(self): - pass - - def setup_method(self, method): - self.engine = create_engine( - 'sqlite:///' - ) - #self.engine.echo = True - self.Base = declarative_base() - - self.create_models() - self.Base.metadata.create_all(self.engine) - Session = sessionmaker(bind=self.engine) - self.session = Session() - - def teardown_method(self, method): - self.engine.dispose() - self.Base.metadata.drop_all(self.engine) - self.session.expunge_all() - - -class TestMerge(DatabaseTestCase): - def create_models(self): - class User(self.Base): - __tablename__ = 'user' - id = sa.Column(sa.Integer, primary_key=True) - name = sa.Column(sa.Unicode(255)) - - def __repr__(self): - return 'User(%r)' % self.name - - class BlogPost(self.Base): - __tablename__ = 'blog_post' - id = sa.Column(sa.Integer, primary_key=True) - title = sa.Column(sa.Unicode(255)) - content = sa.Column(sa.UnicodeText) - author_id = sa.Column(sa.Integer, sa.ForeignKey('user.id')) - - author = sa.orm.relationship(User) - - self.User = User - self.BlogPost = BlogPost - - def test_updates_foreign_keys(self): - john = self.User(name=u'John') - jack = self.User(name=u'Jack') - post = self.BlogPost(title=u'Some title', author=john) - post2 = self.BlogPost(title=u'Other title', author=jack) - self.session.add(john) - self.session.add(jack) - self.session.add(post) - self.session.add(post2) - self.session.commit() - merge(john, jack) - assert post.author == jack - assert post2.author == jack - - def test_deletes_from_entity(self): - john = self.User(name=u'John') - jack = self.User(name=u'Jack') - self.session.add(john) - self.session.add(jack) - self.session.commit() - merge(john, jack) - assert john in self.session.deleted - - -class TestMergeManyToManyAssociations(DatabaseTestCase): - def create_models(self): - class User(self.Base): - __tablename__ = 'user' - id = sa.Column(sa.Integer, primary_key=True) - name = sa.Column(sa.Unicode(255)) - - def __repr__(self): - return 'User(%r)' % self.name - - team_member = sa.Table( - 'team_member', self.Base.metadata, - sa.Column( - 'user_id', sa.Integer, - sa.ForeignKey('user.id', ondelete='CASCADE'), - primary_key=True - ), - sa.Column( - 'team_id', sa.Integer, - sa.ForeignKey('team.id', ondelete='CASCADE'), - primary_key=True - ) - ) - - class Team(self.Base): - __tablename__ = 'team' - id = sa.Column(sa.Integer, primary_key=True) - name = sa.Column(sa.Unicode(255)) - - members = sa.orm.relationship( - User, - secondary=team_member, - backref='teams' - ) - - self.User = User - self.Team = Team - - def test_when_association_only_exists_in_from_entity(self): - john = self.User(name=u'John') - jack = self.User(name=u'Jack') - team = self.Team(name=u'Team') - team.members.append(john) - self.session.add(john) - self.session.add(jack) - self.session.commit() - merge(john, jack) - assert john not in team.members - assert jack in team.members - - def test_when_association_exists_in_both(self): - john = self.User(name=u'John') - jack = self.User(name=u'Jack') - team = self.Team(name=u'Team') - team.members.append(john) - team.members.append(jack) - self.session.add(john) - self.session.add(jack) - self.session.commit() - merge(john, jack) - assert john not in team.members - assert jack in team.members - count = self.session.execute( - 'SELECT COUNT(1) FROM team_member' - ).fetchone()[0] - assert count == 1 - - -class TestMergeManyToManyAssociationObjects(DatabaseTestCase): - def create_models(self): - class Team(self.Base): - __tablename__ = 'team' - id = sa.Column(sa.Integer, autoincrement=True, primary_key=True) - name = sa.Column(sa.Unicode(255)) - - class User(self.Base): - __tablename__ = 'user' - id = sa.Column(sa.Integer, autoincrement=True, primary_key=True) - name = sa.Column(sa.Unicode(255)) - - class TeamMember(self.Base): - __tablename__ = 'team_member' - user_id = sa.Column( - sa.Integer, - sa.ForeignKey(User.id, ondelete='CASCADE'), - primary_key=True - ) - team_id = sa.Column( - sa.Integer, - sa.ForeignKey(Team.id, ondelete='CASCADE'), - primary_key=True - ) - role = sa.Column(sa.Unicode(255)) - team = sa.orm.relationship( - Team, - backref=sa.orm.backref( - 'members', - cascade='all, delete-orphan' - ), - primaryjoin=team_id == Team.id, - ) - user = sa.orm.relationship( - User, - backref=sa.orm.backref( - 'memberships', - cascade='all, delete-orphan' - ), - primaryjoin=user_id == User.id, - ) - - self.User = User - self.TeamMember = TeamMember - self.Team = Team - - def test_when_association_only_exists_in_from_entity(self): - john = self.User(name=u'John') - jack = self.User(name=u'Jack') - team = self.Team(name=u'Team') - team.members.append(self.TeamMember(user=john)) - self.session.add(john) - self.session.add(jack) - self.session.add(team) - self.session.commit() - merge(john, jack) - self.session.commit() - users = [member.user for member in team.members] - assert john not in users - assert jack in users - - def test_when_association_exists_in_both(self): - john = self.User(name=u'John') - jack = self.User(name=u'Jack') - team = self.Team(name=u'Team') - team.members.append(self.TeamMember(user=john)) - team.members.append(self.TeamMember(user=jack)) - self.session.add(john) - self.session.add(jack) - self.session.add(team) - self.session.commit() - merge(john, jack) - users = [member.user for member in team.members] - assert john not in users - assert jack in users - assert self.session.query(self.TeamMember).count() == 1 diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..875d0c1 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1,86 @@ +import sqlalchemy as sa + +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker +from sqlalchemy.ext.declarative import declarative_base + +from sqlalchemy_utils import ( + escape_like, + sort_query, + InstrumentedList, + PhoneNumber, + PhoneNumberType, + merge +) + + +class TestCase(object): + def setup_method(self, method): + self.engine = create_engine( + 'postgres://postgres@localhost/sqlalchemy_utils_test' + ) + self.Base = declarative_base() + + self.create_models() + self.Base.metadata.create_all(self.engine) + + Session = sessionmaker(bind=self.engine) + self.session = Session() + + def teardown_method(self, method): + self.session.close_all() + self.Base.metadata.drop_all(self.engine) + self.engine.dispose() + + def create_models(self): + class User(self.Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, autoincrement=True, primary_key=True) + name = sa.Column(sa.Unicode(255)) + phone_number = sa.Column(PhoneNumberType()) + + class Category(self.Base): + __tablename__ = 'category' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + + class Article(self.Base): + __tablename__ = 'article' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + category_id = sa.Column(sa.Integer, sa.ForeignKey(Category.id)) + + category = sa.orm.relationship( + Category, + primaryjoin=category_id == Category.id, + backref=sa.orm.backref( + 'articles', + collection_class=InstrumentedList + ) + ) + + self.User = User + self.Category = Category + self.Article = Article + + +class DatabaseTestCase(object): + def create_models(self): + pass + + def setup_method(self, method): + self.engine = create_engine( + 'sqlite:///' + ) + #self.engine.echo = True + self.Base = declarative_base() + + self.create_models() + self.Base.metadata.create_all(self.engine) + Session = sessionmaker(bind=self.engine) + self.session = Session() + + def teardown_method(self, method): + self.engine.dispose() + self.Base.metadata.drop_all(self.engine) + self.session.expunge_all() diff --git a/tests/test_instrumented_list.py b/tests/test_instrumented_list.py new file mode 100644 index 0000000..525d46b --- /dev/null +++ b/tests/test_instrumented_list.py @@ -0,0 +1,14 @@ +from tests import TestCase + + +class TestInstrumentedList(TestCase): + def test_any_returns_true_if_member_has_attr_defined(self): + category = self.Category() + category.articles.append(self.Article()) + category.articles.append(self.Article(name=u'some name')) + assert category.articles.any('name') + + def test_any_returns_false_if_no_member_has_attr_defined(self): + category = self.Category() + category.articles.append(self.Article()) + assert not category.articles.any('name') diff --git a/tests/test_merge.py b/tests/test_merge.py new file mode 100644 index 0000000..c776785 --- /dev/null +++ b/tests/test_merge.py @@ -0,0 +1,196 @@ +import sqlalchemy as sa +from sqlalchemy_utils import merge + +from tests import DatabaseTestCase + + +class TestMerge(DatabaseTestCase): + def create_models(self): + class User(self.Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + + def __repr__(self): + return 'User(%r)' % self.name + + class BlogPost(self.Base): + __tablename__ = 'blog_post' + id = sa.Column(sa.Integer, primary_key=True) + title = sa.Column(sa.Unicode(255)) + content = sa.Column(sa.UnicodeText) + author_id = sa.Column(sa.Integer, sa.ForeignKey('user.id')) + + author = sa.orm.relationship(User) + + self.User = User + self.BlogPost = BlogPost + + def test_updates_foreign_keys(self): + john = self.User(name=u'John') + jack = self.User(name=u'Jack') + post = self.BlogPost(title=u'Some title', author=john) + post2 = self.BlogPost(title=u'Other title', author=jack) + self.session.add(john) + self.session.add(jack) + self.session.add(post) + self.session.add(post2) + self.session.commit() + merge(john, jack) + assert post.author == jack + assert post2.author == jack + + def test_deletes_from_entity(self): + john = self.User(name=u'John') + jack = self.User(name=u'Jack') + self.session.add(john) + self.session.add(jack) + self.session.commit() + merge(john, jack) + assert john in self.session.deleted + + +class TestMergeManyToManyAssociations(DatabaseTestCase): + def create_models(self): + class User(self.Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + + def __repr__(self): + return 'User(%r)' % self.name + + team_member = sa.Table( + 'team_member', self.Base.metadata, + sa.Column( + 'user_id', sa.Integer, + sa.ForeignKey('user.id', ondelete='CASCADE'), + primary_key=True + ), + sa.Column( + 'team_id', sa.Integer, + sa.ForeignKey('team.id', ondelete='CASCADE'), + primary_key=True + ) + ) + + class Team(self.Base): + __tablename__ = 'team' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + + members = sa.orm.relationship( + User, + secondary=team_member, + backref='teams' + ) + + self.User = User + self.Team = Team + + def test_when_association_only_exists_in_from_entity(self): + john = self.User(name=u'John') + jack = self.User(name=u'Jack') + team = self.Team(name=u'Team') + team.members.append(john) + self.session.add(john) + self.session.add(jack) + self.session.commit() + merge(john, jack) + assert john not in team.members + assert jack in team.members + + def test_when_association_exists_in_both(self): + john = self.User(name=u'John') + jack = self.User(name=u'Jack') + team = self.Team(name=u'Team') + team.members.append(john) + team.members.append(jack) + self.session.add(john) + self.session.add(jack) + self.session.commit() + merge(john, jack) + assert john not in team.members + assert jack in team.members + count = self.session.execute( + 'SELECT COUNT(1) FROM team_member' + ).fetchone()[0] + assert count == 1 + + +class TestMergeManyToManyAssociationObjects(DatabaseTestCase): + def create_models(self): + class Team(self.Base): + __tablename__ = 'team' + id = sa.Column(sa.Integer, autoincrement=True, primary_key=True) + name = sa.Column(sa.Unicode(255)) + + class User(self.Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, autoincrement=True, primary_key=True) + name = sa.Column(sa.Unicode(255)) + + class TeamMember(self.Base): + __tablename__ = 'team_member' + user_id = sa.Column( + sa.Integer, + sa.ForeignKey(User.id, ondelete='CASCADE'), + primary_key=True + ) + team_id = sa.Column( + sa.Integer, + sa.ForeignKey(Team.id, ondelete='CASCADE'), + primary_key=True + ) + role = sa.Column(sa.Unicode(255)) + team = sa.orm.relationship( + Team, + backref=sa.orm.backref( + 'members', + cascade='all, delete-orphan' + ), + primaryjoin=team_id == Team.id, + ) + user = sa.orm.relationship( + User, + backref=sa.orm.backref( + 'memberships', + cascade='all, delete-orphan' + ), + primaryjoin=user_id == User.id, + ) + + self.User = User + self.TeamMember = TeamMember + self.Team = Team + + def test_when_association_only_exists_in_from_entity(self): + john = self.User(name=u'John') + jack = self.User(name=u'Jack') + team = self.Team(name=u'Team') + team.members.append(self.TeamMember(user=john)) + self.session.add(john) + self.session.add(jack) + self.session.add(team) + self.session.commit() + merge(john, jack) + self.session.commit() + users = [member.user for member in team.members] + assert john not in users + assert jack in users + + def test_when_association_exists_in_both(self): + john = self.User(name=u'John') + jack = self.User(name=u'Jack') + team = self.Team(name=u'Team') + team.members.append(self.TeamMember(user=john)) + team.members.append(self.TeamMember(user=jack)) + self.session.add(john) + self.session.add(jack) + self.session.add(team) + self.session.commit() + merge(john, jack) + users = [member.user for member in team.members] + assert john not in users + assert jack in users + assert self.session.query(self.TeamMember).count() == 1 diff --git a/tests/test_phonenumber_type.py b/tests/test_phonenumber_type.py new file mode 100644 index 0000000..3ecfbde --- /dev/null +++ b/tests/test_phonenumber_type.py @@ -0,0 +1,65 @@ +from tests import TestCase +from sqlalchemy_utils import PhoneNumber + + +class TestPhoneNumber(object): + def setup_method(self, method): + self.valid_phone_numbers = [ + '040 1234567', + '+358 401234567', + '09 2501234', + '+358 92501234', + '0800 939393', + '09 4243 0456', + '0600 900 500' + ] + self.invalid_phone_numbers = [ + 'abc', + '+040 1234567', + '0111234567', + '358' + ] + + def test_valid_phone_numbers(self): + for raw_number in self.valid_phone_numbers: + phone_number = PhoneNumber(raw_number, 'FI') + assert phone_number.is_valid_number() + + def test_invalid_phone_numbers(self): + for raw_number in self.invalid_phone_numbers: + try: + phone_number = PhoneNumber(raw_number, 'FI') + assert not phone_number.is_valid_number() + except: + pass + + def test_phone_number_attributes(self): + phone_number = PhoneNumber('+358401234567') + assert phone_number.e164 == u'+358401234567' + assert phone_number.international == u'+358 40 1234567' + assert phone_number.national == u'040 1234567' + + +class TestPhoneNumberType(TestCase): + def setup_method(self, method): + super(TestPhoneNumberType, self).setup_method(method) + self.phone_number = PhoneNumber( + '040 1234567', + 'FI' + ) + self.user = self.User() + self.user.name = u'Someone' + self.user.phone_number = self.phone_number + self.session.add(self.user) + self.session.commit() + + def test_query_returns_phone_number_object(self): + queried_user = self.session.query(self.User).first() + assert queried_user.phone_number == self.phone_number + + def test_phone_number_is_stored_as_string(self): + result = self.session.execute( + 'SELECT phone_number FROM "user" WHERE id=:param', + {'param': self.user.id} + ) + assert result.first()[0] == u'+358401234567' diff --git a/tests/test_utility_functions.py b/tests/test_utility_functions.py new file mode 100644 index 0000000..26d802a --- /dev/null +++ b/tests/test_utility_functions.py @@ -0,0 +1,47 @@ +import sqlalchemy as sa +from sqlalchemy_utils import escape_like, sort_query +from tests import TestCase + + +class TestEscapeLike(TestCase): + def test_escapes_wildcards(self): + assert escape_like('_*%') == '*_***%' + + +class TestSortQuery(TestCase): + def test_without_sort_param_returns_the_query_object_untouched(self): + query = self.session.query(self.Article) + sorted_query = sort_query(query, '') + assert query == sorted_query + + def test_sort_by_column_ascending(self): + query = sort_query(self.session.query(self.Article), 'name') + assert 'ORDER BY article.name ASC' in str(query) + + def test_sort_by_column_descending(self): + query = sort_query(self.session.query(self.Article), '-name') + assert 'ORDER BY article.name DESC' in str(query) + + def test_skips_unknown_columns(self): + query = self.session.query(self.Article) + sorted_query = sort_query(query, '-unknown') + assert query == sorted_query + + def test_sort_by_calculated_value_ascending(self): + query = self.session.query( + self.Category, sa.func.count(self.Article.id).label('articles') + ) + query = sort_query(query, 'articles') + assert 'ORDER BY articles ASC' in str(query) + + def test_sort_by_calculated_value_descending(self): + query = self.session.query( + self.Category, sa.func.count(self.Article.id).label('articles') + ) + query = sort_query(query, '-articles') + assert 'ORDER BY articles DESC' in str(query) + + def test_sort_by_joined_table_column(self): + query = self.session.query(self.Article).join(self.Article.category) + sorted_query = sort_query(query, 'category-name') + assert 'category.name ASC' in str(sorted_query)