Add select_correlated_expression

This commit is contained in:
Konsta Vesterinen
2015-07-16 09:57:43 +03:00
parent 6f1b61762f
commit 9ab88a7ab5
2 changed files with 564 additions and 0 deletions

View File

@@ -1,2 +1,118 @@
import sqlalchemy as sa
from sqlalchemy.sql.util import ClauseAdapter
from .chained_join import chained_join # noqa from .chained_join import chained_join # noqa
from .select_aggregate import select_aggregate # noqa from .select_aggregate import select_aggregate # noqa
def path_to_relationships(path, cls):
relationships = []
for path_name in path.split('.'):
rel = getattr(cls, path_name)
relationships.append(rel)
cls = rel.mapper.class_
return relationships
def adapt_expr(expr, *selectables):
for selectable in selectables:
expr = ClauseAdapter(selectable).traverse(expr)
return expr
def inverse_join(selectable, left_alias, right_alias, relationship):
if relationship.property.secondary is not None:
secondary_alias = sa.alias(relationship.property.secondary)
return selectable.join(
secondary_alias,
adapt_expr(
relationship.property.secondaryjoin,
sa.inspect(left_alias).selectable,
secondary_alias
)
).join(
right_alias,
adapt_expr(
relationship.property.primaryjoin,
sa.inspect(right_alias).selectable,
secondary_alias
)
)
else:
join = sa.orm.join(right_alias, left_alias, relationship)
onclause = join.onclause
return selectable.join(right_alias, onclause)
def relationship_to_correlation(relationship, alias):
if relationship.property.secondary is not None:
return adapt_expr(
relationship.property.primaryjoin,
alias,
)
else:
return sa.orm.join(
relationship.parent,
alias,
relationship
).onclause
def chained_inverse_join(relationships, leaf_model):
selectable = sa.inspect(leaf_model).selectable
aliases = [leaf_model]
for index, relationship in enumerate(relationships[1:]):
aliases.append(sa.orm.aliased(relationship.mapper.class_))
selectable = inverse_join(
selectable,
aliases[index],
aliases[index + 1],
relationships[index]
)
if relationships[-1].property.secondary is not None:
secondary_alias = sa.alias(relationships[-1].property.secondary)
selectable = selectable.join(
secondary_alias,
adapt_expr(
relationships[-1].property.secondaryjoin,
secondary_alias,
sa.inspect(aliases[-1]).selectable
)
)
aliases.append(secondary_alias)
return selectable, aliases
def select_correlated_expression(
root_model,
expr,
path,
leaf_model,
from_obj=None,
order_by=None
):
relationships = list(reversed(path_to_relationships(path, root_model)))
query = sa.select([expr])
selectable = sa.inspect(leaf_model).selectable
if order_by:
query = query.order_by(
*[adapt_expr(o, selectable) for o in order_by]
)
join_expr, aliases = chained_inverse_join(relationships, leaf_model)
condition = relationship_to_correlation(
relationships[-1],
aliases[-1]
)
if from_obj is not None:
condition = adapt_expr(condition, from_obj)
query = query.select_from(join_expr.selectable)
return query.correlate(
from_obj if from_obj is not None else root_model
).where(condition)

View File

@@ -0,0 +1,448 @@
import pytest
import sqlalchemy as sa
from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import sessionmaker
from sqlalchemy_utils.relationships import select_correlated_expression
@pytest.fixture(scope='class')
def base():
return declarative_base()
@pytest.fixture(scope='class')
def group_user_cls(base):
return sa.Table(
'group_user',
base.metadata,
sa.Column('user_id', sa.Integer, sa.ForeignKey('user.id')),
sa.Column('group_id', sa.Integer, sa.ForeignKey('group.id'))
)
@pytest.fixture(scope='class')
def group_cls(base):
class Group(base):
__tablename__ = 'group'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.String)
return Group
@pytest.fixture(scope='class')
def friendship_cls(base):
return sa.Table(
'friendships',
base.metadata,
sa.Column(
'friend_a_id',
sa.Integer,
sa.ForeignKey('user.id'),
primary_key=True
),
sa.Column(
'friend_b_id',
sa.Integer,
sa.ForeignKey('user.id'),
primary_key=True
)
)
@pytest.fixture(scope='class')
def user_cls(base, group_user_cls, friendship_cls):
class User(base):
__tablename__ = 'user'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.String)
groups = sa.orm.relationship(
'Group',
secondary=group_user_cls,
backref='users'
)
# this relationship is used for persistence
friends = sa.orm.relationship(
'User',
secondary=friendship_cls,
primaryjoin=id == friendship_cls.c.friend_a_id,
secondaryjoin=id == friendship_cls.c.friend_b_id,
)
friendship_union = sa.select([
friendship_cls.c.friend_a_id,
friendship_cls.c.friend_b_id
]).union(
sa.select([
friendship_cls.c.friend_b_id,
friendship_cls.c.friend_a_id]
)
).alias()
User.all_friends = sa.orm.relationship(
'User',
secondary=friendship_union,
primaryjoin=User.id == friendship_union.c.friend_a_id,
secondaryjoin=User.id == friendship_union.c.friend_b_id,
viewonly=True,
order_by=User.id
)
return User
@pytest.fixture(scope='class')
def category_cls(base, group_user_cls, friendship_cls):
class Category(base):
__tablename__ = 'category'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.String)
created_at = sa.Column(sa.DateTime)
parent_id = sa.Column(sa.Integer, sa.ForeignKey('category.id'))
parent = sa.orm.relationship(
'Category',
backref='subcategories',
remote_side=[id],
order_by=id
)
return Category
@pytest.fixture(scope='class')
def article_cls(base, category_cls, user_cls):
class Article(base):
__tablename__ = 'article'
id = sa.Column('_id', sa.Integer, primary_key=True)
name = sa.Column(sa.String)
name_synonym = sa.orm.synonym('name')
@hybrid_property
def name_upper(self):
return self.name.upper() if self.name else None
@name_upper.expression
def name_upper(cls):
return sa.func.upper(cls.name)
content = sa.Column(sa.String)
category_id = sa.Column(sa.Integer, sa.ForeignKey(category_cls.id))
category = sa.orm.relationship(category_cls, backref='articles')
author_id = sa.Column(sa.Integer, sa.ForeignKey(user_cls.id))
author = sa.orm.relationship(
user_cls,
primaryjoin=author_id == user_cls.id,
backref='authored_articles'
)
owner_id = sa.Column(sa.Integer, sa.ForeignKey(user_cls.id))
owner = sa.orm.relationship(
user_cls,
primaryjoin=owner_id == user_cls.id,
backref='owned_articles'
)
return Article
@pytest.fixture(scope='class')
def comment_cls(base, article_cls, user_cls):
class Comment(base):
__tablename__ = 'comment'
id = sa.Column(sa.Integer, primary_key=True)
content = sa.Column(sa.String)
article_id = sa.Column(sa.Integer, sa.ForeignKey(article_cls.id))
article = sa.orm.relationship(article_cls, backref='comments')
author_id = sa.Column(sa.Integer, sa.ForeignKey(user_cls.id))
author = sa.orm.relationship(user_cls, backref='comments')
article_cls.comment_count = sa.orm.column_property(
sa.select([sa.func.count(Comment.id)])
.where(Comment.article_id == article_cls.id)
.correlate_except(article_cls)
)
return Comment
@pytest.fixture(scope='class')
def composite_pk_cls(base):
class CompositePKModel(base):
__tablename__ = 'composite_pk_model'
a = sa.Column(sa.Integer, primary_key=True)
b = sa.Column(sa.Integer, primary_key=True)
return CompositePKModel
@pytest.fixture(scope='class')
def dns():
return 'postgres://postgres@localhost/sqlalchemy_utils_test'
@pytest.yield_fixture(scope='class')
def engine(dns):
engine = create_engine(dns)
engine.echo = True
yield engine
engine.dispose()
@pytest.yield_fixture(scope='class')
def connection(engine):
conn = engine.connect()
yield conn
conn.close()
@pytest.fixture(scope='class')
def model_mapping(article_cls, category_cls, comment_cls, group_cls, user_cls):
return {
'articles': article_cls,
'categories': category_cls,
'comments': comment_cls,
'groups': group_cls,
'users': user_cls
}
@pytest.yield_fixture(scope='class')
def table_creator(base, connection, model_mapping):
sa.orm.configure_mappers()
base.metadata.create_all(connection)
yield
base.metadata.drop_all(connection)
@pytest.yield_fixture(scope='class')
def session(connection):
Session = sessionmaker(bind=connection)
session = Session()
yield session
session.close_all()
@pytest.fixture(scope='class')
def dataset(
session,
user_cls,
group_cls,
article_cls,
category_cls,
comment_cls
):
group = group_cls(name='Group 1')
group2 = group_cls(name='Group 2')
user = user_cls(id=1, name='User 1', groups=[group, group2])
user2 = user_cls(id=2, name='User 2')
user3 = user_cls(id=3, name='User 3', groups=[group])
user4 = user_cls(id=4, name='User 4', groups=[group2])
user5 = user_cls(id=5, name='User 5')
user.friends = [user2]
user2.friends = [user3, user4]
user3.friends = [user5]
article = article_cls(
name='Some article',
author=user,
owner=user2,
category=category_cls(
id=1,
name='Some category',
subcategories=[
category_cls(
id=2,
name='Subcategory 1',
subcategories=[
category_cls(
id=3,
name='Subsubcategory 1',
subcategories=[
category_cls(
id=5,
name='Subsubsubcategory 1',
),
category_cls(
id=6,
name='Subsubsubcategory 2',
)
]
)
]
),
category_cls(id=4, name='Subcategory 2'),
]
),
comments=[
comment_cls(
content='Some comment',
author=user
)
]
)
session.add(user3)
session.add(user4)
session.add(article)
session.commit()
@pytest.mark.usefixtures('table_creator', 'dataset')
class TestSelectCorrelatedExpression(object):
@pytest.mark.parametrize(
('model_key', 'related_model_key', 'path', 'result'),
(
(
'categories',
'categories',
'subcategories',
[
(1, 2),
(2, 1),
(3, 2),
(4, 0),
(5, 0),
(6, 0)
]
),
(
'articles',
'comments',
'comments',
[
(1, 1),
]
),
(
'users',
'groups',
'groups',
[
(1, 2),
(2, 0),
(3, 1),
(4, 1),
(5, 0)
]
),
(
'users',
'users',
'all_friends',
[
(1, 1),
(2, 3),
(3, 2),
(4, 1),
(5, 1)
]
),
(
'users',
'users',
'all_friends.all_friends',
[
(1, 3),
(2, 2),
(3, 3),
(4, 3),
(5, 2)
]
),
(
'users',
'users',
'groups.users',
[
(1, 3),
(2, 0),
(3, 2),
(4, 2),
(5, 0)
]
),
(
'groups',
'articles',
'users.authored_articles',
[
(1, 1),
(2, 1),
]
),
(
'categories',
'categories',
'subcategories.subcategories',
[
(1, 1),
(2, 2),
(3, 0),
(4, 0),
(5, 0),
(6, 0)
]
),
(
'categories',
'categories',
'subcategories.subcategories.subcategories',
[
(1, 2),
(2, 0),
(3, 0),
(4, 0),
(5, 0),
(6, 0)
]
),
)
)
def test_returns_correct_results(
self,
session,
model_mapping,
model_key,
related_model_key,
path,
result
):
model = model_mapping[model_key]
alias = sa.orm.aliased(model_mapping[related_model_key])
aggregate = select_correlated_expression(
model,
sa.func.count(sa.distinct(alias.id)),
path,
alias
)
query = session.query(
model.id,
aggregate.label('count')
).order_by(model.id)
assert query.all() == result
def test_with_non_aggregate_function(
self,
session,
user_cls,
article_cls
):
aggregate = select_correlated_expression(
article_cls,
sa.func.json_build_object('name', user_cls.name),
'comments.author',
user_cls
)
query = session.query(
article_cls.id,
aggregate.label('author_json')
).order_by(article_cls.id)
result = query.all()
assert result == [
(1, {'name': 'User 1'})
]