diff --git a/test-requirements.txt b/test-requirements.txt index a7b8d480..83873b42 100644 --- a/test-requirements.txt +++ b/test-requirements.txt @@ -2,6 +2,7 @@ pep8>=1.4.5 pyflakes>=0.7.2,<0.7.4 flake8>=2.1.0 discover +mock>=1.0 # only needed on < python 3.3 sphinx>=1.1.2,<1.2 python-subunit testrepository>=0.0.17 diff --git a/tooz/drivers/pgsql.py b/tooz/drivers/pgsql.py index d142137d..84940664 100644 --- a/tooz/drivers/pgsql.py +++ b/tooz/drivers/pgsql.py @@ -15,6 +15,8 @@ # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations # under the License. + +import contextlib import hashlib import psycopg2 @@ -24,6 +26,64 @@ import tooz from tooz import coordination from tooz.drivers import _retry from tooz import locking +from tooz import utils + + +# See: psycopg/diagnostics_type.c for what kind of fields these +# objects may have (things like 'schema_name', 'internal_query' +# and so-on which are useful for figuring out what went wrong...) +_DIAGNOSTICS_ATTRS = tuple([ + 'column_name', + 'constraint_name', + 'context', + 'datatype_name', + 'internal_position', + 'internal_query', + 'message_detail', + 'message_hint', + 'message_primary', + 'schema_name', + 'severity', + 'source_file', + 'source_function', + 'source_line', + 'sqlstate', + 'statement_position', + 'table_name', +]) + + +def _format_exception(e): + lines = [ + "%s: %s" % (type(e).__name__, utils.exception_message(e).strip()), + ] + if hasattr(e, 'pgcode') and e.pgcode is not None: + lines.append("Error code: %s" % e.pgcode) + # The reason this hasattr check is done is that the 'diag' may not always + # be present, depending on how new of a psycopg is installed... so better + # to be safe than sorry... + if hasattr(e, 'diag') and e.diag is not None: + diagnostic_lines = [] + for attr_name in _DIAGNOSTICS_ATTRS: + if not hasattr(e.diag, attr_name): + continue + attr_value = getattr(e.diag, attr_name) + if attr_value is None: + continue + diagnostic_lines.append(" %s = %s" (attr_name, attr_value)) + if diagnostic_lines: + lines.append('Diagnostics:') + lines.extend(diagnostic_lines) + return "\n".join(lines) + + +@contextlib.contextmanager +def _translating_cursor(conn): + try: + with conn.cursor() as cur: + yield cur + except psycopg2.Error as e: + raise coordination.ToozError(_format_exception(e)) class PostgresLock(locking.Lock): @@ -41,16 +101,16 @@ class PostgresLock(locking.Lock): def acquire(self, blocking=True): if blocking is True: - with self._conn.cursor() as cur: + with _translating_cursor(self._conn) as cur: cur.execute("SELECT pg_advisory_lock(%s, %s);", self.key) - return True + return True elif blocking is False: - with self._conn.cursor() as cur: + with _translating_cursor(self._conn) as cur: cur.execute("SELECT pg_try_advisory_lock(%s, %s);", self.key) return cur.fetchone()[0] else: def _acquire(): - with self._conn.cursor() as cur: + with _translating_cursor(self._conn) as cur: cur.execute("SELECT pg_try_advisory_lock(%s, %s);", self.key) if cur.fetchone()[0] is True: @@ -61,7 +121,7 @@ class PostgresLock(locking.Lock): return _retry.Retrying(**kwargs).call(_acquire) def release(self): - with self._conn.cursor() as cur: + with _translating_cursor(self._conn) as cur: cur.execute("SELECT pg_advisory_unlock(%s, %s);", self.key) return cur.fetchone()[0] @@ -78,11 +138,14 @@ class PostgresDriver(coordination.CoordinationDriver): self._password = parsed_url.password def _start(self): - self._conn = psycopg2.connect(host=self._host, - port=self._port, - user=self._username, - password=self._password, - database=self._dbname) + try: + self._conn = psycopg2.connect(host=self._host, + port=self._port, + user=self._username, + password=self._password, + database=self._dbname) + except psycopg2.Error as e: + raise coordination.ToozConnectionError(_format_exception(e)) def _stop(self): self._conn.close() diff --git a/tooz/tests/test_postgresql.py b/tooz/tests/test_postgresql.py new file mode 100644 index 00000000..35c8d0e8 --- /dev/null +++ b/tooz/tests/test_postgresql.py @@ -0,0 +1,114 @@ +# -*- coding: utf-8 -*- + +# Copyright (C) 2014 Yahoo! Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import uuid + +try: + # Added in python 3.3+ + from unittest import mock +except ImportError: + import mock + +import testtools +from testtools import testcase + +from tooz import coordination +from tooz import utils + +# Handle the case gracefully where the driver is not installed. +try: + import psycopg2 + PGSQL_AVAILABLE = True +except ImportError: + PGSQL_AVAILABLE = False + + +@testtools.skipUnless(PGSQL_AVAILABLE, 'psycopg2 is not available') +class TestPostgreSQLFailures(testcase.TestCase): + + # Not actually used (but required none the less), since we mock out + # the connect() method... + FAKE_URL = "postgresql://localhost:1" + + def _create_coordinator(self): + + def _safe_stop(coord): + try: + coord.stop() + except coordination.ToozError as e: + # TODO(harlowja): make this better, so that we don't have to + # do string checking... + message = utils.exception_message(e) + if (message != 'Can not stop a driver which has not' + ' been started'): + raise + + coord = coordination.get_coordinator(self.FAKE_URL, + str(uuid.uuid4()).encode('ascii')) + self.addCleanup(_safe_stop, coord) + return coord + + @mock.patch("tooz.drivers.pgsql.psycopg2.connect") + def test_connect_failure(self, psycopg2_connector): + psycopg2_connector.side_effect = psycopg2.Error("Broken") + c = self._create_coordinator() + self.assertRaises(coordination.ToozConnectionError, c.start) + + @mock.patch("tooz.drivers.pgsql.psycopg2.connect") + def test_connect_failure_operational(self, psycopg2_connector): + psycopg2_connector.side_effect = psycopg2.OperationalError("Broken") + c = self._create_coordinator() + self.assertRaises(coordination.ToozConnectionError, c.start) + + @mock.patch("tooz.drivers.pgsql.psycopg2.connect") + def test_failure_acquire_lock(self, psycopg2_connector): + execute_mock = mock.MagicMock() + execute_mock.execute.side_effect = psycopg2.OperationalError("Broken") + + cursor_mock = mock.MagicMock() + cursor_mock.__enter__ = mock.MagicMock(return_value=execute_mock) + cursor_mock.__exit__ = mock.MagicMock(return_value=False) + + conn_mock = mock.MagicMock() + conn_mock.cursor.return_value = cursor_mock + psycopg2_connector.return_value = conn_mock + + c = self._create_coordinator() + c.start() + test_lock = c.get_lock(b'test-lock') + self.assertRaises(coordination.ToozError, test_lock.acquire) + + @mock.patch("tooz.drivers.pgsql.psycopg2.connect") + def test_failure_release_lock(self, psycopg2_connector): + execute_mock = mock.MagicMock() + execute_mock.execute.side_effect = [ + True, + psycopg2.OperationalError("Broken"), + ] + + cursor_mock = mock.MagicMock() + cursor_mock.__enter__ = mock.MagicMock(return_value=execute_mock) + cursor_mock.__exit__ = mock.MagicMock(return_value=False) + + conn_mock = mock.MagicMock() + conn_mock.cursor.return_value = cursor_mock + psycopg2_connector.return_value = conn_mock + + c = self._create_coordinator() + c.start() + test_lock = c.get_lock(b'test-lock') + self.assertTrue(test_lock.acquire()) + self.assertRaises(coordination.ToozError, test_lock.release) diff --git a/tox.ini b/tox.ini index 1c503148..74fa529f 100644 --- a/tox.ini +++ b/tox.ini @@ -106,3 +106,4 @@ show-source = True [hacking] import_exceptions = six.moves + unittest.mock