From a1955c3a00a8398dc27beacb92c4efcb1119c8d3 Mon Sep 17 00:00:00 2001 From: Konsta Vesterinen Date: Thu, 7 Mar 2013 11:09:49 +0200 Subject: [PATCH] Added SmartList, 0.3 release --- CHANGES.rst | 7 +++++++ docs/index.rst | 4 ++++ setup.py | 2 +- sqlalchemy_utils/__init__.py | 33 ++++++++++++++++++++++++++++++++- tests.py | 28 ++++++++++++++++++++++++++-- 5 files changed, 70 insertions(+), 4 deletions(-) diff --git a/CHANGES.rst b/CHANGES.rst index ac2dc79..86d52a6 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -4,6 +4,13 @@ Changelog Here you can see the full list of changes between each SQLAlchemy-Utils release. + +0.3.0 (2013-03-01) +^^^^^^^^^^^^^^^^^^ + +- Added new collection class SmartList + + 0.2.0 (2013-03-01) ^^^^^^^^^^^^^^^^^^ diff --git a/docs/index.rst b/docs/index.rst index a477467..65f1d3c 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -6,12 +6,16 @@ SQLAlchemy-Utils ================ +SQLAlchemy-Utils provides various utility classes and functions for SQLAlchemy. + API Documentation ----------------- .. module:: sqlalchemy_utils +.. autoclass:: SmartList + :members: .. autofunction:: sort_query .. autofunction:: escape_like diff --git a/setup.py b/setup.py index 28ad7c7..ae9f16e 100644 --- a/setup.py +++ b/setup.py @@ -24,7 +24,7 @@ class PyTest(Command): setup( name='SQLAlchemy-Utils', - version='0.2', + version='0.3', url='https://github.com/kvesteri/sqlalchemy-utils', license='BSD', author='Konsta Vesterinen', diff --git a/sqlalchemy_utils/__init__.py b/sqlalchemy_utils/__init__.py index 432cea5..5d55d59 100644 --- a/sqlalchemy_utils/__init__.py +++ b/sqlalchemy_utils/__init__.py @@ -1,10 +1,41 @@ from sqlalchemy.orm import defer +from sqlalchemy.orm.collections import InstrumentedList from sqlalchemy.orm.mapper import Mapper from sqlalchemy.orm.query import _ColumnEntity from sqlalchemy.orm.properties import ColumnProperty from sqlalchemy.sql.expression import desc, asc +class SmartList(InstrumentedList): + def has(self, attr): + """ + Returns True if any member of this collection has given attribute + defined. + + Example syntax: + + >>> Category.articles.has('name') + + :param attr: collection member attribute name + """ + adapter = self._sa_adapter + owner_class = adapter.owner_state.class_ + relation = getattr(owner_class, adapter._key).property + relation_class = relation.mapper.class_ + + if not hasattr(relation_class, attr): + raise AttributeError( + 'Class %s does not have attribute named %s' % + (relation_class.__name__, attr) + ) + + for record in self: + if getattr(record, attr): + return True + + return False + + def sort_query(query, sort): """ Applies an sql ORDER BY for given query. This function can be easily used @@ -16,7 +47,7 @@ def sort_query(query, sort): >>> from sqlalchemy import create_engine >>> from sqlalchemy.orm import sessionmaker >>> from sqlalchemy.ext.declarative import declarative_base - >>> from sqlalchemy_utils import escape_like, sort_query + >>> from sqlalchemy_utils import sort_query >>> >>> >>> engine = create_engine( diff --git a/tests.py b/tests.py index 20716ff..b21c109 100644 --- a/tests.py +++ b/tests.py @@ -1,10 +1,11 @@ +from pytest import raises import sqlalchemy as sa from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy_utils import escape_like, sort_query +from sqlalchemy_utils import escape_like, sort_query, SmartList engine = create_engine( @@ -29,10 +30,33 @@ class Article(Base): category_id = sa.Column(sa.Integer, sa.ForeignKey(Category.id)) category = sa.orm.relationship( - Category, primaryjoin=category_id == Category.id + Category, + primaryjoin=category_id == Category.id, + backref=sa.orm.backref( + 'articles', + collection_class=SmartList + ) ) +class TestSmartList(object): + def test_has_raises_error_for_unknown_attribute(self): + category = Category() + with raises(AttributeError): + category.articles.has('unknown_column') + + def test_has_returns_true_if_member_has_attr_defined(self): + category = Category() + category.articles.append(Article()) + category.articles.append(Article(name=u'some name')) + assert category.articles.has('name') + + def test_has_returns_false_if_no_member_has_attr_defined(self): + category = Category() + category.articles.append(Article()) + assert not category.articles.has('name') + + class TestEscapeLike(object): def test_escapes_wildcards(self): assert escape_like('_*%') == '*_***%'