From e181292bb5364083e2b3bf28f04e09d45510a123 Mon Sep 17 00:00:00 2001 From: Konsta Vesterinen Date: Wed, 5 Jun 2013 13:38:42 +0300 Subject: [PATCH] More efficient contains method for proxy dict --- CHANGES.rst | 6 ++++++ setup.py | 2 +- sqlalchemy_utils/proxy_dict.py | 22 ++++++++++++---------- tests/__init__.py | 20 +++++++++++++------- tests/test_proxy_dict.py | 19 +++++++++++++++++++ 5 files changed, 51 insertions(+), 18 deletions(-) diff --git a/CHANGES.rst b/CHANGES.rst index ffe9883..8a648cd 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -4,6 +4,12 @@ Changelog Here you can see the full list of changes between each SQLAlchemy-Utils release. +0.12.5 (2013-06-05) +^^^^^^^^^^^^^^^^^^^ + +- ProxyDict now contains None values in cache - more efficient contains method. + + 0.12.4 (2013-06-01) ^^^^^^^^^^^^^^^^^^^ diff --git a/setup.py b/setup.py index c14e256..169ac49 100644 --- a/setup.py +++ b/setup.py @@ -24,7 +24,7 @@ class PyTest(Command): setup( name='SQLAlchemy-Utils', - version='0.12.4', + version='0.12.5', url='https://github.com/kvesteri/sqlalchemy-utils', license='BSD', author='Konsta Vesterinen', diff --git a/sqlalchemy_utils/proxy_dict.py b/sqlalchemy_utils/proxy_dict.py index df1bade..e6ef228 100644 --- a/sqlalchemy_utils/proxy_dict.py +++ b/sqlalchemy_utils/proxy_dict.py @@ -18,10 +18,9 @@ class ProxyDict(object): return [x[0] for x in self.collection.values(descriptor)] def __contains__(self, key): - try: - return key in self.cache or self.fetch(key) is not None - except KeyError: - return False + if key in self.cache: + return self.cache[key] is not None + return self.fetch(key) is not None def has_key(self, key): return self.__contains__(key) @@ -29,7 +28,9 @@ class ProxyDict(object): def fetch(self, key): session = sa.orm.object_session(self.parent) if session and sa.orm.util.has_identity(self.parent): - return self.collection.filter_by(**{self.key_name: key}).first() + obj = self.collection.filter_by(**{self.key_name: key}).first() + self.cache[key] = obj + return obj def create_new_instance(self, key): value = self.child_class(**{self.key_name: key}) @@ -39,11 +40,12 @@ class ProxyDict(object): def __getitem__(self, key): if key in self.cache: - return self.cache[key] - - value = self.fetch(key) - if value: - return value + if self.cache[key] is not None: + return self.cache[key] + else: + value = self.fetch(key) + if value: + return value return self.create_new_instance(key) diff --git a/tests/__init__.py b/tests/__init__.py index 875d0c1..eb93fe8 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -5,31 +5,37 @@ from sqlalchemy.orm import sessionmaker from sqlalchemy.ext.declarative import declarative_base from sqlalchemy_utils import ( - escape_like, - sort_query, InstrumentedList, - PhoneNumber, PhoneNumberType, - merge ) +@sa.event.listens_for(sa.engine.Engine, 'before_cursor_execute') +def count_sql_calls(conn, cursor, statement, parameters, context, executemany): + try: + conn.query_count += 1 + except AttributeError: + conn.query_count = 0 + + class TestCase(object): def setup_method(self, method): self.engine = create_engine( 'postgres://postgres@localhost/sqlalchemy_utils_test' ) + self.connection = self.engine.connect() self.Base = declarative_base() self.create_models() - self.Base.metadata.create_all(self.engine) + self.Base.metadata.create_all(self.connection) - Session = sessionmaker(bind=self.engine) + Session = sessionmaker(bind=self.connection) self.session = Session() def teardown_method(self, method): self.session.close_all() - self.Base.metadata.drop_all(self.engine) + self.Base.metadata.drop_all(self.connection) + self.connection.close() self.engine.dispose() def create_models(self): diff --git a/tests/test_proxy_dict.py b/tests/test_proxy_dict.py index 13ebcd9..58bec79 100644 --- a/tests/test_proxy_dict.py +++ b/tests/test_proxy_dict.py @@ -79,6 +79,25 @@ class TestProxyDict(TestCase): ) article.translations['en'] + def test_contains_efficiency(self): + article = self.Article() + self.session.add(article) + self.session.commit() + article.id + query_count = self.connection.query_count + 'en' in article.translations + 'en' in article.translations + 'en' in article.translations + assert self.connection.query_count == query_count + 1 + + def test_getitem_with_none_value_in_cache(self): + article = self.Article() + self.session.add(article) + self.session.commit() + article.id + 'en' in article.translations + assert article.translations['en'] + def test_contains(self): article = self.Article() assert 'en' not in article.translations