From 03129bfebd80f2f1289f8d160812bdb28d558212 Mon Sep 17 00:00:00 2001 From: Konsta Vesterinen Date: Tue, 7 Oct 2014 10:28:43 +0300 Subject: [PATCH] Add explain and explain_analyze expressions --- CHANGES.rst | 6 +++++ sqlalchemy_utils/expressions.py | 21 +++++++++++++++++ tests/test_expressions.py | 40 ++++++++++++++++++++++++++++----- 3 files changed, 62 insertions(+), 5 deletions(-) diff --git a/CHANGES.rst b/CHANGES.rst index 9e0f903..a9eafc0 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -4,6 +4,12 @@ Changelog Here you can see the full list of changes between each SQLAlchemy-Utils release. +0.26.17 (2014-10-xx) +^^^^^^^^^^^^^^^^^^^^ + +- Added explain and explain_analyze expressions + + 0.26.16 (2014-09-09) ^^^^^^^^^^^^^^^^^^^^ diff --git a/sqlalchemy_utils/expressions.py b/sqlalchemy_utils/expressions.py index 150c6fa..b0da0fc 100644 --- a/sqlalchemy_utils/expressions.py +++ b/sqlalchemy_utils/expressions.py @@ -1,9 +1,30 @@ import sqlalchemy as sa from sqlalchemy.sql import expression +from sqlalchemy.sql.expression import Executable, ClauseElement, _literal_as_text from sqlalchemy.ext.compiler import compiles from sqlalchemy_utils.types import TSVectorType +class explain(Executable, ClauseElement): + def __init__(self, stmt, analyze=False): + self.statement = _literal_as_text(stmt) + self.analyze = analyze + + +class explain_analyze(explain): + def __init__(self, stmt): + super(explain_analyze, self).__init__(stmt, analyze=True) + + +@compiles(explain, 'postgresql') +def pg_explain(element, compiler, **kw): + text = "EXPLAIN " + if element.analyze: + text += "ANALYZE " + text += compiler.process(element.statement) + return text + + class tsvector_match(expression.FunctionElement): type = sa.types.Unicode() name = 'tsvector_match' diff --git a/tests/test_expressions.py b/tests/test_expressions.py index 7718b53..09bb571 100644 --- a/tests/test_expressions.py +++ b/tests/test_expressions.py @@ -1,7 +1,10 @@ from pytest import raises import sqlalchemy as sa +from sqlalchemy.dialects import postgresql from sqlalchemy_utils.types import TSVectorType from sqlalchemy_utils.expressions import ( + explain, + explain_analyze, tsvector_match, tsvector_concat, to_tsquery, @@ -10,7 +13,7 @@ from sqlalchemy_utils.expressions import ( from tests import TestCase -class TSVectorTestCase(TestCase): +class ExpressionTestCase(TestCase): dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' def create_models(self): @@ -25,7 +28,34 @@ class TSVectorTestCase(TestCase): self.Article = Article -class TestMatchTSVector(TSVectorTestCase): +class TestExplain(ExpressionTestCase): + def test_render_explain(self): + assert str( + explain(self.session.query(self.Article)).compile( + dialect=postgresql.dialect() + ) + ).startswith('EXPLAIN SELECT') + + def test_render_explain_with_analyze(self): + assert str( + explain(self.session.query(self.Article), analyze=True) + .compile( + dialect=postgresql.dialect() + ) + ).startswith('EXPLAIN ANALYZE SELECT') + + +class TestExplainAnalyze(ExpressionTestCase): + def test_render_explain_analyze(self): + assert str( + explain_analyze(self.session.query(self.Article)) + .compile( + dialect=postgresql.dialect() + ) + ).startswith('EXPLAIN ANALYZE SELECT') + + +class TestMatchTSVector(ExpressionTestCase): def test_raises_exception_if_less_than_2_parameters_given(self): with raises(Exception): str( @@ -41,7 +71,7 @@ class TestMatchTSVector(TSVectorTestCase): )) == '(article.search_vector) @@ to_tsquery(:to_tsquery_1)' -class TestToTSQuery(TSVectorTestCase): +class TestToTSQuery(ExpressionTestCase): def test_requires_atleast_one_parameter(self): with raises(Exception): str(to_tsquery()) @@ -50,7 +80,7 @@ class TestToTSQuery(TSVectorTestCase): assert str(to_tsquery('something')) == 'to_tsquery(:to_tsquery_1)' -class TestPlainToTSQuery(TSVectorTestCase): +class TestPlainToTSQuery(ExpressionTestCase): def test_requires_atleast_one_parameter(self): with raises(Exception): str(plainto_tsquery()) @@ -61,7 +91,7 @@ class TestPlainToTSQuery(TSVectorTestCase): ) -class TestConcatTSVector(TSVectorTestCase): +class TestConcatTSVector(ExpressionTestCase): def test_concatenate_search_vectors(self): assert str(tsvector_match( tsvector_concat(