diff --git a/designate/central/service.py b/designate/central/service.py index 24d6953a..f72d4a10 100644 --- a/designate/central/service.py +++ b/designate/central/service.py @@ -129,7 +129,7 @@ class Service(service.Service): if self.check_for_tlds: try: self.storage.find_tld(context, {'name': domain_labels[-1]}) - except exceptions.TLDNotFound: + except exceptions.TldNotFound: raise exceptions.InvalidDomainName('Invalid TLD') # Now check that the domain name is not the same as a TLD @@ -138,7 +138,7 @@ class Service(service.Service): self.storage.find_tld( context, {'name': stripped_domain_name}) - except exceptions.TLDNotFound: + except exceptions.TldNotFound: pass else: raise exceptions.InvalidDomainName( diff --git a/designate/exceptions.py b/designate/exceptions.py index 9fd8d746..e1029e58 100644 --- a/designate/exceptions.py +++ b/designate/exceptions.py @@ -187,7 +187,7 @@ class DuplicateDomain(Duplicate): error_type = 'duplicate_domain' -class DuplicateTLD(Duplicate): +class DuplicateTld(Duplicate): error_type = 'duplicate_tld' @@ -235,7 +235,7 @@ class DomainNotFound(NotFound): error_type = 'domain_not_found' -class TLDNotFound(NotFound): +class TldNotFound(NotFound): error_type = 'tld_not_found' diff --git a/designate/manage/tlds.py b/designate/manage/tlds.py index 0c5bb4f6..e1669ae6 100644 --- a/designate/manage/tlds.py +++ b/designate/manage/tlds.py @@ -49,8 +49,8 @@ class TLDCommands(base.Commands): --> can be one of the following: - DuplicateTLD - This occurs if the TLD is already present. - InvalidTLD - This occurs if the TLD does not conform to the TLD schema. + DuplicateTld - This occurs if the TLD is already present. + InvalidTld - This occurs if the TLD does not conform to the TLD schema. InvalidDescription - This occurs if the description does not conform to the description schema InvalidLine - This occurs if the line contains more than 2 fields. @@ -81,7 +81,7 @@ class TLDCommands(base.Commands): def _validate_and_create_tld(self, line, error_lines): # validate the tld name if not format.is_tldname(line['name']): - error_lines.append("InvalidTLD --> " + + error_lines.append("InvalidTld --> " + self._convert_tld_dict_to_str(line)) return 0 # validate the description if there is one @@ -94,8 +94,8 @@ class TLDCommands(base.Commands): try: self.central_api.create_tld(self.context, values=line) return 1 - except exceptions.DuplicateTLD: - error_lines.append("DuplicateTLD --> " + + except exceptions.DuplicateTld: + error_lines.append("DuplicateTld --> " + self._convert_tld_dict_to_str(line)) return 0 diff --git a/designate/objects/base.py b/designate/objects/base.py index 6b4b916f..2b97bdf7 100644 --- a/designate/objects/base.py +++ b/designate/objects/base.py @@ -367,3 +367,12 @@ class PersistentObjectMixin(object): This adds the fields that we use in common for all persisent objects. """ FIELDS = ['id', 'created_at', 'updated_at', 'version'] + + +class SoftDeleteObjectMixin(object): + """ + Mixin class for Soft-Deleted objects. + + This adds the fields that we use in common for all soft-deleted objects. + """ + FIELDS = ['deleted', 'deleted_at'] diff --git a/designate/objects/domain.py b/designate/objects/domain.py index 08f857ba..43ed776e 100644 --- a/designate/objects/domain.py +++ b/designate/objects/domain.py @@ -15,8 +15,8 @@ from designate.objects import base -class Domain(base.DictObjectMixin, base.PersistentObjectMixin, - base.DesignateObject): +class Domain(base.DictObjectMixin, base.SoftDeleteObjectMixin, + base.PersistentObjectMixin, base.DesignateObject): FIELDS = ['tenant_id', 'name', 'email', 'ttl', 'refresh', 'retry', 'expire', 'minimum', 'parent_domain_id', 'serial', 'description', 'status'] diff --git a/designate/objects/record.py b/designate/objects/record.py index 69636337..589f4206 100644 --- a/designate/objects/record.py +++ b/designate/objects/record.py @@ -17,6 +17,8 @@ from designate.objects import base class Record(base.DictObjectMixin, base.PersistentObjectMixin, base.DesignateObject): + # TODO(kiall): `hash` is an implementation detail of our SQLA driver, + # so we should remove it. FIELDS = ['data', 'priority', 'domain_id', 'managed', 'managed_resource_type', 'managed_resource_id', 'managed_plugin_name', 'managed_plugin_type', 'hash', diff --git a/designate/sqlalchemy/utils.py b/designate/sqlalchemy/utils.py new file mode 100644 index 00000000..953da791 --- /dev/null +++ b/designate/sqlalchemy/utils.py @@ -0,0 +1,96 @@ +# Copyright 2010 United States Government as represented by the +# Administrator of the National Aeronautics and Space Administration. +# Copyright 2010-2011 OpenStack Foundation. +# Copyright 2012 Justin Santa Barbara +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +import logging + +import sqlalchemy +from oslo.db.sqlalchemy import utils + +from designate.i18n import _ +from designate.i18n import _LW + + +LOG = logging.getLogger(__name__) + + +# copy from olso/db/sqlalchemy/utils.py +def paginate_query(query, table, limit, sort_keys, marker=None, + sort_dir=None, sort_dirs=None): + if 'id' not in sort_keys: + # TODO(justinsb): If this ever gives a false-positive, check + # the actual primary key, rather than assuming its id + LOG.warning(_LW('Id not in sort_keys; is sort_keys unique?')) + + assert(not (sort_dir and sort_dirs)) + + # Default the sort direction to ascending + if sort_dirs is None and sort_dir is None: + sort_dir = 'asc' + + # Ensure a per-column sort direction + if sort_dirs is None: + sort_dirs = [sort_dir for _sort_key in sort_keys] + + assert(len(sort_dirs) == len(sort_keys)) + + # Add sorting + for current_sort_key, current_sort_dir in zip(sort_keys, sort_dirs): + try: + sort_dir_func = { + 'asc': sqlalchemy.asc, + 'desc': sqlalchemy.desc, + }[current_sort_dir] + except KeyError: + raise ValueError(_("Unknown sort direction, " + "must be 'desc' or 'asc'")) + try: + sort_key_attr = getattr(table.c, current_sort_key) + except AttributeError: + raise utils.InvalidSortKey() + query = query.order_by(sort_dir_func(sort_key_attr)) + + # Add pagination + if marker is not None: + marker_values = [] + for sort_key in sort_keys: + v = marker[sort_key] + marker_values.append(v) + + # Build up an array of sort criteria as in the docstring + criteria_list = [] + for i in range(len(sort_keys)): + crit_attrs = [] + for j in range(i): + table_attr = getattr(table.c, sort_keys[j]) + crit_attrs.append((table_attr == marker_values[j])) + + table_attr = getattr(table.c, sort_keys[i]) + if sort_dirs[i] == 'desc': + crit_attrs.append((table_attr < marker_values[i])) + else: + crit_attrs.append((table_attr > marker_values[i])) + + criteria = sqlalchemy.sql.and_(*crit_attrs) + criteria_list.append(criteria) + + f = sqlalchemy.sql.or_(*criteria_list) + query = query.where(f) + + if limit is not None: + query = query.limit(limit) + + return query diff --git a/designate/storage/base.py b/designate/storage/base.py index 698136eb..d9ee63e0 100644 --- a/designate/storage/base.py +++ b/designate/storage/base.py @@ -28,12 +28,12 @@ class Storage(DriverPlugin): __plugin_type__ = 'storage' @abc.abstractmethod - def create_quota(self, context, values): + def create_quota(self, context, quota): """ Create a Quota. :param context: RPC Context. - :param values: Values to create the new Quota from. + :param quota: Quota object with the values to be created. """ @abc.abstractmethod diff --git a/designate/storage/impl_sqlalchemy/__init__.py b/designate/storage/impl_sqlalchemy/__init__.py index 02b029bf..9b6c4c7b 100644 --- a/designate/storage/impl_sqlalchemy/__init__.py +++ b/designate/storage/impl_sqlalchemy/__init__.py @@ -15,21 +15,23 @@ # under the License. import time import threading +import hashlib from oslo.config import cfg from oslo.db.sqlalchemy import utils as oslo_utils from oslo.db import options -from sqlalchemy.orm import exc +from oslo.db import exception as oslo_db_exception from sqlalchemy import exc as sqlalchemy_exc -from sqlalchemy import distinct, func +from sqlalchemy import select, distinct, func from designate.openstack.common import log as logging +from designate.openstack.common import timeutils from designate import exceptions from designate import objects from designate.sqlalchemy import session +from designate.sqlalchemy import utils from designate.storage import base -from designate.storage.impl_sqlalchemy import models -from designate.sqlalchemy.models import SoftDeleteMixin +from designate.storage.impl_sqlalchemy import tables LOG = logging.getLogger(__name__) @@ -103,81 +105,115 @@ class SQLAlchemyStorage(base.Storage): def rollback(self): self.session.rollback() - def _apply_criterion(self, model, query, criterion): + def _apply_criterion(self, table, query, criterion): if criterion is not None: for name, value in criterion.items(): - column = getattr(model, name) + column = getattr(table.c, name) # Wildcard value: '*' if isinstance(value, basestring) and '*' in value: queryval = value.replace('*', '%') - query = query.filter(column.like(queryval)) + query = query.where(column.like(queryval)) else: - query = query.filter(column == value) + query = query.where(column == value) return query - def _apply_tenant_criteria(self, context, model, query): - if hasattr(model, 'tenant_id'): + def _apply_tenant_criteria(self, context, table, query): + if hasattr(table.c, 'tenant_id'): if context.all_tenants: LOG.debug('Including all tenants items in query results') else: - query = query.filter(model.tenant_id == context.tenant) + query = query.where(table.c.tenant_id == context.tenant) return query - def _apply_deleted_criteria(self, context, model, query): - if issubclass(model, SoftDeleteMixin): + def _apply_deleted_criteria(self, context, table, query): + if hasattr(table.c, 'deleted'): if context.show_deleted: LOG.debug('Including deleted items in query results') else: - query = query.filter(model.deleted == "0") + query = query.where(table.c.deleted == "0") return query - def _find(self, model, context, criterion, one=False, - marker=None, limit=None, sort_key=None, sort_dir=None): - """ - Base "finder" method + def _create(self, table, obj, exc_dup, skip_values=None): + values = obj.obj_get_changes() - Used to abstract these details from all the _find_*() methods. - """ - # First up, create a query and apply the various filters - query = self.session.query(model) - query = self._apply_criterion(model, query, criterion) - query = self._apply_tenant_criteria(context, model, query) - query = self._apply_deleted_criteria(context, model, query) + if skip_values is not None: + for skip_value in skip_values: + values.pop(skip_value, None) + query = table.insert() + + try: + resultproxy = self.session.execute(query, [dict(values)]) + except oslo_db_exception.DBDuplicateEntry: + raise exc_dup() + + # Refetch the row, for generated columns etc + query = select([table]).where( + table.c.id == resultproxy.inserted_primary_key[0]) + resultproxy = self.session.execute(query) + + return _set_object_from_model(obj, resultproxy.fetchone()) + + def _find(self, context, table, cls, list_cls, exc_notfound, criterion, + one=False, marker=None, limit=None, sort_key=None, + sort_dir=None): + sort_key = sort_key or 'created_at' + sort_dir = sort_dir or 'asc' + + # Build the query + query = select([table]) + query = self._apply_criterion(table, query, criterion) + query = self._apply_tenant_criteria(context, table, query) + query = self._apply_deleted_criteria(context, table, query) + + # Execute the Query if one: - # If we're asked to return exactly one record, but multiple or - # none match, raise a NotFound - try: - return query.one() - except (exc.NoResultFound, exc.MultipleResultsFound): - raise exceptions.NotFound() + # NOTE(kiall): If we expect one value, and two rows match, we raise + # a NotFound. Limiting to 2 allows us to determine + # when we need to raise, while selecting the minimal + # number of rows. + resultproxy = self.session.execute(query.limit(2)) + results = resultproxy.fetchall() + + if len(results) != 1: + raise exc_notfound() + else: + return _set_object_from_model(cls(), results[0]) else: - # If marker is not none and basestring we query it. - # Otherwise, return all matching records if marker is not None: + # If marker is not none and basestring we query it. + # Otherwise, return all matching records + marker_query = select([table]).where(table.c.id == marker) + try: - marker = self._find(model, context, {'id': marker}, - one=True) - except exceptions.NotFound: - raise exceptions.MarkerNotFound( - 'Marker %s could not be found' % marker) - # Malformed UUIDs return StatementError - except sqlalchemy_exc.StatementError as statement_error: - raise exceptions.InvalidMarker(statement_error.message) - sort_key = sort_key or 'created_at' - sort_dir = sort_dir or 'asc' + marker_resultproxy = self.session.execute(marker_query) + marker = marker_resultproxy.fetchone() + if marker is None: + raise exceptions.MarkerNotFound( + 'Marker %s could not be found' % marker) + except oslo_db_exception.DBError as e: + # Malformed UUIDs return StatementError wrapped in a + # DBError + if isinstance(e.inner_exception, + sqlalchemy_exc.StatementError): + raise exceptions.InvalidMarker() + else: + raise try: - query = oslo_utils.paginate_query( - query, model, limit, + query = utils.paginate_query( + query, table, limit, [sort_key, 'id', 'created_at'], marker=marker, sort_dir=sort_dir) - return query.all() + resultproxy = self.session.execute(query) + results = resultproxy.fetchall() + + return _set_listobject_from_models(list_cls(), results) except oslo_utils.InvalidSortKey as sort_key_error: raise exceptions.InvalidSortKey(sort_key_error.message) # Any ValueErrors are propagated back to the user as is. @@ -187,6 +223,64 @@ class SQLAlchemyStorage(base.Storage): except ValueError as value_error: raise exceptions.ValueError(value_error.message) + def _update(self, context, table, obj, exc_dup, exc_notfound, + skip_values=None): + values = obj.obj_get_changes() + + if skip_values is not None: + for skip_value in skip_values: + values.pop(skip_value, None) + + query = table.update()\ + .where(table.c.id == obj.id)\ + .values(**values) + + query = self._apply_tenant_criteria(context, table, query) + query = self._apply_deleted_criteria(context, table, query) + + try: + resultproxy = self.session.execute(query) + except oslo_db_exception.DBDuplicateEntry: + raise exc_dup() + + if resultproxy.rowcount != 1: + raise exc_notfound() + + # Refetch the row, for generated columns etc + query = select([table]).where(table.c.id == obj.id) + resultproxy = self.session.execute(query) + + return _set_object_from_model(obj, resultproxy.fetchone()) + + def _delete(self, context, table, obj, exc_notfound): + if hasattr(table.c, 'deleted'): + # Perform a Soft Delete + # TODO(kiall): If the object has any changed fields, they will be + # persisted here when we don't want that. + obj.deleted = obj.id.replace('-', '') + obj.deleted_at = timeutils.utcnow() + + # NOTE(kiall): It should be impossible for a duplicate exception to + # be raised in this call, therefore, it is OK to pass + # in "None" as the exc_dup param. + return self._update(context, table, obj, None, exc_notfound) + + # Delete the quota. + query = table.delete().where(table.c.id == obj.id) + query = self._apply_tenant_criteria(context, table, query) + query = self._apply_deleted_criteria(context, table, query) + + resultproxy = self.session.execute(query) + + if resultproxy.rowcount != 1: + raise exc_notfound() + + # Refetch the row, for generated columns etc + query = select([table]).where(table.c.id == obj.id) + resultproxy = self.session.execute(query) + + return _set_object_from_model(obj, resultproxy.fetchone()) + # CRUD for our resources (quota, server, tsigkey, tenant, domain & record) # R - get_*, find_*s # @@ -198,242 +292,164 @@ class SQLAlchemyStorage(base.Storage): # # Quota Methods - def _find_quotas(self, context, criterion, one=False, - marker=None, limit=None, sort_key=None, sort_dir=None): - try: - return self._find(models.Quota, context, criterion, one=one, - marker=marker, limit=limit, sort_key=sort_key, - sort_dir=sort_dir) - except exceptions.NotFound: - raise exceptions.QuotaNotFound() + def _find_quotas(self, context, criterion, one=False, marker=None, + limit=None, sort_key=None, sort_dir=None): + return self._find( + context, tables.quotas, objects.Quota, objects.QuotaList, + exceptions.QuotaNotFound, criterion, one, marker, limit, + sort_key, sort_dir) - def create_quota(self, context, values): - quota = models.Quota() + def create_quota(self, context, quota): + if not isinstance(quota, objects.Quota): + # TODO(kiall): Quotas should always use Objects + quota = objects.Quota(**quota) - quota.update(values) - - try: - quota.save(self.session) - except exceptions.Duplicate: - raise exceptions.DuplicateQuota() - - return _set_object_from_model(objects.Quota(), quota) + return self._create( + tables.quotas, quota, exceptions.DuplicateQuota) def get_quota(self, context, quota_id): + return self._find_quotas(context, {'id': quota_id}, one=True) + + def find_quotas(self, context, criterion=None, marker=None, limit=None, + sort_key=None, sort_dir=None): + return self._find_quotas(context, criterion, marker=marker, + limit=limit, sort_key=sort_key, + sort_dir=sort_dir) + + def find_quota(self, context, criterion): + return self._find_quotas(context, criterion, one=True) + + def update_quota(self, context, quota): + return self._update( + context, tables.quotas, quota, exceptions.DuplicateQuota, + exceptions.QuotaNotFound) + + def delete_quota(self, context, quota_id): + # Fetch the existing quota, we'll need to return it. quota = self._find_quotas(context, {'id': quota_id}, one=True) + return self._delete(context, tables.quotas, quota, + exceptions.QuotaNotFound) - return _set_object_from_model(objects.Quota(), quota) + # Server Methods + def _find_servers(self, context, criterion, one=False, marker=None, + limit=None, sort_key=None, sort_dir=None): + return self._find( + context, tables.servers, objects.Server, objects.ServerList, + exceptions.ServerNotFound, criterion, one, marker, limit, + sort_key, sort_dir) - def find_quotas(self, context, criterion=None, - marker=None, limit=None, sort_key=None, sort_dir=None): - quotas = self._find_quotas(context, criterion, marker=marker, + def create_server(self, context, server): + return self._create( + tables.servers, server, exceptions.DuplicateServer) + + def get_server(self, context, server_id): + return self._find_servers(context, {'id': server_id}, one=True) + + def find_servers(self, context, criterion=None, marker=None, limit=None, + sort_key=None, sort_dir=None): + return self._find_servers(context, criterion, marker=marker, + limit=limit, sort_key=sort_key, + sort_dir=sort_dir) + + def find_server(self, context, criterion): + return self._find_servers(context, criterion, one=True) + + def update_server(self, context, server): + return self._update( + context, tables.servers, server, exceptions.DuplicateServer, + exceptions.ServerNotFound) + + def delete_server(self, context, server_id): + # Fetch the existing server, we'll need to return it. + server = self._find_servers(context, {'id': server_id}, one=True) + return self._delete(context, tables.servers, server, + exceptions.ServerNotFound) + + # TLD Methods + def _find_tlds(self, context, criterion, one=False, marker=None, + limit=None, sort_key=None, sort_dir=None): + return self._find( + context, tables.tlds, objects.Tld, objects.TldList, + exceptions.TldNotFound, criterion, one, marker, limit, + sort_key, sort_dir) + + def create_tld(self, context, tld): + return self._create( + tables.tlds, tld, exceptions.DuplicateTld) + + def get_tld(self, context, tld_id): + return self._find_tlds(context, {'id': tld_id}, one=True) + + def find_tlds(self, context, criterion=None, marker=None, limit=None, + sort_key=None, sort_dir=None): + return self._find_tlds(context, criterion, marker=marker, limit=limit, + sort_key=sort_key, sort_dir=sort_dir) + + def find_tld(self, context, criterion): + return self._find_tlds(context, criterion, one=True) + + def update_tld(self, context, tld): + return self._update( + context, tables.tlds, tld, exceptions.DuplicateTld, + exceptions.TldNotFound) + + def delete_tld(self, context, tld_id): + # Fetch the existing tld, we'll need to return it. + tld = self._find_tlds(context, {'id': tld_id}, one=True) + return self._delete(context, tables.tlds, tld, exceptions.TldNotFound) + + # TSIG Key Methods + def _find_tsigkeys(self, context, criterion, one=False, marker=None, + limit=None, sort_key=None, sort_dir=None): + return self._find( + context, tables.tsigkeys, objects.TsigKey, objects.TsigKeyList, + exceptions.TsigKeyNotFound, criterion, one, marker, limit, + sort_key, sort_dir) + + def create_tsigkey(self, context, tsigkey): + return self._create( + tables.tsigkeys, tsigkey, exceptions.DuplicateTsigKey) + + def get_tsigkey(self, context, tsigkey_id): + return self._find_tsigkeys(context, {'id': tsigkey_id}, one=True) + + def find_tsigkeys(self, context, criterion=None, marker=None, limit=None, + sort_key=None, sort_dir=None): + return self._find_tsigkeys(context, criterion, marker=marker, limit=limit, sort_key=sort_key, sort_dir=sort_dir) - return _set_listobject_from_models(objects.QuotaList(), quotas) - - def find_quota(self, context, criterion): - quota = self._find_quotas(context, criterion, one=True) - - return _set_object_from_model(objects.Quota(), quota) - - def update_quota(self, context, quota): - storage_quota = self._find_quotas(context, {'id': quota.id}, one=True) - - storage_quota.update(quota.obj_get_changes()) - - try: - storage_quota.save(self.session) - except exceptions.Duplicate: - raise exceptions.DuplicateQuota() - - return _set_object_from_model(quota, storage_quota) - - def delete_quota(self, context, quota_id): - quota = self._find_quotas(context, {'id': quota_id}, one=True) - - quota.delete(self.session) - - return _set_object_from_model(objects.Quota(), quota) - - # Server Methods - def _find_servers(self, context, criterion, one=False, - marker=None, limit=None, sort_key=None, sort_dir=None): - try: - return self._find(models.Server, context, criterion, one, - marker=marker, limit=limit, sort_key=sort_key, - sort_dir=sort_dir) - except exceptions.NotFound: - raise exceptions.ServerNotFound() - - def create_server(self, context, server): - storage_server = models.Server() - - storage_server.update(server) - - try: - storage_server.save(self.session) - except exceptions.Duplicate: - raise exceptions.DuplicateServer() - - return _set_object_from_model(server, storage_server) - - def find_servers(self, context, criterion=None, - marker=None, limit=None, sort_key=None, sort_dir=None): - servers = self._find_servers(context, criterion, marker=marker, - limit=limit, sort_key=sort_key, - sort_dir=sort_dir) - - return _set_listobject_from_models(objects.ServerList(), servers) - - def get_server(self, context, server_id): - server = self._find_servers(context, {'id': server_id}, one=True) - return _set_object_from_model(objects.Server(), server) - - def update_server(self, context, server): - storage_server = self._find_servers(context, {'id': server.id}, - one=True) - - storage_server.update(server.obj_get_changes()) - - try: - storage_server.save(self.session) - except exceptions.Duplicate: - raise exceptions.DuplicateServer() - - return _set_object_from_model(server, storage_server) - - def delete_server(self, context, server_id): - server = self._find_servers(context, {'id': server_id}, one=True) - - server.delete(self.session) - - return _set_object_from_model(objects.Server(), server) - - # TLD Methods - def _find_tlds(self, context, criterion, one=False, - marker=None, limit=None, sort_key=None, sort_dir=None): - try: - return self._find(models.Tld, context, criterion, one=one, - marker=marker, limit=limit, sort_key=sort_key, - sort_dir=sort_dir) - except exceptions.NotFound: - raise exceptions.TLDNotFound() - - def create_tld(self, context, tld): - storage_tld = models.Tld() - storage_tld.update(tld) - - try: - storage_tld.save(self.session) - except exceptions.Duplicate: - raise exceptions.DuplicateTLD() - - return _set_object_from_model(tld, storage_tld) - - def find_tlds(self, context, criterion=None, - marker=None, limit=None, sort_key=None, sort_dir=None): - tlds = self._find_tlds(context, criterion, marker=marker, limit=limit, - sort_key=sort_key, sort_dir=sort_dir) - - return _set_listobject_from_models(objects.TldList(), tlds) - - def find_tld(self, context, criterion): - tld = self._find_tlds(context, criterion, one=True) - return _set_object_from_model(objects.Tld(), tld) - - def get_tld(self, context, tld_id): - tld = self._find_tlds(context, {'id': tld_id}, one=True) - return _set_object_from_model(objects.Tld(), tld) - - def update_tld(self, context, tld): - storage_tld = self._find_tlds(context, {'id': tld.id}, one=True) - storage_tld.update(tld.obj_get_changes()) - - try: - storage_tld.save(self.session) - except exceptions.Duplicate: - raise exceptions.DuplicateTLD() - - return _set_object_from_model(tld, storage_tld) - - def delete_tld(self, context, tld_id): - tld = self._find_tlds(context, {'id': tld_id}, one=True) - tld.delete(self.session) - - return _set_object_from_model(objects.Tld(), tld) - - # TSIG Key Methods - def _find_tsigkeys(self, context, criterion, one=False, - marker=None, limit=None, sort_key=None, sort_dir=None): - try: - return self._find(models.TsigKey, context, criterion, one=one, - marker=marker, limit=limit, sort_key=sort_key, - sort_dir=sort_dir) - except exceptions.NotFound: - raise exceptions.TsigKeyNotFound() - - def create_tsigkey(self, context, tsigkey): - storage_tsigkey = models.TsigKey() - - storage_tsigkey.update(tsigkey) - - try: - storage_tsigkey.save(self.session) - except exceptions.Duplicate: - raise exceptions.DuplicateTsigKey() - - return _set_object_from_model(tsigkey, storage_tsigkey) - - def find_tsigkeys(self, context, criterion=None, - marker=None, limit=None, sort_key=None, sort_dir=None): - tsigkeys = self._find_tsigkeys(context, criterion, marker=marker, - limit=limit, sort_key=sort_key, - sort_dir=sort_dir) - - return _set_listobject_from_models(objects.TsigKeyList(), tsigkeys) - - def get_tsigkey(self, context, tsigkey_id): - tsigkey = self._find_tsigkeys(context, {'id': tsigkey_id}, one=True) - - return _set_object_from_model(objects.TsigKey(), tsigkey) + def find_tsigkey(self, context, criterion): + return self._find_tsigkeys(context, criterion, one=True) def update_tsigkey(self, context, tsigkey): - storage_tsigkey = self._find_tsigkeys(context, {'id': tsigkey.id}, - one=True) - - storage_tsigkey.update(tsigkey.obj_get_changes()) - - try: - storage_tsigkey.save(self.session) - except exceptions.Duplicate: - raise exceptions.DuplicateTsigKey() - - return _set_object_from_model(tsigkey, storage_tsigkey) + return self._update( + context, tables.tsigkeys, tsigkey, exceptions.DuplicateTsigKey, + exceptions.TsigKeyNotFound) def delete_tsigkey(self, context, tsigkey_id): + # Fetch the existing tsigkey, we'll need to return it. tsigkey = self._find_tsigkeys(context, {'id': tsigkey_id}, one=True) - - tsigkey.delete(self.session) - - return _set_object_from_model(objects.TsigKey(), tsigkey) + return self._delete(context, tables.tsigkeys, tsigkey, + exceptions.TsigKeyNotFound) ## # Tenant Methods ## def find_tenants(self, context): # returns an array of tenant_id & count of their domains - query = self.session.query(models.Domain.tenant_id, - func.count(models.Domain.id)) - query = self._apply_tenant_criteria(context, models.Domain, query) - query = self._apply_deleted_criteria(context, models.Domain, query) - query = query.group_by(models.Domain.tenant_id) + query = select([tables.domains.c.tenant_id, + func.count(tables.domains.c.id)]) + query = self._apply_tenant_criteria(context, tables.domains, query) + query = self._apply_deleted_criteria(context, tables.domains, query) + query = query.group_by(tables.domains.c.tenant_id) - tenants = query.all() + resultproxy = self.session.execute(query) + results = resultproxy.fetchall() tenant_list = objects.TenantList( objects=[objects.Tenant(id=t[0], domain_count=t[1]) for t in - tenants]) + results]) tenant_list.obj_reset_changes() @@ -441,191 +457,162 @@ class SQLAlchemyStorage(base.Storage): def get_tenant(self, context, tenant_id): # get list list & count of all domains owned by given tenant_id - query = self.session.query(models.Domain.name) - query = self._apply_tenant_criteria(context, models.Domain, query) - query = self._apply_deleted_criteria(context, models.Domain, query) - query = query.filter(models.Domain.tenant_id == tenant_id) + query = select([tables.domains.c.name]) + query = self._apply_tenant_criteria(context, tables.domains, query) + query = self._apply_deleted_criteria(context, tables.domains, query) + query = query.where(tables.domains.c.tenant_id == tenant_id) - result = query.all() + resultproxy = self.session.execute(query) + results = resultproxy.fetchall() return objects.Tenant( id=tenant_id, - domain_count=len(result), - domains=[r[0] for r in result]) + domain_count=len(results), + domains=[r[0] for r in results]) def count_tenants(self, context): # tenants are the owner of domains, count the number of unique tenants # select count(distinct tenant_id) from domains - query = self.session.query(distinct(models.Domain.tenant_id)) - query = self._apply_tenant_criteria(context, models.Domain, query) - query = self._apply_deleted_criteria(context, models.Domain, query) + query = select([func.count(distinct(tables.domains.c.tenant_id))]) + query = self._apply_tenant_criteria(context, tables.domains, query) + query = self._apply_deleted_criteria(context, tables.domains, query) - return query.count() + resultproxy = self.session.execute(query) + result = resultproxy.fetchone() + + if result is None: + return 0 + + return result[0] ## # Domain Methods ## - def _find_domains(self, context, criterion, one=False, - marker=None, limit=None, sort_key=None, sort_dir=None): - try: - return self._find(models.Domain, context, criterion, one=one, - marker=marker, limit=limit, sort_key=sort_key, - sort_dir=sort_dir) - except exceptions.NotFound: - raise exceptions.DomainNotFound() + def _find_domains(self, context, criterion, one=False, marker=None, + limit=None, sort_key=None, sort_dir=None): + return self._find( + context, tables.domains, objects.Domain, objects.DomainList, + exceptions.DomainNotFound, criterion, one, marker, limit, + sort_key, sort_dir) def create_domain(self, context, domain): - storage_domain = models.Domain() - - storage_domain.update(domain) - - try: - storage_domain.save(self.session) - except exceptions.Duplicate: - raise exceptions.DuplicateDomain() - - return _set_object_from_model(domain, storage_domain) + return self._create( + tables.domains, domain, exceptions.DuplicateDomain) def get_domain(self, context, domain_id): - domain = self._find_domains(context, {'id': domain_id}, one=True) + return self._find_domains(context, {'id': domain_id}, one=True) - return _set_object_from_model(objects.Domain(), domain) - - def find_domains(self, context, criterion=None, - marker=None, limit=None, sort_key=None, sort_dir=None): - domains = self._find_domains(context, criterion, marker=marker, - limit=limit, sort_key=sort_key, - sort_dir=sort_dir) - - return _set_listobject_from_models(objects.DomainList(), domains) + def find_domains(self, context, criterion=None, marker=None, limit=None, + sort_key=None, sort_dir=None): + return self._find_domains(context, criterion, marker=marker, + limit=limit, sort_key=sort_key, + sort_dir=sort_dir) def find_domain(self, context, criterion): - domain = self._find_domains(context, criterion, one=True) - return _set_object_from_model(objects.Domain(), domain) + return self._find_domains(context, criterion, one=True) def update_domain(self, context, domain): - storage_domain = self._find_domains(context, {'id': domain.id}, - one=True) - - storage_domain.update(domain.obj_get_changes()) - - try: - storage_domain.save(self.session) - except exceptions.Duplicate: - raise exceptions.DuplicateDomain() - - return _set_object_from_model(domain, storage_domain) + return self._update( + context, tables.domains, domain, exceptions.DuplicateDomain, + exceptions.DomainNotFound) def delete_domain(self, context, domain_id): + # Fetch the existing domain, we'll need to return it. domain = self._find_domains(context, {'id': domain_id}, one=True) - - domain.soft_delete(self.session) - - return _set_object_from_model(objects.Domain(), domain) + return self._delete(context, tables.domains, domain, + exceptions.DomainNotFound) def count_domains(self, context, criterion=None): - query = self.session.query(models.Domain) - query = self._apply_criterion(models.Domain, query, criterion) - query = self._apply_tenant_criteria(context, models.Domain, query) - query = self._apply_deleted_criteria(context, models.Domain, query) + query = select([func.count(tables.domains.c.id)]) + query = self._apply_criterion(tables.domains, query, criterion) + query = self._apply_tenant_criteria(context, tables.domains, query) + query = self._apply_deleted_criteria(context, tables.domains, query) - return query.count() + resultproxy = self.session.execute(query) + result = resultproxy.fetchone() + + if result is None: + return 0 + + return result[0] # RecordSet Methods - def _find_recordsets(self, context, criterion, one=False, - marker=None, limit=None, sort_key=None, - sort_dir=None): - try: - return self._find(models.RecordSet, context, criterion, one=one, - marker=marker, limit=limit, sort_key=sort_key, - sort_dir=sort_dir) - except exceptions.NotFound: - raise exceptions.RecordSetNotFound() + def _find_recordsets(self, context, criterion, one=False, marker=None, + limit=None, sort_key=None, sort_dir=None): + return self._find( + context, tables.recordsets, objects.RecordSet, + objects.RecordSetList, exceptions.RecordSetNotFound, criterion, + one, marker, limit, sort_key, sort_dir) def create_recordset(self, context, domain_id, recordset): # Fetch the domain as we need the tenant_id domain = self._find_domains(context, {'id': domain_id}, one=True) - storage_recordset = models.RecordSet() + recordset.tenant_id = domain.tenant_id + recordset.domain_id = domain_id - # We'll need to handle records separately - values = dict(recordset) - values.pop('records', None) - - storage_recordset.update(values) - storage_recordset.tenant_id = domain['tenant_id'] - storage_recordset.domain_id = domain_id - - try: - storage_recordset.save(self.session) - except exceptions.Duplicate: - raise exceptions.DuplicateRecordSet() + recordset = self._create( + tables.recordsets, recordset, exceptions.DuplicateRecordSet, + ['records']) if recordset.obj_attr_is_set('records'): for record in recordset.records: # NOTE: Since we're dealing with a mutable object, the return # value is not needed. The original item will be mutated # in place on the input "recordset.records" list. - self.create_record(context, domain_id, storage_recordset.id, - record) + self.create_record(context, domain_id, recordset.id, record) else: recordset.records = objects.RecordList() - return _set_object_from_model(recordset, storage_recordset, - records=recordset.records) + recordset.obj_reset_changes('records') + + return recordset def get_recordset(self, context, recordset_id): - recordset = self._find_recordsets(context, {'id': recordset_id}, - one=True) + recordset = self._find_recordsets( + context, {'id': recordset_id}, one=True) - records = _set_listobject_from_models( - objects.RecordList(), recordset.records) + recordset.records = self._find_records( + context, {'recordset_id': recordset.id}) - return _set_object_from_model(objects.RecordSet(), recordset, - records=records) + recordset.obj_reset_changes('records') - def find_recordsets(self, context, criterion=None, - marker=None, limit=None, sort_key=None, sort_dir=None): - recordsets = self._find_recordsets( - context, criterion, marker=marker, limit=limit, sort_key=sort_key, - sort_dir=sort_dir) + return recordset - def map_(recordset): - return { - 'records': _set_listobject_from_models( - objects.RecordList(), recordset.records) - } + def find_recordsets(self, context, criterion=None, marker=None, limit=None, + sort_key=None, sort_dir=None): + recordsets = self._find_recordsets(context, criterion, marker=marker, + limit=limit, sort_key=sort_key, + sort_dir=sort_dir) - return _set_listobject_from_models( - objects.RecordSetList(), recordsets, map_=map_) + for recordset in recordsets: + recordset.records = self._find_records( + context, {'recordset_id': recordset.id}) + + recordset.obj_reset_changes('records') + + return recordsets def find_recordset(self, context, criterion): recordset = self._find_recordsets(context, criterion, one=True) - records = _set_listobject_from_models( - objects.RecordList(), recordset.records) + recordset.records = self._find_records( + context, {'recordset_id': recordset.id}) - return _set_object_from_model(objects.RecordSet(), recordset, - records=records) + recordset.obj_reset_changes('records') + + return recordset def update_recordset(self, context, recordset): - storage_recordset = self._find_recordsets( - context, {'id': recordset.id}, one=True) - - # We'll need to handle records separately - values = dict(recordset.obj_get_changes()) - values.pop('records', None) - - storage_recordset.update(values) - - try: - storage_recordset.save(self.session) - except exceptions.Duplicate: - raise exceptions.DuplicateRecordSet() + recordset = self._update( + context, tables.recordsets, recordset, + exceptions.DuplicateRecordSet, exceptions.RecordSetNotFound, + ['records']) if recordset.obj_attr_is_set('records'): # Gather the Record ID's we have - have_records = set([r.id for r in storage_recordset.records]) + have_records = set([r.id for r in self._find_records( + context, {'recordset_id': recordset.id})]) # Prep some lists of changes keep_records = set([]) @@ -660,165 +647,136 @@ class SQLAlchemyStorage(base.Storage): self.create_record( context, recordset.domain_id, recordset.id, record) - # Honestly, I have no idea why this is necessary. Without this - # call, then fetching the RecordSet's records again in the same - # session will return the deleted records. - self.session.refresh(storage_recordset) - - return _set_object_from_model(recordset, storage_recordset, - records=recordset.records) + return recordset def delete_recordset(self, context, recordset_id): - recordset = self._find_recordsets(context, {'id': recordset_id}, - one=True) + # Fetch the existing recordset, we'll need to return it. + recordset = self._find_recordsets( + context, {'id': recordset_id}, one=True) - recordset.delete(self.session) - - return _set_object_from_model(objects.RecordSet(), recordset) + return self._delete(context, tables.recordsets, recordset, + exceptions.RecordSetNotFound) def count_recordsets(self, context, criterion=None): - query = self.session.query(models.RecordSet) - query = self._apply_criterion(models.RecordSet, query, criterion) + query = select([func.count(tables.recordsets.c.id)]) + query = self._apply_criterion(tables.recordsets, query, criterion) + query = self._apply_tenant_criteria(context, tables.recordsets, query) + query = self._apply_deleted_criteria(context, tables.recordsets, query) - return query.count() + resultproxy = self.session.execute(query) + result = resultproxy.fetchone() + + if result is None: + return 0 + + return result[0] # Record Methods - def _find_records(self, context, criterion, one=False, - marker=None, limit=None, sort_key=None, sort_dir=None): - try: - return self._find(models.Record, context, criterion, one=one, - marker=marker, limit=limit, sort_key=sort_key, - sort_dir=sort_dir) - except exceptions.NotFound: - raise exceptions.RecordNotFound() + def _find_records(self, context, criterion, one=False, marker=None, + limit=None, sort_key=None, sort_dir=None): + return self._find( + context, tables.records, objects.Record, objects.RecordList, + exceptions.RecordNotFound, criterion, one, marker, limit, + sort_key, sort_dir) + + def _recalculate_record_hash(self, record): + """ + Calculates the hash of the record, used to ensure record uniqueness. + """ + md5 = hashlib.md5() + md5.update("%s:%s:%s" % (record.recordset_id, record.data, + record.priority)) + + return md5.hexdigest() def create_record(self, context, domain_id, recordset_id, record): # Fetch the domain as we need the tenant_id domain = self._find_domains(context, {'id': domain_id}, one=True) - # Create and populate the new Record model - storage_record = models.Record() - storage_record.update(record) - storage_record.tenant_id = domain['tenant_id'] - storage_record.domain_id = domain_id - storage_record.recordset_id = recordset_id + record.tenant_id = domain.tenant_id + record.domain_id = domain_id + record.recordset_id = recordset_id + record.hash = self._recalculate_record_hash(record) - try: - # Save the new Record model - storage_record.save(self.session) - except exceptions.Duplicate: - raise exceptions.DuplicateRecord() - - return _set_object_from_model(record, storage_record) - - def find_records(self, context, criterion=None, - marker=None, limit=None, sort_key=None, sort_dir=None): - records = self._find_records( - context, criterion, marker=marker, limit=limit, sort_key=sort_key, - sort_dir=sort_dir) - - return _set_listobject_from_models(objects.RecordList(), records) + return self._create( + tables.records, record, exceptions.DuplicateRecord) def get_record(self, context, record_id): - record = self._find_records(context, {'id': record_id}, one=True) + return self._find_records(context, {'id': record_id}, one=True) - return _set_object_from_model(objects.Record(), record) + def find_records(self, context, criterion=None, marker=None, limit=None, + sort_key=None, sort_dir=None): + return self._find_records(context, criterion, marker=marker, + limit=limit, sort_key=sort_key, + sort_dir=sort_dir) def find_record(self, context, criterion): - record = self._find_records(context, criterion, one=True) - - return _set_object_from_model(objects.Record(), record) + return self._find_records(context, criterion, one=True) def update_record(self, context, record): - storage_record = self._find_records(context, {'id': record.id}, - one=True) + if record.obj_what_changed(): + record.hash = self._recalculate_record_hash(record) - storage_record.update(record.obj_get_changes()) - - try: - storage_record.save(self.session) - except exceptions.Duplicate: - raise exceptions.DuplicateRecord() - - return _set_object_from_model(record, storage_record) + return self._update( + context, tables.records, record, exceptions.DuplicateRecord, + exceptions.RecordNotFound) def delete_record(self, context, record_id): + # Fetch the existing record, we'll need to return it. record = self._find_records(context, {'id': record_id}, one=True) - - record.delete(self.session) - - return _set_object_from_model(objects.Record(), record) + return self._delete(context, tables.records, record, + exceptions.RecordNotFound) def count_records(self, context, criterion=None): - query = self.session.query(models.Record) - query = self._apply_tenant_criteria(context, models.Record, query) - query = self._apply_criterion(models.Record, query, criterion) - return query.count() + query = select([func.count(tables.records.c.id)]) + query = self._apply_criterion(tables.records, query, criterion) + query = self._apply_tenant_criteria(context, tables.records, query) + query = self._apply_deleted_criteria(context, tables.records, query) + + resultproxy = self.session.execute(query) + result = resultproxy.fetchone() + + if result is None: + return 0 + + return result[0] - # # Blacklist Methods - # - def _find_blacklist(self, context, criterion, one=False, - marker=None, limit=None, sort_key=None, sort_dir=None): - try: - return self._find(models.Blacklists, context, criterion, one=one, - marker=marker, limit=limit, sort_key=sort_key, - sort_dir=sort_dir) - except exceptions.NotFound: - raise exceptions.BlacklistNotFound() + def _find_blacklists(self, context, criterion, one=False, marker=None, + limit=None, sort_key=None, sort_dir=None): + return self._find( + context, tables.blacklists, objects.Blacklist, + objects.BlacklistList, exceptions.BlacklistNotFound, criterion, + one, marker, limit, sort_key, sort_dir) def create_blacklist(self, context, blacklist): - storage_blacklist = models.Blacklists() - - storage_blacklist.update(blacklist) - - try: - storage_blacklist.save(self.session) - except exceptions.Duplicate: - raise exceptions.DuplicateBlacklist() - - return _set_object_from_model(blacklist, storage_blacklist) - - def find_blacklists(self, context, criterion=None, - marker=None, limit=None, sort_key=None, sort_dir=None): - blacklists = self._find_blacklist( - context, criterion, marker=marker, limit=limit, sort_key=sort_key, - sort_dir=sort_dir) - - return _set_listobject_from_models(objects.BlacklistList(), blacklists) + return self._create( + tables.blacklists, blacklist, exceptions.DuplicateBlacklist) def get_blacklist(self, context, blacklist_id): - blacklist = self._find_blacklist(context, - {'id': blacklist_id}, one=True) + return self._find_blacklists(context, {'id': blacklist_id}, one=True) - return _set_object_from_model(objects.Blacklist(), blacklist) + def find_blacklists(self, context, criterion=None, marker=None, limit=None, + sort_key=None, sort_dir=None): + return self._find_blacklists(context, criterion, marker=marker, + limit=limit, sort_key=sort_key, + sort_dir=sort_dir) def find_blacklist(self, context, criterion): - blacklist = self._find_blacklist(context, criterion, one=True) - - return _set_object_from_model(objects.Blacklist(), blacklist) + return self._find_blacklists(context, criterion, one=True) def update_blacklist(self, context, blacklist): - storage_blacklist = self._find_blacklist(context, {'id': blacklist.id}, - one=True) - - storage_blacklist.update(blacklist.obj_get_changes()) - - try: - storage_blacklist.save(self.session) - except exceptions.Duplicate: - raise exceptions.DuplicateBlacklist() - - return _set_object_from_model(blacklist, storage_blacklist) + return self._update( + context, tables.blacklists, blacklist, + exceptions.DuplicateBlacklist, exceptions.BlacklistNotFound) def delete_blacklist(self, context, blacklist_id): + # Fetch the existing blacklist, we'll need to return it. + blacklist = self._find_blacklists( + context, {'id': blacklist_id}, one=True) - blacklist = self._find_blacklist(context, {'id': blacklist_id}, - one=True) - - blacklist.delete(self.session) - - return _set_object_from_model(objects.Blacklist(), blacklist) + return self._delete(context, tables.blacklists, blacklist, + exceptions.BlacklistNotFound) # diagnostics def ping(self, context): diff --git a/designate/storage/impl_sqlalchemy/models.py b/designate/storage/impl_sqlalchemy/models.py deleted file mode 100644 index eb68fa85..00000000 --- a/designate/storage/impl_sqlalchemy/models.py +++ /dev/null @@ -1,197 +0,0 @@ -# Copyright 2012 Hewlett-Packard Development Company, L.P. -# Copyright 2012 Managed I.T. -# -# Author: Kiall Mac Innes -# Modified: Patrick Galbraith -# -# Licensed under the Apache License, Version 2.0 (the "License"); you may -# not use this file except in compliance with the License. You may obtain -# a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -# License for the specific language governing permissions and limitations -# under the License. -import hashlib - -from oslo.config import cfg -from oslo.db.sqlalchemy import models as oslo_models -from sqlalchemy import (Column, String, Text, Integer, ForeignKey, - Enum, Boolean, Unicode, UniqueConstraint, event) -from sqlalchemy.orm import relationship, backref -from sqlalchemy.ext.declarative import declarative_base - -from designate.openstack.common import timeutils -from designate.sqlalchemy.types import UUID -from designate.sqlalchemy import models -from designate import utils - - -CONF = cfg.CONF - -RESOURCE_STATUSES = ['ACTIVE', 'PENDING', 'DELETED'] -RECORD_TYPES = ['A', 'AAAA', 'CNAME', 'MX', 'SRV', 'TXT', 'SPF', 'NS', 'PTR', - 'SSHFP'] -TSIG_ALGORITHMS = ['hmac-md5', 'hmac-sha1', 'hmac-sha224', 'hmac-sha256', - 'hmac-sha384', 'hmac-sha512'] - - -class Base(models.Base, oslo_models.TimestampMixin): - id = Column(UUID, default=utils.generate_uuid, primary_key=True) - version = Column(Integer, default=1, nullable=False) - - __mapper_args__ = { - 'version_id_col': version - } - - __table_args__ = {'mysql_engine': 'InnoDB', 'mysql_charset': 'utf8'} - - -Base = declarative_base(cls=Base) - - -class Quota(Base): - __tablename__ = 'quotas' - __table_args__ = ( - UniqueConstraint('tenant_id', 'resource', name='unique_quota'), - {'mysql_engine': 'InnoDB', 'mysql_charset': 'utf8'} - ) - - tenant_id = Column(String(36), default=None, nullable=True) - resource = Column(String(32), nullable=False) - hard_limit = Column(Integer(), nullable=False) - - -class Server(Base): - __tablename__ = 'servers' - - name = Column(String(255), nullable=False, unique=True) - - -class Tld(Base): - __tablename__ = 'tlds' - - name = Column(String(255), nullable=False, unique=True) - description = Column(Unicode(160), nullable=True) - - -class Domain(models.SoftDeleteMixin, Base): - __tablename__ = 'domains' - __table_args__ = ( - UniqueConstraint('name', 'deleted', name='unique_domain_name'), - {'mysql_engine': 'InnoDB', 'mysql_charset': 'utf8'} - ) - - tenant_id = Column(String(36), default=None, nullable=True) - - name = Column(String(255), nullable=False) - email = Column(String(255), nullable=False) - description = Column(Unicode(160), nullable=True) - ttl = Column(Integer, default=CONF.default_ttl, nullable=False) - - serial = Column(Integer, default=timeutils.utcnow_ts, nullable=False) - refresh = Column(Integer, default=CONF.default_soa_refresh, nullable=False) - retry = Column(Integer, default=CONF.default_soa_retry, nullable=False) - expire = Column(Integer, default=CONF.default_soa_expire, nullable=False) - minimum = Column(Integer, default=CONF.default_soa_minimum, nullable=False) - status = Column(Enum(name='resource_statuses', *RESOURCE_STATUSES), - nullable=False, server_default='ACTIVE', - default='ACTIVE') - - recordsets = relationship('RecordSet', - backref=backref('domain', uselist=False), - cascade="all, delete-orphan", - passive_deletes=True) - - parent_domain_id = Column(UUID, ForeignKey('domains.id'), default=None, - nullable=True) - - -class RecordSet(Base): - __tablename__ = 'recordsets' - __table_args__ = ( - UniqueConstraint('domain_id', 'name', 'type', name='unique_recordset'), - {'mysql_engine': 'InnoDB', 'mysql_charset': 'utf8'} - ) - - tenant_id = Column(String(36), default=None, nullable=True) - domain_id = Column(UUID, ForeignKey('domains.id', ondelete='CASCADE'), - nullable=False) - - name = Column(String(255), nullable=False) - type = Column(Enum(name='record_types', *RECORD_TYPES), nullable=False) - ttl = Column(Integer, default=None, nullable=True) - description = Column(Unicode(160), nullable=True) - - records = relationship('Record', - backref=backref('recordset', uselist=False), - cascade="all, delete-orphan", - passive_deletes=True) - - -class Record(Base): - __tablename__ = 'records' - - tenant_id = Column(String(36), default=None, nullable=True) - domain_id = Column(UUID, ForeignKey('domains.id', ondelete='CASCADE'), - nullable=False) - - recordset_id = Column(UUID, - ForeignKey('recordsets.id', ondelete='CASCADE'), - nullable=False) - - data = Column(Text, nullable=False) - priority = Column(Integer, default=None, nullable=True) - description = Column(Unicode(160), nullable=True) - - hash = Column(String(32), nullable=False, unique=True) - - managed = Column(Boolean, default=False) - managed_extra = Column(Unicode(100), default=None, nullable=True) - managed_plugin_type = Column(Unicode(50), default=None, nullable=True) - managed_plugin_name = Column(Unicode(50), default=None, nullable=True) - managed_resource_type = Column(Unicode(50), default=None, nullable=True) - managed_resource_region = Column(Unicode(100), default=None, nullable=True) - managed_resource_id = Column(UUID, default=None, nullable=True) - managed_tenant_id = Column(Unicode(36), default=None, nullable=True) - status = Column(Enum(name='resource_statuses', *RESOURCE_STATUSES), - nullable=False, server_default='ACTIVE', - default='ACTIVE') - - def recalculate_hash(self): - """ - Calculates the hash of the record, used to ensure record uniqueness. - """ - md5 = hashlib.md5() - md5.update("%s:%s:%s" % (self.recordset_id, self.data, self.priority)) - - self.hash = md5.hexdigest() - - -@event.listens_for(Record, "before_insert") -def recalculate_record_hash_before_insert(mapper, connection, instance): - instance.recalculate_hash() - - -@event.listens_for(Record, "before_update") -def recalculate_record_hash_before_update(mapper, connection, instance): - instance.recalculate_hash() - - -class TsigKey(Base): - __tablename__ = 'tsigkeys' - - name = Column(String(255), nullable=False, unique=True) - algorithm = Column(Enum(name='tsig_algorithms', *TSIG_ALGORITHMS), - nullable=False) - secret = Column(String(255), nullable=False) - - -class Blacklists(Base): - __tablename__ = 'blacklists' - - pattern = Column(String(255), nullable=False, unique=True) - description = Column(Unicode(160), nullable=True) diff --git a/designate/storage/impl_sqlalchemy/tables.py b/designate/storage/impl_sqlalchemy/tables.py new file mode 100644 index 00000000..cdf6baab --- /dev/null +++ b/designate/storage/impl_sqlalchemy/tables.py @@ -0,0 +1,190 @@ +# Copyright 2012-2014 Hewlett-Packard Development Company, L.P. +# +# Author: Kiall Mac Innes +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +from sqlalchemy import (Table, MetaData, Column, String, Text, Integer, CHAR, + DateTime, Enum, Boolean, Unicode, UniqueConstraint, + ForeignKeyConstraint) + +from oslo.config import cfg + +from designate import utils +from designate.openstack.common import timeutils +from designate.sqlalchemy.types import UUID + + +CONF = cfg.CONF + +RESOURCE_STATUSES = ['ACTIVE', 'PENDING', 'DELETED'] +RECORD_TYPES = ['A', 'AAAA', 'CNAME', 'MX', 'SRV', 'TXT', 'SPF', 'NS', 'PTR', + 'SSHFP'] +TSIG_ALGORITHMS = ['hmac-md5', 'hmac-sha1', 'hmac-sha224', 'hmac-sha256', + 'hmac-sha384', 'hmac-sha512'] + +metadata = MetaData() + +quotas = Table('quotas', metadata, + Column('id', UUID, default=utils.generate_uuid, primary_key=True), + Column('version', Integer(), default=1, nullable=False), + Column('created_at', DateTime, default=lambda: timeutils.utcnow()), + Column('updated_at', DateTime, onupdate=lambda: timeutils.utcnow()), + + Column('tenant_id', String(36), default=None, nullable=True), + Column('resource', String(32), nullable=False), + Column('hard_limit', Integer(), nullable=False), + + mysql_engine='InnoDB', + mysql_charset='utf8', +) + +servers = Table('servers', metadata, + Column('id', UUID, default=utils.generate_uuid, primary_key=True), + Column('version', Integer(), default=1, nullable=False), + Column('created_at', DateTime, default=lambda: timeutils.utcnow()), + Column('updated_at', DateTime, onupdate=lambda: timeutils.utcnow()), + + Column('name', String(255), nullable=False, unique=True), + + mysql_engine='InnoDB', + mysql_charset='utf8', +) + +tlds = Table('tlds', metadata, + Column('id', UUID, default=utils.generate_uuid, primary_key=True), + Column('version', Integer(), default=1, nullable=False), + Column('created_at', DateTime, default=lambda: timeutils.utcnow()), + Column('updated_at', DateTime, onupdate=lambda: timeutils.utcnow()), + + Column('name', String(255), nullable=False, unique=True), + Column('description', Unicode(160), nullable=True), + + mysql_engine='InnoDB', + mysql_charset='utf8', +) + +domains = Table('domains', metadata, + Column('id', UUID, default=utils.generate_uuid, primary_key=True), + Column('version', Integer(), default=1, nullable=False), + Column('created_at', DateTime, default=lambda: timeutils.utcnow()), + Column('updated_at', DateTime, onupdate=lambda: timeutils.utcnow()), + + Column('deleted', CHAR(32), nullable=False, default='0', + server_default='0'), + Column('deleted_at', DateTime, nullable=True, default=None), + + Column('tenant_id', String(36), default=None, nullable=True), + Column('name', String(255), nullable=False), + Column('email', String(255), nullable=False), + Column('description', Unicode(160), nullable=True), + Column('ttl', Integer, default=CONF.default_ttl, nullable=False), + Column('serial', Integer, default=timeutils.utcnow_ts, nullable=False), + Column('refresh', Integer, default=CONF.default_soa_refresh, + nullable=False), + Column('retry', Integer, default=CONF.default_soa_retry, nullable=False), + Column('expire', Integer, default=CONF.default_soa_expire, nullable=False), + Column('minimum', Integer, default=CONF.default_soa_minimum, + nullable=False), + Column('status', Enum(name='resource_statuses', *RESOURCE_STATUSES), + nullable=False, server_default='ACTIVE', default='ACTIVE'), + Column('parent_domain_id', UUID, default=None, nullable=True), + + UniqueConstraint('name', 'deleted', name='unique_domain_name'), + ForeignKeyConstraint(['parent_domain_id'], + ['domains.id'], + ondelete='SET NULL'), + + mysql_engine='InnoDB', + mysql_charset='utf8', +) + +recordsets = Table('recordsets', metadata, + Column('id', UUID, default=utils.generate_uuid, primary_key=True), + Column('version', Integer(), default=1, nullable=False), + Column('created_at', DateTime, default=lambda: timeutils.utcnow()), + Column('updated_at', DateTime, onupdate=lambda: timeutils.utcnow()), + + Column('tenant_id', String(36), default=None, nullable=True), + Column('domain_id', UUID, nullable=False), + Column('name', String(255), nullable=False), + Column('type', Enum(name='record_types', *RECORD_TYPES), nullable=False), + Column('ttl', Integer, default=None, nullable=True), + Column('description', Unicode(160), nullable=True), + + UniqueConstraint('domain_id', 'name', 'type', name='unique_recordset'), + ForeignKeyConstraint(['domain_id'], ['domains.id'], ondelete='CASCADE'), + + mysql_engine='InnoDB', + mysql_charset='utf8', +) + +records = Table('records', metadata, + Column('id', UUID, default=utils.generate_uuid, primary_key=True), + Column('version', Integer(), default=1, nullable=False), + Column('created_at', DateTime, default=lambda: timeutils.utcnow()), + Column('updated_at', DateTime, onupdate=lambda: timeutils.utcnow()), + + Column('tenant_id', String(36), default=None, nullable=True), + Column('domain_id', UUID, nullable=False), + Column('recordset_id', UUID, nullable=False), + Column('data', Text, nullable=False), + Column('priority', Integer, default=None, nullable=True), + Column('description', Unicode(160), nullable=True), + Column('hash', String(32), nullable=False, unique=True), + Column('managed', Boolean, default=False), + Column('managed_extra', Unicode(100), default=None, nullable=True), + Column('managed_plugin_type', Unicode(50), default=None, nullable=True), + Column('managed_plugin_name', Unicode(50), default=None, nullable=True), + Column('managed_resource_type', Unicode(50), default=None, nullable=True), + Column('managed_resource_region', Unicode(100), default=None, + nullable=True), + Column('managed_resource_id', UUID, default=None, nullable=True), + Column('managed_tenant_id', Unicode(36), default=None, nullable=True), + Column('status', Enum(name='resource_statuses', *RESOURCE_STATUSES), + nullable=False, server_default='ACTIVE', default='ACTIVE'), + + ForeignKeyConstraint(['domain_id'], ['domains.id'], ondelete='CASCADE'), + ForeignKeyConstraint(['recordset_id'], ['recordsets.id'], + ondelete='CASCADE'), + + mysql_engine='InnoDB', + mysql_charset='utf8', +) + +tsigkeys = Table('tsigkeys', metadata, + Column('id', UUID, default=utils.generate_uuid, primary_key=True), + Column('version', Integer(), default=1, nullable=False), + Column('created_at', DateTime, default=lambda: timeutils.utcnow()), + Column('updated_at', DateTime, onupdate=lambda: timeutils.utcnow()), + + Column('name', String(255), nullable=False, unique=True), + Column('algorithm', Enum(name='tsig_algorithms', *TSIG_ALGORITHMS), + nullable=False), + Column('secret', String(255), nullable=False), + + mysql_engine='InnoDB', + mysql_charset='utf8', +) + +blacklists = Table('blacklists', metadata, + Column('id', UUID, default=utils.generate_uuid, primary_key=True), + Column('version', Integer(), default=1, nullable=False), + Column('created_at', DateTime, default=lambda: timeutils.utcnow()), + Column('updated_at', DateTime, onupdate=lambda: timeutils.utcnow()), + + Column('pattern', String(255), nullable=False, unique=True), + Column('description', Unicode(160), nullable=True), + + mysql_engine='InnoDB', + mysql_charset='utf8', +) diff --git a/designate/tests/__init__.py b/designate/tests/__init__.py index 86b0d180..0a2c603b 100644 --- a/designate/tests/__init__.py +++ b/designate/tests/__init__.py @@ -42,11 +42,6 @@ from designate.network_api import fake as fake_network_api from designate import network_api from designate import objects - -# NOTE: If eventlet isn't patched and there's a exc tests block -import eventlet -eventlet.monkey_patch(os=False) - LOG = logging.getLogger(__name__) cfg.CONF.import_opt('storage_driver', 'designate.central', @@ -455,7 +450,7 @@ class TestCase(base.BaseTestCase): for index in range(len(self.default_tld_fixtures)): try: self.create_default_tld(fixture=index) - except exceptions.DuplicateTLD: + except exceptions.DuplicateTld: pass def create_tsigkey(self, **kwargs): diff --git a/designate/tests/test_central/test_service.py b/designate/tests/test_central/test_service.py index 0a374446..54411647 100644 --- a/designate/tests/test_central/test_service.py +++ b/designate/tests/test_central/test_service.py @@ -347,7 +347,7 @@ class CentralServiceTest(CentralTestCase): # Fetch the tld again, ensuring an exception is raised self.assertRaises( - exceptions.TLDNotFound, + exceptions.TldNotFound, self.central_service.get_tld, self.admin_context, tld['id']) diff --git a/designate/tests/test_storage/__init__.py b/designate/tests/test_storage/__init__.py index 8c36adf5..8e4d1cbe 100644 --- a/designate/tests/test_storage/__init__.py +++ b/designate/tests/test_storage/__init__.py @@ -14,6 +14,7 @@ # License for the specific language governing permissions and limitations # under the License. import uuid +import math import testtools @@ -49,15 +50,27 @@ class StorageTestCase(object): Given an array of created items we iterate through them making sure they match up to things returned by paged results. """ - found = method(self.admin_context, limit=5) - x = 0 - for i in xrange(0, len(data)): - self.assertEqual(data[i]['id'], found[x]['id']) - x += 1 - if x == len(found): - x = 0 - found = method( - self.admin_context, limit=5, marker=found[-1:][0]['id']) + results = None + item_number = 0 + + for current_page in range(0, int(math.ceil(float(len(data)) / 2))): + LOG.debug('Validating results on page %d', current_page) + + if results is not None: + results = method( + self.admin_context, limit=2, marker=results[-1]['id']) + else: + results = method(self.admin_context, limit=2) + + LOG.critical('Results: %d', len(results)) + + for result_number, result in enumerate(results): + LOG.debug('Validating result %d on page %d', result_number, + current_page) + self.assertEqual( + data[item_number]['id'], results[result_number]['id']) + + item_number += 1 def test_paging_marker_not_found(self): with testtools.ExpectedException(exceptions.MarkerNotFound): @@ -93,7 +106,7 @@ class StorageTestCase(object): values = self.get_quota_fixture() values['tenant_id'] = self.admin_context.tenant - result = self.storage.create_quota(self.admin_context, values=values) + result = self.storage.create_quota(self.admin_context, values) self.assertIsNotNone(result['id']) self.assertIsNotNone(result['created_at']) @@ -290,12 +303,12 @@ class StorageTestCase(object): self.assertEqual(1, len(actual)) self.assertEqual(server['name'], actual[0]['name']) - # Order of found items later will be reverse of the order they are - # created - created = [self.create_server( - name='ns%s.example.org.' % i) for i in xrange(10, 20)] - created.insert(0, server) + def test_find_servers_paging(self): + # Create 10 Servers + created = [self.create_server(name='ns%d.example.org.' % i) + for i in xrange(10)] + # Ensure we can page through the results. self._ensure_paging(created, self.storage.find_servers) def test_find_servers_criterion(self): @@ -415,12 +428,12 @@ class StorageTestCase(object): self.assertEqual(tsig['algorithm'], actual[0]['algorithm']) self.assertEqual(tsig['secret'], actual[0]['secret']) - # Order of found items later will be reverse of the order they are - # created - created = [self.create_tsigkey(name='tsig%s.' % i) - for i in xrange(10, 20)] - created.insert(0, tsig) + def test_find_tsigkeys_paging(self): + # Create 10 TSIG Keys + created = [self.create_tsigkey(name='tsig-%s' % i) + for i in xrange(10)] + # Ensure we can page through the results. self._ensure_paging(created, self.storage.find_tsigkeys) def test_find_tsigkeys_criterion(self): @@ -612,12 +625,12 @@ class StorageTestCase(object): self.assertEqual(domain['name'], actual[0]['name']) self.assertEqual(domain['email'], actual[0]['email']) - # Order of found items later will be reverse of the order they are - # created XXXX - created = [self.create_domain(name='x%s.org.' % i) - for i in xrange(10, 20)] - created.insert(0, domain) + def test_find_domains_paging(self): + # Create 10 Domains + created = [self.create_domain(name='example-%d.org.' % i) + for i in xrange(10)] + # Ensure we can page through the results. self._ensure_paging(created, self.storage.find_domains) def test_find_domains_criterion(self): @@ -868,12 +881,14 @@ class StorageTestCase(object): self.assertEqual(recordset_one['name'], actual[0]['name']) self.assertEqual(recordset_one['type'], actual[0]['type']) - # Order of found items later will be reverse of the order they are - # created - created = [self.create_recordset( - domain, name='test%s' % i + '.%s') for i in xrange(10, 20)] - created.insert(0, recordset_one) + def test_find_recordsets_paging(self): + domain = self.create_domain(name='example.org.') + # Create 10 RecordSets + created = [self.create_recordset(domain, name='r-%d.example.org.' % i) + for i in xrange(10)] + + # Ensure we can page through the results. self._ensure_paging(created, self.storage.find_recordsets) def test_find_recordsets_criterion(self): @@ -1235,13 +1250,15 @@ class StorageTestCase(object): self.assertEqual(record['data'], actual[0]['data']) self.assertIn('status', record) - # Order of found items later will be reverse of the order they are - # created - created = [self.create_record( - domain, recordset, data='192.0.0.%s' % i) - for i in xrange(10, 20)] - created.insert(0, record) + def test_find_records_paging(self): + domain = self.create_domain() + recordset = self.create_recordset(domain, type='A') + # Create 10 Records + created = [self.create_record(domain, recordset, data='192.0.2.%d' % i) + for i in xrange(10)] + + # Ensure we can page through the results. self._ensure_paging(created, self.storage.find_records) def test_find_records_criterion(self):