diff --git a/nova/db/api.py b/nova/db/api.py index 63a586dea..3d0efd138 100644 --- a/nova/db/api.py +++ b/nova/db/api.py @@ -87,6 +87,32 @@ class NoMoreTargets(exception.NovaException): ################### +def constraint(**conditions): + """Return a constraint object suitable for use with some updates.""" + return IMPL.constraint(**conditions) + + +def equal_any(*values): + """Return an equality condition object suitable for use in a constraint. + + Equal_any conditions require that a model object's attribute equal any + one of the given values. + """ + return IMPL.equal_any(*values) + + +def not_equal(*values): + """Return an inequality condition object suitable for use in a constraint. + + Not_equal conditions require that a model object's attribute differs from + all of the given values. + """ + return IMPL.not_equal(*values) + + +################### + + def service_destroy(context, instance_id): """Destroy the service or raise if it does not exist.""" return IMPL.service_destroy(context, instance_id) @@ -527,9 +553,9 @@ def instance_data_get_for_project(context, project_id, session=None): session=session) -def instance_destroy(context, instance_id): +def instance_destroy(context, instance_id, constraint=None): """Destroy the instance or raise if it does not exist.""" - return IMPL.instance_destroy(context, instance_id) + return IMPL.instance_destroy(context, instance_id, constraint) def instance_get_by_uuid(context, uuid): diff --git a/nova/db/sqlalchemy/api.py b/nova/db/sqlalchemy/api.py index 30c556b4d..0881af62f 100644 --- a/nova/db/sqlalchemy/api.py +++ b/nova/db/sqlalchemy/api.py @@ -43,6 +43,7 @@ from sqlalchemy.orm import joinedload_all from sqlalchemy.sql.expression import asc from sqlalchemy.sql.expression import desc from sqlalchemy.sql.expression import literal_column +from sqlalchemy.sql.expression import or_ from sqlalchemy.sql import func FLAGS = flags.FLAGS @@ -263,6 +264,52 @@ def exact_filter(query, model, filters, legal_keys): ################### +def constraint(**conditions): + return Constraint(conditions) + + +def equal_any(*values): + return EqualityCondition(values) + + +def not_equal(*values): + return InequalityCondition(values) + + +class Constraint(object): + + def __init__(self, conditions): + self.conditions = conditions + + def apply(self, model, query): + clauses = [] + for key, condition in self.conditions.iteritems(): + for clause in condition.clauses(getattr(model, key)): + query = query.filter(clause) + return query + + +class EqualityCondition(object): + + def __init__(self, values): + self.values = values + + def clauses(self, field): + return or_([field == value for value in self.values]) + + +class InequalityCondition(object): + + def __init__(self, values): + self.values = values + + def clauses(self, field): + return [field != value for value in self.values] + + +################### + + @require_admin_context def service_destroy(context, service_id): session = get_session() @@ -1311,7 +1358,7 @@ def instance_data_get_for_project(context, project_id, session=None): @require_context -def instance_destroy(context, instance_id): +def instance_destroy(context, instance_id, constraint=None): session = get_session() with session.begin(): if utils.is_uuid_like(instance_id): @@ -1321,11 +1368,14 @@ def instance_destroy(context, instance_id): else: instance_ref = instance_get(context, instance_id, session=session) - session.query(models.Instance).\ - filter_by(id=instance_id).\ - update({'deleted': True, - 'deleted_at': utils.utcnow(), - 'updated_at': literal_column('updated_at')}) + query = session.query(models.Instance).filter_by(id=instance_id) + if constraint is not None: + query = constraint.apply(models.Instance, query) + count = query.update({'deleted': True, + 'deleted_at': utils.utcnow(), + 'updated_at': literal_column('updated_at')}) + if count == 0: + raise exception.ConstraintNotMet() session.query(models.SecurityGroupInstanceAssociation).\ filter_by(instance_id=instance_id).\ update({'deleted': True,