commit 8d736544fa43e52d4a3127fb9f69ec11ce16aeb4 Author: bninja Date: Tue Aug 19 22:37:39 2014 -0700 init diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..c96ae0a --- /dev/null +++ b/.gitignore @@ -0,0 +1,28 @@ +*.py[co] + +# packages +*.egg +*.egg-info +dist +build +eggs +parts +bin +var +sdist +develop-eggs +.installed.cfg + +# installer logs +pip-log.txt + +# unit test / coverage reports +.coverage +.tox + +# sphnix +docs/_build + +# pydev +.pydevproject +.project diff --git a/.travis.yml b/.travis.yml new file mode 100644 index 0000000..ca0676b --- /dev/null +++ b/.travis.yml @@ -0,0 +1,18 @@ +language: python +python: + - 2.7 +install: + - pip install -e .[tests] + - pip install coveralls +script: + - py.test test.py --cov=pika_pool --cov-report term-missing +after_success: + - coveralls +deploy: + provider: pypi + user: somepie + password: + secure: f01cUCmCqzNyQ2cQcY5sNQPKXlBgESqp7LMohTfZxyugGxdyGV2YgHLH/BJc31JH4Fnq7WBa7Kjo94c2BofY2jTkQMqdMBXXKLbdg5Rwn4imlvFzv1VYJB8Cwl7aTCIIHn5UldKTnJPzt47rn9DI17iPbqCIdOZ5nKt/fO6ltC0= + on: + all_branches: true + tags: true diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 0000000..9561fb1 --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1 @@ +include README.rst diff --git a/README.rst b/README.rst new file mode 100644 index 0000000..a3435fc --- /dev/null +++ b/README.rst @@ -0,0 +1,61 @@ +========= +pika-pool +========= + +.. image:: https://travis-ci.org/bninja/pika-pool.png + :target: https://travis-ci.org/bninja/pika-pool + +.. image:: https://coveralls.io/repos/bninja/pika-pool/badge.png + :target: https://coveralls.io/r/bninja/pika-pool + +Pika connection pooling inspired by: + +- `flask-pika `_ +- `sqlalchemy.pool.Pool `_ + +Typically you'll go with local shovel(s), krazy kombu, etc. but this might work too. + +Get it: + +.. code:: bash + + $ pip install pika-pool + +and use it: + +.. code:: python + + import json + + import pika + import pika_pool + + params = pika.URLParameters( + 'amqp://guest:guest@localhost:5672/?' + 'socket_timeout=10&' + 'connection_attempts=2' + ) + + pool = pika_pool.QueuedPool( + create=lambda: pika.BlockingConnection(parameters=params), + max_size=10, + max_overflow=10, + timeout=10, + recycle=3600, + stale=45, + ) + + with pool.acquire() as cxn: + cxn.channel.basic_publish( + body=json.dumps({ + 'type': 'banana', + 'description': 'they are yellow' + }), + exchange='', + routing_key='fruits', + properties=pika.BasicProperties( + content_type='application/json', + content_encoding='utf-8', + delivery_mode=2, + ) + ) diff --git a/pika_pool.py b/pika_pool.py new file mode 100644 index 0000000..52275d9 --- /dev/null +++ b/pika_pool.py @@ -0,0 +1,353 @@ +""" +Pika connection pool inspired by: + + - https://github.com/WeatherDecisionTechnologies/flask-pika + +and this interface: + + - http://docs.sqlalchemy.org/en/latest/core/pooling.html#sqlalchemy.pool.Pool + +Get it like this: + +.. code:: python + + $ pip install pika-pool + +Use it like e.g. this: + +.. code:: python + + import json + + import pika + import pika_pool + + params = pika.URLParameters( + 'amqp://guest:guest@localhost:5672/%2F?socket_timeout=5&connection_attempts=2' + ) + + pool = pika_pool.QueuedPool( + create=lambda: pika.BlockingConnection(parameters=params) + recycle=45, + max_size=10, + max_overflow=10, + timeout=10, + ) + + with pool.acquire() as cxn: + cxn.channel.basic_publish( + body=json.dumps({'type': 'banana', 'color': 'yellow'}), + exchange='exchange', + routing_key='banana', + properties={ + 'content_type': 'application/json', + 'content_encoding': 'utf-8', + 'delivery_mode': 2 + } + ) + +""" +from __future__ import unicode_literals + +import logging +import Queue as queue +import select +import socket +import threading +import time + +import pika.exceptions + + +__version__ = '0.1.0' + +__all__ = [ + 'Error' + 'Timeout' + 'Overflow' + 'Connection', + 'Pool', + 'NullPool', + 'QueuedPool', +] + + +logger = logging.getLogger(__name__) + + +class Error(Exception): + + pass + + +class Overflow(Error): + """ + Raised when a `Pool.acquire` cannot allocate anymore connections. + """ + + pass + + +class Timeout(Error): + """ + Raised when an attempt to `Pool.acquire` a connection has timedout. + """ + + pass + + +class Connection(object): + """ + Connection acquired from a `Pool` instance. Get them like this: + + .. code:: python + + with pool.acquire() as cxn: + print cxn.channel + + """ + + #: Exceptions that imply connection has been invalidated. + connectivity_errors = ( + pika.exceptions.AMQPConnectionError, + pika.exceptions.ConnectionClosed, + pika.exceptions.ChannelClosed, + select.error, # XXX: https://github.com/pika/pika/issues/412 + ) + + @classmethod + def is_connection_invalidated(cls, exc): + """ + Says whether the given exception indicates the connection has been invalidated. + + :param exc: Exception object. + + :return: True if connection has been invalidted, otherwise False. + """ + return any( + isinstance(exc, error)for error in cls.connectivity_errors + ) + + def __init__(self, pool, fairy): + self.pool = pool + self.fairy = fairy + + @property + def channel(self): + if self.fairy.channel is None: + self.fairy.channel = self.fairy.cxn.channel() + return self.fairy.channel + + def close(self): + self.pool.close(self.fairy) + self.fairy = None + + def release(self): + self.pool.release(self.fairy) + self.fairy = None + + def __enter__(self): + return self + + def __exit__(self, type, value, traceback): + if type is None or not self.is_connection_invalidated(value): + self.release() + else: + self.close() + + +class Pool(object): + """ + Pool interface similar to: + + http://docs.sqlalchemy.org/en/latest/core/pooling.html#sqlalchemy.pool.Pool + + and used like: + + .. code:: python + + with pool.acquire(timeout=60) as cxn: + cxn.channel.basic_publish( + ... + ) + + """ + + #: Acquired connection type. + Connection = Connection + + def __init__(self, create): + """ + :param create: Callable creating a new connection. + """ + self.create = create + + def acquire(self, timeout=None): + """ + Retrieve a connection from the pool or create a new one. + """ + raise NotImplementedError + + def release(self, fairy): + """ + Return a connection to the pool. + """ + raise NotImplementedError + + def close(self, fairy): + """ + Forcibly close a connection, suppressing any connection errors. + """ + fairy.close() + + class Fairy(object): + """ + Connection wrapper for tracking its associated state. + """ + + def __init__(self, cxn): + self.cxn = cxn + self.channel = None + + def close(self): + if self.channel: + try: + self.channel.close() + self.channel = None + except Connection.connectivity_errors as ex: + if not Connection.is_connection_invalidated(ex): + raise + try: + self.cxn.close() + except Connection.connectivity_errors as ex: + if not Connection.is_connection_invalidated(ex): + raise + + def _create(self): + """ + All fairy creates go through here. + """ + return self.Fairy(self.create()) + + +class NullPool(Pool): + """ + Dummy pool. It opens/closes connections on each acquire/release. + """ + + def acquire(self, timeout=None): + return self.Connection(self, self._create()) + + def release(self, fairy): + self.close(fairy) + + +class QueuedPool(Pool): + """ + Queue backed pool. + """ + + def __init__(self, + create, + max_size=10, + max_overflow=10, + timeout=30, + recycle=None, + stale=None, + ): + """ + :param max_size: + Maximum number of connections to keep queued. + + :param max_overflow: + Maximum number of connections to create above `max_size`. + + :param timeout: + Default number of seconds to wait for a connections to available. + + :param recycle: + Lifetime of a connection (since creation) in seconds or None for no + recycling. Expired connections are closed on acquire. + + :param stale: + Threshold at which inactive (since release) connections are + considered stale in seconds or None for no staleness. Stale + connections are closed on acquire. + """ + self.max_size = max_size + self.max_overflow = max_overflow + self.timeout = timeout + self.recycle = recycle + self.stale = stale + self._queue = queue.Queue(maxsize=self.max_size) + self._avail_lock = threading.Lock() + self._avail = self.max_size + self.max_overflow + super(QueuedPool, self).__init__(create) + + def acquire(self, timeout=None): + try: + fairy = self._queue.get(False) + except queue.Empty: + try: + fairy = self._create() + except Overflow: + timeout = timeout or self.timeout + try: + fairy = self._queue.get(timeout=self.timeout) + except queue.Empty: + try: + fairy = self._create() + except Overflow: + raise Timeout() + if self.is_expired(fairy): + logger.info('connection %r expired', fairy) + self.close(fairy) + return self.acquire(timeout=timeout) + if self.is_stale(fairy): + logger.info('connection %r stale', fairy) + self.close(fairy) + return self.acquire(timeout=timeout) + return self.Connection(self, fairy) + + def release(self, fairy): + fairy.released_at = time.time() + try: + self._queue.put_nowait(fairy) + except queue.Full: + self.close(fairy) + + def close(self, fairy): + # inc + with self._avail_lock: + self._avail += 1 + return super(QueuedPool, self).close(fairy) + + def _create(self): + # dec + with self._avail_lock: + if self._avail <= 0: + raise Overflow() + self._avail -= 1 + try: + return super(QueuedPool, self)._create() + except: + # inc + with self._avail_lock: + self._avail += 1 + raise + + class Fairy(Pool.Fairy): + + def __init__(self, cxn): + super(QueuedPool.Fairy, self).__init__(cxn) + self.released_at = self.created_at = time.time() + + def is_stale(self, fairy): + if not self.stale: + return False + return (time.time() - fairy.released_at) > self.stale + + def is_expired(self, fairy): + if not self.recycle: + return False + return (time.time() - fairy.created_at) > self.recycle diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..bc2037c --- /dev/null +++ b/setup.py @@ -0,0 +1,40 @@ +import re +import setuptools + + +setuptools.setup( + name='pika-pool', + version=( + re + .compile(r".*__version__ = '(.*?)'", re.S) + .match(open('pika_pool.py').read()) + .group(1) + ), + url='https://github.com/bninja/pika-pool', + license='BSD', + author='egon', + author_email='egon@gb.com', + description='Pools for pikas.', + long_description=open('README.rst').read(), + py_modules=['pika_pool'], + include_package_data=True, + platforms='any', + install_requires=[ + 'pika >=0.9,<0.10', + ], + extras_require={ + 'tests': [ + 'pytest >=2.5.2,<3', + 'pytest-cov >=1.7,<2', + ], + }, + classifiers=[ + 'Environment :: Web Environment', + 'Intended Audience :: Developers', + 'License :: OSI Approved :: BSD License', + 'Operating System :: OS Independent', + 'Programming Language :: Python', + 'Topic :: Internet :: WWW/HTTP :: Dynamic Content', + 'Topic :: Software Development :: Libraries :: Python Modules' + ] +) diff --git a/test.py b/test.py new file mode 100644 index 0000000..eac8fa6 --- /dev/null +++ b/test.py @@ -0,0 +1,190 @@ +from __future__ import unicode_literals + +import json +import select +import threading +import time +import uuid + +import pika +import pytest + +import pika_pool + + +@pytest.fixture(scope='session') +def params(): + return pika.URLParameters('amqp://guest:guest@localhost:5672/') + + +@pytest.fixture(scope='session', autouse=True) +def schema(request, params): + cxn = pika.BlockingConnection(params) + channel = cxn.channel() + channel.queue_declare(queue='pika_pool_test') + + +consumed = { +} + + +@pytest.fixture(scope='session', autouse=True) +def consume(request, params): + + def _callback(ch, method, properties, body): + msg = Message.from_json(body) + consumed[msg.id] = msg + + def _forever(): + channel.start_consuming() + + cxn = pika.BlockingConnection(params) + channel = cxn.channel() + channel.basic_consume(_callback, queue='pika_pool_test', no_ack=True) + + threading.Thread(target=_forever).start() + + request.addfinalizer(lambda: channel.stop_consuming()) + + +@pytest.fixture +def null_pool(params): + return pika_pool.NullPool( + create=lambda: pika.BlockingConnection(params), + ) + + +class Message(dict): + + @classmethod + def generate(cls, **kwargs): + id = kwargs.pop('id', uuid.uuid4().hex) + return cls(id=id, **kwargs) + + @property + def id(self): + return self['id'] + + def to_json(self): + return json.dumps(self) + + @classmethod + def from_json(cls, raw): + return cls(json.loads(raw)) + + +class TestNullPool(object): + + def test_pub(self, null_pool): + msg = Message.generate() + with null_pool.acquire() as cxn: + cxn.channel.basic_publish( + exchange='', + routing_key='pika_pool_test', + body=msg.to_json() + ) + time.sleep(0.1) + assert msg.id in consumed + + +@pytest.fixture +def queued_pool(params): + return pika_pool.QueuedPool( + create=lambda: pika.BlockingConnection(params), + recycle=10, + stale=10, + max_size=10, + max_overflow=10, + timeout=10, + ) + + +def test_use_it(): + params = pika.URLParameters( + 'amqp://guest:guest@localhost:5672/?' + 'socket_timeout=10&' + 'connection_attempts=2' + ) + + pool = pika_pool.QueuedPool( + create=lambda: pika.BlockingConnection(parameters=params), + max_size=10, + max_overflow=10, + timeout=10, + recycle=3600, + stale=45, + ) + + with pool.acquire() as cxn: + cxn.channel.basic_publish( + body=json.dumps({ + 'type': 'banana', + 'description': 'they are yellow' + }), + exchange='', + routing_key='fruits', + properties=pika.BasicProperties( + content_type='application/json', + content_encoding='utf-8', + delivery_mode=2, + ) + ) + + +class TestQueuedPool(object): + + def test_invalidate_connection(slef, queued_pool): + msg = Message.generate() + with pytest.raises(select.error): + with queued_pool.acquire() as cxn: + fairy = cxn.fairy + raise select.error(9, 'Bad file descriptor') + assert fairy.cxn.is_closed + + def test_pub(self, queued_pool): + msg = Message.generate() + with queued_pool.acquire() as cxn: + cxn.channel.basic_publish( + exchange='', + routing_key='pika_pool_test', + body=msg.to_json() + ) + time.sleep(0.1) + assert msg.id in consumed + + def test_expire(self, queued_pool): + with queued_pool.acquire() as cxn: + expired = id(cxn.fairy.cxn) + expires_at = cxn.fairy.created_at + queued_pool.recycle + with queued_pool.acquire() as cxn: + assert expired == id(cxn.fairy.cxn) + cxn.fairy.created_at -= queued_pool.recycle + with queued_pool.acquire() as cxn: + assert expired != id(cxn.fairy.cxn) + + def test_stale(self, queued_pool): + with queued_pool.acquire() as cxn: + stale = id(cxn.fairy.cxn) + fairy = cxn.fairy + with queued_pool.acquire() as cxn: + assert stale == id(cxn.fairy.cxn) + fairy.released_at -= queued_pool.stale + with queued_pool.acquire() as cxn: + assert stale != id(cxn.fairy.cxn) + + def test_overflow(self, queued_pool): + queued = [queued_pool.acquire() for _ in range(queued_pool.max_size)] + with queued_pool.acquire() as cxn: + fairy = cxn.fairy + for cxn in queued: + cxn.release() + assert fairy.cxn.is_closed + + def test_timeout(self, request, queued_pool): + queued = [queued_pool.acquire() for _ in range(queued_pool.max_size)] + request.addfinalizer(lambda: [cxn.release() for cxn in queued]) + overflow = [queued_pool.acquire() for _ in range(queued_pool.max_overflow)] + request.addfinalizer(lambda: [cxn.release() for cxn in overflow]) + queued_pool.timeout = 1 + with pytest.raises(pika_pool.Timeout): + queued_pool.acquire()