refresh UDTs on keyspace update for v[12]

PYTHON-106
This commit is contained in:
Adam Holmberg
2016-08-17 12:50:51 -05:00
parent fa5cf2d503
commit a4cf9f6afc

View File

@@ -131,7 +131,12 @@ class Metadata(object):
meta = parse_method(self.keyspaces, **kwargs)
if meta:
update_method = getattr(self, '_update_' + tt_lower)
update_method(meta)
if tt_lower == 'keyspace' and connection.protocol_version < 3:
# we didn't have 'type' target in legacy protocol versions, so we need to query those too
user_types = parser.get_types_map(self.keyspaces, **kwargs)
self._update_keyspace(meta, user_types)
else:
update_method(meta)
else:
drop_method = getattr(self, '_drop_' + tt_lower)
drop_method(**kwargs)
@@ -157,13 +162,13 @@ class Metadata(object):
for ksname in removed_keyspaces:
self._keyspace_removed(ksname)
def _update_keyspace(self, keyspace_meta):
def _update_keyspace(self, keyspace_meta, new_user_types=None):
ks_name = keyspace_meta.name
old_keyspace_meta = self.keyspaces.get(ks_name, None)
self.keyspaces[ks_name] = keyspace_meta
if old_keyspace_meta:
keyspace_meta.tables = old_keyspace_meta.tables
keyspace_meta.user_types = old_keyspace_meta.user_types
keyspace_meta.user_types = new_user_types or old_keyspace_meta.user_types
keyspace_meta.indexes = old_keyspace_meta.indexes
keyspace_meta.functions = old_keyspace_meta.functions
keyspace_meta.aggregates = old_keyspace_meta.aggregates
@@ -1588,11 +1593,14 @@ class _SchemaParser(object):
raise result
def _query_build_row(self, query_string, build_func):
result = self._query_build_rows(query_string, build_func)
return result[0] if result else None
def _query_build_rows(self, query_string, build_func):
query = QueryMessage(query=query_string, consistency_level=ConsistencyLevel.ONE)
response = self.connection.wait_for_response(query, self.timeout)
result = dict_factory(*response.results)
if result:
return build_func(result[0])
return [build_func(row) for row in result]
class SchemaParserV22(_SchemaParser):
@@ -1701,6 +1709,11 @@ class SchemaParserV22(_SchemaParser):
where_clause = bind_params(" WHERE keyspace_name = %s AND type_name = %s", (keyspace, type), _encoder)
return self._query_build_row(self._SELECT_TYPES + where_clause, self._build_user_type)
def get_types_map(self, keyspaces, keyspace):
where_clause = bind_params(" WHERE keyspace_name = %s", (keyspace,), _encoder)
types = self._query_build_rows(self._SELECT_TYPES + where_clause, self._build_user_type)
return dict((t.name, t) for t in types)
def get_function(self, keyspaces, keyspace, function):
where_clause = bind_params(" WHERE keyspace_name = %%s AND function_name = %%s AND %s = %%s" % (self._function_agg_arument_type_col,),
(keyspace, function.name, function.argument_types), _encoder)