Added more experimental tests for EAV

This commit is contained in:
Konsta Vesterinen
2013-11-15 11:15:07 +02:00
parent 536eeb212c
commit 741a6b8728

View File

@@ -1,5 +1,14 @@
from pytest import raises
import sqlalchemy as sa
import sqlalchemy.ext.associationproxy
from sqlalchemy.ext.associationproxy import (
AssociationProxy, _AssociationDict
)
from sqlalchemy.orm.collections import (
attribute_mapped_collection,
collection,
MappedCollection,
)
from sqlalchemy_utils import MetaType, MetaValue
from tests import TestCase
@@ -55,54 +64,71 @@ class TestMetaModel(TestCase):
)
from sqlalchemy.orm.collections import (
attribute_mapped_collection,
collection,
MappedCollection,
)
class MetaTypedCollection(MappedCollection):
def __init__(self):
self.keyfunc = lambda value: value.attr.name
def __getitem__(self, key):
if not self.key_exists(key):
raise KeyError(key)
class MyMappedCollection(MappedCollection):
def __init__(self, *args, **kwargs):
MappedCollection.__init__(
self, keyfunc=lambda node: node.attr.name
)
return self.get(key)
def __contains__(self, key):
def key_exists(self, key):
adapter = self._sa_adapter
obj = adapter.owner_state.object
return obj.category and key in obj.category.attributes
@collection.appender
@collection.internally_instrumented
def __getitem__(self, key):
if not self.__contains__(key):
raise KeyError(key)
def set(self, *args, **kwargs):
if len(args) > 1:
if not self.key_exists(args[0]):
raise KeyError(args[0])
arg = args[1]
else:
arg = args[0]
try:
return super(MyMappedCollection, self).__getitem__(key)
except KeyError:
super(MetaTypedCollection, self).set(arg, **kwargs)
@collection.remover
@collection.internally_instrumented
def remove(self, key):
del self[key]
def assoc_dict_factory(lazy_collection, creator, getter, setter, parent):
if isinstance(parent, MetaAssociationProxy):
return MetaAssociationDict(
lazy_collection, creator, getter, setter, parent
)
else:
return _AssociationDict(
lazy_collection, creator, getter, setter, parent
)
sqlalchemy.ext.associationproxy._AssociationDict = assoc_dict_factory
class MetaAssociationDict(_AssociationDict):
def _create(self, key, value):
parent_obj = self.lazy_collection.ref()
class_ = parent_obj.__mapper__.relationships[
self.lazy_collection.target
].mapper.class_
if not parent_obj.category:
raise KeyError(key)
return class_(attr=parent_obj.category.attributes[key], value=value)
def _get(self, object):
if object is None:
return None
return self.getter(object)
@collection.internally_instrumented
def __setitem__(self, key, value, _sa_initiator=None):
if not self.__contains__(key):
raise KeyError(key)
adapter = self._sa_adapter
obj = adapter.owner_state.object
mapper = adapter.owner_state.mapper
class_ = mapper.relationships[adapter._key].mapper.class_
if not isinstance(value, class_):
value = class_(attr=obj.category.attributes[key], value=value)
super(MyMappedCollection, self).__setitem__(key, value, _sa_initiator)
@collection.internally_instrumented
def __delitem__(self, key, _sa_initiator=None):
# do something with key
print self, key
super(MyMappedCollection, self).__delitem__(key, _sa_initiator)
class MetaAssociationProxy(AssociationProxy):
pass
class TestProductCatalog(TestCase):
@@ -122,11 +148,13 @@ class TestProductCatalog(TestCase):
)
category = sa.orm.relationship(Category)
def __getattr__(self, attr):
return self.attributes[unicode(attr)].value
attributes = MetaAssociationProxy(
'attribute_objects',
'value',
)
class Attribute(self.Base):
__tablename__ = 'product_attribute'
__tablename__ = 'attribute'
id = sa.Column(sa.Integer, primary_key=True)
data_type = sa.Column(
MetaType({
@@ -157,8 +185,8 @@ class TestProductCatalog(TestCase):
product = sa.orm.relationship(
Product,
backref=sa.orm.backref(
'attributes',
collection_class=MyMappedCollection
'attribute_objects',
collection_class=MetaTypedCollection
)
)
@@ -177,6 +205,12 @@ class TestProductCatalog(TestCase):
self.Attribute = Attribute
self.AttributeValue = AttributeValue
def test_attr_value_setting(self):
attr = self.Attribute(data_type=sa.UnicodeText)
value = self.AttributeValue(attr=attr)
value.value = u'some answer'
assert u'some answer' == value.value_unicode
def test_unknown_attribute_key(self):
product = self.Product()
@@ -211,8 +245,36 @@ class TestProductCatalog(TestCase):
self.session.add(product)
self.session.commit()
product.attribute_objects[u'color'] = self.AttributeValue(
attr=category.attributes['color'], value=u'red'
)
product.attribute_objects[u'maxspeed'] = self.AttributeValue(
attr=category.attributes['maxspeed'], value=300
)
assert product.attribute_objects[u'color'].value_unicode == u'red'
assert product.attribute_objects[u'maxspeed'].value_int == 300
self.session.commit()
assert product.attribute_objects[u'color'].value == u'red'
assert product.attribute_objects[u'maxspeed'].value == 300
def test_association_proxies(self):
category = self.Category(name=u'cars')
category.attributes = {
u'color': self.Attribute(name=u'color', data_type=sa.UnicodeText),
u'maxspeed': self.Attribute(name=u'maxspeed', data_type=sa.Integer)
}
product = self.Product(
name=u'Porsche 911',
category=category
)
self.session.add(product)
self.session.commit()
product.attributes[u'color'] = u'red'
product.attributes[u'maxspeed'] = 300
assert product.attributes[u'color'] == u'red'
assert product.attributes[u'maxspeed'] == 300
self.session.commit()
assert product.attributes[u'color'] == u'red'
@@ -234,11 +296,8 @@ class TestProductCatalog(TestCase):
product.attributes[u'maxspeed'] = 300
self.session.commit()
assert product.color == u'red'
assert product.maxspeed == 300
(
self.session.query(self.Product)
.filter(self.Product.color.in_([u'red', u'blue']))
.order_by(self.Product.color)
.filter(self.Product.attributes['color'].in_([u'red', u'blue']))
.order_by(self.Product.attributes['color'])
)