diff --git a/tests/test_eav.py b/tests/test_eav.py index 19a0492..7e55121 100644 --- a/tests/test_eav.py +++ b/tests/test_eav.py @@ -1,5 +1,14 @@ from pytest import raises import sqlalchemy as sa +import sqlalchemy.ext.associationproxy +from sqlalchemy.ext.associationproxy import ( + AssociationProxy, _AssociationDict +) +from sqlalchemy.orm.collections import ( + attribute_mapped_collection, + collection, + MappedCollection, +) from sqlalchemy_utils import MetaType, MetaValue from tests import TestCase @@ -55,54 +64,71 @@ class TestMetaModel(TestCase): ) -from sqlalchemy.orm.collections import ( - attribute_mapped_collection, - collection, - MappedCollection, -) +class MetaTypedCollection(MappedCollection): + def __init__(self): + self.keyfunc = lambda value: value.attr.name + def __getitem__(self, key): + if not self.key_exists(key): + raise KeyError(key) -class MyMappedCollection(MappedCollection): - def __init__(self, *args, **kwargs): - MappedCollection.__init__( - self, keyfunc=lambda node: node.attr.name - ) + return self.get(key) - def __contains__(self, key): + def key_exists(self, key): adapter = self._sa_adapter obj = adapter.owner_state.object return obj.category and key in obj.category.attributes + @collection.appender @collection.internally_instrumented - def __getitem__(self, key): - if not self.__contains__(key): - raise KeyError(key) + def set(self, *args, **kwargs): + if len(args) > 1: + if not self.key_exists(args[0]): + raise KeyError(args[0]) + arg = args[1] + else: + arg = args[0] - try: - return super(MyMappedCollection, self).__getitem__(key) - except KeyError: + super(MetaTypedCollection, self).set(arg, **kwargs) + + @collection.remover + @collection.internally_instrumented + def remove(self, key): + del self[key] + + +def assoc_dict_factory(lazy_collection, creator, getter, setter, parent): + if isinstance(parent, MetaAssociationProxy): + return MetaAssociationDict( + lazy_collection, creator, getter, setter, parent + ) + else: + return _AssociationDict( + lazy_collection, creator, getter, setter, parent + ) + + +sqlalchemy.ext.associationproxy._AssociationDict = assoc_dict_factory + + +class MetaAssociationDict(_AssociationDict): + def _create(self, key, value): + parent_obj = self.lazy_collection.ref() + class_ = parent_obj.__mapper__.relationships[ + self.lazy_collection.target + ].mapper.class_ + if not parent_obj.category: + raise KeyError(key) + return class_(attr=parent_obj.category.attributes[key], value=value) + + def _get(self, object): + if object is None: return None + return self.getter(object) - @collection.internally_instrumented - def __setitem__(self, key, value, _sa_initiator=None): - if not self.__contains__(key): - raise KeyError(key) - adapter = self._sa_adapter - obj = adapter.owner_state.object - mapper = adapter.owner_state.mapper - class_ = mapper.relationships[adapter._key].mapper.class_ - - if not isinstance(value, class_): - value = class_(attr=obj.category.attributes[key], value=value) - - super(MyMappedCollection, self).__setitem__(key, value, _sa_initiator) - - @collection.internally_instrumented - def __delitem__(self, key, _sa_initiator=None): - # do something with key - print self, key - super(MyMappedCollection, self).__delitem__(key, _sa_initiator) +class MetaAssociationProxy(AssociationProxy): + pass class TestProductCatalog(TestCase): @@ -122,11 +148,13 @@ class TestProductCatalog(TestCase): ) category = sa.orm.relationship(Category) - def __getattr__(self, attr): - return self.attributes[unicode(attr)].value + attributes = MetaAssociationProxy( + 'attribute_objects', + 'value', + ) class Attribute(self.Base): - __tablename__ = 'product_attribute' + __tablename__ = 'attribute' id = sa.Column(sa.Integer, primary_key=True) data_type = sa.Column( MetaType({ @@ -157,8 +185,8 @@ class TestProductCatalog(TestCase): product = sa.orm.relationship( Product, backref=sa.orm.backref( - 'attributes', - collection_class=MyMappedCollection + 'attribute_objects', + collection_class=MetaTypedCollection ) ) @@ -177,6 +205,12 @@ class TestProductCatalog(TestCase): self.Attribute = Attribute self.AttributeValue = AttributeValue + def test_attr_value_setting(self): + attr = self.Attribute(data_type=sa.UnicodeText) + value = self.AttributeValue(attr=attr) + value.value = u'some answer' + assert u'some answer' == value.value_unicode + def test_unknown_attribute_key(self): product = self.Product() @@ -211,8 +245,36 @@ class TestProductCatalog(TestCase): self.session.add(product) self.session.commit() + product.attribute_objects[u'color'] = self.AttributeValue( + attr=category.attributes['color'], value=u'red' + ) + product.attribute_objects[u'maxspeed'] = self.AttributeValue( + attr=category.attributes['maxspeed'], value=300 + ) + assert product.attribute_objects[u'color'].value_unicode == u'red' + assert product.attribute_objects[u'maxspeed'].value_int == 300 + self.session.commit() + + assert product.attribute_objects[u'color'].value == u'red' + assert product.attribute_objects[u'maxspeed'].value == 300 + + def test_association_proxies(self): + category = self.Category(name=u'cars') + category.attributes = { + u'color': self.Attribute(name=u'color', data_type=sa.UnicodeText), + u'maxspeed': self.Attribute(name=u'maxspeed', data_type=sa.Integer) + } + product = self.Product( + name=u'Porsche 911', + category=category + ) + self.session.add(product) + self.session.commit() + product.attributes[u'color'] = u'red' product.attributes[u'maxspeed'] = 300 + assert product.attributes[u'color'] == u'red' + assert product.attributes[u'maxspeed'] == 300 self.session.commit() assert product.attributes[u'color'] == u'red' @@ -234,11 +296,8 @@ class TestProductCatalog(TestCase): product.attributes[u'maxspeed'] = 300 self.session.commit() - assert product.color == u'red' - assert product.maxspeed == 300 - ( self.session.query(self.Product) - .filter(self.Product.color.in_([u'red', u'blue'])) - .order_by(self.Product.color) + .filter(self.Product.attributes['color'].in_([u'red', u'blue'])) + .order_by(self.Product.attributes['color']) )