From e53676bb32b70ff01ca27c310e558b651590be3d Mon Sep 17 00:00:00 2001 From: Devin Carlen Date: Fri, 10 Sep 2010 15:26:13 -0700 Subject: [PATCH] Refactored to security group api to support projects --- nova/auth/manager.py | 2 - nova/db/api.py | 34 ++++++++------- nova/db/sqlalchemy/api.py | 76 ++++++++++++++++++++------------- nova/db/sqlalchemy/models.py | 22 +++++----- nova/endpoint/cloud.py | 81 ++++++++++++++++++++++-------------- nova/tests/api_unittest.py | 1 + nova/tests/virt_unittest.py | 4 +- nova/virt/libvirt_conn.py | 2 +- 8 files changed, 132 insertions(+), 90 deletions(-) diff --git a/nova/auth/manager.py b/nova/auth/manager.py index 281e2d8f05ee..34aa73bf61b6 100644 --- a/nova/auth/manager.py +++ b/nova/auth/manager.py @@ -649,8 +649,6 @@ class AuthManager(object): def delete_user(self, user): """Deletes a user""" with self.driver() as drv: - for security_group in db.security_group_get_by_user(context = {}, user_id=User.safe_id(user)): - db.security_group_destroy({}, security_group.id) drv.delete_user(User.safe_id(user)) def generate_key_pair(self, user, key_name): diff --git a/nova/db/api.py b/nova/db/api.py index 2bcf0bd2b35e..cdbd1548683a 100644 --- a/nova/db/api.py +++ b/nova/db/api.py @@ -442,17 +442,28 @@ def volume_update(context, volume_id, values): """ return IMPL.volume_update(context, volume_id, values) + #################### -def security_group_create(context, values): - """Create a new security group""" - return IMPL.security_group_create(context, values) +def security_group_get_all(context): + """Get all security groups""" + return IMPL.security_group_get_all(context) -def security_group_get_by_id(context, security_group_id): +def security_group_get(context, security_group_id): """Get security group by its internal id""" - return IMPL.security_group_get_by_id(context, security_group_id) + return IMPL.security_group_get(context, security_group_id) + + +def security_group_get_by_name(context, project_id, group_name): + """Returns a security group with the specified name from a project""" + return IMPL.securitygroup_get_by_name(context, project_id, group_name) + + +def security_group_get_by_project(context, project_id): + """Get all security groups belonging to a project""" + return IMPL.securitygroup_get_by_project(context, project_id) def security_group_get_by_instance(context, instance_id): @@ -460,15 +471,10 @@ def security_group_get_by_instance(context, instance_id): return IMPL.security_group_get_by_instance(context, instance_id) -def security_group_get_by_user(context, user_id): - """Get security groups owned by the given user""" - return IMPL.security_group_get_by_user(context, user_id) - - -def security_group_get_by_user_and_name(context, user_id, name): - """Get user's named security group""" - return IMPL.security_group_get_by_user_and_name(context, user_id, name) - +def security_group_create(context, values): + """Create a new security group""" + return IMPL.security_group_create(context, values) + def security_group_destroy(context, security_group_id): """Deletes a security group""" diff --git a/nova/db/sqlalchemy/api.py b/nova/db/sqlalchemy/api.py index 1c95efd83109..61d733940b9e 100644 --- a/nova/db/sqlalchemy/api.py +++ b/nova/db/sqlalchemy/api.py @@ -616,20 +616,45 @@ def volume_update(_context, volume_id, values): ################### -def security_group_create(_context, values): - security_group_ref = models.SecurityGroup() - for (key, value) in values.iteritems(): - security_group_ref[key] = value - security_group_ref.save() - return security_group_ref +def security_group_get_all(_context): + session = get_session() + return session.query(models.SecurityGroup + ).options(eagerload('rules') + ).filter_by(deleted=False + ).all() -def security_group_get_by_id(_context, security_group_id): +def security_group_get(_context, security_group_id): session = get_session() with session.begin(): return session.query(models.SecurityGroup + ).options(eagerload('rules') + ).get(security_group_id) + + +def securitygroup_get_by_name(context, project_id, group_name): + session = get_session() + group_ref = session.query(models.SecurityGroup ).options(eagerload('rules') - ).get(security_group_id) + ).filter_by(project_id=project_id + ).filter_by(name=group_name + ).filter_by(deleted=False + ).first() + if not group_ref: + raise exception.NotFound( + 'No security group named %s for project: %s' \ + % (group_name, project_id)) + + return group_ref + + +def securitygroup_get_by_project(_context, project_id): + session = get_session() + return session.query(models.SecurityGroup + ).options(eagerload('rules') + ).filter_by(project_id=project_id + ).filter_by(deleted=False + ).all() def security_group_get_by_instance(_context, instance_id): @@ -638,34 +663,27 @@ def security_group_get_by_instance(_context, instance_id): return session.query(models.Instance ).get(instance_id ).security_groups \ - .all() + .filter_by(deleted=False + ).all() -def security_group_get_by_user(_context, user_id): - session = get_session() - with session.begin(): - return session.query(models.SecurityGroup - ).filter_by(user_id=user_id - ).filter_by(deleted=False - ).options(eagerload('rules') - ).all() +def security_group_create(_context, values): + security_group_ref = models.SecurityGroup() + for (key, value) in values.iteritems(): + security_group_ref[key] = value + security_group_ref.save() + return security_group_ref -def security_group_get_by_user_and_name(_context, user_id, name): - session = get_session() - with session.begin(): - return session.query(models.SecurityGroup - ).filter_by(user_id=user_id - ).filter_by(name=name - ).filter_by(deleted=False - ).options(eagerload('rules') - ).one() def security_group_destroy(_context, security_group_id): session = get_session() with session.begin(): - security_group = session.query(models.SecurityGroup - ).get(security_group_id) - security_group.delete(session=session) + # TODO(vish): do we have to use sql here? + session.execute('update security_group set deleted=1 where id=:id', + {'id': security_group_id}) + session.execute('update security_group_rule set deleted=1 ' + 'where group_id=:id', + {'id': security_group_id}) ################### diff --git a/nova/db/sqlalchemy/models.py b/nova/db/sqlalchemy/models.py index f27520aa81ae..3c4b9ddd7c50 100644 --- a/nova/db/sqlalchemy/models.py +++ b/nova/db/sqlalchemy/models.py @@ -306,26 +306,23 @@ class SecurityGroup(BASE, NovaBase): class SecurityGroupIngressRule(BASE, NovaBase): """Represents a rule in a security group""" - __tablename__ = 'security_group_rules' + __tablename__ = 'security_group_rule' id = Column(Integer, primary_key=True) - parent_group_id = Column(Integer, ForeignKey('security_group.id')) - parent_group = relationship("SecurityGroup", backref="rules", foreign_keys=parent_group_id, - primaryjoin=parent_group_id==SecurityGroup.id) + group_id = Column(Integer, ForeignKey('security_group.id')) + group = relationship("SecurityGroup", backref="rules", + foreign_keys=group_id, + primaryjoin=group_id==SecurityGroup.id) protocol = Column(String(5)) # "tcp", "udp", or "icmp" from_port = Column(Integer) to_port = Column(Integer) + cidr = Column(String(255)) # Note: This is not the parent SecurityGroup. It's SecurityGroup we're # granting access for. - group_id = Column(Integer, ForeignKey('security_group.id')) + source_group_id = Column(Integer, ForeignKey('security_group.id')) - @property - def user(self): - return auth.manager.AuthManager().get_user(self.user_id) - - cidr = Column(String(255)) class Network(BASE, NovaBase): """Represents a network""" @@ -430,8 +427,9 @@ class FloatingIp(BASE, NovaBase): def register_models(): """Register Models and create metadata""" from sqlalchemy import create_engine - models = (Service, Instance, Volume, ExportDevice, - FixedIp, FloatingIp, Network, NetworkIndex) # , Image, Host) + models = (Service, Instance, Volume, ExportDevice, FixedIp, FloatingIp, + Network, NetworkIndex, SecurityGroup, SecurityGroupIngressRule) + # , Image, Host engine = create_engine(FLAGS.sql_connection, echo=False) for model in models: model.metadata.create_all(engine) diff --git a/nova/endpoint/cloud.py b/nova/endpoint/cloud.py index 930274aed19e..4cb09bedb4b5 100644 --- a/nova/endpoint/cloud.py +++ b/nova/endpoint/cloud.py @@ -216,7 +216,8 @@ class CloudController(object): @rbac.allow('all') def describe_security_groups(self, context, **kwargs): groups = [] - for group in db.security_group_get_by_user(context, context.user.id): + for group in db.security_group_get_by_project(context, + context.project.id): group_dict = {} group_dict['groupDescription'] = group.description group_dict['groupName'] = group.name @@ -229,10 +230,11 @@ class CloudController(object): rule_dict['toPort'] = rule.to_port rule_dict['groups'] = [] rule_dict['ipRanges'] = [] + import pdb; pdb.set_trace() if rule.group_id: - foreign_group = db.security_group_get_by_id({}, rule.group_id) - rule_dict['groups'] += [ { 'groupName': foreign_group.name, - 'userId': foreign_group.user_id } ] + source_group = db.security_group_get(context, rule.group_id) + rule_dict['groups'] += [ { 'groupName': source_group.name, + 'userId': source_group.user_id } ] else: rule_dict['ipRanges'] += [ { 'cidrIp': rule.cidr } ] group_dict['ipPermissions'] += [ rule_dict ] @@ -258,23 +260,22 @@ class CloudController(object): user_id=None, source_security_group_name=None, source_security_group_owner_id=None): - security_group = db.security_group_get_by_user_and_name(context, - context.user.id, - group_name) + security_group = db.security_group_get_by_name(context, + context.project.id, + group_name) criteria = {} if source_security_group_name: - if source_security_group_owner_id: - other_user_id = source_security_group_owner_id - else: - other_user_id = context.user.id - - foreign_security_group = \ - db.security_group_get_by_user_and_name(context, - other_user_id, - source_security_group_name) - criteria['group_id'] = foreign_security_group.id + source_project_id = self._get_source_project_id(context, + source_security_group_owner_id) + + source_security_group = \ + db.security_group_get_by_name(context, + source_project_id, + source_security_group_name) + + criteria['group_id'] = source_security_group.id elif cidr_ip: criteria['cidr'] = cidr_ip else: @@ -303,22 +304,20 @@ class CloudController(object): ip_protocol=None, cidr_ip=None, source_security_group_name=None, source_security_group_owner_id=None): - security_group = db.security_group_get_by_user_and_name(context, - context.user.id, - group_name) - values = { 'parent_group_id' : security_group.id } + security_group = db.security_group_get_by_name(context, + context.project.id, + group_name) + values = { 'group_id' : security_group.id } if source_security_group_name: - if source_security_group_owner_id: - other_user_id = source_security_group_owner_id - else: - other_user_id = context.user.id + source_project_id = self._get_source_project_id(context, + source_security_group_owner_id) - foreign_security_group = \ - db.security_group_get_by_user_and_name(context, - other_user_id, - source_security_group_name) - values['group_id'] = foreign_security_group.id + source_security_group = \ + db.security_group_get_by_name(context, + source_project_id, + source_security_group_name) + values['source_group_id'] = source_security_group.id elif cidr_ip: values['cidr'] = cidr_ip else: @@ -336,18 +335,38 @@ class CloudController(object): security_group_rule = db.security_group_rule_create(context, values) return True + + def _get_source_project_id(self, context, source_security_group_owner_id): + if source_security_group_owner_id: + # Parse user:project for source group. + source_parts = source_security_group_owner_id.split(':') + + # If no project name specified, assume it's same as user name. + # Since we're looking up by project name, the user name is not + # used here. It's only read for EC2 API compatibility. + if len(source_parts) == 2: + source_project_id = parts[1] + else: + source_project_id = parts[0] + else: + source_project_id = context.project.id + + return source_project_id @rbac.allow('netadmin') def create_security_group(self, context, group_name, group_description): db.security_group_create(context, values = { 'user_id' : context.user.id, + 'project_id': context.project.id, 'name': group_name, 'description': group_description }) return True @rbac.allow('netadmin') def delete_security_group(self, context, group_name, **kwargs): - security_group = db.security_group_get_by_user_and_name(context, context.user.id, group_name) + security_group = db.security_group_get_by_name(context, + context.project.id, + group_name) db.security_group_destroy(context, security_group.id) return True diff --git a/nova/tests/api_unittest.py b/nova/tests/api_unittest.py index 7e914e6f5690..55b7cb4d8e40 100644 --- a/nova/tests/api_unittest.py +++ b/nova/tests/api_unittest.py @@ -304,6 +304,7 @@ class ApiEc2TestCase(test.BaseTestCase): # be good enough for that. for group in rv: if group.name == security_group_name: + import pdb; pdb.set_trace() self.assertEquals(len(group.rules), 1) self.assertEquals(int(group.rules[0].from_port), 80) self.assertEquals(int(group.rules[0].to_port), 81) diff --git a/nova/tests/virt_unittest.py b/nova/tests/virt_unittest.py index 1f573c4632b3..dceced3a9add 100644 --- a/nova/tests/virt_unittest.py +++ b/nova/tests/virt_unittest.py @@ -86,6 +86,8 @@ class NWFilterTestCase(test.TrialTestCase): context.user = FakeContext() context.user.id = 'fake' context.user.is_superuser = lambda:True + context.project = FakeContext() + context.project.id = 'fake' cloud_controller.create_security_group(context, 'testgroup', 'test group description') cloud_controller.authorize_security_group_ingress(context, 'testgroup', from_port='80', to_port='81', ip_protocol='tcp', @@ -93,7 +95,7 @@ class NWFilterTestCase(test.TrialTestCase): fw = libvirt_conn.NWFilterFirewall() - security_group = db.security_group_get_by_user_and_name({}, 'fake', 'testgroup') + security_group = db.security_group_get_by_name({}, 'fake', 'testgroup') xml = fw.security_group_to_nwfilter_xml(security_group.id) diff --git a/nova/virt/libvirt_conn.py b/nova/virt/libvirt_conn.py index 6f708bb8085d..09c94577c066 100644 --- a/nova/virt/libvirt_conn.py +++ b/nova/virt/libvirt_conn.py @@ -492,7 +492,7 @@ class NWFilterFirewall(object): ''' def security_group_to_nwfilter_xml(self, security_group_id): - security_group = db.security_group_get_by_id({}, security_group_id) + security_group = db.security_group_get({}, security_group_id) rule_xml = "" for rule in security_group.rules: rule_xml += ""