diff --git a/cereal/__init__.py b/cereal/__init__.py index 1648330..bd93c04 100644 --- a/cereal/__init__.py +++ b/cereal/__init__.py @@ -1,4 +1,5 @@ import pkg_resources +import itertools def resolve_dotted(dottedname, package=None): if dottedname.startswith('.') or dottedname.startswith(':'): @@ -310,6 +311,8 @@ class Integer(object): def deserialize(self, struct, value): return self._validate(struct, value) +Int = Integer + class GlobalObject(object): """ A type representing an importable Python object """ def __init__(self, package): @@ -331,12 +334,19 @@ class GlobalObject(object): 'The dotted name %r cannot be imported' % value) class Structure(object): - def __init__(self, name, typ, *structs, **kw): - self.name = name + _counter = itertools.count() + + def __new__(cls, *arg, **kw): + inst = object.__new__(cls) + inst._order = cls._counter.next() + return inst + + def __init__(self, typ, *structs, **kw): self.typ = typ self.validator = kw.get('validator', None) self.default = kw.get('default', None) self.required = kw.get('required', True) + self.name = kw.get('name', '') self.structs = list(structs) def serialize(self, value): @@ -351,3 +361,44 @@ class Structure(object): def add(self, struct): self.structs.append(struct) +class _SchemaMeta(type): + def __init__(cls, name, bases, clsattrs): + structs = [] + for name, value in clsattrs.items(): + if isinstance(value, Structure): + value.name = name + structs.append((value._order, value)) + cls.__schema_structures__ = structs + # Combine all attrs from this class and its subclasses. + extended = [] + for c in cls.__mro__: + extended.extend(getattr(c, '__schema_structures__', [])) + # Sort the attrs to maintain the order as defined, and assign to the + # class. + extended.sort() + cls.structs = [x[1] for x in extended] + +class Schema(object): + struct_type = Mapping + __metaclass__ = _SchemaMeta + + def __new__(cls, *args, **kw): + inst = object.__new__(Structure) + inst.name = None + inst._order = Structure._counter.next() + struct = cls.struct_type(*args, **kw) + inst.__init__(struct) + for s in cls.structs: + inst.add(s) + return inst + +MappingSchema = Schema + +class SequenceSchema(Schema): + struct_type = Sequence + +class TupleSchema(Schema): + struct_type = Tuple + + + diff --git a/cereal/tests.py b/cereal/tests.py index 7226391..2df5432 100644 --- a/cereal/tests.py +++ b/cereal/tests.py @@ -1,67 +1,6 @@ import unittest -class TestFunctional(unittest.TestCase): - def _makeSchema(self): - import cereal - - integer = cereal.Structure( - 'int', - cereal.Integer(), - validator=cereal.Range(0, 10) - ) - - ob = cereal.Structure( - 'ob', - cereal.GlobalObject(package=cereal), - ) - - tup = cereal.Structure( - 'tup', - cereal.Tuple(), - cereal.Structure( - 'tupint', - cereal.Integer(), - ), - cereal.Structure( - 'tupstring', - cereal.String(), - ), - ) - - seq = cereal.Structure( - 'seq', - cereal.Sequence(tup), - ) - - seq2 = cereal.Structure( - 'seq2', - cereal.Sequence( - cereal.Structure( - 'mapping', - cereal.Mapping(), - cereal.Structure( - 'key', - cereal.Integer(), - ), - cereal.Structure( - 'key2', - cereal.Integer(), - ), - ) - ), - ) - - schema = cereal.Structure( - None, - cereal.Mapping(), - integer, - ob, - tup, - seq, - seq2) - - return schema - +class TestFunctional(object): def test_deserialize_ok(self): import cereal.tests data = { @@ -108,3 +47,90 @@ class TestFunctional(unittest.TestCase): except cereal.Invalid, e: errors = e.asdict() self.assertEqual(errors, expected) + +class TestImperative(unittest.TestCase, TestFunctional): + + def _makeSchema(self): + import cereal + + integer = cereal.Structure( + cereal.Integer(), + name='int', + validator=cereal.Range(0, 10) + ) + + ob = cereal.Structure( + cereal.GlobalObject(package=cereal), + name='ob', + ) + + tup = cereal.Structure( + cereal.Tuple(), + cereal.Structure( + cereal.Integer(), + name='tupint', + ), + cereal.Structure( + cereal.String(), + name='tupstring', + ), + name='tup', + ) + + seq = cereal.Structure( + cereal.Sequence(tup), + name='seq', + ) + + seq2 = cereal.Structure( + cereal.Sequence( + cereal.Structure( + cereal.Mapping(), + cereal.Structure( + cereal.Integer(), + name='key', + ), + cereal.Structure( + cereal.Integer(), + name='key2', + ), + name='mapping', + ) + ), + name='seq2', + ) + + schema = cereal.Structure( + cereal.Mapping(), + integer, + ob, + tup, + seq, + seq2) + + return schema + +class TestDeclarative(unittest.TestCase, TestFunctional): + + def _makeSchema(self): + + import cereal + + class TupleSchema(cereal.TupleSchema): + tupint = cereal.Structure(cereal.Int()) + tupstring = cereal.Structure(cereal.String()) + + class MappingSchema(cereal.MappingSchema): + key = cereal.Structure(cereal.Int()) + key2 = cereal.Structure(cereal.Int()) + + class MainSchema(cereal.MappingSchema): + int = cereal.Structure(cereal.Int(), validator=cereal.Range(0, 10)) + ob = cereal.Structure(cereal.GlobalObject(package=cereal)) + seq = cereal.Structure(cereal.Sequence(TupleSchema())) + tup = TupleSchema() + seq2 = cereal.SequenceSchema(MappingSchema()) + + schema = MainSchema() + return schema +