diff --git a/setup.py b/setup.py index afa2336..9cb382e 100644 --- a/setup.py +++ b/setup.py @@ -34,6 +34,7 @@ extras_require = { 'docutils>=0.10', 'flexmock>=0.9.7', 'psycopg2>=2.4.6', + 'pymysql' ], 'anyjson': ['anyjson>=0.3.3'], 'babel': ['Babel>=1.3'], diff --git a/sqlalchemy_utils/__init__.py b/sqlalchemy_utils/__init__.py index dcc636d..7911b77 100644 --- a/sqlalchemy_utils/__init__.py +++ b/sqlalchemy_utils/__init__.py @@ -13,7 +13,10 @@ from .functions import ( mock_engine, sort_query, table_name, - with_backrefs + with_backrefs, + database_exists, + create_database, + drop_database ) from .listeners import coercion_listener from .merge import merge, Merger @@ -96,4 +99,7 @@ __all__ = ( TimezoneType, TSVectorType, UUIDType, + database_exists, + create_database, + drop_database ) diff --git a/sqlalchemy_utils/functions/__init__.py b/sqlalchemy_utils/functions/__init__.py index 0ac53be..9cce2a0 100644 --- a/sqlalchemy_utils/functions/__init__.py +++ b/sqlalchemy_utils/functions/__init__.py @@ -6,6 +6,7 @@ from .defer_except import defer_except from .mock import create_mock_engine, mock_engine from .render import render_expression, render_statement from .sort_query import sort_query, QuerySorterException +from .database import database_exists, create_database, drop_database __all__ = ( @@ -18,7 +19,10 @@ __all__ = ( render_statement, with_backrefs, CompositePath, - QuerySorterException + QuerySorterException, + database_exists, + create_database, + drop_database ) diff --git a/sqlalchemy_utils/functions/database.py b/sqlalchemy_utils/functions/database.py new file mode 100644 index 0000000..be2ebe6 --- /dev/null +++ b/sqlalchemy_utils/functions/database.py @@ -0,0 +1,86 @@ +from sqlalchemy.engine.url import make_url +import sqlalchemy as sa +from sqlalchemy.exc import ProgrammingError +import os + + +def database_exists(url): + """Check if a database exists. + """ + + url = make_url(url) + database = url.database + url.database = None + + engine = sa.create_engine(url) + + if engine.dialect.name == 'postgres': + text = "SELECT 1 FROM pg_database WHERE datname='%s'" % database + return bool(engine.execute(text).scalar()) + + elif engine.dialect.name == 'mysql': + text = ("SELECT SCHEMA_NAME FROM INFORMATION_SCHEMA.SCHEMATA " + "WHERE SCHEMA_NAME = '%s'" % database) + return bool(engine.execute(text).scalar()) + + elif engine.dialect.name == 'sqlite': + return database == ':memory:' or os.path.exists(database) + + else: + text = 'SELECT 1' + try: + url.database = database + engine = sa.create_engine(url) + engine.execute(text) + return True + + except ProgrammingError: + return False + + +def create_database(url, encoding='utf8'): + """Issue the appropriate CREATE DATABASE statement. + """ + + url = make_url(url) + + database = url.database + if not url.drivername.startswith('sqlite'): + url.database = None + + engine = sa.create_engine(url) + + if engine.dialect.name == 'postgres': + text = "CREATE DATABASE %s ENCODING = '%s'" % (database, encoding) + engine.execute(text) + + elif engine.dialect.name == 'mysql': + text = "CREATE DATABASE %s CHARACTER SET = '%s'" % (database, encoding) + engine.execute(text) + + elif engine.dialect.name == 'sqlite' and database != ':memory:': + open(database, 'w').close() + + else: + text = "CREATE DATABASE %s" % database + engine.execute(text) + + +def drop_database(url): + """Issue the appropriate DROP DATABASE statement. + """ + + url = make_url(url) + + database = url.database + if not url.drivername.startswith('sqlite'): + url.database = None + + engine = sa.create_engine(url) + + if engine.dialect.name == 'sqlite' and url.database != ':memory:': + os.remove(url.database) + + else: + text = "DROP DATABASE %s" % database + engine.execute(text) diff --git a/tests/functions/__init__.py b/tests/functions/__init__.py index e69de29..ad6e251 100644 --- a/tests/functions/__init__.py +++ b/tests/functions/__init__.py @@ -0,0 +1,7 @@ +import sqlalchemy as sa +from tests import TestCase +from sqlalchemy_utils.functions import ( + render_statement, + render_expression, + mock_engine +) diff --git a/tests/functions/test_database.py b/tests/functions/test_database.py new file mode 100644 index 0000000..34559b1 --- /dev/null +++ b/tests/functions/test_database.py @@ -0,0 +1,38 @@ +import sqlalchemy as sa +import os +from tests import TestCase +from sqlalchemy_utils import ( + create_database, + drop_database, + database_exists, +) + + +class DatabaseTest(TestCase): + + def test_create_and_drop(self): + assert not database_exists(self.url) + create_database(self.url) + assert database_exists(self.url) + drop_database(self.url) + assert not database_exists(self.url) + + +class TestDatabaseSQLite(DatabaseTest): + + url = 'sqlite:///sqlalchemy_utils.db' + + def setup(self): + if os.path.exists('sqlalchemy_utils.db'): + os.remove('sqlalchemy_utils.db') + + def test_exists_memory(self): + assert database_exists('sqlite:///:memory:') + + +class TestDatabaseMySQL(DatabaseTest): + url = 'mysql+pymysql://travis@localhost/db_test_sqlalchemy_util' + + +class TestDatabasePostgres(DatabaseTest): + url = 'postgres://postgres@localhost/db_test_sqlalchemy_util'