diff --git a/tricircle/tests/unit/utils.py b/tricircle/tests/unit/utils.py index 4987e7b5..901470ed 100644 --- a/tricircle/tests/unit/utils.py +++ b/tricircle/tests/unit/utils.py @@ -144,10 +144,11 @@ class DotList(list): class FakeQuery(object): pk_map = {'ports': 'id'} - def __init__(self, records, table): + def __init__(self, records, table, field=None): self.records = records self.table = table self.index = 0 + self.field = field def _handle_pagination_by_id(self, record_id): for i, record in enumerate(self.records): @@ -156,7 +157,7 @@ class FakeQuery(object): return FakeQuery(self.records[i + 1:], self.table) else: return FakeQuery([], self.table) - return FakeQuery([], self.table) + return FakeQuery([], self.table, self.field) def _handle_filter(self, keys, values): filtered_list = [] @@ -168,7 +169,7 @@ class FakeQuery(object): break if selected: filtered_list.append(record) - return FakeQuery(filtered_list, self.table) + return FakeQuery(filtered_list, self.table, self.field) def filter(self, *criteria): _filter = [] @@ -194,7 +195,7 @@ class FakeQuery(object): values.append(True) if not _filter: if not keys: - return FakeQuery(self.records, self.table) + return FakeQuery(self.records, self.table, self.field) else: return self._handle_filter(keys, values) if hasattr(_filter[0].right, 'value'): @@ -219,7 +220,7 @@ class FakeQuery(object): break if selected: filtered_list.append(record) - return FakeQuery(filtered_list, self.table) + return FakeQuery(filtered_list, self.table, self.field) def get(self, pk): pk_field = self.pk_map[self.table] @@ -231,26 +232,33 @@ class FakeQuery(object): pass def outerjoin(self, *props, **kwargs): - return FakeQuery(self.records, self.table) + return FakeQuery(self.records, self.table, self.field) def join(self, *props, **kwargs): - return FakeQuery(self.records, self.table) + return FakeQuery(self.records, self.table, self.field) def order_by(self, func): self.records.sort(key=lambda x: x['id']) - return FakeQuery(self.records, self.table) + return FakeQuery(self.records, self.table, self.field) def enable_eagerloads(self, value): - return FakeQuery(self.records, self.table) + return FakeQuery(self.records, self.table, self.field) def limit(self, limit): - return FakeQuery(self.records[:limit], self.table) + return FakeQuery(self.records[:limit], self.table, self.field) def next(self): if self.index >= len(self.records): raise StopIteration self.index += 1 - return self.records[self.index - 1] + record = self.records[self.index - 1] + # populate integer indices + i = 0 + for key, value in list(record.items()): + if key == self.field: + record[i] = value + i += 1 + return record __next__ = next @@ -359,12 +367,14 @@ class FakeSession(object): return FakeSession.WithWrapper() def query(self, model): + field = None if isinstance(model, attributes.InstrumentedAttribute): + field = model.key model = model.class_ if model.__tablename__ not in self.resource_store.store_map: - return FakeQuery([], model.__tablename__) + return FakeQuery([], model.__tablename__, field) return FakeQuery(self.resource_store.store_map[model.__tablename__], - model.__tablename__) + model.__tablename__, field) def _extend_standard_attr(self, model_dict): if 'standard_attr' in model_dict: